NeRF源码分析解读(四)

NeRF源码分析解读(四)

上一篇博客中我们已经得到了每张图像下每个像素点对应的光线的 ray_o、 ray_d。有了这两个量可以说我们已经模拟到了一条光线,下面我们继续分析代码,看看空间中的点是如何生成的,MLP 又是怎样输出点的颜色和密度的。
首先我们回顾一下之前博客中分析的代码:

def train():
	# 1、加载数据
	if args.dataset_type == 'llff': ...
	elif args.dataset_type == 'blender': ...
	elif args.dataset_type == 'LINEMOD': ...
	elif args.dataset_type == 'deepvoxels': ...

	# 2、创建 NeRF 网络
	render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)

	# 3、如果以批处理的形式对进行训练,首先生成所有图片的光线
	if use_batching: ...
	

到此我们已经结束了前期的准备工作,包括 NeRF 模型的初始化构建,在批处理的情况下生成所有图片的光线原点以及方向。下面我们就可以开始我们的训练步骤。

def train():
	...
	
	# 4、开始训练
	for i in trange(start, N_iters):
		
		if use_batching:
			...
		else: 
			...

我们首先分析批处理的情况:

    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # 分批加载光线,大小为 N_rand 
            batch = rays_rgb[i_batch : i_batch + N_rand] # [B, 2+1, 3]
            batch = torch.transpose(batch, 0, 1)  # [3, B, 3]
            
            batch_rays, target_s = batch[:2], batch[2]  # [2, B, 3]  [B, 3]  将光线和对应的像素点颜色分离

            i_batch += N_rand
			
			# 经过一定批次的处理后,所有的图片都经过了一次。这时候要对数据打乱,重新再挑选
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

由于我们在第三步已经生成了所有图片的像素点对应的光线原点和方向,并将光线对应的像素颜色与光线聚合到了一起构成 rays_rgb,因此我们直接对 rays_rgb 分离即可。接下来我们对单张图片的训练进行分析。

    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching: ...
		
		else:
			# 从所有的图像中随机选择一张图像用于训练
            img_i = np.random.choice(i_train)
            target = images[img_i]
            target = torch.Tensor(target).to(device)
            pose = poses[img_i, :3,:4]

            if N_rand is not None:
            	# 生成这张图像中每个像素点对应的光线的原点和方向
                rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)
				# 生成每个像素点的笛卡尔坐标,前 precrop_iters 生成图像中心的像素坐标坐标
                if i < args.precrop_iters:
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 
                            torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
                        ), -1)
                    if i == start:
                        print(f"[Config] Center cropping of size {
      
      2*dH} x {
      
      2*dW} is enabled until iter {
      
      args.precrop_iters}")                
                else:
                	# 生成图像中每个像素的坐标
                    coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1,2])  # (H * W, 2)
                # 注意,在训练的时候并不是给图像中每个像素都打光线,而是加载一批光线,批大小为 N_rand
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                # 选择像素坐标
                select_coords = coords[select_inds].long()  # (N_rand, 2)
                # 选择对应的光线
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)
                
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]  # (N_rand, 3)

注意: 想必读者已经注意到不论是 use_batching 的情况还是单张图像的情况,每个 epoch 选择的光线的数量是恒定的,即 N_rand 。这么做实际上是为了减少计算的工程量。虽然每次都只随机挑选了一部分像素对应的光线,但是经过多达 200 000 次的训练实际上已经足以把所有的像素对应的光线都挑选一遍了。
我们得到生成的光线以后就可以对光线进行渲染操作了。代码如下:

def train():
	...
	
	# 4、开始训练
	for i in trange(start, N_iters):
		
		if use_batching:
			...
		else: 
			...
		
		rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                                verbose=i < 10, retraw=True,
                                                **render_kwargs_train)

我们对 render() 函数进行分析。首先我们分析 render() 函数的各项参数

