├── modules
│   ├── utils.py
│   ├── model.py
│   ├── sh.py
│   └── config.py
├── dataloader
│   ├── load_blender.py
│   └── dataset.py
├── README.md
└── main.py

/modules/utils.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import torch
4 | import numpy as np
5 | import open3d as o3d
6 | 
7 | # Setup
8 | if torch.cuda.is_available():
9 |     device = torch.device("cuda:0")
10 |     torch.cuda.set_device(device)
11 | else:
12 |     device = torch.device("cpu")
13 | 
14 | mse2psnr = lambda x : -10. * torch.log(x) \
15 |            / torch.log(torch.tensor([10.], device=x.device))
16 | 
17 | def safe_path(path):
18 |     if os.path.exists(path):
19 |         return path
20 |     else:
21 |         os.mkdir(path)
22 |         return path
23 | 
24 | def load_mem_data(mem):
25 |     poses = mem.pose
26 |     R, T = (poses[:, :3, :3]), poses[:, :3, -1]
27 |     R, T = R, -(T[: ,None ,:] @ R)[: ,0]
28 |     return mem.pts, mem.image, mem.K, R, T, poses, mem.mask
29 | 
30 | def get_rays(H, W, K, c2w):
31 |     device = c2w.device
32 |     i, j = torch.meshgrid(torch.linspace(0, W-1, W, device=device),
33 |                           torch.linspace(0, H-1, H, device=device))  # pytorch's meshgrid has indexing='ij'
34 |     i = i.t()
35 |     j = j.t()
36 |     dirs = torch.stack([(i-K[0][2])/K[0][0],
37 |                         -(j-K[1][2])/K[1][1], -torch.ones_like(i, device=device)], -1)
38 |     # Rotate ray directions from camera frame to the world frame
39 |     rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
40 |     # Translate camera frame's origin to the world frame. It is the origin of all rays.
41 |     rays_o = c2w[:3,-1].expand(rays_d.shape)
42 |     return rays_o, rays_d
43 | 
44 | 
45 | def remove_outlier(pts):
46 |     pcd = o3d.geometry.PointCloud()
47 |     pcd.points = o3d.utility.Vector3dVector(pts)
48 |     pcd = pcd.voxel_down_sample(voxel_size=0.010)
49 |     cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
50 |     return np.array(pcd.points)[np.array(ind)]
51 | 
52 | 
53 | def grad_loss(output, gt):
54 |     def one_grad(shift):
55 |         ox = output[shift:] - output[:-shift]
56 |         oy = output[:, shift:] - output[:, :-shift]
57 |         gx = gt[shift:] - gt[:-shift]
58 |         gy = gt[:, shift:] - gt[:, :-shift]
59 |         loss = (ox - gx).abs().mean() + (oy - gy).abs().mean()
60 |         return loss
61 |     loss = (one_grad(1) + one_grad(2) + one_grad(3)) / 3.
62 | return loss 63 | 64 | 65 | def set_seed(seed=0): 66 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 67 | torch.manual_seed(seed) 68 | torch.cuda.manual_seed(seed) 69 | torch.cuda.manual_seed_all(seed) 70 | np.random.seed(seed) 71 | torch.backends.cudnn.benchmark = False 72 | torch.backends.cudnn.deterministic = True 73 | -------------------------------------------------------------------------------- /dataloader/load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_blender_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for frame in meta['frames'][::skip]: 57 | fname = os.path.join(basedir, frame['file_path'] + '.png') 58 | imgs.append(imageio.imread(fname)) 59 | poses.append(np.array(frame['transform_matrix'])) 60 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 61 | poses = np.array(poses).astype(np.float32) 62 | counts.append(counts[-1] + imgs.shape[0]) 63 | all_imgs.append(imgs) 64 | all_poses.append(poses) 65 | 66 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 67 | 68 | imgs = np.concatenate(all_imgs, 0) 69 | poses = np.concatenate(all_poses, 0) 70 | 71 | H, W = imgs[0].shape[:2] 72 | camera_angle_x = float(meta['camera_angle_x']) 73 | focal = .5 * W / np.tan(.5 * camera_angle_x) 74 | 75 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 76 | 77 | if half_res: 78 | H = H//2 79 | W = W//2 80 | focal = focal/2. 
81 | 
82 |         imgs_half_res = np.zeros((imgs.shape[0], H, W, 4))
83 |         for i, img in enumerate(imgs):
84 |             imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
85 |         imgs = imgs_half_res
86 |         # imgs = tf.image.resize_area(imgs, [400, 400]).numpy()
87 | 
88 | 
89 |     return imgs, poses, render_poses, [H, W, focal], i_split
90 | 
91 | 
92 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # point-radiance
2 | 
3 | ---
4 | 
5 | This code release accompanies the following paper:
6 | 
7 | ### Differentiable Point-Based Radiance Fields for Efficient View Synthesis
8 | Qiang Zhang, Seung-Hwan Baek, Szymon Rusinkiewicz, Felix Heide
9 | 
10 | *SIGGRAPH Asia*, 2022
11 | 
12 | [PDF](https://arxiv.org/pdf/2205.14330.pdf) | [arXiv](https://arxiv.org/abs/2205.14330)
13 | **Abstract:**
14 | We propose a differentiable rendering algorithm for efficient novel
15 | view synthesis. By departing from volume-based representations
16 | in favor of a learned point representation, we improve on existing
17 | methods by more than an order of magnitude in memory and
18 | runtime, both in training and inference. The method begins with a
19 | uniformly-sampled random point cloud and learns per-point
20 | position and view-dependent appearance, using a differentiable
21 | splat-based renderer to train the model to reproduce a set of input
22 | training images with the given poses. Our method is up to 300×
23 | faster than NeRF in both training and inference, with only a
24 | marginal sacrifice in quality, while using less than 10 MB of memory
25 | for a static scene. For dynamic scenes, our method trains two orders
26 | of magnitude faster than STNeRF and renders at a near-interactive
27 | rate, while maintaining high image quality and temporal coherence
28 | even without imposing any temporal-coherency regularizers.
29 | 
30 | 
31 | ## Installation
32 | 
33 | We recommend using a [`conda`](https://docs.conda.io/en/latest/miniconda.html) environment for this codebase. The following commands will set up a new conda environment with the correct requirements:
34 | 
35 | ```bash
36 | # Create and activate new conda env
37 | conda create -n my-conda-env python=3.9
38 | conda activate my-conda-env
39 | 
40 | # Install pytorch and related libraries
41 | conda install -y pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.0 -c pytorch
42 | conda install numpy matplotlib tqdm imageio
43 | ```
44 | Then follow the official [INSTALL.md](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) to install [pytorch3d](https://pytorch3d.org/).
45 | 
46 | ## Reproduce
47 | You can train the model on the NeRF synthetic dataset within 3 minutes. Here `--datadir` is the dataset folder path, `--dataname` is the scene name, and `--basedir` is the log folder path. `--data_r` is the ratio between the number of points kept and the number of points initialized, and `--splatting_r` is the splatting radius. For example:
48 | ```bash 49 | python main.py --datadir xxx --dataname hotdog --basedir xxx --data_r 0.012 --splatting_r 0.015 50 | ``` 51 | After around three minutes, you can see the following output (the example is tested on one A100 GPU): 52 | 53 | ``` 54 | Training time: 148.59 s 55 | Rendering quality: 34.70 dB 56 | Rendering speed: 120.01 fps 57 | Model size: 7.32 MB 58 | ``` 59 | 60 | ## Citation 61 | 62 | If you find this work useful for your research, please consider citing: 63 | ``` 64 | @article{zhang2022differentiable, 65 | title={Differentiable Point-Based Radiance Fields for Efficient View Synthesis}, 66 | author={Zhang, Qiang and Baek, Seung-Hwan and Rusinkiewicz, Szymon and Heide, Felix}, 67 | journal={arXiv preprint arXiv:2205.14330}, 68 | year={2022} 69 | } 70 | ``` 71 | -------------------------------------------------------------------------------- /modules/model.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from pytorch3d.ops import knn_points 7 | from pytorch3d.structures import Pointclouds 8 | from pytorch3d.renderer import ( 9 | PerspectiveCameras, 10 | PointsRasterizationSettings, 11 | PointsRenderer, 12 | PointsRasterizer, 13 | AlphaCompositor 14 | ) 15 | 16 | 17 | from modules.sh import eval_sh 18 | from modules.utils import device, load_mem_data, \ 19 | get_rays, remove_outlier 20 | 21 | 22 | class CoreModel(torch.nn.Module): 23 | def __init__(self, args): 24 | super(CoreModel, self).__init__() 25 | self.raster_n = args.raster_n 26 | self.img_s = args.img_s 27 | self.dataname = args.dataname 28 | self.splatting_r = args.splatting_r 29 | pointcloud, imagesgt, K, R, T, poses,masks = load_mem_data(args.memitem) 30 | 31 | self.R, self.T, self.K = R, T, K 32 | self.poses = torch.tensor(poses).to(device) 33 | self.imagesgt = imagesgt 34 | self.masks = masks 35 | N = int(pointcloud.shape[0] * args.data_r) 36 | ids = np.random.permutation(pointcloud.shape[0])[:N] 37 | pointcloud = pointcloud[ids][:, :6] 38 | print('Initialized point number:{}'.format(pointcloud.shape[0])) 39 | 40 | self.vertsparam = torch.nn.Parameter(torch.Tensor(pointcloud[:, :3])) 41 | self.sh_n, sh_param = 2, [torch.Tensor(pointcloud[:, 3:])] 42 | for i in range((self.sh_n + 1) ** 2): 43 | sh_param.append(torch.rand((pointcloud.shape[0], 3))) 44 | sh_param = torch.cat(sh_param, -1) 45 | self.sh_param = torch.nn.Parameter(sh_param) 46 | self.viewdir = [] 47 | for i in range(self.poses.shape[0]): 48 | rays_o, rays_d = get_rays(self.img_s, self.img_s, torch.tensor(K).to(device), self.poses[i]) 49 | rays_d = torch.nn.functional.normalize(rays_d, dim=2) 50 | self.viewdir.append(rays_d) 51 | 52 | self.raster_settings = PointsRasterizationSettings( 53 | bin_size=23, 54 | image_size=self.img_s, 55 | radius=self.splatting_r, 56 | points_per_pixel=self.raster_n, 57 | ) 58 | self.onlybase = False 59 | 60 | def repeat_pts(self): 61 | self.vertsparam.data = self.vertsparam.data.repeat(2,1) 62 | self.sh_param.data = self.sh_param.data.repeat(2, 1) 63 | if self.vertsparam.grad is not None: 64 | self.vertsparam.grad = self.vertsparam.grad.repeat(2,1) 65 | if self.sh_param.grad is not None: 66 | self.sh_param.grad = self.sh_param.grad.repeat(2, 1) 67 | 68 | def remove_out(self): 69 | pts_all = self.vertsparam.data 70 | pts_in = remove_outlier(pts_all.cpu().data.numpy()) 71 | pts_in = torch.tensor(pts_in).cuda().float() 72 | idx = knn_points(pts_in[None,...], pts_all[None,...], None, None, 1).idx[0,:,0] 
73 | self.vertsparam.data = self.vertsparam.data[idx].detach() 74 | self.sh_param.data = self.sh_param.data[idx].detach() 75 | if self.vertsparam.grad is not None: 76 | self.vertsparam.grad = self.vertsparam.grad[idx].detach() 77 | if self.sh_param.grad is not None: 78 | self.sh_param.grad = self.sh_param.grad[idx].detach() 79 | 80 | def forward(self, id): 81 | cameras = PerspectiveCameras(focal_length=self.K[0][0] / self.K[0][2], 82 | device=device, R=-self.R[id:id + 1], T=-self.T[id:id + 1]) 83 | rasterizer = PointsRasterizer(cameras=cameras, raster_settings=self.raster_settings) 84 | renderer = PointsRenderer( 85 | rasterizer=rasterizer, 86 | compositor=AlphaCompositor() 87 | ) 88 | point_cloud = Pointclouds(points=[self.vertsparam], features=[self.sh_param]) 89 | feat = renderer(point_cloud).flip(1) 90 | base, shfeat = feat[..., :3], feat[..., 3:] 91 | shfeat = torch.stack(shfeat.split(3, 3), -1) 92 | if self.onlybase: 93 | image = base 94 | else: 95 | image = base + eval_sh(self.sh_n, shfeat, self.viewdir[id]) 96 | return image.clamp(min=0, max=1) 97 | -------------------------------------------------------------------------------- /modules/sh.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The PlenOctree Authors. 2 | # Redistribution and use in source and binary forms, with or without 3 | # modification, are permitted provided that the following conditions are met: 4 | # 5 | # 1. Redistributions of source code must retain the above copyright notice, 6 | # this list of conditions and the following disclaimer. 7 | # 8 | # 2. Redistributions in binary form must reproduce the above copyright notice, 9 | # this list of conditions and the following disclaimer in the documentation 10 | # and/or other materials provided with the distribution. 11 | # 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 13 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 14 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 15 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 16 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 17 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 18 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 19 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 20 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 21 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 22 | # POSSIBILITY OF SUCH DAMAGE. 23 | 24 | C0 = 0.28209479177387814 25 | C1 = 0.4886025119029199 26 | C2 = [ 27 | 1.0925484305920792, 28 | -1.0925484305920792, 29 | 0.31539156525252005, 30 | -1.0925484305920792, 31 | 0.5462742152960396 32 | ] 33 | C3 = [ 34 | -0.5900435899266435, 35 | 2.890611442640554, 36 | -0.4570457994644658, 37 | 0.3731763325901154, 38 | -0.4570457994644658, 39 | 1.445305721320277, 40 | -0.5900435899266435 41 | ] 42 | C4 = [ 43 | 2.5033429417967046, 44 | -1.7701307697799304, 45 | 0.9461746957575601, 46 | -0.6690465435572892, 47 | 0.10578554691520431, 48 | -0.6690465435572892, 49 | 0.47308734787878004, 50 | -1.7701307697799304, 51 | 0.6258357354491761, 52 | ] 53 | 54 | def eval_sh(deg, sh, dirs): 55 | """ 56 | Evaluate spherical harmonics at unit directions 57 | using hardcoded SH polynomials. 58 | Works with torch/np/jnp. 59 | ... Can be 0 or more batch dimensions. 
60 | 61 | Args: 62 | deg: int SH deg. Currently, 0-3 supported 63 | sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] 64 | dirs: jnp.ndarray unit directions [..., 3] 65 | 66 | Returns: 67 | [..., C] 68 | """ 69 | assert deg <= 4 and deg >= 0 70 | assert (deg + 1) ** 2 == sh.shape[-1] 71 | C = sh.shape[-2] 72 | 73 | result = C0 * sh[..., 0] 74 | if deg > 0: 75 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 76 | result = (result - 77 | C1 * y * sh[..., 1] + 78 | C1 * z * sh[..., 2] - 79 | C1 * x * sh[..., 3]) 80 | if deg > 1: 81 | xx, yy, zz = x * x, y * y, z * z 82 | xy, yz, xz = x * y, y * z, x * z 83 | result = (result + 84 | C2[0] * xy * sh[..., 4] + 85 | C2[1] * yz * sh[..., 5] + 86 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 87 | C2[3] * xz * sh[..., 7] + 88 | C2[4] * (xx - yy) * sh[..., 8]) 89 | if deg > 2: 90 | result = (result + 91 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 92 | C3[1] * xy * z * sh[..., 10] + 93 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 94 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 95 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 96 | C3[5] * z * (xx - yy) * sh[..., 14] + 97 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 98 | if deg > 3: 99 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 100 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 101 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 102 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 103 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 104 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 105 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 106 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 107 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 108 | return result 109 | 110 | 111 | if __name__ == '__main__': 112 | import numpy as np 113 | deg = 3 114 | sh = np.random.random((400,400,3,16)) 115 | dirs = np.random.random((400,400,3)) 116 | result = eval_sh(deg,sh,dirs) 117 | print(result.shape) 118 | -------------------------------------------------------------------------------- /dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | import matplotlib.pyplot as plt 8 | 9 | from dataloader.load_blender import load_blender_data 10 | 11 | def safe_path(path): 12 | if os.path.exists(path): 13 | return path 14 | else: 15 | os.mkdir(path) 16 | return path 17 | 18 | 19 | # Ray helpers 20 | def get_rays(H, W, K, c2w): 21 | device = c2w.device 22 | 23 | i, j = torch.meshgrid(torch.linspace(0, W-1, W, device=device), 24 | torch.linspace(0, H-1, H, device=device)) # pytorch's meshgrid has indexing='ij' 25 | i = i.t() 26 | j = j.t() 27 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i, device=device)], -1) 28 | # Rotate ray directions from camera frame to the world frame 29 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 30 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 
31 | rays_o = c2w[:3,-1].expand(rays_d.shape) 32 | return rays_o, rays_d 33 | 34 | 35 | # Ray helpers 36 | def get_uvs_from_ray(H, W, K, c2w,pts): 37 | RP = torch.bmm(c2w[:3,:3].T[None,:,:].repeat(pts.shape[0],1,1),pts[:,:,None])[:,:,0] 38 | t = torch.mm(c2w[:3,:3].T,-c2w[:3,-1][:,None]) 39 | pts_local0 = torch.sum((pts-c2w[:3,-1])[..., None, :] * (c2w[:3,:3].T), -1) 40 | pts_local = pts_local0/(-pts_local0[...,-1][...,None]+1e-7) 41 | u = pts_local[...,0]*K[0][0]+K[0][2] 42 | v = -pts_local[...,1]*K[1][1]+K[1][2] 43 | uv = torch.stack((u,v),-1) 44 | return uv,pts_local0 45 | 46 | 47 | def batch_get_uv_from_ray(H,W,K,poses,pts): 48 | RT = (poses[:, :3, :3].transpose(1, 2)) 49 | pts_local = torch.sum((pts[..., None, :] - poses[:, :3, -1])[..., None, :] * RT, -1) 50 | pts_local = pts_local / (-pts_local[..., -1][..., None] + 1e-7) 51 | u = pts_local[..., 0] * K[0][0] + K[0][2] 52 | v = -pts_local[..., 1] * K[1][1] + K[1][2] 53 | uv0 = torch.stack((u, v), -1) 54 | uv0[...,0] = uv0[...,0]/W*2-1 55 | uv0[...,1] = uv0[...,1]/H*2-1 56 | uv0 = uv0.permute(2,0,1,3) 57 | return uv0 58 | 59 | 60 | class MemDataset(object): 61 | def __init__(self,pts,pose,image,mask,K): 62 | self.pts = pts 63 | self.pose = pose 64 | self.mask = mask 65 | self.image = image 66 | self.K = K 67 | 68 | class Data(): 69 | def __init__(self, args): 70 | self.dataname = args.dataname 71 | self.datadir = os.path.join(args.datadir,args.dataname) 72 | self.logpath = safe_path(args.basedir) 73 | # self.initpath = safe_path(os.path.join(self.logpath,'init')) 74 | K = None 75 | if args.dataset_type == 'blender': 76 | images, poses, render_poses, hwf, i_split = load_blender_data(self.datadir, args.half_res, args.testskip) 77 | # print('Loaded blender', images.shape, render_poses.shape, hwf, self.datadir) 78 | masks = images[..., -1:] 79 | near = 2. 80 | far = 6. 81 | if args.white_bkgd: 82 | images = images[..., :3] * images[..., -1:] + (1. 
- images[..., -1:]) 83 | else: 84 | images = images[..., :3] 85 | else: 86 | print('Unknown dataset type', args.dataset_type, 'exiting') 87 | return 88 | 89 | self.i_split = i_split 90 | self.images = images 91 | self.masks = masks 92 | self.poses = poses 93 | self.render_poses = render_poses 94 | 95 | # Cast intrinsics to right types 96 | H, W, focal = hwf 97 | H, W = int(H), int(W) 98 | hwf = [H, W, focal] 99 | 100 | if K is None: 101 | self.K = np.array([ 102 | [focal, 0, 0.5 * W], 103 | [0, focal, 0.5 * H], 104 | [0, 0, 1] 105 | ]) 106 | else: 107 | self.K = K 108 | 109 | 110 | self.hwf = hwf 111 | self.near = near 112 | self.far = far 113 | 114 | 115 | def genpc(self): 116 | [H, W, focal] = self.hwf 117 | K = torch.tensor(self.K).cuda() 118 | train_n = 100 119 | poses = torch.tensor(self.poses).cuda()[:train_n] 120 | images = torch.tensor(self.images)[:train_n] 121 | 122 | pc,color,N = [],[],400 123 | [xs,ys,zs],[xe,ye,ze] = [-2,-2,-2],[2,2,2] 124 | pts_all = [] 125 | for h_id in tqdm(range(N)): 126 | i, j = torch.meshgrid(torch.linspace(xs, xe, N).cuda(), 127 | torch.linspace(ys, ye, N).cuda()) # pytorch's meshgrid has indexing='ij' 128 | i, j = i.t(), j.t() 129 | pts = torch.stack([i, j, torch.ones_like(i).cuda()], -1) 130 | pts[...,2] = h_id / N * (ze - zs) + zs 131 | pts_all.append(pts.clone()) 132 | uv = batch_get_uv_from_ray(H,W,K,poses,pts) 133 | result = F.grid_sample(images.permute(0, 3, 1, 2).float(), uv).permute(0,2,3,1) 134 | 135 | margin = 0.05 136 | result[(uv[..., 0] >= 1.0) * (uv[..., 0] <= 1.0 + margin)] = 1 137 | result[(uv[..., 0] >= -1.0 - margin) * (uv[..., 0] <= -1.0)] = 1 138 | result[(uv[..., 1] >= 1.0) * (uv[..., 1] <= 1.0 + margin)] = 1 139 | result[(uv[..., 1] >= -1.0 - margin) * (uv[..., 1] <= -1.0)] = 1 140 | result[(uv[..., 0] <= -1.0 - margin) + (uv[..., 0] >= 1.0 + margin)] = 0 141 | result[(uv[..., 1] <= -1.0 - margin) + (uv[..., 1] >= 1.0 + margin)] = 0 142 | 143 | img = ((result>0.).sum(0)[...,0]>train_n-1).float() 144 | pc.append(img) 145 | color.append(result.mean(0)) 146 | pc = torch.stack(pc,-1) 147 | color = torch.stack(color,-1) 148 | r, g, b = color[:, :, 0], color[:, :, 1], color[:, :, 2] 149 | idx = torch.where(pc > 0) 150 | color = torch.stack((r[idx],g[idx],b[idx]),-1) 151 | idx = (idx[1],idx[0],idx[2]) 152 | pts = torch.stack(idx,-1).float()/N 153 | pts[:,0] = pts[:,0]*(xe-xs)+xs 154 | pts[:,1] = pts[:,1]*(ye-ys)+ys 155 | pts[:,2] = pts[:,2]*(ze-zs)+zs 156 | 157 | pts = torch.cat((pts,color),-1).cpu().data.numpy() 158 | print('Initialization, data:{} point:{}'.format(self.dataname,pts.shape)) 159 | item = MemDataset(pts,self.poses,self.images,self.masks,self.K) 160 | return item 161 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import time 5 | import torch 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from tqdm import tqdm 9 | from PIL import Image 10 | 11 | from modules.model import CoreModel 12 | from modules.utils import device, mse2psnr, \ 13 | grad_loss, safe_path, set_seed 14 | from modules.config import config_parser 15 | from dataloader.dataset import Data 16 | 17 | 18 | class Trainder(object): 19 | def __init__(self, args): 20 | self.args = args 21 | self.dataname = args.dataname 22 | self.logpath = args.basedir 23 | self.outpath = safe_path(os.path.join(self.logpath, 'output')) 24 | self.weightpath = safe_path(os.path.join(self.logpath, 'weight')) 25 | 
self.imgpath = safe_path(os.path.join(self.outpath, 'images')) 26 | self.imgpath = safe_path(os.path.join(self.imgpath, '{}'.format(self.dataname))) 27 | self.logfile = os.path.join(self.outpath, 'log_{}.txt'.format(self.dataname)) 28 | self.logfile = open(self.logfile, 'w') 29 | self.model = CoreModel(args).to(device) 30 | self.loss_fn = torch.nn.MSELoss() 31 | self.lr1, self.lr2 = args.lr1, args.lr2 32 | self.lrexp, self.lr_s = args.lrexp, args.lr_s 33 | self.set_optimizer(self.lr1, self.lr2) 34 | self.imagesgt = torch.tensor(self.model.imagesgt).float().to(device) 35 | self.masks = torch.tensor(self.model.masks).float().to(device) 36 | self.imagesgt_train = self.imagesgt 37 | self.imgout_path = safe_path(os.path.join(self.imgpath, 38 | 'v2_{:.3f}_{:.3f}'.format(args.data_r, args.splatting_r))) 39 | self.training_time = 0 40 | print(self.imgout_path) 41 | 42 | def set_onlybase(self): 43 | self.model.onlybase = True 44 | self.set_optimizer(3e-3,self.lr2) 45 | 46 | def remove_onlybase(self): 47 | self.model.onlybase = False 48 | self.set_optimizer(self.lr1,self.lr2) 49 | 50 | def set_optimizer(self, lr1=3e-3, lr2=8e-4): 51 | sh_list = [name for name, params in self.model.named_parameters() if 'sh' in name] 52 | sh_params = list(map(lambda x: x[1], list(filter(lambda kv: kv[0] in sh_list, 53 | self.model.named_parameters())))) 54 | other_params = list(map(lambda x: x[1], list(filter(lambda kv: kv[0] not in sh_list, 55 | self.model.named_parameters())))) 56 | optimizer = torch.optim.Adam([ 57 | {'params': sh_params, 'lr': lr1}, 58 | {'params': other_params, 'lr': lr2}]) 59 | lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, self.lrexp, -1) 60 | self.optimizer, self.lr_scheduler = optimizer, lr_scheduler 61 | return None 62 | 63 | def train(self,epoch_n=30): 64 | self.logfile.write('-----------Stage Segmentation Line-----------') 65 | self.logfile.flush() 66 | max_psnr = 0. 
67 | start_time = time.time() 68 | for epoch in range(epoch_n): 69 | loss_all, psnr_all = [], [] 70 | ids = np.random.permutation(100) 71 | for id in tqdm(ids): 72 | images = self.model(id) 73 | loss = self.loss_fn(images[0], self.imagesgt_train[id]) 74 | loss = loss + self.lr_s * grad_loss(images[0], self.imagesgt_train[id]) 75 | self.optimizer.zero_grad() 76 | loss.backward() 77 | self.optimizer.step() 78 | loss_all.append(loss) 79 | psnr_all.append(mse2psnr(loss)) 80 | self.lr_scheduler.step() 81 | loss_e = torch.stack(loss_all).mean().item() 82 | psnr_e = torch.stack(psnr_all).mean().item() 83 | info = '-----train----- epoch:{} loss:{:.3f} psnr:{:.3f}'.format(epoch, loss_e, psnr_e) 84 | print(info) 85 | self.logfile.write(info + '\n') 86 | self.logfile.flush() 87 | psnr_val = self.test(115, 138, False) 88 | if psnr_val > max_psnr: 89 | max_psnr = psnr_val 90 | self.training_time += time.time()-start_time 91 | torch.save(self.model.state_dict(), os.path.join( 92 | self.weightpath,'model_{}.pth'.format(self.dataname))) 93 | 94 | def test(self, start=100, end=115, visual=False): 95 | plt.cla() 96 | plt.clf() 97 | with torch.no_grad(): 98 | loss_all, psnr_all = [], [] 99 | for id in (range(start, end)): 100 | images = self.model(id) 101 | loss = self.loss_fn(images[0], self.imagesgt[id]) 102 | loss_all.append(loss) 103 | psnr_all.append(mse2psnr(loss)) 104 | if visual: 105 | pred = images[0, ..., :3].detach().cpu().data.numpy() 106 | gt = self.imagesgt[id].detach().cpu().data.numpy() 107 | # set background as white for visualization 108 | mask = self.masks[id].cpu().data.numpy() 109 | pred = pred*mask+1-mask 110 | gt = gt*mask+1-mask 111 | img_gt = np.concatenate((pred,gt),1) 112 | img_gt = Image.fromarray((img_gt*255).astype(np.uint8)) 113 | img_gt.save(os.path.join(self.imgout_path, 114 | 'img_{}_{}_{:.2f}.png'.format(self.dataname, id, mse2psnr(loss).item()))) 115 | loss_e = torch.stack(loss_all).mean().item() 116 | psnr_e = torch.stack(psnr_all).mean().item() 117 | info = '-----eval----- loss:{:.3f} psnr:{:.3f}'.format(loss_e, psnr_e) 118 | print(info) 119 | self.logfile.write(info + '\n') 120 | self.logfile.flush() 121 | return psnr_e 122 | 123 | def get_fps_modelsize(self): 124 | start_time = time.time() 125 | for id in (range(0, 138)): 126 | images = self.model(id) 127 | end_time = time.time() 128 | fps = 138 / (end_time - start_time) 129 | model_path = os.path.join( 130 | self.weightpath,'model_{}.pth'.format(self.dataname)) 131 | model_size = os.path.getsize(model_path) 132 | model_size = model_size / float(1024 * 1024) 133 | model_size = round(model_size, 2) 134 | return fps,model_size 135 | 136 | 137 | def solve(args): 138 | trainer = Trainder(args) 139 | trainer.set_onlybase() 140 | trainer.train(epoch_n=20) 141 | trainer.remove_onlybase() 142 | trainer.train() 143 | for i in range(args.refine_n): 144 | trainer.model.remove_out() 145 | trainer.model.repeat_pts() 146 | trainer.set_optimizer(args.lr1, args.lr2) 147 | trainer.train() 148 | trainer.logfile.write('Total Training Time: ' 149 | '{:.2f}s\n'.format(trainer.training_time)) 150 | trainer.logfile.flush() 151 | psnr_e = trainer.test(115, 138, True) 152 | fps,model_size = trainer.get_fps_modelsize() 153 | print('Training time: {:.2f} s'.format(trainer.training_time)) 154 | print('Rendering quality: {:.2f} dB'.format(psnr_e)) 155 | print('Rendering speed: {:.2f} fps'.format(fps)) 156 | print('Model size: {:.2f} MB'.format(model_size)) 157 | 158 | 159 | if __name__ == '__main__': 160 | set_seed(0) 161 | parser = 
config_parser()
162 |     args = parser.parse_args()
163 |     dataset = Data(args)
164 |     args.memitem = dataset.genpc()
165 |     solve(args)
166 | 
--------------------------------------------------------------------------------
/modules/config.py:
--------------------------------------------------------------------------------
1 | 
2 | import argparse
3 | 
4 | 
5 | def config_parser():
6 |     parser = argparse.ArgumentParser()
7 | 
8 |     parser.add_argument("--splatting_r", type=float, default=0.015, help='the radius for the splatting')
9 |     parser.add_argument("--raster_n", type=int, default=15, help='the number of points per pixel for soft rasterization')
10 |     parser.add_argument("--refine_n", type=int, default=2, help='the number of refinement rounds (outlier removal + point doubling)')
11 |     parser.add_argument("--data_r", type=float, default=0.012, help='the ratio between the used point number and the initialized point number')
12 |     parser.add_argument("--step", type=str, default='brdf', help='the running step for the algorithm')
13 |     parser.add_argument("--savemodel", type=str, default=None, help='whether to save the model weight or not')
14 | 
15 |     parser.add_argument("--lr1", type=float, default=22e-3, help='the learning rate for the rgb texture')
16 |     parser.add_argument("--lr2", type=float, default=8e-4, help='the learning rate for the point position')
17 |     parser.add_argument("--lrexp", type=float, default=0.93, help='the coefficient for the exponential lr decay')
18 |     parser.add_argument("--lr_s", type=float, default=0.03, help='the coefficient for the gradient (total variation) loss')
19 |     parser.add_argument("--img_s", type=int, default=400, help='the square image resolution used for rendering')
20 |     parser.add_argument("--memitem", type=object, default=None, help='the in-memory dataset item, filled in at runtime by main.py')
21 | 
22 | 
23 |     parser.add_argument("--expname", type=str, default='pcdata',
24 |                         help='experiment name')
25 |     parser.add_argument("--basedir", type=str, default='../logs/',
26 |                         help='where to store ckpts and logs')
27 |     parser.add_argument("--datadir", type=str, default='/tigress/qz9238/workspace/workspace/data/nerf/nerf_synthetic/',
28 |                         help='input data directory')
29 |     parser.add_argument("--dataname", type=str, default='hotdog',
30 |                         help='dataset name')
31 |     parser.add_argument("--grey", type=int, default=1,
32 |                         help='whether to set grey or not')
33 | 
34 |     # training options
35 |     parser.add_argument("--netdepth", type=int, default=8,
36 |                         help='layers in network')
37 |     parser.add_argument("--netwidth", type=int, default=256,
38 |                         help='channels per layer')
39 |     parser.add_argument("--netdepth_fine", type=int, default=8,
40 |                         help='layers in fine network')
41 |     parser.add_argument("--netwidth_fine", type=int, default=256,
42 |                         help='channels per layer in fine network')
43 |     parser.add_argument("--N_rand", type=int, default=32*32,
44 |                         help='batch size (number of random rays per gradient step)')
45 |     parser.add_argument("--lrate", type=float, default=5e-4,
46 |                         help='learning rate')
47 |     parser.add_argument("--lrate_decay", type=int, default=500,
48 |                         help='exponential learning rate decay (in 1000 steps)')
49 |     parser.add_argument("--chunk", type=int, default=1024*32,
50 |                         help='number of rays processed in parallel, decrease if running out of memory')
51 |     parser.add_argument("--netchunk", type=int, default=1024*64,
52 |                         help='number of pts sent through network in parallel, decrease if running out of memory')
53 |     parser.add_argument("--no_batching", default=True,
54 |                         help='only take random rays from 1 image at a time')
55 |     parser.add_argument("--no_reload",
action='store_true', 56 | help='do not reload weights from saved ckpt') 57 | parser.add_argument("--ft_path", type=str, default=None, 58 | help='specific weights npy file to reload for coarse network') 59 | 60 | # rendering options 61 | parser.add_argument("--N_samples", type=int, default=64, 62 | help='number of coarse samples per ray') 63 | parser.add_argument("--N_importance", type=int, default=128, 64 | help='number of additional fine samples per ray') 65 | parser.add_argument("--perturb", type=float, default=1., 66 | help='set to 0. for no jitter, 1. for jitter') 67 | parser.add_argument("--use_viewdirs", default=True, 68 | help='use full 5D input instead of 3D') 69 | parser.add_argument("--i_embed", type=int, default=0, 70 | help='set 0 for default positional encoding, -1 for none') 71 | parser.add_argument("--multires", type=int, default=10, 72 | help='log2 of max freq for positional encoding (3D location)') 73 | parser.add_argument("--multires_views", type=int, default=4, 74 | help='log2 of max freq for positional encoding (2D direction)') 75 | parser.add_argument("--raw_noise_std", type=float, default=0., 76 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 77 | 78 | parser.add_argument("--render_only", action='store_true', 79 | help='do not optimize, reload weights and render out render_poses path') 80 | parser.add_argument("--render_test", action='store_true', 81 | help='render the test set instead of render_poses path') 82 | parser.add_argument("--render_factor", type=int, default=0, 83 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 84 | 85 | # training options 86 | parser.add_argument("--precrop_iters", type=int, default=500, 87 | help='number of steps to train on central crops') 88 | parser.add_argument("--precrop_frac", type=float, 89 | default=.5, help='fraction of img taken for central crops') 90 | 91 | # dataset options 92 | parser.add_argument("--dataset_type", type=str, default='blender', 93 | help='options: llff / blender / deepvoxels') 94 | parser.add_argument("--testskip", type=int, default=1, 95 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 96 | 97 | ## deepvoxels flags 98 | parser.add_argument("--shape", type=str, default='greek', 99 | help='options : armchair / cube / greek / vase') 100 | 101 | ## blender flags 102 | parser.add_argument("--white_bkgd", action='store_true', 103 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 104 | parser.add_argument("--half_res", default=True, 105 | help='load blender synthetic data at 400x400 instead of 800x800') 106 | 107 | ## llff flags 108 | parser.add_argument("--factor", type=int, default=8, 109 | help='downsample factor for LLFF images') 110 | parser.add_argument("--no_ndc", default=True, help='do not use normalized device coordinates (set for non-forward facing scenes)') 111 | parser.add_argument("--lindisp", action='store_true', 112 | help='sampling linearly in disparity rather than depth') 113 | parser.add_argument("--spherify", action='store_true', 114 | help='set for spherical 360 scenes') 115 | parser.add_argument("--llffhold", type=int, default=8, 116 | help='will take every 1/N images as LLFF test set, paper uses 8') 117 | 118 | # logging/saving options 119 | parser.add_argument("--i_print", type=int, default=100, 120 | help='frequency of console printout and metric loggin') 121 | parser.add_argument("--i_img", type=int, default=500, 122 | help='frequency of tensorboard image 
logging') 123 | parser.add_argument("--i_weights", type=int, default=500, 124 | help='frequency of weight ckpt saving') 125 | parser.add_argument("--i_testset", type=int, default=5000, 126 | help='frequency of testset saving') 127 | parser.add_argument("--i_video", type=int, default=10000, 128 | help='frequency of render_poses video saving') 129 | 130 | # name_list = ['lego','materials','mic','ficus','drums','chair','ship','hotdog'] 131 | # data_r_mapping = {'hotdog':0.012,'lego':0.08,'materials':0.03,'mic':0.15,'ficus':0.15,'drums':0.25,'chair':0.06,'ship':0.03} 132 | # splatting_mapping = {'hotdog':0.015,'lego':0.010,'materials':0.010,'mic':0.008,'ficus':0.008,'drums':0.008,'chair':0.010,'ship':0.010} 133 | 134 | # args.dataname = name_list[7] 135 | # dataset = Data(args) 136 | # args.memitem = dataset.genpc() 137 | # args.splatting_r = splatting_mapping[args.dataname] 138 | # args.data_r = data_r_mapping[args.dataname] 139 | # solve(args) 140 | 141 | # for i in range(8): 142 | # args.dataname = name_list[i] 143 | # dataset = Data(args) 144 | # args.memitem = dataset.genpc() 145 | # args.splatting_r = splatting_mapping[args.dataname] 146 | # args.data_r = data_r_mapping[args.dataname] 147 | # solve(args) 148 | 149 | return parser 150 | 151 | 152 | 153 | --------------------------------------------------------------------------------
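The commented-out block near the end of `modules/config.py` records the per-scene `data_r` and `splatting_r` values used for the eight Blender scenes and hints at a driver loop over them. A minimal standalone sketch in that spirit is shown below; the script name `sweep_scenes.py` is hypothetical, and it assumes it is run from the repository root (on a CUDA machine) so that `main.py`, `modules`, and `dataloader` are importable.

```python
# sweep_scenes.py -- hypothetical driver, adapted from the commented-out
# block in modules/config.py; not part of the original release.
from modules.config import config_parser
from modules.utils import set_seed
from dataloader.dataset import Data
from main import solve

# Per-scene hyperparameters taken from the comments in config.py.
name_list = ['lego', 'materials', 'mic', 'ficus', 'drums', 'chair', 'ship', 'hotdog']
data_r_mapping = {'hotdog': 0.012, 'lego': 0.08, 'materials': 0.03, 'mic': 0.15,
                  'ficus': 0.15, 'drums': 0.25, 'chair': 0.06, 'ship': 0.03}
splatting_mapping = {'hotdog': 0.015, 'lego': 0.010, 'materials': 0.010, 'mic': 0.008,
                     'ficus': 0.008, 'drums': 0.008, 'chair': 0.010, 'ship': 0.010}

if __name__ == '__main__':
    set_seed(0)
    parser = config_parser()
    args = parser.parse_args()
    for name in name_list:
        # Select the scene and its point ratio / splatting radius.
        args.dataname = name
        args.data_r = data_r_mapping[name]
        args.splatting_r = splatting_mapping[name]
        # Build the initial point cloud for this scene, then run training.
        dataset = Data(args)
        args.memitem = dataset.genpc()
        solve(args)
```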