Nets
Canonical Space MLP
import torch
import torch.nn as nn
from core.utils.network_util import initseq
class CanonicalMLP(nn.Module):
def __init__(self, mlp_depth=8, mlp_width=256,
input_ch=3, skips=None,
**_):
super(CanonicalMLP, self).__init__()
if skips is None:
skips = [4]
self.mlp_depth = mlp_depth
self.mlp_width = mlp_width
self.input_ch = input_ch
pts_block_mlps = [nn.Linear(input_ch, mlp_width), nn.ReLU()]
layers_to_cat_input = []
for i in range(mlp_depth-1):
if i in skips:
layers_to_cat_input.append(len(pts_block_mlps))
pts_block_mlps += [nn.Linear(mlp_width + input_ch, mlp_width),
nn.ReLU()]
else:
pts_block_mlps += [nn.Linear(mlp_width, mlp_width), nn.ReLU()]
self.layers_to_cat_input = layers_to_cat_input
self.pts_linears = nn.ModuleList(pts_block_mlps)
initseq(self.pts_linears)
# output: rgb + sigma (density)
self.output_linear = nn.Sequential(nn.Linear(mlp_width, 4))
initseq(self.output_linear)
def forward(self, pos_embed, **_):
h = pos_embed
for i, _ in enumerate(self.pts_linears):
if i in self.layers_to_cat_input:
h = torch.cat([pos_embed, h], dim=-1)
h = self.pts_linears[i](h)
outputs = self.output_linear(h)
return outputs
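A minimal smoke test for the class above (hypothetical, not part of the repo; the 63-dim input assumes a 10-band Fourier embedding of 3D points with the raw coordinates included):
import torch

# input_ch must match the positional-embedding width: with multires=10 and
# include_input=True, the fourier embedder below yields 3 + 3*2*10 = 63 dims.
mlp = CanonicalMLP(mlp_depth=8, mlp_width=256, input_ch=63, skips=[4])
pos_embed = torch.randn(1024, 63)   # (N_rays * N_samples, 63)
raw = mlp(pos_embed=pos_embed)      # (1024, 4): rgb logits + raw density
assert raw.shape == (1024, 4)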
Embedders
hannw_fourier.py
import numpy as np
import torch
import torch.nn as nn
from configs import cfg
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
# get hann window weights
kick_in_iter = torch.tensor(cfg.non_rigid_motion_mlp.kick_in_iter,
dtype=torch.float32)
t = torch.clamp(self.kwargs['iter_val'] - kick_in_iter, min=0.)
N = cfg.non_rigid_motion_mlp.full_band_iter - kick_in_iter
m = N_freqs
alpha = m * t / N
for freq_idx, freq in enumerate(freq_bands):
w = (1. - torch.cos(np.pi * torch.clamp(alpha - freq_idx,
min=0., max=1.))) / 2.
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq, w=w: w * p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, iter_val, is_identity=0):
if is_identity == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input' : False,
'input_dims' : 3,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'periodic_fns' : [torch.sin, torch.cos],
'iter_val': iter_val
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim
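The weight w implements coarse-to-fine frequency annealing (à la Nerfies): band j fades in smoothly as alpha = m * t / N sweeps past it. A standalone sketch of the schedule, with hypothetical kick-in and full-band iterations standing in for the cfg values:
import numpy as np
import torch

def hann_band_weights(iter_val, num_freqs=6, kick_in_iter=10000, full_band_iter=50000):
    t = max(iter_val - kick_in_iter, 0)
    alpha = num_freqs * t / (full_band_iter - kick_in_iter)
    j = torch.arange(num_freqs, dtype=torch.float32)
    # band j ramps from 0 to 1 while alpha sweeps from j to j+1
    return (1. - torch.cos(np.pi * torch.clamp(alpha - j, min=0., max=1.))) / 2.

print(hann_band_weights(10000))  # all zeros: no frequency bands active yet
print(hann_band_weights(30000))  # lower bands fully on, upper bands still off
print(hann_band_weights(50000))  # all ones: full Fourier band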
fourier.py
import torch
import torch.nn as nn
class Embedder:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.create_embedding_fn()
def create_embedding_fn(self):
embed_fns = []
d = self.kwargs['input_dims']
out_dim = 0
if self.kwargs['include_input']:
embed_fns.append(lambda x : x)
out_dim += d
max_freq = self.kwargs['max_freq_log2']
N_freqs = self.kwargs['num_freqs']
freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
for freq in freq_bands:
for p_fn in self.kwargs['periodic_fns']:
embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
out_dim += d
self.embed_fns = embed_fns
self.out_dim = out_dim
def embed(self, inputs):
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(multires, i=0):
if i == -1:
return nn.Identity(), 3
embed_kwargs = {
'include_input' : True,
'input_dims' : 3,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed, embedder_obj.out_dim
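Hypothetical usage (not from the repo): with multires=10 the embedding of 3D points has 3 + 3*2*10 = 63 channels, which is exactly the input_ch the canonical MLP expects:
import torch

embed_fn, out_dim = get_embedder(multires=10)
pts = torch.rand(1024, 3)
embedded = embed_fn(pts)
assert out_dim == 63 and embedded.shape == (1024, 63)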
mweight_vol_decoder
deconv_vol_decoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.utils.network_util import ConvDecoder3D
class MotionWeightVolumeDecoder(nn.Module):
def __init__(self, embedding_size=256, volume_size=32, total_bones=24):
super(MotionWeightVolumeDecoder, self).__init__()
self.total_bones = total_bones
self.volume_size = volume_size
self.const_embedding = nn.Parameter(
torch.randn(embedding_size), requires_grad=True
)
self.decoder = ConvDecoder3D(
embedding_size=embedding_size,
volume_size=volume_size,
voxel_channels=total_bones+1)
def forward(self,
motion_weights_priors,
**_):
embedding = self.const_embedding[None, ...]
decoded_weights = F.softmax(self.decoder(embedding) + \
torch.log(motion_weights_priors),
dim=1)
return decoded_weights
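Shape-wise, the decoder upsamples one learned 256-dim embedding into a (total_bones+1)-channel volume, fuses it with the priors by a log-space product, and normalizes each voxel with a channel softmax. A hedged shape check (this assumes ConvDecoder3D emits a (1, 25, 32, 32, 32) grid, as the constructor arguments imply):
import torch

decoder = MotionWeightVolumeDecoder(embedding_size=256, volume_size=32, total_bones=24)
priors = torch.rand(1, 25, 32, 32, 32) + 1e-9   # strictly positive, since log() is applied
weights = decoder(motion_weights_priors=priors)
assert weights.shape == (1, 25, 32, 32, 32)
# softmax over the bone channel: the 25 weights at each voxel sum to one
assert torch.allclose(weights.sum(dim=1), torch.ones(1, 32, 32, 32))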
non_rigid_motion_mlps
mlp_offset.py
import torch
import torch.nn as nn
from core.utils.network_util import initseq
class NonRigidMotionMLP(nn.Module):
def __init__(self,
pos_embed_size=3,
condition_code_size=69,
mlp_width=128,
mlp_depth=6,
skips=None):
super(NonRigidMotionMLP, self).__init__()
self.skips = [4] if skips is None else skips
block_mlps = [nn.Linear(pos_embed_size+condition_code_size,
mlp_width), nn.ReLU()]
layers_to_cat_inputs = []
for i in range(1, mlp_depth):
if i in self.skips:
layers_to_cat_inputs.append(len(block_mlps))
block_mlps += [nn.Linear(mlp_width+pos_embed_size, mlp_width),
nn.ReLU()]
else:
block_mlps += [nn.Linear(mlp_width, mlp_width), nn.ReLU()]
block_mlps += [nn.Linear(mlp_width, 3)]
self.block_mlps = nn.ModuleList(block_mlps)
initseq(self.block_mlps)
self.layers_to_cat_inputs = layers_to_cat_inputs
# initialize the last layer with very small weights
# -- at the beginning, we want the non-rigid offsets to be near zero
init_val = 1e-5
last_layer = self.block_mlps[-1]
last_layer.weight.data.uniform_(-init_val, init_val)
last_layer.bias.data.zero_()
def forward(self, pos_embed, pos_xyz, condition_code, viewdirs=None, **_):
h = torch.cat([condition_code, pos_embed], dim=-1)
if viewdirs is not None:
h = torch.cat([h, viewdirs], dim=-1)
for i in range(len(self.block_mlps)):
if i in self.layers_to_cat_inputs:
h = torch.cat([h, pos_embed], dim=-1)
h = self.block_mlps[i](h)
trans = h
result = {
'xyz': pos_xyz + trans,
'offsets': trans
}
return result
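Hypothetical usage (not from the repo): the MLP predicts pose-conditioned offsets that warp observation-space samples toward the canonical pose. pos_embed_size must match the non-rigid embedder output, 3*2*multires = 36 for multires=6 with include_input=False:
import torch

mlp = NonRigidMotionMLP(pos_embed_size=36, condition_code_size=69)
xyz = torch.rand(1024, 3)
pos_embed = torch.rand(1024, 36)   # stand-in for the hannw_fourier embedding
pose_code = torch.rand(1024, 69)   # 23 joints x 3 axis-angle components
out = mlp(pos_embed=pos_embed, pos_xyz=xyz, condition_code=pose_code)
assert out['xyz'].shape == (1024, 3) and out['offsets'].shape == (1024, 3)
# thanks to the tiny last-layer init, offsets start out near zero
assert out['offsets'].abs().max() < 1e-2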
pose_decoders
mlp_delta_body_pose.py
import torch.nn as nn
from core.utils.network_util import initseq, RodriguesModule
from configs import cfg
class BodyPoseRefiner(nn.Module):
def __init__(self,
embedding_size=69,
mlp_width=256,
mlp_depth=4,
**_):
super(BodyPoseRefiner, self).__init__()
block_mlps = [nn.Linear(embedding_size, mlp_width), nn.ReLU()]
for _ in range(0, mlp_depth-1):
block_mlps += [nn.Linear(mlp_width, mlp_width), nn.ReLU()]
self.total_bones = cfg.total_bones - 1
block_mlps += [nn.Linear(mlp_width, 3 * self.total_bones)]
self.block_mlps = nn.Sequential(*block_mlps)
initseq(self.block_mlps)
# initialize the last layer with very small weights
# -- at the beginning, we want the rotation corrections to be near identity
init_val = 1e-5
last_layer = self.block_mlps[-1]
last_layer.weight.data.uniform_(-init_val, init_val)
last_layer.bias.data.zero_()
self.rodriguez = RodriguesModule()
def forward(self, pose_input):
rvec = self.block_mlps(pose_input).view(-1, 3)
Rs = self.rodriguez(rvec).view(-1, self.total_bones, 3, 3)
return {
"Rs": Rs
}
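Hypothetical usage (assumes cfg.total_bones == 24, i.e. SMPL's root plus 23 body joints): the refiner maps a 69-dim pose vector to one 3x3 rotation correction per non-root joint, starting near identity thanks to the tiny last-layer init:
import torch

refiner = BodyPoseRefiner(embedding_size=69, mlp_width=256, mlp_depth=4)
posevec = torch.rand(1, 69)
Rs = refiner(posevec)['Rs']        # (1, 23, 3, 3)
assert Rs.shape == (1, 23, 3, 3)
assert torch.allclose(Rs[0, 0], torch.eye(3), atol=1e-2)  # ~identity at init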
network.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from core.utils.network_util import MotionBasisComputer
from core.nets.human_nerf.component_factory import \
load_positional_embedder, \
load_canonical_mlp, \
load_mweight_vol_decoder, \
load_pose_decoder, \
load_non_rigid_motion_mlp
from configs import cfg
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()
# motion basis computer
self.motion_basis_computer = MotionBasisComputer(
total_bones=cfg.total_bones)
# motion weight volume
self.mweight_vol_decoder = load_mweight_vol_decoder(cfg.mweight_volume.module)(
embedding_size=cfg.mweight_volume.embedding_size,
volume_size=cfg.mweight_volume.volume_size,
total_bones=cfg.total_bones
)
# non-rigid motion positional encoding
self.get_non_rigid_embedder = \
load_positional_embedder(cfg.non_rigid_embedder.module)
# non-rigid motion MLP
_, non_rigid_pos_embed_size = \
self.get_non_rigid_embedder(cfg.non_rigid_motion_mlp.multires,
cfg.non_rigid_motion_mlp.i_embed)
self.non_rigid_mlp = \
load_non_rigid_motion_mlp(cfg.non_rigid_motion_mlp.module)(
pos_embed_size=non_rigid_pos_embed_size,
condition_code_size=cfg.non_rigid_motion_mlp.condition_code_size,
mlp_width=cfg.non_rigid_motion_mlp.mlp_width,
mlp_depth=cfg.non_rigid_motion_mlp.mlp_depth,
skips=cfg.non_rigid_motion_mlp.skips)
self.non_rigid_mlp = \
nn.DataParallel(
self.non_rigid_mlp,
device_ids=cfg.secondary_gpus,
output_device=cfg.secondary_gpus[0])
# canonical positional encoding
get_embedder = load_positional_embedder(cfg.embedder.module)
cnl_pos_embed_fn, cnl_pos_embed_size = \
get_embedder(cfg.canonical_mlp.multires,
cfg.canonical_mlp.i_embed)
self.pos_embed_fn = cnl_pos_embed_fn
# canonical mlp
skips = [4]
self.cnl_mlp = \
load_canonical_mlp(cfg.canonical_mlp.module)(
input_ch=cnl_pos_embed_size,
mlp_depth=cfg.canonical_mlp.mlp_depth,
mlp_width=cfg.canonical_mlp.mlp_width,
skips=skips)
self.cnl_mlp = \
nn.DataParallel(
self.cnl_mlp,
device_ids=cfg.secondary_gpus,
output_device=cfg.primary_gpus[0])
# pose decoder MLP
self.pose_decoder = \
load_pose_decoder(cfg.pose_decoder.module)(
embedding_size=cfg.pose_decoder.embedding_size,
mlp_width=cfg.pose_decoder.mlp_width,
mlp_depth=cfg.pose_decoder.mlp_depth)
def deploy_mlps_to_secondary_gpus(self):
self.cnl_mlp = self.cnl_mlp.to(cfg.secondary_gpus[0])
if self.non_rigid_mlp:
self.non_rigid_mlp = self.non_rigid_mlp.to(cfg.secondary_gpus[0])
return self
def _query_mlp(
self,
pos_xyz,
pos_embed_fn,
non_rigid_pos_embed_fn,
non_rigid_mlp_input):
# (N_rays, N_samples, 3) --> (N_rays x N_samples, 3)
pos_flat = torch.reshape(pos_xyz, [-1, pos_xyz.shape[-1]])
chunk = cfg.netchunk_per_gpu*len(cfg.secondary_gpus)
result = self._apply_mlp_kernals(
pos_flat=pos_flat,
pos_embed_fn=pos_embed_fn,
non_rigid_mlp_input=non_rigid_mlp_input,
non_rigid_pos_embed_fn=non_rigid_pos_embed_fn,
chunk=chunk)
output = {}
raws_flat = result['raws']
output['raws'] = torch.reshape(
raws_flat,
list(pos_xyz.shape[:-1]) + [raws_flat.shape[-1]])
return output
@staticmethod
def _expand_input(input_data, total_elem):
assert input_data.shape[0] == 1
input_size = input_data.shape[1]
return input_data.expand((total_elem, input_size))
def _apply_mlp_kernals(
self,
pos_flat,
pos_embed_fn,
non_rigid_mlp_input,
non_rigid_pos_embed_fn,
chunk):
raws = []
# iterate over ray samples in chunks
for i in range(0, pos_flat.shape[0], chunk):
start = i
end = i + chunk
if end > pos_flat.shape[0]:
end = pos_flat.shape[0]
total_elem = end - start
xyz = pos_flat[start:end]
if not cfg.ignore_non_rigid_motions:
non_rigid_embed_xyz = non_rigid_pos_embed_fn(xyz)
result = self.non_rigid_mlp(
pos_embed=non_rigid_embed_xyz,
pos_xyz=xyz,
condition_code=self._expand_input(non_rigid_mlp_input, total_elem)
)
xyz = result['xyz']
xyz_embedded = pos_embed_fn(xyz)
raws += [self.cnl_mlp(
pos_embed=xyz_embedded)]
output = {}
output['raws'] = torch.cat(raws, dim=0).to(cfg.primary_gpus[0])
return output
def _batchify_rays(self, rays_flat, **kwargs):
all_ret = {}
for i in range(0, rays_flat.shape[0], cfg.chunk):
ret = self._render_rays(rays_flat[i:i+cfg.chunk], **kwargs)
for k in ret:
if k not in all_ret:
all_ret[k] = []
all_ret[k].append(ret[k])
all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
return all_ret
@staticmethod
def _raw2outputs(raw, raw_mask, z_vals, rays_d, bgcolor=None):
def _raw2alpha(raw, dists, act_fn=F.relu):
return 1.0 - torch.exp(-act_fn(raw)*dists)
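# NeRF quadrature recap: alpha_i = 1 - exp(-sigma_i * delta_i), and the
# sample weight is w_i = alpha_i * prod_{j<i} (1 - alpha_j). Multiplying
# alpha by the foreground-likelihood mask below zeroes out samples that
# no bone's motion weight volume covers, letting the background color
# fill in through (1 - acc_map).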
dists = z_vals[...,1:] - z_vals[...,:-1]
infinity_dists = torch.Tensor([1e10])
infinity_dists = infinity_dists.expand(dists[...,:1].shape).to(dists)
dists = torch.cat([dists, infinity_dists], dim=-1)
dists = dists * torch.norm(rays_d[...,None,:], dim=-1)
rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3]
alpha = _raw2alpha(raw[...,3], dists) # [N_rays, N_samples]
alpha = alpha * raw_mask[:, :, 0]
weights = alpha * torch.cumprod(
torch.cat([torch.ones((alpha.shape[0], 1)).to(alpha),
1.-alpha + 1e-10], dim=-1), dim=-1)[:, :-1]
rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3]
depth_map = torch.sum(weights * z_vals, -1)
acc_map = torch.sum(weights, -1)
rgb_map = rgb_map + (1.-acc_map[...,None]) * bgcolor[None, :]/255.
return rgb_map, acc_map, weights, depth_map
@staticmethod
def _sample_motion_fields(
pts,
motion_scale_Rs,
motion_Ts,
motion_weights_vol,
cnl_bbox_min_xyz, cnl_bbox_scale_xyz,
output_list):
orig_shape = list(pts.shape)
pts = pts.reshape(-1, 3) # [N_rays x N_samples, 3]
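# Inverse linear-blend skinning: each bone i maps an observation-space
# point into canonical space via (motion_scale_Rs[i], motion_Ts[i]); its
# blend weight is trilinearly sampled (grid_sample) from the motion
# weight volume at that candidate position, and the canonical point
# x_skel is the weight-normalized average over all bones.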
# remove BG channel
motion_weights = motion_weights_vol[:-1]
weights_list = []
for i in range(motion_weights.size(0)):
pos = torch.matmul(motion_scale_Rs[i, :, :], pts.T).T + motion_Ts[i, :]
pos = (pos - cnl_bbox_min_xyz[None, :]) \
* cnl_bbox_scale_xyz[None, :] - 1.0
weights = F.grid_sample(input=motion_weights[None, i:i+1, :, :, :],
grid=pos[None, None, None, :, :],
padding_mode='zeros', align_corners=True)
weights = weights[0, 0, 0, 0, :, None]
weights_list.append(weights)
backwarp_motion_weights = torch.cat(weights_list, dim=-1)
total_bases = backwarp_motion_weights.shape[-1]
backwarp_motion_weights_sum = torch.sum(backwarp_motion_weights,
dim=-1, keepdim=True)
weighted_motion_fields = []
for i in range(total_bases):
pos = torch.matmul(motion_scale_Rs[i, :, :], pts.T).T + motion_Ts[i, :]
weighted_pos = backwarp_motion_weights[:, i:i+1] * pos
weighted_motion_fields.append(weighted_pos)
x_skel = torch.sum(
torch.stack(weighted_motion_fields, dim=0), dim=0
) / backwarp_motion_weights_sum.clamp(min=0.0001)
fg_likelihood_mask = backwarp_motion_weights_sum
x_skel = x_skel.reshape(orig_shape[:2]+[3])
backwarp_motion_weights = \
backwarp_motion_weights.reshape(orig_shape[:2]+[total_bases])
fg_likelihood_mask = fg_likelihood_mask.reshape(orig_shape[:2]+[1])
results = {}
if 'x_skel' in output_list: # [N_rays x N_samples, 3]
results['x_skel'] = x_skel
if 'fg_likelihood_mask' in output_list: # [N_rays x N_samples, 1]
results['fg_likelihood_mask'] = fg_likelihood_mask
return results
@staticmethod
def _unpack_ray_batch(ray_batch):
rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6]
bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
near, far = bounds[...,0], bounds[...,1]
return rays_o, rays_d, near, far
@staticmethod
def _get_samples_along_ray(N_rays, near, far):
t_vals = torch.linspace(0., 1., steps=cfg.N_samples).to(near)
z_vals = near * (1.-t_vals) + far * (t_vals)
return z_vals.expand([N_rays, cfg.N_samples])
@staticmethod
def _stratified_sampling(z_vals):
mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
upper = torch.cat([mids, z_vals[...,-1:]], -1)
lower = torch.cat([z_vals[...,:1], mids], -1)
t_rand = torch.rand(z_vals.shape).to(z_vals)
z_vals = lower + (upper - lower) * t_rand
return z_vals
def _render_rays(
self,
ray_batch,
motion_scale_Rs,
motion_Ts,
motion_weights_vol,
cnl_bbox_min_xyz,
cnl_bbox_scale_xyz,
pos_embed_fn,
non_rigid_pos_embed_fn,
non_rigid_mlp_input=None,
bgcolor=None,
**_):
N_rays = ray_batch.shape[0]
rays_o, rays_d, near, far = self._unpack_ray_batch(ray_batch)
z_vals = self._get_samples_along_ray(N_rays, near, far)
if cfg.perturb > 0.:
z_vals = self._stratified_sampling(z_vals)
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]
mv_output = self._sample_motion_fields(
pts=pts,
motion_scale_Rs=motion_scale_Rs[0],
motion_Ts=motion_Ts[0],
motion_weights_vol=motion_weights_vol,
cnl_bbox_min_xyz=cnl_bbox_min_xyz,
cnl_bbox_scale_xyz=cnl_bbox_scale_xyz,
output_list=['x_skel', 'fg_likelihood_mask'])
pts_mask = mv_output['fg_likelihood_mask']
cnl_pts = mv_output['x_skel']
query_result = self._query_mlp(
pos_xyz=cnl_pts,
non_rigid_mlp_input=non_rigid_mlp_input,
pos_embed_fn=pos_embed_fn,
non_rigid_pos_embed_fn=non_rigid_pos_embed_fn)
raw = query_result['raws']
rgb_map, acc_map, _, depth_map = \
self._raw2outputs(raw, pts_mask, z_vals, rays_d, bgcolor)
return {
'rgb' : rgb_map,
'alpha' : acc_map,
'depth': depth_map}
def _get_motion_base(self, dst_Rs, dst_Ts, cnl_gtfms):
motion_scale_Rs, motion_Ts = self.motion_basis_computer(
dst_Rs, dst_Ts, cnl_gtfms)
return motion_scale_Rs, motion_Ts
@staticmethod
def _multiply_corrected_Rs(Rs, correct_Rs):
total_bones = cfg.total_bones - 1
return torch.matmul(Rs.reshape(-1, 3, 3),
correct_Rs.reshape(-1, 3, 3)).reshape(-1, total_bones, 3, 3)
def forward(self,
rays,
dst_Rs, dst_Ts, cnl_gtfms,
motion_weights_priors,
dst_posevec=None,
near=None, far=None,
iter_val=1e7,
**kwargs):
dst_Rs=dst_Rs[None, ...]
dst_Ts=dst_Ts[None, ...]
dst_posevec=dst_posevec[None, ...]
cnl_gtfms=cnl_gtfms[None, ...]
motion_weights_priors=motion_weights_priors[None, ...]
# correct body pose
if iter_val >= cfg.pose_decoder.get('kick_in_iter', 0):
pose_out = self.pose_decoder(dst_posevec)
refined_Rs = pose_out['Rs']
refined_Ts = pose_out.get('Ts', None)
dst_Rs_no_root = dst_Rs[:, 1:, ...]
dst_Rs_no_root = self._multiply_corrected_Rs(
dst_Rs_no_root,
refined_Rs)
dst_Rs = torch.cat(
[dst_Rs[:, 0:1, ...], dst_Rs_no_root], dim=1)
if refined_Ts is not None:
dst_Ts = dst_Ts + refined_Ts
non_rigid_pos_embed_fn, _ = \
self.get_non_rigid_embedder(
multires=cfg.non_rigid_motion_mlp.multires,
is_identity=cfg.non_rigid_motion_mlp.i_embed,
iter_val=iter_val,)
if iter_val < cfg.non_rigid_motion_mlp.kick_in_iter:
# mask-out non_rigid_mlp_input
non_rigid_mlp_input = torch.zeros_like(dst_posevec) * dst_posevec
else:
non_rigid_mlp_input = dst_posevec
kwargs.update({
"pos_embed_fn": self.pos_embed_fn,
"non_rigid_pos_embed_fn": non_rigid_pos_embed_fn,
"non_rigid_mlp_input": non_rigid_mlp_input
})
motion_scale_Rs, motion_Ts = self._get_motion_base(
dst_Rs=dst_Rs,
dst_Ts=dst_Ts,
cnl_gtfms=cnl_gtfms)
motion_weights_vol = self.mweight_vol_decoder(
motion_weights_priors=motion_weights_priors)
motion_weights_vol=motion_weights_vol[0] # remove batch dimension
kwargs.update({
'motion_scale_Rs': motion_scale_Rs,
'motion_Ts': motion_Ts,
'motion_weights_vol': motion_weights_vol
})
rays_o, rays_d = rays
rays_shape = rays_d.shape
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()
packed_ray_infos = torch.cat([rays_o, rays_d, near, far], -1)
all_ret = self._batchify_rays(packed_ray_infos, **kwargs)
for k in all_ret:
k_shape = list(rays_shape[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_shape)
return all_ret
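For orientation, the tensor contract of Network.forward, inferred from how the inputs are indexed above (shapes are assumptions, not an official spec; 24 bones corresponds to the SMPL skeleton):
# rays                  : tuple (rays_o, rays_d), each (N_rays, 3)
# dst_Rs                : (24, 3, 3)    per-bone rotations, root first
# dst_Ts                : (24, 3)       per-bone translations
# cnl_gtfms             : (24, 4, 4)    canonical-pose global transforms
# motion_weights_priors : (25, D, D, D) bone + background prior volume
# dst_posevec           : (69,)         23 joints x 3 axis-angle components
# near, far             : (N_rays, 1)
# returns {'rgb': (N_rays, 3), 'alpha': (N_rays,), 'depth': (N_rays,)}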
create_network.py
import imp
from configs import cfg
def _query_network():
module = cfg.network_module
module_path = module.replace(".", "/") + ".py"
network = imp.load_source(module, module_path).Network
return network
def create_network():
network = _query_network()
network = network()
return network
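Note that imp has long been deprecated and was removed in Python 3.12; a drop-in sketch of the same loader built on importlib (my substitution, not repo code):
import importlib.util
from configs import cfg

def _query_network():
    module = cfg.network_module
    module_path = module.replace(".", "/") + ".py"
    # load the module from its file path, mirroring imp.load_source
    spec = importlib.util.spec_from_file_location(module, module_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod.Network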
optimizer.py
import torch.optim as optim
from configs import cfg
_optimizers = {
'adam': optim.Adam
}
def get_customized_lr_names():
return [k[3:] for k in cfg.train.keys() if k.startswith('lr_')]
def get_optimizer(network):
optimizer = _optimizers[cfg.train.optimizer]
cus_lr_names = get_customized_lr_names()
params = []
print('\n\n********** learnable parameters **********\n')
for key, value in network.named_parameters():
if not value.requires_grad:
continue
is_assigned_lr = False
for lr_name in cus_lr_names:
if lr_name in key:
params += [{
"params": [value],
"lr": cfg.train[f'lr_{lr_name}'],
"name": lr_name}]
print(f"{key}: lr = {cfg.train[f'lr_{lr_name}']}")
is_assigned_lr = True
if not is_assigned_lr:
params += [{
"params": [value],
"name": key}]
print(f"{key}: lr = {cfg.train.lr}")
print('\n******************************************\n\n')
if cfg.train.optimizer == 'adam':
optimizer = optimizer(params, lr=cfg.train.lr, betas=(0.9, 0.999))
else:
assert False, "Unsupported Optimizer."
return optimizer
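The per-module learning rates are matched by name substring. With a hypothetical cfg.train such as the following, every parameter whose qualified name contains 'pose_decoder' lands in its own param group:
# Hypothetical cfg.train excerpt (illustrative values, not from the repo):
#   lr: 0.0005
#   lr_pose_decoder: 0.00005
#
# get_customized_lr_names() -> ['pose_decoder']
# pose_decoder.block_mlps.0.weight      -> lr = 0.00005 (matched)
# cnl_mlp.module.pts_linears.0.weight   -> lr = 0.0005  (default)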
trainer.py
import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from third_parties.lpips import LPIPS
from core.train import create_lr_updater
from core.data import create_dataloader
from core.utils.network_util import set_requires_grad
from core.utils.train_util import cpu_data_to_gpu, Timer
from core.utils.image_util import tile_images, to_8b_image
from configs import cfg
img2mse = lambda x, y : torch.mean((x - y) ** 2)
img2l1 = lambda x, y : torch.mean(torch.abs(x-y))
to8b = lambda x : (255.*np.clip(x,0.,1.)).astype(np.uint8)
EXCLUDE_KEYS_TO_GPU = ['frame_name', 'img_width', 'img_height']
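# _unpack_imgs scatters flat per-ray predictions back into image patches:
# div_indices[i]:div_indices[i+1] selects the rays of patch i, and
# patch_masks[i] marks which of its HxW pixels those rays cover; all
# remaining pixels keep the background color.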
def _unpack_imgs(rgbs, patch_masks, bgcolor, targets, div_indices):
N_patch = len(div_indices) - 1
assert patch_masks.shape[0] == N_patch
assert targets.shape[0] == N_patch
patch_imgs = bgcolor.expand(targets.shape).clone() # (N_patch, H, W, 3)
for i in range(N_patch):
patch_imgs[i, patch_masks[i]] = rgbs[div_indices[i]:div_indices[i+1]]
return patch_imgs
def scale_for_lpips(image_tensor):
return image_tensor * 2. - 1.
class Trainer(object):
def __init__(self, network, optimizer):
print('\n********** Init Trainer ***********')
network = network.cuda().deploy_mlps_to_secondary_gpus()
self.network = network
self.optimizer = optimizer
self.update_lr = create_lr_updater()
if cfg.resume and Trainer.ckpt_exists(cfg.load_net):
self.load_ckpt(f'{cfg.load_net}')
else:
self.iter = 0
self.save_ckpt('init')
self.iter = 1
self.timer = Timer()
if "lpips" in cfg.train.lossweights.keys():
self.lpips = LPIPS(net='vgg')
set_requires_grad(self.lpips, requires_grad=False)
self.lpips = nn.DataParallel(self.lpips).cuda()
print("Load Progress Dataset ...")
self.prog_dataloader = create_dataloader(data_type='progress')
print('************************************')
@staticmethod
def get_ckpt_path(name):
return os.path.join(cfg.logdir, f'{name}.tar')
@staticmethod
def ckpt_exists(name):
return os.path.exists(Trainer.get_ckpt_path(name))
######################################################
## Training
def get_img_rebuild_loss(self, loss_names, rgb, target):
losses = {}
if "mse" in loss_names:
losses["mse"] = img2mse(rgb, target)
if "l1" in loss_names:
losses["l1"] = img2l1(rgb, target)
if "lpips" in loss_names:
lpips_loss = self.lpips(scale_for_lpips(rgb.permute(0, 3, 1, 2)),
scale_for_lpips(target.permute(0, 3, 1, 2)))
losses["lpips"] = torch.mean(lpips_loss)
return losses
def get_loss(self, net_output,
patch_masks, bgcolor, targets, div_indices):
lossweights = cfg.train.lossweights
loss_names = list(lossweights.keys())
rgb = net_output['rgb']
losses = self.get_img_rebuild_loss(
loss_names,
_unpack_imgs(rgb, patch_masks, bgcolor,
targets, div_indices),
targets)
train_losses = [
weight * losses[k] for k, weight in lossweights.items()
]
return sum(train_losses), \
{loss_names[i]: train_losses[i] for i in range(len(loss_names))}
def train_begin(self, train_dataloader):
assert train_dataloader.batch_size == 1
self.network.train()
cfg.perturb = cfg.train.perturb
def train_end(self):
pass
def train(self, epoch, train_dataloader):
self.train_begin(train_dataloader=train_dataloader)
self.timer.begin()
for batch_idx, batch in enumerate(train_dataloader):
if self.iter > cfg.train.maxiter:
break
self.optimizer.zero_grad()
# only access the first batch as we process one image at a time
for k, v in batch.items():
batch[k] = v[0]
batch['iter_val'] = torch.full((1,), self.iter)
data = cpu_data_to_gpu(
batch, exclude_keys=EXCLUDE_KEYS_TO_GPU)
net_output = self.network(**data)
train_loss, loss_dict = self.get_loss(
net_output=net_output,
patch_masks=data['patch_masks'],
bgcolor=data['bgcolor'] / 255.,
targets=data['target_patches'],
div_indices=data['patch_div_indices'])
train_loss.backward()
self.optimizer.step()
if self.iter % cfg.train.log_interval == 0:
loss_str = f"Loss: {
train_loss.item():.4f} ["
for k, v in loss_dict.items():
loss_str += f"{
k}: {
v.item():.4f} "
loss_str += "]"
log_str = 'Epoch: {} [Iter {}, {}/{} ({:.0f}%), {}] {}'
log_str = log_str.format(
epoch, self.iter,
batch_idx * cfg.train.batch_size, len(train_dataloader.dataset),
100. * batch_idx / len(train_dataloader),
self.timer.log(),
loss_str)
print(log_str)
is_reload_model = False
if self.iter in [100, 300, 1000, 2500] or \
self.iter % cfg.progress.dump_interval == 0:
is_reload_model = self.progress()
if not is_reload_model:
if self.iter % cfg.train.save_checkpt_interval == 0:
self.save_ckpt('latest')
if cfg.save_all:
if self.iter % cfg.train.save_model_interval == 0:
self.save_ckpt(f'iter_{self.iter}')
self.update_lr(self.optimizer, self.iter)
self.iter += 1
def finalize(self):
self.save_ckpt('latest')
######################################################
## Progress
def progress_begin(self):
self.network.eval()
cfg.perturb = 0.
def progress_end(self):
self.network.train()
cfg.perturb = cfg.train.perturb
def progress(self):
self.progress_begin()
print('Evaluate Progress Images ...')
images = []
is_empty_img = False
for _, batch in enumerate(tqdm(self.prog_dataloader)):
# only access the first batch as we process one image at a time
for k, v in batch.items():
batch[k] = v[0]
width = batch['img_width']
height = batch['img_height']
ray_mask = batch['ray_mask']
rendered = np.full(
(height * width, 3), np.array(cfg.bgcolor)/255.,
dtype='float32')
truth = np.full(
(height * width, 3), np.array(cfg.bgcolor)/255.,
dtype='float32')
batch['iter_val'] = torch.full((1,), self.iter)
data = cpu_data_to_gpu(
batch, exclude_keys=EXCLUDE_KEYS_TO_GPU + ['target_rgbs'])
with torch.no_grad():
net_output = self.network(**data)
rgb = net_output['rgb'].data.to("cpu").numpy()
target_rgbs = batch['target_rgbs']
rendered[ray_mask] = rgb
truth[ray_mask] = target_rgbs
truth = to_8b_image(truth.reshape((height, width, -1)))
rendered = to_8b_image(rendered.reshape((height, width, -1)))
images.append(np.concatenate([rendered, truth], axis=1))
# check if we rendered empty images (only happens at the beginning of training)
if self.iter <= 5000 and \
np.allclose(rendered, np.array(cfg.bgcolor), atol=5.):
is_empty_img = True
break
tiled_image = tile_images(images)
Image.fromarray(tiled_image).save(
os.path.join(cfg.logdir, "prog_{:06}.jpg".format(self.iter)))
if is_empty_img:
print("Produce empty images; reload the init model.")
self.load_ckpt('init')
self.progress_end()
return is_empty_img
######################################################
## Utils
def save_ckpt(self, name):
path = Trainer.get_ckpt_path(name)
print(f"Save checkpoint to {
path} ...")
torch.save({
'iter': self.iter,
'network': self.network.state_dict(),
'optimizer': self.optimizer.state_dict()
}, path)
def load_ckpt(self, name):
path = Trainer.get_ckpt_path(name)
print(f"Load checkpoint from {
path} ...")
ckpt = torch.load(path, map_location='cuda:0')
self.iter = ckpt['iter'] + 1
self.network.load_state_dict(ckpt['network'], strict=False)
self.optimizer.load_state_dict(ckpt['optimizer'])
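Finally, a minimal driver sketch showing how the pieces compose (hypothetical; the repo's real entry script and dataloader arguments may differ):
from core.data import create_dataloader
from configs import cfg

def main():
    network = create_network()                 # create_network.py
    optimizer = get_optimizer(network)         # optimizer.py
    trainer = Trainer(network, optimizer)      # trainer.py
    train_loader = create_dataloader('train')  # assumed data_type key
    epoch = 0
    while trainer.iter <= cfg.train.maxiter:
        epoch += 1
        trainer.train(epoch=epoch, train_dataloader=train_loader)
    trainer.finalize()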