chunk : 并行处理的光线的数量
batch_rays : 经由上一的步骤我们挑选出的光线
render_kwargs_train : 初始化模型代码 create_nerf() 返回的字典数据,具体内容请查看 NeRF 源码分析解读(二)
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):
	...
	# 如果使用视图方向,根据光线的 ray_d 计算单位方向作为 view_dirs
	if use_viewdirs:
		viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1, 3]).float()
        
	...
	
	# 生成光线的远近端,用于确定边界框,并将其聚合到 rays 中
	near, far = near * torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])  # [-1,]
    rays = torch.cat([rays_o, rays_d, near, far], -1)  # [-1, 8]
    # 视图方向聚合到光线中
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)  # [-1, 11]
	# 开始并行计算光线属性
	all_ret = batchify_rays(rays, chunk, **kwargs)
	

我们经过上述步骤计算得到了光线的 ray_o、ray_d、near、far、viewdirs 并投入到批处理程序 batchify_rays() 中。接下来我们对 batchify_rays() 进行分析。

def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    all_ret = {
    
    }
    for i in range(0, rays_flat.shape[0], chunk):
    	# 渲染光线
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)
        ...

事实上我们可以看到,主要起作用的代码为render_rays()注意 以下代码涉及到光线上采样点的位置生成,请读者从数学上思考 pts 的生成过程。本文给出实例印证

def render_rays():
	# 从 ray 中分离出 rays_o, rays_d, viewdirs, near, far
	...
	
	# 确定空间中一个坐标的 Z 轴位置
	t_vals = torch.linspace(0., 1., steps=N_samples)  # 在 0-1 内生成 N_samples 个等差点
	# 根据参数确定不同的采样方式,从而确定 Z 轴在边界框内的的具体位置
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

	...
	# 生成光线上每个采样点的位置
	pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]  # [N_rays, N_samples, 3]
	# 将光线上的每个点投入到 MLP 网络 network_fn 中前向传播得到每个点对应的 (RGB,A)
	raw = network_query_fn(pts, viewdirs, network_fn)

对于 z_vals 的计算。举一个例子,我们想要生成 a-b 之间的 n 个均匀的数组,我们改如何计算呢?按照我们习惯的计算方法应该将将 a - b 之间的长度分成 n-1 等份,每份距离为 d ,然后依次将 d 加到 a 上即可,表述为数学形式即为:
d = b − a n − 1 a i = a + i d = a + i ⋅ b − a n − 1 = a ( 1 − i n − 1 ) + b i n − 1 ( i = 0 , . . . , n − 1 ) \begin{aligned} d &= \frac{b-a}{n-1} \\ a_i &= a + id \\ &= a + i \cdot \frac{b-a}{n-1} \\ &= a(1-\frac{i}{n-1}) + b\frac{i}{n-1} \quad (i = 0,...,n-1) \end{aligned} dai=n1ba=a+id=a+in1ba=a(1n1i)+bn1i(i=0,...,n1)这里的 i n − 1 ( i = 0 , . . . , n − 1 ) \frac{i}{n-1} \quad (i = 0,...,n-1) n1i(i=0,...,n1) 可以表述成 0-1 内的 n 个等差点。对应于代码中

z_vals = near * (1.-t_vals) + far * (t_vals)

有了 Z 轴距离,自然而然的我们想到空间中的点的位置可以表述为原点加方向和距离的乘积。

pts = rays_o + rays_d * z_vals

把位置坐标 pts 输入到网络 network_fn 中输出每个点对应的 RGB A ,聚合到变量 raw 中。


到此我们分析了光线的选择性投射,以及空间中点的位置的生成,并给出了点对应的颜色和密度。接下来我们可以根据论文中提到的体积渲染计算光线对应的最终颜色。具体的代码分析见下一片博客NeRF源码分析解读(五)

猜你喜欢

转载自blog.csdn.net/qq_41071191/article/details/126052795