├── figs └── pipeline.png ├── __pycache__ ├── load_llff.cpython-38.pyc ├── load_LINEMOD.cpython-38.pyc ├── load_blender.cpython-38.pyc ├── load_deepvoxels.cpython-38.pyc └── run_nerf_helpers.cpython-38.pyc ├── requirements.txt ├── configs └── llff.txt ├── load_blender.py ├── load_LINEMOD.py ├── load_deepvoxels.py ├── README.md ├── run_nerf_helpers.py ├── load_llff.py └── run_nerf.py /figs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsh2000/SS-NeRF/HEAD/figs/pipeline.png -------------------------------------------------------------------------------- /__pycache__/load_llff.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsh2000/SS-NeRF/HEAD/__pycache__/load_llff.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/load_LINEMOD.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsh2000/SS-NeRF/HEAD/__pycache__/load_LINEMOD.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/load_blender.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsh2000/SS-NeRF/HEAD/__pycache__/load_blender.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/load_deepvoxels.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsh2000/SS-NeRF/HEAD/__pycache__/load_deepvoxels.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/run_nerf_helpers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zsh2000/SS-NeRF/HEAD/__pycache__/run_nerf_helpers.cpython-38.pyc -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.8 2 | torchvision>=0.9.1 3 | imageio 4 | imageio-ffmpeg 5 | matplotlib 6 | configargparse 7 | tensorboard>=2.0 8 | tqdm 9 | opencv-python 10 | -------------------------------------------------------------------------------- /configs/llff.txt: -------------------------------------------------------------------------------- 1 | expname = office3_sn_debug 2 | basedir = ./logs 3 | datadir = ./nerf_data/office_4_sample_5 4 | dataset_type = llff 5 | 6 | factor = 1 7 | llffhold = 8 8 | 9 | N_rand = 1024 10 | N_samples = 64 11 | N_importance = 64 12 | 13 | use_viewdirs = True 14 | raw_noise_std = 1e0 15 | 16 | -------------------------------------------------------------------------------- /load_blender.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | 
[np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_blender_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for frame in meta['frames'][::skip]: 57 | fname = os.path.join(basedir, frame['file_path'] + '.png') 58 | imgs.append(imageio.imread(fname)) 59 | poses.append(np.array(frame['transform_matrix'])) 60 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 61 | poses = np.array(poses).astype(np.float32) 62 | counts.append(counts[-1] + imgs.shape[0]) 63 | all_imgs.append(imgs) 64 | all_poses.append(poses) 65 | 66 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 67 | 68 | imgs = np.concatenate(all_imgs, 0) 69 | poses = np.concatenate(all_poses, 0) 70 | 71 | H, W = imgs[0].shape[:2] 72 | camera_angle_x = float(meta['camera_angle_x']) 73 | focal = .5 * W / np.tan(.5 * camera_angle_x) 74 | 75 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 76 | 77 | if half_res: 78 | H = H//2 79 | W = W//2 80 | focal = focal/2. 
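# Note: halving H and W also halves the focal length in pixels (focal = .5 * W / np.tan(.5 * camera_angle_x) above), so the field of view stays unchanged.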
81 | 82 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 4)) 83 | for i, img in enumerate(imgs): 84 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 85 | imgs = imgs_half_res 86 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 87 | 88 | 89 | return imgs, poses, render_poses, [H, W, focal], i_split 90 | 91 | 92 | -------------------------------------------------------------------------------- /load_LINEMOD.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import imageio 5 | import json 6 | import torch.nn.functional as F 7 | import cv2 8 | 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | 37 | def load_LINEMOD_data(basedir, half_res=False, testskip=1): 38 | splits = ['train', 'val', 'test'] 39 | metas = {} 40 | for s in splits: 41 | with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: 42 | metas[s] = json.load(fp) 43 | 44 | all_imgs = [] 45 | all_poses = [] 46 | counts = [0] 47 | for s in splits: 48 | meta = metas[s] 49 | imgs = [] 50 | poses = [] 51 | if s=='train' or testskip==0: 52 | skip = 1 53 | else: 54 | skip = testskip 55 | 56 | for idx_test, frame in enumerate(meta['frames'][::skip]): 57 | fname = frame['file_path'] 58 | if s == 'test': 59 | print(f"{idx_test}th test frame: {fname}") 60 | imgs.append(imageio.imread(fname)) 61 | poses.append(np.array(frame['transform_matrix'])) 62 | imgs = (np.array(imgs) / 255.).astype(np.float32) # keep all 4 channels (RGBA) 63 | poses = np.array(poses).astype(np.float32) 64 | counts.append(counts[-1] + imgs.shape[0]) 65 | all_imgs.append(imgs) 66 | all_poses.append(poses) 67 | 68 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 69 | 70 | imgs = np.concatenate(all_imgs, 0) 71 | poses = np.concatenate(all_poses, 0) 72 | 73 | H, W = imgs[0].shape[:2] 74 | focal = float(meta['frames'][0]['intrinsic_matrix'][0][0]) 75 | K = meta['frames'][0]['intrinsic_matrix'] 76 | print(f"Focal: {focal}") 77 | 78 | render_poses = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 79 | 80 | if half_res: 81 | H = H//2 82 | W = W//2 83 | focal = focal/2. 
84 | 85 | imgs_half_res = np.zeros((imgs.shape[0], H, W, 3)) 86 | for i, img in enumerate(imgs): 87 | imgs_half_res[i] = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) 88 | imgs = imgs_half_res 89 | # imgs = tf.image.resize_area(imgs, [400, 400]).numpy() 90 | 91 | near = np.floor(min(metas['train']['near'], metas['test']['near'])) 92 | far = np.ceil(max(metas['train']['far'], metas['test']['far'])) 93 | return imgs, poses, render_poses, [H, W, focal], K, i_split, near, far 94 | 95 | 96 | -------------------------------------------------------------------------------- /load_deepvoxels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import imageio 4 | 5 | 6 | def load_dv_data(scene='cube', basedir='/data/deepvoxels', testskip=8): 7 | 8 | 9 | def parse_intrinsics(filepath, trgt_sidelength, invert_y=False): 10 | # Get camera intrinsics 11 | with open(filepath, 'r') as file: 12 | f, cx, cy = list(map(float, file.readline().split()))[:3] 13 | grid_barycenter = np.array(list(map(float, file.readline().split()))) 14 | near_plane = float(file.readline()) 15 | scale = float(file.readline()) 16 | height, width = map(float, file.readline().split()) 17 | 18 | try: 19 | world2cam_poses = int(file.readline()) 20 | except ValueError: 21 | world2cam_poses = None 22 | 23 | if world2cam_poses is None: 24 | world2cam_poses = False 25 | 26 | world2cam_poses = bool(world2cam_poses) 27 | 28 | print(cx,cy,f,height,width) 29 | 30 | cx = cx / width * trgt_sidelength 31 | cy = cy / height * trgt_sidelength 32 | f = trgt_sidelength / height * f 33 | 34 | fx = f 35 | if invert_y: 36 | fy = -f 37 | else: 38 | fy = f 39 | 40 | # Build the intrinsic matrices 41 | full_intrinsic = np.array([[fx, 0., cx, 0.], 42 | [0., fy, cy, 0], 43 | [0., 0, 1, 0], 44 | [0, 0, 0, 1]]) 45 | 46 | return full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses 47 | 48 | 49 | def load_pose(filename): 50 | assert os.path.isfile(filename) 51 | nums = open(filename).read().split() 52 | return np.array([float(x) for x in nums]).reshape([4,4]).astype(np.float32) 53 | 54 | 55 | H = 512 56 | W = 512 57 | deepvoxels_base = '{}/train/{}/'.format(basedir, scene) 58 | 59 | full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses = parse_intrinsics(os.path.join(deepvoxels_base, 'intrinsics.txt'), H) 60 | print(full_intrinsic, grid_barycenter, scale, near_plane, world2cam_poses) 61 | focal = full_intrinsic[0,0] 62 | print(H, W, focal) 63 | 64 | 65 | def dir2poses(posedir): 66 | poses = np.stack([load_pose(os.path.join(posedir, f)) for f in sorted(os.listdir(posedir)) if f.endswith('txt')], 0) 67 | transf = np.array([ 68 | [1,0,0,0], 69 | [0,-1,0,0], 70 | [0,0,-1,0], 71 | [0,0,0,1.], 72 | ]) 73 | poses = poses @ transf 74 | poses = poses[:,:3,:4].astype(np.float32) 75 | return poses 76 | 77 | posedir = os.path.join(deepvoxels_base, 'pose') 78 | poses = dir2poses(posedir) 79 | testposes = dir2poses('{}/test/{}/pose'.format(basedir, scene)) 80 | testposes = testposes[::testskip] 81 | valposes = dir2poses('{}/validation/{}/pose'.format(basedir, scene)) 82 | valposes = valposes[::testskip] 83 | 84 | imgfiles = [f for f in sorted(os.listdir(os.path.join(deepvoxels_base, 'rgb'))) if f.endswith('png')] 85 | imgs = np.stack([imageio.imread(os.path.join(deepvoxels_base, 'rgb', f))/255. 
for f in imgfiles], 0).astype(np.float32) 86 | 87 | 88 | testimgd = '{}/test/{}/rgb'.format(basedir, scene) 89 | imgfiles = [f for f in sorted(os.listdir(testimgd)) if f.endswith('png')] 90 | testimgs = np.stack([imageio.imread(os.path.join(testimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 91 | 92 | valimgd = '{}/validation/{}/rgb'.format(basedir, scene) 93 | imgfiles = [f for f in sorted(os.listdir(valimgd)) if f.endswith('png')] 94 | valimgs = np.stack([imageio.imread(os.path.join(valimgd, f))/255. for f in imgfiles[::testskip]], 0).astype(np.float32) 95 | 96 | all_imgs = [imgs, valimgs, testimgs] 97 | counts = [0] + [x.shape[0] for x in all_imgs] 98 | counts = np.cumsum(counts) 99 | i_split = [np.arange(counts[i], counts[i+1]) for i in range(3)] 100 | 101 | imgs = np.concatenate(all_imgs, 0) 102 | poses = np.concatenate([poses, valposes, testposes], 0) 103 | 104 | render_poses = testposes 105 | 106 | print(poses.shape, imgs.shape) 107 | 108 | return imgs, poses, render_poses, [H,W,focal], i_split 109 | 110 | 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > # [WACV 2023] Beyond RGB: Scene-Property Synthesis with Neural Radiance Fields
2 | > [Paper](https://arxiv.org/abs/2206.04669) 3 | 4 | This repository contains a PyTorch implementation of our paper "Beyond RGB: Scene-Property Synthesis with Neural Radiance Fields". 5 | 6 | ![Pipeline](figs/pipeline.png) 7 | 8 | ## Installation 9 | 10 | #### Tested on a single NVIDIA GeForce GTX 1080 / RTX 2080 GPU with 11GB of memory. 11 | 12 | To install the dependencies, run: 13 | 14 | ``` 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ## Dataset 19 | 20 | The dataset used in the paper can be accessed through https://drive.google.com/file/d/1GilHPjIqsa3Fjtizy1RJIGL2p2beHbhP/view?usp=sharing. It consists of 4 scenes from the Replica dataset, each having 50 frames. 21 | 22 | [^_^]: If you want to render your own scenes in Replica, note that a directly rendered Replica scene contains RGB images, depth maps, and 88-class semantic labels. We map the 88-class semantic labels to NYU-13 definitions. The surface normals are derived from the depth maps. The edges are 23 | 24 | ## Data Preparation 25 | First, organize the data of the scene in the hierarchy shown below (in this example the scene has 50 frames): 26 | 27 | ``` 28 | datadir 29 | - images 30 | - 000.png 31 | - 001.png 32 | ... 33 | - 049.png 34 | - normal 35 | - 000.png 36 | - 001.png 37 | ... 38 | - 049.png 39 | - reshading 40 | - 000.png 41 | - 001.png 42 | ... 43 | - 049.png 44 | - label 45 | - 000.png 46 | - 001.png 47 | ... 48 | - 049.png 49 | - keypoint 50 | - 000.png 51 | - 001.png 52 | ... 53 | - 049.png 54 | - edge 55 | - 000.png 56 | - 001.png 57 | ... 58 | - 049.png 59 | - poses_bounds.npy 60 | ``` 61 | 62 | `datadir` is the directory of the scene, which is specified in the configuration file. `poses_bounds.npy` contains the camera poses and depth bounds of all the frames in the forward-facing scene, 63 | and can be constructed following the guidance in https://github.com/Fyusion/LLFF#using-your-own-poses-without-running-colmap. 64 | 65 | ## Training (Optimizing on a specific scene) 66 | 67 | Run the command: 68 | 69 | ``` 70 | python run_nerf.py --config configs/llff.txt --no_ndc 71 | ``` 72 | 73 | Tips: If you are running the code on a device that only has approximately 11GB of GPU memory (e.g., GeForce GTX 1080 / RTX 2080 GPUs), 74 | we recommend setting the argument `i_testset` to 200000 (the total number of iterations) to avoid interruption during training due to limited GPU memory. 75 | 76 | Remember to turn on the `--no_ndc` flag, as it is a more appropriate setting for these forward-facing scenes. 77 | 78 | By default, the seven views {0, 8, 16, 24, 32, 40, 48} are the held-out testing views 79 | and the rest are the training views used for optimization. You can customize your own training/testing split at [L651](https://github.com/zsh2000/SS-NeRF/blob/856a3b3d12698a710b2b7a6805d878109a7cc692/run_nerf.py#L651) in `run_nerf.py`. 80 | 81 | 82 | ## Testing (Novel view synthesis) 83 | 84 | When using a larger GPU that is capable of rendering novel-view results during training, there is no need for a separate testing command.
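Whether the test views are rendered during training or with the separate command described below, `render_path()` in `run_nerf.py` writes, for each rendered view, an RGB image plus one `.npy` file per scene property (`000.png`, `000_depth.npy`, `000_seg.npy`, `000_normal.npy`, `000_reshading.npy`, `000_kp.npy`, `000_edge.npy`) whenever a save directory is given. Below is a minimal sketch for inspecting these outputs; the save directory shown is only a placeholder and depends on your `expname` and the test folder created under `basedir`:

```python
# Minimal inspection sketch (not part of the repo). The directory below is an
# assumption -- point it at the test-view folder created under ./logs/<expname>/.
import numpy as np
import imageio
import matplotlib.pyplot as plt

savedir = "./logs/office3_sn_debug/testset_example"  # hypothetical path
view_id = 0

rgb = imageio.imread(f"{savedir}/{view_id:03d}.png")             # [H, W, 3] uint8
depth = np.load(f"{savedir}/{view_id:03d}_depth.npy")            # [H, W]
seg = np.load(f"{savedir}/{view_id:03d}_seg.npy")                # [H, W, 13] NYU-13 scores
normal = np.load(f"{savedir}/{view_id:03d}_normal.npy")          # [H, W, 3] in [0, 1]

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
axes[0].imshow(rgb);                          axes[0].set_title("RGB")
axes[1].imshow(depth, cmap="viridis");        axes[1].set_title("depth")
axes[2].imshow(seg.argmax(-1), cmap="tab20"); axes[2].set_title("semantics (argmax)")
axes[3].imshow(normal);                       axes[3].set_title("normal")
for ax in axes:
    ax.axis("off")
plt.show()
```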
85 | 86 | If you are using a device with small GPU memory as mentioned in the "Tips" above, you can perform novel view synthesis by 87 | running an additional command, which is the same as the training command: 88 | 89 | ``` 90 | python run_nerf.py --config configs/llff.txt --no_ndc 91 | ``` 92 | 93 | ## Citation 94 | If you find our work useful, please consider citing: 95 | ```BibTeX 96 | @inproceedings{ssnerf-2023, 97 | author = {Zhang, Mingtong and Zheng, Shuhong and Bao, Zhipeng and Hebert, Martial and Wang, Yu-Xiong}, 98 | title = {Beyond RGB: Scene-Property Synthesis with Neural Radiance Fields}, 99 | booktitle = {IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 100 | year = {2023} 101 | } 102 | ``` 103 | 104 | ### Acknowledgement 105 | The code is largely borrowed from the PyTorch implementation of NeRF: 106 | 107 | https://github.com/yenchenlin/nerf-pytorch 108 | 109 | This work was supported in part by an NSF grant, the Toyota Research Institute, a NIFA award, the Jump ARCHES endowment through the Health Care 110 | Engineering Systems Center, the National Center for Supercomputing Applications (NCSA) at the University of Illinois 111 | at Urbana-Champaign through the NCSA Fellows program, 112 | and the IBM-Illinois Discovery Accelerator Institute. 113 | -------------------------------------------------------------------------------- /run_nerf_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.autograd.set_detect_anomaly(True) 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | # Misc 9 | img2mse = lambda x, y : torch.mean((x - y) ** 2) 10 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 11 | to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8) 12 | 13 | def cosine_similarity_loss(x, y): 14 | cosine_similarity = x*y 15 | loss_pointwise = 1.
- torch.sum(cosine_similarity, dim = -1) 16 | return torch.mean(loss_pointwise) 17 | 18 | # Positional encoding (section 5.1) 19 | class Embedder: 20 | def __init__(self, **kwargs): 21 | self.kwargs = kwargs 22 | self.create_embedding_fn() 23 | 24 | def create_embedding_fn(self): 25 | embed_fns = [] 26 | d = self.kwargs['input_dims'] 27 | out_dim = 0 28 | if self.kwargs['include_input']: 29 | embed_fns.append(lambda x : x) 30 | out_dim += d 31 | 32 | max_freq = self.kwargs['max_freq_log2'] 33 | N_freqs = self.kwargs['num_freqs'] 34 | 35 | if self.kwargs['log_sampling']: 36 | freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs) 37 | else: 38 | freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs) 39 | 40 | for freq in freq_bands: 41 | for p_fn in self.kwargs['periodic_fns']: 42 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq)) 43 | out_dim += d 44 | 45 | self.embed_fns = embed_fns 46 | self.out_dim = out_dim 47 | 48 | def embed(self, inputs): 49 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 50 | 51 | 52 | def get_embedder(multires, input_dims, i=0): 53 | if i == -1: 54 | return nn.Identity(), 3 55 | 56 | embed_kwargs = { 57 | 'include_input' : True, 58 | 'input_dims' : input_dims, 59 | 'max_freq_log2' : multires-1, 60 | 'num_freqs' : multires, 61 | 'log_sampling' : True, 62 | 'periodic_fns' : [torch.sin, torch.cos], 63 | } 64 | 65 | embedder_obj = Embedder(**embed_kwargs) 66 | embed = lambda x, eo=embedder_obj : eo.embed(x) 67 | return embed, embedder_obj.out_dim 68 | 69 | 70 | # Model 71 | class NeRF(nn.Module): 72 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, input_ch_poses=12, output_ch=4, skips=[4], use_viewdirs=False): 73 | """ 74 | """ 75 | super(NeRF, self).__init__() 76 | self.D = D 77 | self.W = W 78 | self.input_ch = input_ch 79 | self.input_ch_views = input_ch_views 80 | self.input_ch_poses = input_ch_poses 81 | self.skips = skips 82 | self.use_viewdirs = use_viewdirs 83 | 84 | self.pts_linears = nn.ModuleList( 85 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)]) 86 | 87 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)]) 88 | self.poses_linears = nn.ModuleList([nn.Linear(input_ch_poses + W, W//2)]) 89 | 90 | if use_viewdirs: 91 | self.feature_linear = nn.Linear(W, W) 92 | self.alpha_linear = nn.Linear(W, 1) 93 | self.rgb_linear = nn.Linear(W//2, 3) 94 | self.seg_linear = nn.Linear(W, 13) 95 | self.kp_linear = nn.Linear(W//2, 1) 96 | self.edge_linear = nn.Linear(W//2, 1) 97 | self.normal_linear = nn.Linear(W//2, 3) 98 | self.reshading_linear = nn.Linear(W//2, 1) 99 | else: 100 | self.output_linear = nn.Linear(W, output_ch) 101 | 102 | def forward(self, x): 103 | input_pts, input_views, input_poses = torch.split(x, [self.input_ch, self.input_ch_views, self.input_ch_poses], dim=-1) 104 | 105 | 106 | h = input_pts 107 | for i, l in enumerate(self.pts_linears): 108 | h = self.pts_linears[i](h) 109 | h = F.relu(h) 110 | if i in self.skips: 111 | h = torch.cat([input_pts, h], -1) 112 | 113 | if self.use_viewdirs: 114 | alpha = self.alpha_linear(h) 115 | seg = self.seg_linear(h) 116 | feature = self.feature_linear(h) 117 | h = torch.cat([feature, input_views], -1) 118 | 119 | h_star = torch.cat([feature, input_poses], -1) 120 | 121 | for i, l in enumerate(self.views_linears): 122 | h = self.views_linears[i](h) 123 | h = F.relu(h) 124 | 125 | for i, l in enumerate(self.poses_linears): 126 | h_star = 
self.poses_linears[i](h_star) 127 | h_star = F.relu(h_star) 128 | 129 | 130 | 131 | rgb = self.rgb_linear(h) 132 | normal = self.normal_linear(h_star) 133 | reshading = self.reshading_linear(h) 134 | kp = self.kp_linear(h) 135 | edge = self.edge_linear(h) 136 | 137 | outputs = torch.cat([rgb, alpha, seg, normal, reshading, kp, edge], -1) 138 | else: 139 | outputs = self.output_linear(h) 140 | 141 | return outputs 142 | 143 | def load_weights_from_keras(self, weights): 144 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False" 145 | 146 | # Load pts_linears 147 | for i in range(self.D): 148 | idx_pts_linears = 2 * i 149 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears])) 150 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears+1])) 151 | 152 | # Load feature_linear 153 | idx_feature_linear = 2 * self.D 154 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear])) 155 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear+1])) 156 | 157 | # Load views_linears 158 | idx_views_linears = 2 * self.D + 2 159 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears])) 160 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears+1])) 161 | 162 | # Load rgb_linear 163 | idx_rbg_linear = 2 * self.D + 4 164 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear])) 165 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear+1])) 166 | 167 | # Load alpha_linear 168 | idx_alpha_linear = 2 * self.D + 6 169 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear])) 170 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear+1])) 171 | 172 | 173 | 174 | # Ray helpers 175 | def get_rays(H, W, K, c2w): 176 | i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij' 177 | i = i.t() 178 | j = j.t() 179 | dirs = torch.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -torch.ones_like(i)], -1) 180 | # Rotate ray directions from camera frame to the world frame 181 | rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 182 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 183 | rays_o = c2w[:3,-1].expand(rays_d.shape) 184 | return rays_o, rays_d 185 | 186 | 187 | def get_rays_np(H, W, K, c2w): 188 | i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') 189 | dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1) 190 | # Rotate ray directions from camera frame to the world frame 191 | rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs] 192 | # Translate camera frame's origin to the world frame. It is the origin of all rays. 193 | rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d)) 194 | return rays_o, rays_d 195 | 196 | 197 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 198 | # Shift ray origins to near plane 199 | t = -(near + rays_o[...,2]) / rays_d[...,2] 200 | rays_o = rays_o + t[...,None] * rays_d 201 | 202 | # Projection 203 | o0 = -1./(W/(2.*focal)) * rays_o[...,0] / rays_o[...,2] 204 | o1 = -1./(H/(2.*focal)) * rays_o[...,1] / rays_o[...,2] 205 | o2 = 1. + 2. 
* near / rays_o[...,2] 206 | 207 | d0 = -1./(W/(2.*focal)) * (rays_d[...,0]/rays_d[...,2] - rays_o[...,0]/rays_o[...,2]) 208 | d1 = -1./(H/(2.*focal)) * (rays_d[...,1]/rays_d[...,2] - rays_o[...,1]/rays_o[...,2]) 209 | d2 = -2. * near / rays_o[...,2] 210 | 211 | rays_o = torch.stack([o0,o1,o2], -1) 212 | rays_d = torch.stack([d0,d1,d2], -1) 213 | 214 | return rays_o, rays_d 215 | 216 | 217 | # Hierarchical sampling (section 5.2) 218 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 219 | # Get pdf 220 | weights = weights + 1e-5 # prevent nans 221 | pdf = weights / torch.sum(weights, -1, keepdim=True) 222 | cdf = torch.cumsum(pdf, -1) 223 | cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins)) 224 | 225 | # Take uniform samples 226 | if det: 227 | u = torch.linspace(0., 1., steps=N_samples) 228 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 229 | else: 230 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples]) 231 | 232 | # Pytest, overwrite u with numpy's fixed random numbers 233 | if pytest: 234 | np.random.seed(0) 235 | new_shape = list(cdf.shape[:-1]) + [N_samples] 236 | if det: 237 | u = np.linspace(0., 1., N_samples) 238 | u = np.broadcast_to(u, new_shape) 239 | else: 240 | u = np.random.rand(*new_shape) 241 | u = torch.Tensor(u) 242 | 243 | # Invert CDF 244 | u = u.contiguous() 245 | inds = torch.searchsorted(cdf, u, right=True) 246 | below = torch.max(torch.zeros_like(inds-1), inds-1) 247 | above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) 248 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 249 | 250 | # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 251 | # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) 252 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 253 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 254 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 255 | 256 | denom = (cdf_g[...,1]-cdf_g[...,0]) 257 | denom = torch.where(denom<1e-5, torch.ones_like(denom), denom) 258 | t = (u-cdf_g[...,0])/denom 259 | samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0]) 260 | 261 | return samples 262 | -------------------------------------------------------------------------------- /load_llff.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, imageio 3 | import cv2 4 | 5 | ########## Slightly modified version of LLFF data loading code 6 | ########## see https://github.com/Fyusion/LLFF for original 7 | 8 | def _minify(basedir, factors=[], resolutions=[]): 9 | needtoload = False 10 | for r in factors: 11 | imgdir = os.path.join(basedir, 'images_{}'.format(r)) 12 | if not os.path.exists(imgdir): 13 | needtoload = True 14 | for r in resolutions: 15 | imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0])) 16 | if not os.path.exists(imgdir): 17 | needtoload = True 18 | if not needtoload: 19 | return 20 | 21 | from shutil import copy 22 | from subprocess import check_output 23 | 24 | imgdir = os.path.join(basedir, 'images') 25 | imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))] 26 | imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])] 27 | imgdir_orig = imgdir 28 | 29 | wd = os.getcwd() 30 | 31 | for r in factors + resolutions: 32 | if isinstance(r, int): 33 | name = 'images_{}'.format(r) 34 | resizearg = '{}%'.format(100./r) 35 | else: 36 | 
name = 'images_{}x{}'.format(r[1], r[0]) 37 | resizearg = '{}x{}'.format(r[1], r[0]) 38 | imgdir = os.path.join(basedir, name) 39 | if os.path.exists(imgdir): 40 | continue 41 | 42 | print('Minifying', r, basedir) 43 | 44 | os.makedirs(imgdir) 45 | check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True) 46 | 47 | ext = imgs[0].split('.')[-1] 48 | args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)]) 49 | print(args) 50 | os.chdir(imgdir) 51 | check_output(args, shell=True) 52 | os.chdir(wd) 53 | 54 | if ext != 'png': 55 | check_output('rm {}/*.{}'.format(imgdir, ext), shell=True) 56 | print('Removed duplicates') 57 | print('Done') 58 | 59 | 60 | 61 | 62 | def _load_data(basedir, factor=None, width=None, height=None, load_imgs=True): 63 | 64 | poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy')) 65 | poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0]) 66 | bds = poses_arr[:, -2:].transpose([1,0]) 67 | 68 | img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \ 69 | if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0] 70 | sh = imageio.imread(img0).shape 71 | print(sh, "sh") 72 | sfx = '' 73 | 74 | if factor is not None: 75 | sfx = '_{}'.format(factor) 76 | _minify(basedir, factors=[factor]) 77 | factor = factor 78 | elif height is not None: 79 | factor = sh[0] / float(height) 80 | width = int(sh[1] / factor) 81 | _minify(basedir, resolutions=[[height, width]]) 82 | sfx = '_{}x{}'.format(width, height) 83 | elif width is not None: 84 | factor = sh[1] / float(width) 85 | height = int(sh[0] / factor) 86 | _minify(basedir, resolutions=[[height, width]]) 87 | sfx = '_{}x{}'.format(width, height) 88 | else: 89 | factor = 1 90 | 91 | imgdir = os.path.join(basedir, 'images' + sfx) 92 | if not os.path.exists(imgdir): 93 | print( imgdir, 'does not exist, returning' ) 94 | return 95 | 96 | imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')] 97 | if poses.shape[-1] != len(imgfiles): 98 | print( 'Mismatch between imgs {} and poses {} !!!!'.format(len(imgfiles), poses.shape[-1]) ) 99 | return 100 | 101 | sh = imageio.imread(imgfiles[0]).shape 102 | poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1]) 103 | poses[2, 4, :] = poses[2, 4, :] * 1./factor 104 | 105 | if not load_imgs: 106 | return poses, bds 107 | 108 | def imread(f): 109 | if f.endswith('png'): 110 | return imageio.imread(f, ignoregamma=True) 111 | else: 112 | return imageio.imread(f) 113 | 114 | imgs = imgs = [imread(f)[...,:3]/255. for f in imgfiles] 115 | # imgs = imgs = [np.repeat(imread(f)[...,None], 3)/255. 
for f in imgfiles] 116 | # print(imgs[0].shape,"imgs") 117 | # for f in imgfiles: 118 | # print(imread(f).shape) 119 | 120 | 121 | imgs = np.stack(imgs, -1) 122 | # print(imgs.shape,"imgs") 123 | print('Loaded image data', imgs.shape, poses[:,-1,0]) 124 | 125 | def segread(f): 126 | return cv2.imread(f)[:,:,0:1] 127 | 128 | segdir = os.path.join(basedir, 'label') 129 | 130 | segfiles = [os.path.join(segdir, f) for f in sorted(os.listdir(segdir)) if f.endswith('PNG') or f.endswith('png')] 131 | 132 | segs = segs = [segread(f) for f in segfiles] 133 | 134 | segs = np.stack(segs, -1) 135 | print(segs.shape, "segs") 136 | 137 | 138 | 139 | def normread(f): 140 | if f.endswith('png'): 141 | normal = imageio.imread(f, ignoregamma=True) 142 | else: 143 | normal = imageio.imread(f) 144 | # normal = normal/np.linalg.norm(normal, ord=2, axis=2, keepdims=True) 145 | return normal 146 | 147 | normdir = os.path.join(basedir, 'normal_new') 148 | 149 | normfiles = [os.path.join(normdir, f) for f in sorted(os.listdir(normdir)) if f.endswith('PNG') or f.endswith('png')] 150 | 151 | normals = normals = [normread(f)[...,:3]/255. for f in normfiles] 152 | 153 | normals = np.stack(normals, -1) 154 | 155 | 156 | reshadingdir = os.path.join(basedir, 'reshading') 157 | 158 | reshadingfiles = [os.path.join(reshadingdir, f) for f in sorted(os.listdir(reshadingdir)) if f.endswith('PNG') or f.endswith('png')] 159 | 160 | reshadings = reshadings = [segread(f)/255. for f in reshadingfiles] 161 | 162 | reshadings = np.stack(reshadings, -1) 163 | print(reshadings.shape, "reshadings") 164 | 165 | 166 | kpdir = os.path.join(basedir, 'keypoint') 167 | 168 | kpfiles = [os.path.join(kpdir, f) for f in sorted(os.listdir(kpdir)) if f.endswith('PNG') or f.endswith('png')] 169 | 170 | kps = kps = [segread(f)/255. for f in kpfiles] 171 | 172 | kps = np.stack(kps, -1) 173 | print(kps.shape, "kps") 174 | 175 | 176 | edgedir = os.path.join(basedir, 'edge') 177 | 178 | edgefiles = [os.path.join(edgedir, f) for f in sorted(os.listdir(edgedir)) if f.endswith('PNG') or f.endswith('png')] 179 | 180 | edges = edges = [segread(f)/255. for f in edgefiles] 181 | 182 | edges = np.stack(edges, -1) 183 | print(edges.shape, "edges") 184 | 185 | return poses, bds, imgs, segs, normals, reshadings, kps, edges 186 | 187 | 188 | 189 | 190 | 191 | 192 | def normalize(x): 193 | return x / np.linalg.norm(x) 194 | 195 | def viewmatrix(z, up, pos): 196 | vec2 = normalize(z) 197 | vec1_avg = up 198 | vec0 = normalize(np.cross(vec1_avg, vec2)) 199 | vec1 = normalize(np.cross(vec2, vec0)) 200 | m = np.stack([vec0, vec1, vec2, pos], 1) 201 | return m 202 | 203 | def ptstocam(pts, c2w): 204 | tt = np.matmul(c2w[:3,:3].T, (pts-c2w[:3,3])[...,np.newaxis])[...,0] 205 | return tt 206 | 207 | def poses_avg(poses): 208 | 209 | hwf = poses[0, :3, -1:] 210 | 211 | center = poses[:, :3, 3].mean(0) 212 | vec2 = normalize(poses[:, :3, 2].sum(0)) 213 | up = poses[:, :3, 1].sum(0) 214 | c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1) 215 | 216 | return c2w 217 | 218 | 219 | 220 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N): 221 | render_poses = [] 222 | rads = np.array(list(rads) + [1.]) 223 | hwf = c2w[:,4:5] 224 | 225 | for theta in np.linspace(0., 2. 
* np.pi * rots, N+1)[:-1]: 226 | c = np.dot(c2w[:3,:4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta*zrate), 1.]) * rads) 227 | z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.]))) 228 | render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1)) 229 | return render_poses 230 | 231 | 232 | 233 | def recenter_poses(poses): 234 | 235 | poses_ = poses+0 236 | bottom = np.reshape([0,0,0,1.], [1,4]) 237 | c2w = poses_avg(poses) 238 | c2w = np.concatenate([c2w[:3,:4], bottom], -2) 239 | bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1]) 240 | poses = np.concatenate([poses[:,:3,:4], bottom], -2) 241 | 242 | poses = np.linalg.inv(c2w) @ poses 243 | poses_[:,:3,:4] = poses[:,:3,:4] 244 | poses = poses_ 245 | return poses 246 | 247 | 248 | ##################### 249 | 250 | 251 | def spherify_poses(poses, bds): 252 | 253 | p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1) 254 | 255 | rays_d = poses[:,:3,2:3] 256 | rays_o = poses[:,:3,3:4] 257 | 258 | def min_line_dist(rays_o, rays_d): 259 | A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1]) 260 | b_i = -A_i @ rays_o 261 | pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0)) 262 | return pt_mindist 263 | 264 | pt_mindist = min_line_dist(rays_o, rays_d) 265 | 266 | center = pt_mindist 267 | up = (poses[:,:3,3] - center).mean(0) 268 | 269 | vec0 = normalize(up) 270 | vec1 = normalize(np.cross([.1,.2,.3], vec0)) 271 | vec2 = normalize(np.cross(vec0, vec1)) 272 | pos = center 273 | c2w = np.stack([vec1, vec2, vec0, pos], 1) 274 | 275 | poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4]) 276 | 277 | rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1))) 278 | 279 | sc = 1./rad 280 | poses_reset[:,:3,3] *= sc 281 | bds *= sc 282 | rad *= sc 283 | 284 | centroid = np.mean(poses_reset[:,:3,3], 0) 285 | zh = centroid[2] 286 | radcircle = np.sqrt(rad**2-zh**2) 287 | new_poses = [] 288 | 289 | for th in np.linspace(0.,2.*np.pi, 120): 290 | 291 | camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) 292 | up = np.array([0,0,-1.]) 293 | 294 | vec2 = normalize(camorigin) 295 | vec0 = normalize(np.cross(vec2, up)) 296 | vec1 = normalize(np.cross(vec2, vec0)) 297 | pos = camorigin 298 | p = np.stack([vec0, vec1, vec2, pos], 1) 299 | 300 | new_poses.append(p) 301 | 302 | new_poses = np.stack(new_poses, 0) 303 | 304 | new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1) 305 | poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1) 306 | 307 | return poses_reset, new_poses, bds 308 | 309 | 310 | def load_llff_data(basedir, factor=8, recenter=True, bd_factor=.75, spherify=False, path_zflat=False): 311 | 312 | 313 | poses, bds, imgs, segs, normals, reshadings, kps, edges = _load_data(basedir, factor=factor) # factor=8 downsamples original imgs by 8x 314 | 315 | print(imgs.shape,"imgs") 316 | print('Loaded', basedir, bds.min(), bds.max()) 317 | 318 | # Correct rotation matrix ordering and move variable dim to axis 0 319 | poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) 320 | poses = np.moveaxis(poses, -1, 0).astype(np.float32) 321 | imgs = np.moveaxis(imgs, -1, 0).astype(np.float32) 322 | images = imgs 323 | bds = np.moveaxis(bds, -1, 0).astype(np.float32) 324 | segs = np.moveaxis(segs, -1, 0).astype(np.int) 
325 | normals = np.moveaxis(normals, -1, 0).astype(np.float32) 326 | reshadings = np.moveaxis(reshadings, -1, 0).astype(np.float32) 327 | kps = np.moveaxis(kps, -1, 0).astype(np.float32) 328 | edges = np.moveaxis(edges, -1, 0).astype(np.float32) 329 | 330 | # Rescale if bd_factor is provided 331 | # sc = 1. if bd_factor is None else 1./(bds.min() * bd_factor) 332 | sc = 10./bds.max() 333 | poses[:,:3,3] *= sc 334 | bds *= sc 335 | 336 | if recenter: 337 | poses = recenter_poses(poses) 338 | 339 | if spherify: 340 | poses, render_poses, bds = spherify_poses(poses, bds) 341 | 342 | else: 343 | 344 | c2w = poses_avg(poses) 345 | print('recentered', c2w.shape) 346 | print(c2w[:3,:4]) 347 | 348 | ## Get spiral 349 | # Get average pose 350 | up = normalize(poses[:, :3, 1].sum(0)) 351 | 352 | # Find a reasonable "focus depth" for this dataset 353 | close_depth, inf_depth = bds.min()*.9, bds.max()*5. 354 | dt = .75 355 | mean_dz = 1./(((1.-dt)/close_depth + dt/inf_depth)) 356 | focal = mean_dz 357 | 358 | # Get radii for spiral path 359 | shrink_factor = .8 360 | zdelta = close_depth * .2 361 | tt = poses[:,:3,3] # ptstocam(poses[:3,3,:].T, c2w).T 362 | rads = np.percentile(np.abs(tt), 90, 0) 363 | c2w_path = c2w 364 | N_views = 120 365 | N_rots = 2 366 | if path_zflat: 367 | # zloc = np.percentile(tt, 10, 0)[2] 368 | zloc = -close_depth * .1 369 | c2w_path[:3,3] = c2w_path[:3,3] + zloc * c2w_path[:3,2] 370 | rads[2] = 0. 371 | N_rots = 1 372 | N_views/=2 373 | 374 | # Generate poses for spiral path 375 | render_poses = render_path_spiral(c2w_path, up, rads, focal, zdelta, zrate=.5, rots=N_rots, N=N_views) 376 | 377 | 378 | render_poses = np.array(render_poses).astype(np.float32) 379 | 380 | c2w = poses_avg(poses) 381 | print('Data:') 382 | print(poses.shape, images.shape, bds.shape) 383 | 384 | dists = np.sum(np.square(c2w[:3,3] - poses[:,:3,3]), -1) 385 | i_test = np.argmin(dists) 386 | print('HOLDOUT view is', i_test) 387 | 388 | images = images.astype(np.float32) 389 | poses = poses.astype(np.float32) 390 | normals = normals.astype(np.float32) 391 | reshadings = reshadings.astype(np.float32) 392 | kps = kps.astype(np.float32) 393 | edges = edges.astype(np.float32) 394 | 395 | return images, poses, bds, render_poses, i_test, segs, normals, reshadings, kps, edges 396 | 397 | 398 | 399 | -------------------------------------------------------------------------------- /run_nerf.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import imageio 4 | import json 5 | import random 6 | import time 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from tqdm import tqdm, trange 11 | 12 | 13 | from run_nerf_helpers import * 14 | 15 | from load_llff import load_llff_data 16 | from load_deepvoxels import load_dv_data 17 | from load_blender import load_blender_data 18 | from load_LINEMOD import load_LINEMOD_data 19 | 20 | 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | np.random.seed(0) 23 | DEBUG = False 24 | 25 | 26 | 27 | def batchify(fn, chunk): 28 | """Constructs a version of 'fn' that applies to smaller batches. 29 | """ 30 | if chunk is None: 31 | return fn 32 | def ret(inputs): 33 | return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) 34 | return ret 35 | 36 | 37 | def run_network(inputs, viewdirs, poses, fn, embed_fn, embeddirs_fn, embedpose_fn, netchunk=1024*64): 38 | """Prepares inputs and applies network 'fn'. 
39 | """ 40 | inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) 41 | embedded = embed_fn(inputs_flat) 42 | if viewdirs is not None: 43 | input_dirs = viewdirs[:,None].expand(inputs.shape) 44 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 45 | embedded_dirs = embeddirs_fn(input_dirs_flat) 46 | 47 | 48 | input_poses = poses[:,None].expand((inputs.shape[0], inputs.shape[1], 12)) 49 | input_poses_flat = torch.reshape(input_poses, [-1, input_poses.shape[-1]]) 50 | embedded_poses = embedpose_fn(input_poses_flat) 51 | embedded = torch.cat([embedded, embedded_dirs, embedded_poses], -1) 52 | 53 | outputs_flat = batchify(fn, netchunk)(embedded) 54 | outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]]) 55 | return outputs 56 | 57 | 58 | def batchify_rays(rays_flat, chunk=1024*32, **kwargs): 59 | """Render rays in smaller minibatches to avoid OOM. 60 | """ 61 | all_ret = {} 62 | for i in range(0, rays_flat.shape[0], chunk): 63 | ret = render_rays(rays_flat[i:i+chunk], **kwargs) 64 | for k in ret: 65 | if k not in all_ret: 66 | all_ret[k] = [] 67 | all_ret[k].append(ret[k]) 68 | 69 | all_ret = {k : torch.cat(all_ret[k], 0) for k in all_ret} 70 | return all_ret 71 | 72 | 73 | def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True, 74 | near=0., far=1., 75 | use_viewdirs=False, c2w_staticcam=None, 76 | **kwargs): 77 | """Render rays 78 | Args: 79 | H: int. Height of image in pixels. 80 | W: int. Width of image in pixels. 81 | focal: float. Focal length of pinhole camera. 82 | chunk: int. Maximum number of rays to process simultaneously. Used to 83 | control maximum memory usage. Does not affect final results. 84 | rays: array of shape [2, batch_size, 3]. Ray origin and direction for 85 | each example in batch. 86 | c2w: array of shape [3, 4]. Camera-to-world transformation matrix. 87 | ndc: bool. If True, represent ray origin, direction in NDC coordinates. 88 | near: float or array of shape [batch_size]. Nearest distance for a ray. 89 | far: float or array of shape [batch_size]. Farthest distance for a ray. 90 | use_viewdirs: bool. If True, use viewing direction of a point in space in model. 91 | c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for 92 | camera while using other c2w argument for viewing directions. 93 | Returns: 94 | rgb_map: [batch_size, 3]. Predicted RGB values for rays. 95 | disp_map: [batch_size]. Disparity map. Inverse of depth. 96 | acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. 97 | extras: dict with everything returned by render_rays(). 
98 | """ 99 | if c2w is not None: 100 | # special case to render full image 101 | rays_o, rays_d = get_rays(H, W, K, c2w) 102 | 103 | # c2w = c2w.repeat(H, W, 1, 1) 104 | pose_1 = c2w[:,0] 105 | pose_2 = c2w[:,1] 106 | pose_3 = c2w[:,2] 107 | pose_4 = c2w[:,3] 108 | 109 | pose_1 = pose_1.repeat(H, W, 1) 110 | pose_2 = pose_2.repeat(H, W, 1) 111 | pose_3 = pose_3.repeat(H, W, 1) 112 | pose_4 = pose_4.repeat(H, W, 1) 113 | 114 | else: 115 | # use provided ray batch 116 | print(rays.shape) 117 | rays_o, rays_d, pose_1, pose_2, pose_3, pose_4 = rays 118 | 119 | if use_viewdirs: 120 | # provide ray directions as input 121 | viewdirs = rays_d 122 | if c2w_staticcam is not None: 123 | # special case to visualize effect of viewdirs 124 | rays_o, rays_d = get_rays(H, W, K, c2w_staticcam) 125 | pose_1 = c2w_staticcam[:,0] 126 | pose_2 = c2w_staticcam[:,1] 127 | pose_3 = c2w_staticcam[:,2] 128 | pose_4 = c2w_staticcam[:,3] 129 | viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) 130 | viewdirs = torch.reshape(viewdirs, [-1,3]).float() 131 | 132 | sh = rays_d.shape # [..., 3] 133 | if ndc: 134 | # for forward facing scenes 135 | rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d) 136 | 137 | # Create ray batch 138 | rays_o = torch.reshape(rays_o, [-1,3]).float() 139 | rays_d = torch.reshape(rays_d, [-1,3]).float() 140 | pose_1 = torch.reshape(pose_1, [-1,3]).float() 141 | pose_2 = torch.reshape(pose_2, [-1,3]).float() 142 | pose_3 = torch.reshape(pose_3, [-1,3]).float() 143 | pose_4 = torch.reshape(pose_4, [-1,3]).float() 144 | 145 | 146 | near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1]) 147 | rays = torch.cat([rays_o, rays_d, pose_1, pose_2, pose_3, pose_4, near, far], -1) 148 | if use_viewdirs: 149 | rays = torch.cat([rays, viewdirs], -1) 150 | 151 | # Render and reshape 152 | all_ret = batchify_rays(rays, chunk, **kwargs) 153 | for k in all_ret: 154 | k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) 155 | all_ret[k] = torch.reshape(all_ret[k], k_sh) 156 | 157 | k_extract = ['rgb_map', 'depth_map', 'acc_map', 'seg_map', 'normal_map', 'reshading_map', 'kp_map', 'edge_map'] 158 | ret_list = [all_ret[k] for k in k_extract] 159 | ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract} 160 | return ret_list + [ret_dict] 161 | 162 | 163 | def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): 164 | 165 | H, W, focal = hwf 166 | 167 | if render_factor!=0: 168 | # Render downsampled for speed 169 | H = H//render_factor 170 | W = W//render_factor 171 | focal = focal/render_factor 172 | 173 | rgbs = [] 174 | depths = [] 175 | segs = [] 176 | normals = [] 177 | reshadings = [] 178 | kps = [] 179 | edges = [] 180 | 181 | t = time.time() 182 | for i, c2w in enumerate(tqdm(render_poses)): 183 | print(i, time.time() - t) 184 | t = time.time() 185 | rgb, depth, acc, seg, normal, reshading, kp, edge, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs) 186 | # print(seg.shape) 187 | rgbs.append(rgb.cpu().numpy()) 188 | depths.append(depth.cpu().numpy()) 189 | segs.append(seg.cpu().numpy()) 190 | normals.append(normal.cpu().numpy()) 191 | reshadings.append(reshading.cpu().numpy()) 192 | kps.append(kp.cpu().numpy()) 193 | edges.append(edge.cpu().numpy()) 194 | 195 | if i==0: 196 | print(rgb.shape, depth.shape) 197 | 198 | """ 199 | if gt_imgs is not None and render_factor==0: 200 | p = -10. 
* np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) 201 | print(p) 202 | """ 203 | 204 | if savedir is not None: 205 | rgb8 = to8b(rgbs[-1]) 206 | filename = os.path.join(savedir, '{:03d}.png'.format(i)) 207 | depth_filename = os.path.join(savedir, '{:03d}_depth.npy'.format(i)) 208 | seg_filename = os.path.join(savedir, '{:03d}_seg.npy'.format(i)) 209 | normal_filename = os.path.join(savedir, '{:03d}_normal.npy'.format(i)) 210 | reshading_filename = os.path.join(savedir, '{:03d}_reshading.npy'.format(i)) 211 | kp_filename = os.path.join(savedir, '{:03d}_kp.npy'.format(i)) 212 | edge_filename = os.path.join(savedir, '{:03d}_edge.npy'.format(i)) 213 | # seg_save 214 | 215 | imageio.imwrite(filename, rgb8) 216 | np.save(depth_filename, depth.cpu().numpy()) 217 | np.save(seg_filename, seg.cpu().numpy()) 218 | np.save(normal_filename, normal.cpu().numpy()) 219 | np.save(reshading_filename, reshading.cpu().numpy()) 220 | np.save(kp_filename, kp.cpu().numpy()) 221 | np.save(edge_filename, edge.cpu().numpy()) 222 | 223 | 224 | rgbs = np.stack(rgbs, 0) 225 | depths = np.stack(depths, 0) 226 | segs = np.stack(segs, 0) 227 | normals = np.stack(normals, 0) 228 | reshadings = np.stack(reshadings, 0) 229 | kps = np.stack(kps, 0) 230 | edges = np.stack(edges, 0) 231 | 232 | return rgbs, depths, segs, normals, reshadings, kps, edges 233 | 234 | 235 | def create_nerf(args): 236 | """Instantiate NeRF's MLP model. 237 | """ 238 | embed_fn, input_ch = get_embedder(args.multires, 3, args.i_embed) 239 | 240 | input_ch_views = 0 241 | embeddirs_fn = None 242 | if args.use_viewdirs: 243 | embeddirs_fn, input_ch_views = get_embedder(args.multires_views, 3, args.i_embed) 244 | embedpose_fn, input_ch_poses = get_embedder(args.multires_poses, 12, args.i_embed) 245 | output_ch = 5 if args.N_importance > 0 else 4 246 | skips = [4] 247 | 248 | 249 | 250 | model = NeRF(D=args.netdepth, W=args.netwidth, 251 | input_ch=input_ch, output_ch=output_ch, skips=skips, 252 | input_ch_views=input_ch_views, input_ch_poses=input_ch_poses, use_viewdirs=args.use_viewdirs).to(device) 253 | grad_vars = list(model.parameters()) 254 | 255 | model_fine = None 256 | if args.N_importance > 0: 257 | model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine, 258 | input_ch=input_ch, output_ch=output_ch, skips=skips, 259 | input_ch_views=input_ch_views, input_ch_poses=input_ch_poses, use_viewdirs=args.use_viewdirs).to(device) 260 | 261 | 262 | grad_vars += list(model_fine.parameters()) 263 | 264 | network_query_fn = lambda inputs, viewdirs, poses, network_fn : run_network(inputs, viewdirs, poses, network_fn, 265 | embed_fn=embed_fn, 266 | embeddirs_fn=embeddirs_fn, 267 | embedpose_fn=embedpose_fn, 268 | netchunk=args.netchunk) 269 | 270 | # Create optimizer 271 | optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) 272 | 273 | start = 0 274 | basedir = args.basedir 275 | expname = args.expname 276 | 277 | ########################## 278 | 279 | # Load checkpoints 280 | if args.ft_path is not None and args.ft_path!='None': 281 | ckpts = [args.ft_path] 282 | else: 283 | ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f] 284 | 285 | print('Found ckpts', ckpts) 286 | if len(ckpts) > 0 and not args.no_reload: 287 | ckpt_path = ckpts[-1] 288 | print('Reloading from', ckpt_path) 289 | ckpt = torch.load(ckpt_path) 290 | 291 | start = ckpt['global_step'] 292 | optimizer.load_state_dict(ckpt['optimizer_state_dict']) 293 | 294 | # Load model 295 | 
model.load_state_dict(ckpt['network_fn_state_dict']) 296 | if model_fine is not None: 297 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 298 | 299 | ########################## 300 | 301 | render_kwargs_train = { 302 | 'network_query_fn' : network_query_fn, 303 | 'perturb' : args.perturb, 304 | 'N_importance' : args.N_importance, 305 | 'network_fine' : model_fine, 306 | 'N_samples' : args.N_samples, 307 | 'network_fn' : model, 308 | 'use_viewdirs' : args.use_viewdirs, 309 | 'white_bkgd' : args.white_bkgd, 310 | 'raw_noise_std' : args.raw_noise_std, 311 | } 312 | 313 | # NDC only good for LLFF-style forward facing data 314 | if args.dataset_type != 'llff' or args.no_ndc: 315 | print('Not ndc!') 316 | render_kwargs_train['ndc'] = False 317 | render_kwargs_train['lindisp'] = args.lindisp 318 | 319 | render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train} 320 | render_kwargs_test['perturb'] = False 321 | render_kwargs_test['raw_noise_std'] = 0. 322 | 323 | return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer 324 | 325 | 326 | def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False): 327 | """Transforms model's predictions to semantically meaningful values. 328 | Args: 329 | raw: [num_rays, num_samples along ray, 4]. Prediction from model. 330 | z_vals: [num_rays, num_samples along ray]. Integration time. 331 | rays_d: [num_rays, 3]. Direction of each ray. 332 | Returns: 333 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. 334 | disp_map: [num_rays]. Disparity map. Inverse of depth map. 335 | acc_map: [num_rays]. Sum of weights along each ray. 336 | weights: [num_rays, num_samples]. Weights assigned to each sampled color. 337 | depth_map: [num_rays]. Estimated distance to object. 338 | """ 339 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*dists) 340 | 341 | dists = z_vals[...,1:] - z_vals[...,:-1] 342 | dists = torch.cat([dists, torch.Tensor([1e10]).expand(dists[...,:1].shape)], -1) # [N_rays, N_samples] 343 | 344 | dists = dists * torch.norm(rays_d[...,None,:], dim=-1) 345 | 346 | rgb = torch.sigmoid(raw[...,:3]) # [N_rays, N_samples, 3] 347 | 348 | seg = raw[...,4:4+13] 349 | 350 | 351 | normal = torch.sigmoid(raw[...,17:17+3]) 352 | 353 | reshading = torch.sigmoid(raw[...,20:20+1]) 354 | kp = torch.relu(raw[...,21:21+1]) 355 | 356 | edge = torch.sigmoid(raw[...,22:22+1]) 357 | 358 | noise = 0. 
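    # Channel layout of `raw` used in the slicing above (per sample along each ray):
    #   0-2 RGB, 3 density, 4-16 the 13 NYU-13 semantic logits, 17-19 surface normal,
    #   20 reshading, 21 keypoint, 22 edge (23 channels in total, matching NeRF.forward).
    # The noise below is only added to the density channel before computing alpha.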
359 | if raw_noise_std > 0.: 360 | noise = torch.randn(raw[...,3].shape) * raw_noise_std 361 | 362 | # Overwrite randomly sampled data if pytest 363 | if pytest: 364 | np.random.seed(0) 365 | noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std 366 | noise = torch.Tensor(noise) 367 | 368 | alpha = raw2alpha(raw[...,3] + noise, dists) # [N_rays, N_samples] 369 | # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True) 370 | weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)), 1.-alpha + 1e-10], -1), -1)[:, :-1] 371 | rgb_map = torch.sum(weights[...,None] * rgb, -2) # [N_rays, 3] 372 | 373 | seg_map = torch.sum(weights[...,None] * seg, -2) # [N_rays, 1] 374 | 375 | seg_map = torch.nn.functional.softmax(seg_map, dim=1) 376 | 377 | seg_map = torch.clamp(seg_map,0,1) 378 | 379 | normal_map = torch.sum(weights[...,None] * normal, -2) 380 | 381 | reshading_map = torch.sum(weights[...,None] * reshading, -2) 382 | 383 | kp_map = torch.sum(weights[...,None] * kp, -2) 384 | 385 | edge_map = torch.sum(weights[...,None] * edge, -2) 386 | 387 | depth_map = torch.sum(weights * z_vals, -1) 388 | disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map), depth_map / torch.sum(weights, -1)) 389 | acc_map = torch.sum(weights, -1) 390 | 391 | if white_bkgd: 392 | rgb_map = rgb_map + (1.-acc_map[...,None]) 393 | 394 | return rgb_map, disp_map, acc_map, weights, depth_map, seg_map, normal_map, reshading_map, kp_map, edge_map 395 | 396 | 397 | def render_rays(ray_batch, 398 | network_fn, 399 | network_query_fn, 400 | N_samples, 401 | retraw=False, 402 | lindisp=False, 403 | perturb=0., 404 | N_importance=0, 405 | network_fine=None, 406 | white_bkgd=False, 407 | raw_noise_std=0., 408 | verbose=False, 409 | pytest=False): 410 | """Volumetric rendering. 411 | Args: 412 | ray_batch: array of shape [batch_size, ...]. All information necessary 413 | for sampling along a ray, including: ray origin, ray direction, min 414 | dist, max dist, and unit-magnitude viewing direction. 415 | network_fn: function. Model for predicting RGB and density at each point 416 | in space. 417 | network_query_fn: function used for passing queries to network_fn. 418 | N_samples: int. Number of different times to sample along each ray. 419 | retraw: bool. If True, include model's raw, unprocessed predictions. 420 | lindisp: bool. If True, sample linearly in inverse depth rather than in depth. 421 | perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified 422 | random points in time. 423 | N_importance: int. Number of additional times to sample along each ray. 424 | These samples are only passed to network_fine. 425 | network_fine: "fine" network with same spec as network_fn. 426 | white_bkgd: bool. If True, assume a white background. 427 | raw_noise_std: ... 428 | verbose: bool. If True, print more debugging info. 429 | Returns: 430 | rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model. 431 | disp_map: [num_rays]. Disparity map. 1 / depth. 432 | acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model. 433 | raw: [num_rays, num_samples, 4]. Raw predictions from model. 434 | rgb0: See rgb_map. Output for coarse model. 435 | disp0: See disp_map. Output for coarse model. 436 | acc0: See acc_map. Output for coarse model. 437 | z_std: [num_rays]. Standard deviation of distances along ray for each 438 | sample. 
439 | """ 440 | N_rays = ray_batch.shape[0] 441 | # print(ray_batch.shape[1], 23) 442 | rays_o, rays_d, pose_1, pose_2, pose_3, pose_4 = ray_batch[:,0:3], ray_batch[:,3:6], ray_batch[:,6:9], ray_batch[:,9:12], ray_batch[:,12:15], ray_batch[:,15:18] # [N_rays, 3] each 443 | viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 20 else None 444 | bounds = torch.reshape(ray_batch[...,18:20], [-1,1,2]) 445 | near, far = bounds[...,0], bounds[...,1] # [-1,1] 446 | 447 | t_vals = torch.linspace(0., 1., steps=N_samples) 448 | if not lindisp: 449 | z_vals = near * (1.-t_vals) + far * (t_vals) 450 | else: 451 | z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals)) 452 | 453 | z_vals = z_vals.expand([N_rays, N_samples]) 454 | 455 | if perturb > 0.: 456 | # get intervals between samples 457 | mids = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 458 | upper = torch.cat([mids, z_vals[...,-1:]], -1) 459 | lower = torch.cat([z_vals[...,:1], mids], -1) 460 | # stratified samples in those intervals 461 | t_rand = torch.rand(z_vals.shape) 462 | 463 | # Pytest, overwrite u with numpy's fixed random numbers 464 | if pytest: 465 | np.random.seed(0) 466 | t_rand = np.random.rand(*list(z_vals.shape)) 467 | t_rand = torch.Tensor(t_rand) 468 | 469 | z_vals = lower + (upper - lower) * t_rand 470 | 471 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3] 472 | 473 | 474 | pose = torch.cat((pose_1, pose_2, pose_3, pose_4), 1) 475 | raw = network_query_fn(pts, viewdirs, pose, network_fn) 476 | rgb_map, disp_map, acc_map, weights, depth_map, seg_map, normal_map, reshading_map, kp_map, edge_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 477 | 478 | if N_importance > 0: 479 | 480 | rgb_map_0, depth_map_0, acc_map_0, seg_map_0, normal_map_0, reshading_map_0, kp_map_0, edge_map_0 = rgb_map, depth_map, acc_map, seg_map, normal_map, reshading_map, kp_map, edge_map 481 | 482 | z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1]) 483 | z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest) 484 | z_samples = z_samples.detach() 485 | 486 | z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1) 487 | pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3] 488 | 489 | run_fn = network_fn if network_fine is None else network_fine 490 | raw = network_query_fn(pts, viewdirs, pose, run_fn) 491 | 492 | rgb_map, disp_map, acc_map, weights, depth_map, seg_map, normal_map, reshading_map, kp_map, edge_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest) 493 | 494 | ret = {'rgb_map' : rgb_map, 'depth_map' : depth_map, 'acc_map' : acc_map, 'seg_map' : seg_map, 'normal_map' : normal_map, 'reshading_map' : reshading_map, 'kp_map' : kp_map, 'edge_map' : edge_map} 495 | if retraw: 496 | ret['raw'] = raw 497 | if N_importance > 0: 498 | ret['rgb0'] = rgb_map_0 499 | ret['depth0'] = depth_map_0 500 | ret['acc0'] = acc_map_0 501 | ret['seg0'] = seg_map_0 502 | ret['normal0'] = normal_map_0 503 | ret['reshading0'] = reshading_map_0 504 | ret['kp0'] = kp_map_0 505 | ret['edge0'] = edge_map_0 506 | ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays] 507 | 508 | for k in ret: 509 | if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG: 510 | print(f"! 
[Numerical Error] {k} contains nan or inf.") 511 | 512 | return ret 513 | 514 | 515 | def config_parser(): 516 | 517 | import configargparse 518 | parser = configargparse.ArgumentParser() 519 | parser.add_argument('--config', is_config_file=True, 520 | help='config file path') 521 | parser.add_argument("--expname", type=str, 522 | help='experiment name') 523 | parser.add_argument("--basedir", type=str, default='./logs/', 524 | help='where to store ckpts and logs') 525 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 526 | help='input data directory') 527 | 528 | # training options 529 | parser.add_argument("--netdepth", type=int, default=8, 530 | help='layers in network') 531 | parser.add_argument("--netwidth", type=int, default=256, 532 | help='channels per layer') 533 | parser.add_argument("--netdepth_fine", type=int, default=8, 534 | help='layers in fine network') 535 | parser.add_argument("--netwidth_fine", type=int, default=256, 536 | help='channels per layer in fine network') 537 | parser.add_argument("--N_rand", type=int, default=32*32*4, 538 | help='batch size (number of random rays per gradient step)') 539 | parser.add_argument("--lrate", type=float, default=5e-4, 540 | help='learning rate') 541 | parser.add_argument("--lrate_decay", type=int, default=250, 542 | help='exponential learning rate decay (in 1000 steps)') 543 | parser.add_argument("--chunk", type=int, default=1024*16, # this has been changed from 1024*32 544 | help='number of rays processed in parallel, decrease if running out of memory') 545 | parser.add_argument("--netchunk", type=int, default=1024*32, # this has been changed from 1024*64 546 | help='number of pts sent through network in parallel, decrease if running out of memory') 547 | parser.add_argument("--no_batching", action='store_true', 548 | help='only take random rays from 1 image at a time') 549 | parser.add_argument("--no_reload", action='store_true', 550 | help='do not reload weights from saved ckpt') 551 | parser.add_argument("--ft_path", type=str, default=None, 552 | help='specific weights npy file to reload for coarse network') 553 | 554 | # rendering options 555 | parser.add_argument("--N_samples", type=int, default=64, 556 | help='number of coarse samples per ray') 557 | parser.add_argument("--N_importance", type=int, default=0, 558 | help='number of additional fine samples per ray') 559 | parser.add_argument("--perturb", type=float, default=1., 560 | help='set to 0. for no jitter, 1. 
for jitter') 561 | parser.add_argument("--use_viewdirs", action='store_true', 562 | help='use full 5D input instead of 3D') 563 | parser.add_argument("--i_embed", type=int, default=0, 564 | help='set 0 for default positional encoding, -1 for none') 565 | parser.add_argument("--multires", type=int, default=10, 566 | help='log2 of max freq for positional encoding (3D location)') 567 | parser.add_argument("--multires_views", type=int, default=4, 568 | help='log2 of max freq for positional encoding (2D direction)') 569 | parser.add_argument("--multires_poses", type=int, default=10, 570 | help='log2 of max freq for positional encoding (4D pose)') 571 | parser.add_argument("--raw_noise_std", type=float, default=0., 572 | help='std dev of noise added to regularize sigma_a output, 1e0 recommended') 573 | 574 | parser.add_argument("--render_only", action='store_true', 575 | help='do not optimize, reload weights and render out render_poses path') 576 | parser.add_argument("--render_test", action='store_true', 577 | help='render the test set instead of render_poses path') 578 | parser.add_argument("--render_factor", type=int, default=0, 579 | help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') 580 | 581 | # training options 582 | parser.add_argument("--precrop_iters", type=int, default=0, 583 | help='number of steps to train on central crops') 584 | parser.add_argument("--precrop_frac", type=float, 585 | default=.5, help='fraction of img taken for central crops') 586 | 587 | # dataset options 588 | parser.add_argument("--dataset_type", type=str, default='llff', 589 | help='options: llff / blender / deepvoxels') 590 | parser.add_argument("--testskip", type=int, default=8, 591 | help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') 592 | 593 | ## deepvoxels flags 594 | parser.add_argument("--shape", type=str, default='greek', 595 | help='options : armchair / cube / greek / vase') 596 | 597 | ## blender flags 598 | parser.add_argument("--white_bkgd", action='store_true', 599 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 600 | parser.add_argument("--half_res", action='store_true', 601 | help='load blender synthetic data at 400x400 instead of 800x800') 602 | 603 | ## llff flags 604 | parser.add_argument("--factor", type=int, default=8, 605 | help='downsample factor for LLFF images') 606 | parser.add_argument("--no_ndc", action='store_true', 607 | help='do not use normalized device coordinates (set for non-forward facing scenes)') 608 | parser.add_argument("--lindisp", action='store_true', 609 | help='sampling linearly in disparity rather than depth') 610 | parser.add_argument("--spherify", action='store_true', 611 | help='set for spherical 360 scenes') 612 | parser.add_argument("--llffhold", type=int, default=8, 613 | help='will take every 1/N images as LLFF test set, paper uses 8') 614 | 615 | # logging/saving options 616 | parser.add_argument("--i_print", type=int, default=100, 617 | help='frequency of console printout and metric loggin') 618 | parser.add_argument("--i_img", type=int, default=500, 619 | help='frequency of tensorboard image logging') 620 | parser.add_argument("--i_weights", type=int, default=10000, 621 | help='frequency of weight ckpt saving') 622 | parser.add_argument("--i_testset", type=int, default=200000, 623 | help='frequency of testset saving') 624 | parser.add_argument("--i_video", type=int, default=50000, 625 | help='frequency of render_poses video saving') 626 | 627 | return parser 
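# Usage sketch (annotation, not part of the original file; the flag values below are
# purely illustrative): configargparse reads the file given via --config and lets
# command-line flags override it, which is how train() below obtains its settings, e.g.
#
#     parser = config_parser()
#     args = parser.parse_args(['--config', 'configs/llff.txt', '--render_factor', '4'])
#     print(args.expname, args.N_rand, args.N_importance, args.use_viewdirs)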
628 | 629 | 630 | def train(): 631 | 632 | parser = config_parser() 633 | args = parser.parse_args() 634 | 635 | # Load data 636 | K = None 637 | if args.dataset_type == 'llff': 638 | images, poses, bds, render_poses, i_test, segs, normals, reshadings, kps, edges = load_llff_data(args.datadir, args.factor, 639 | recenter=True, bd_factor=.75, 640 | spherify=args.spherify) 641 | print(images.shape) 642 | hwf = poses[0,:3,-1] 643 | poses = poses[:,:3,:4] 644 | print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir) 645 | if not isinstance(i_test, list): 646 | i_test = [i_test] 647 | 648 | if args.llffhold > 0: 649 | print('Auto LLFF holdout,', args.llffhold) 650 | i_val = np.arange(images.shape[0])[::args.llffhold] 651 | i_train = np.array([i for i in np.arange(int(images.shape[0])) if 652 | (i not in i_test and i not in i_val)]) 653 | i_test = i_val 654 | 655 | print('DEFINING BOUNDS') 656 | if args.no_ndc: 657 | near = np.ndarray.min(bds) * .9 658 | far = np.ndarray.max(bds) * 1. 659 | near = 0.1 660 | far = 10 661 | else: 662 | near = 0. 663 | far = 1. 664 | print('NEAR FAR', near, far) 665 | 666 | elif args.dataset_type == 'blender': 667 | images, poses, render_poses, hwf, i_split = load_blender_data(args.datadir, args.half_res, args.testskip) 668 | print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir) 669 | i_train, i_val, i_test = i_split 670 | 671 | near = 2. 672 | far = 6. 673 | 674 | if args.white_bkgd: 675 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 676 | else: 677 | images = images[...,:3] 678 | 679 | elif args.dataset_type == 'LINEMOD': 680 | images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(args.datadir, args.half_res, args.testskip) 681 | print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}') 682 | print(f'[CHECK HERE] near: {near}, far: {far}.') 683 | i_train, i_val, i_test = i_split 684 | 685 | if args.white_bkgd: 686 | images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:]) 687 | else: 688 | images = images[...,:3] 689 | 690 | elif args.dataset_type == 'deepvoxels': 691 | 692 | images, poses, render_poses, hwf, i_split = load_dv_data(scene=args.shape, 693 | basedir=args.datadir, 694 | testskip=args.testskip) 695 | 696 | print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir) 697 | i_train, i_val, i_test = i_split 698 | 699 | hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1)) 700 | near = hemi_R-1. 701 | far = hemi_R+1. 
702 | 703 | else: 704 | print('Unknown dataset type', args.dataset_type, 'exiting') 705 | return 706 | 707 | # Cast intrinsics to right types 708 | H, W, focal = hwf 709 | H, W = int(H), int(W) 710 | hwf = [H, W, focal] 711 | 712 | if K is None: 713 | K = np.array([ 714 | [focal, 0, 0.5*W], 715 | [0, focal, 0.5*H], 716 | [0, 0, 1] 717 | ]) 718 | 719 | if args.render_test: 720 | render_poses = np.array(poses[i_test]) 721 | 722 | # Create log dir and copy the config file 723 | basedir = args.basedir 724 | expname = args.expname 725 | os.makedirs(os.path.join(basedir, expname), exist_ok=True) 726 | f = os.path.join(basedir, expname, 'args.txt') 727 | with open(f, 'w') as file: 728 | for arg in sorted(vars(args)): 729 | attr = getattr(args, arg) 730 | file.write('{} = {}\n'.format(arg, attr)) 731 | if args.config is not None: 732 | f = os.path.join(basedir, expname, 'config.txt') 733 | with open(f, 'w') as file: 734 | file.write(open(args.config, 'r').read()) 735 | 736 | # Create nerf model 737 | render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args) 738 | global_step = start 739 | 740 | bds_dict = { 741 | 'near' : near, 742 | 'far' : far, 743 | } 744 | render_kwargs_train.update(bds_dict) 745 | render_kwargs_test.update(bds_dict) 746 | 747 | # Move testing data to GPU 748 | render_poses = torch.Tensor(render_poses).to(device) 749 | 750 | # Short circuit if only rendering out from trained model 751 | if args.render_only: 752 | print('RENDER ONLY') 753 | with torch.no_grad(): 754 | if args.render_test: 755 | # render_test switches to test poses 756 | images = images[i_test] 757 | else: 758 | # Default is smoother render_poses path 759 | images = None 760 | 761 | testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)) 762 | os.makedirs(testsavedir, exist_ok=True) 763 | print('test poses shape', render_poses.shape) 764 | 765 | rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor) 766 | print('Done rendering', testsavedir) 767 | imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8) 768 | 769 | return 770 | 771 | # Prepare raybatch tensor if batching random rays 772 | N_rand = args.N_rand 773 | use_batching = not args.no_batching 774 | if use_batching: 775 | # For random ray batching 776 | print('get rays') 777 | rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3] 778 | print('done, concats') 779 | 780 | 781 | # Special modeling for SN 782 | 783 | pose_c1s = poses[:,:3,0] 784 | pose_c1s = pose_c1s[:, np.newaxis, np.newaxis, :] 785 | pose_c1s = np.tile(pose_c1s, (1, H, W, 1)) 786 | 787 | pose_c2s = poses[:,:3,1] 788 | pose_c2s = pose_c2s[:, np.newaxis, np.newaxis, :] 789 | pose_c2s = np.tile(pose_c2s, (1, H, W, 1)) 790 | 791 | pose_c3s = poses[:,:3,2] 792 | pose_c3s = pose_c3s[:, np.newaxis, np.newaxis, :] 793 | pose_c3s = np.tile(pose_c3s, (1, H, W, 1)) 794 | 795 | pose_c4s = poses[:,:3,3] 796 | pose_c4s = pose_c4s[:, np.newaxis, np.newaxis, :] 797 | pose_c4s = np.tile(pose_c4s, (1, H, W, 1)) 798 | 799 | 800 | 801 | 802 | rays_rgb = np.concatenate([rays, pose_c1s[:,None], pose_c2s[:,None], pose_c3s[:,None], pose_c4s[:,None], images[:,None], np.tile(segs[:,None], (1,1,1,3)), normals[:,None], np.tile(reshadings[:,None], (1,1,1,3)), np.tile(kps[:,None], (1,1,1,3)), np.tile(edges[:,None], (1,1,1,3))], 1) # [N, 
ro+rd+rgb+seg+normal+reshading+kp+edge, H, W, 3] 803 | rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb+seg+normal+reshading+kp+edge, 3] 804 | rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only 805 | rays_rgb = np.reshape(rays_rgb, [-1,12,3]) # [(N-1)*H*W, ro+rd+rgb+seg+normal+reshading+kp+edge, 3] 806 | rays_rgb = rays_rgb.astype(np.float32) 807 | print('shuffle rays') 808 | np.random.shuffle(rays_rgb) 809 | 810 | print('done') 811 | i_batch = 0 812 | 813 | # Move training data to GPU 814 | if use_batching: 815 | images = torch.Tensor(images).to(device) 816 | poses = torch.Tensor(poses).to(device) 817 | if use_batching: 818 | rays_rgb = torch.Tensor(rays_rgb).to(device) 819 | 820 | 821 | N_iters = 200000 + 1 822 | print('Begin') 823 | print('TRAIN views are', i_train) 824 | print('TEST views are', i_test) 825 | print('VAL views are', i_val) 826 | 827 | # Summary writers 828 | # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname)) 829 | 830 | start = start + 1 831 | for i in trange(start, N_iters): 832 | time0 = time.time() 833 | 834 | # Sample random ray batch 835 | if use_batching: 836 | # Random over all images 837 | batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?] 838 | batch = torch.transpose(batch, 0, 1) 839 | batch_rays, target_s, target_seg, target_normal, target_reshading, target_kp, target_edge = batch[:2+4], batch[2+4], batch[3+4], batch[4+4], batch[5+4], batch[6+4], batch[7+4] 840 | 841 | i_batch += N_rand 842 | if i_batch >= rays_rgb.shape[0]: 843 | print("Shuffle data after an epoch!") 844 | rand_idx = torch.randperm(rays_rgb.shape[0]) 845 | rays_rgb = rays_rgb[rand_idx] 846 | i_batch = 0 847 | 848 | else: 849 | # Random from one image 850 | img_i = np.random.choice(i_train) 851 | target = images[img_i] 852 | target = torch.Tensor(target).to(device) 853 | 854 | seg_gt = segs[img_i] 855 | seg_gt = torch.Tensor(seg_gt).to(device) 856 | 857 | normal_gt = normals[img_i] 858 | normal_gt = torch.Tensor(normal_gt).to(device) 859 | 860 | reshading_gt = reshadings[img_i] 861 | reshading_gt = torch.Tensor(reshading_gt).to(device) 862 | 863 | kp_gt = kps[img_i] 864 | kp_gt = torch.Tensor(kp_gt).to(device) 865 | 866 | edge_gt = edges[img_i] 867 | edge_gt = torch.Tensor(edge_gt).to(device) 868 | 869 | pose = poses[img_i, :3,:4] 870 | 871 | if N_rand is not None: 872 | rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose)) # (H, W, 3), (H, W, 3) 873 | 874 | if i < args.precrop_iters: 875 | dH = int(H//2 * args.precrop_frac) 876 | dW = int(W//2 * args.precrop_frac) 877 | coords = torch.stack( 878 | torch.meshgrid( 879 | torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 880 | torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW) 881 | ), -1) 882 | if i == start: 883 | print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}") 884 | else: 885 | coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2) 886 | 887 | coords = torch.reshape(coords, [-1,2]) # (H * W, 2) 888 | select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False) # (N_rand,) 889 | select_coords = coords[select_inds].long() # (N_rand, 2) 890 | rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 891 | rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 892 | batch_rays = torch.stack([rays_o, rays_d], 0) 893 | target_s = target[select_coords[:, 0], select_coords[:, 1]] # (N_rand, 3) 894 | target_seg = seg_gt[select_coords[:, 
0], select_coords[:, 1]] 895 | target_normal = normal_gt[select_coords[:, 0], select_coords[:, 1]] 896 | target_reshading = reshading_gt[select_coords[:, 0], select_coords[:, 1]] 897 | target_kp = kp_gt[select_coords[:, 0], select_coords[:, 1]] 898 | target_edge = edge_gt[select_coords[:, 0], select_coords[:, 1]] 899 | 900 | ##### Core optimization loop ##### 901 | rgb, disp, acc, seg, normal, reshading, kp, edge, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays, 902 | verbose=i < 10, retraw=True, 903 | **render_kwargs_train) 904 | 905 | optimizer.zero_grad() 906 | 907 | seg_crit = torch.nn.CrossEntropyLoss() 908 | 909 | reshading_crit = torch.nn.L1Loss() 910 | one_hot = target_seg[:, 0].long() 911 | img_loss = img2mse(rgb, target_s) 912 | 913 | seg_loss = seg_crit(seg, one_hot) 914 | 915 | normal_loss = img2mse(normal, target_normal) 916 | reshading_loss = reshading_crit(reshading, target_reshading[:, 0:1]) 917 | kp_loss = reshading_crit(kp, target_kp[:, 0:1]) 918 | edge_loss = reshading_crit(edge, target_edge[:, 0:1]) 919 | 920 | trans = extras['raw'][...,-1] 921 | loss = img_loss + normal_loss + reshading_loss*0.1 + edge_loss*0.4 + kp_loss*2 + seg_loss*0.04 922 | 923 | psnr = mse2psnr(img_loss) 924 | 925 | if 'rgb0' in extras: 926 | img_loss0 = img2mse(extras['rgb0'], target_s) 927 | seg_loss0 = seg_crit(extras['seg0'], one_hot) 928 | normal_loss0 = img2mse(extras['normal0'], target_normal) 929 | reshading_loss0 = reshading_crit(extras['reshading0'], target_reshading[:, 0:1]) 930 | kp_loss0 = reshading_crit(extras['kp0'], target_kp[:, 0:1]) 931 | edge_loss0 = reshading_crit(extras['edge0'], target_edge[:, 0:1]) 932 | 933 | loss = loss + img_loss0 + normal_loss0 + reshading_loss0*0.1 + edge_loss0*0.4 + kp_loss0*2 + seg_loss0*0.04 934 | 935 | psnr0 = mse2psnr(img_loss0) 936 | 937 | loss.backward() 938 | optimizer.step() 939 | 940 | # NOTE: IMPORTANT! 
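        # Annotation (not in the original source): the block below applies a smooth exponential
        # learning-rate decay, new_lr = lrate * 0.1 ** (global_step / (lrate_decay * 1000)).
        # With the defaults lrate = 5e-4 and lrate_decay = 250 the rate drops by 10x every
        # 250k steps; at the final iteration here (200000) it is 5e-4 * 0.1**0.8 ≈ 7.9e-5.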
941 |         ### update learning rate ###
942 |         decay_rate = 0.1
943 |         decay_steps = args.lrate_decay * 1000
944 |         new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
945 |         for param_group in optimizer.param_groups:
946 |             param_group['lr'] = new_lrate
947 |         ################################
948 | 
949 |         dt = time.time()-time0
950 | 
951 |         # Rest is logging
952 | 
953 |         if i%50==0:
954 |             print(loss.item(), img_loss.item(), seg_loss.item(), kp_loss.item(), edge_loss.item(), *([img_loss0.item(), seg_loss0.item(), kp_loss0.item(), edge_loss0.item()] if 'rgb0' in extras else []))  # coarse-network losses only exist when N_importance > 0
955 | 
956 | 
957 |         if i%args.i_weights==0:
958 |             path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
959 |             torch.save({
960 |                 'global_step': global_step,
961 |                 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
962 |                 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict() if render_kwargs_train['network_fine'] is not None else None,  # guard against N_importance == 0
963 |                 'optimizer_state_dict': optimizer.state_dict(),
964 |             }, path)
965 |             print('Saved checkpoints at', path)
966 | 
967 |         if i%args.i_testset==0 and i > 0:
968 |             testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
969 |             os.makedirs(testsavedir, exist_ok=True)
970 |             print('test poses shape', poses[i_test].shape)
971 |             with torch.no_grad():
972 |                 render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
973 |             print('Saved test set')
974 | 
975 | 
976 | 
977 |         if i%args.i_print==0:
978 |             tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
979 |             with open("with_seg.txt", 'a') as file_object:
980 |                 file_object.write("[TRAIN] Iter: {"+str(i)+"} Loss: {"+str(loss.item())+"} PSNR: {"+str(psnr.item())+"}\n")
981 | 
982 |         global_step += 1
983 | 
984 | 
985 | if __name__=='__main__':
986 |     torch.set_default_tensor_type('torch.cuda.FloatTensor')
987 | 
988 |     train()
989 | --------------------------------------------------------------------------------
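
For reference, below is a small self-contained sketch of the alpha-compositing rule that raw2outputs applies to every per-sample prediction (RGB, segmentation logits, normals, reshading, keypoints, edges). It is an illustration only and not part of the repository; the helper name composite_weights and the dummy shapes are invented for the example.

    import torch
    import torch.nn.functional as F

    def composite_weights(sigma, z_vals, rays_d):
        # distances between adjacent samples; the last bin is padded with a huge value
        # so the final sample absorbs whatever transmittance is left
        dists = z_vals[..., 1:] - z_vals[..., :-1]
        dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
        dists = dists * torch.norm(rays_d[..., None, :], dim=-1)   # scale by ray length
        alpha = 1. - torch.exp(-F.relu(sigma) * dists)             # opacity of each sample
        # transmittance: product of (1 - alpha) over all earlier samples along the ray
        trans = torch.cumprod(
            torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)[..., :-1]
        return alpha * trans

    N_rays, N_samples = 4, 8
    z_vals = torch.linspace(0., 1., N_samples).expand(N_rays, N_samples)
    sigma = torch.rand(N_rays, N_samples) * 5.
    rays_d = torch.randn(N_rays, 3)
    rgb_samples = torch.rand(N_rays, N_samples, 3)

    w = composite_weights(sigma, z_vals, rays_d)          # [N_rays, N_samples]
    rgb_map = torch.sum(w[..., None] * rgb_samples, -2)   # same reduction used for every *_map above
    acc_map = torch.sum(w, -1)                            # accumulated opacity per ray
    print(w.shape, rgb_map.shape, acc_map)

The 1e-10 inside the cumulative product keeps the transmittance from collapsing to exactly zero, and each of the other per-sample heads (the 13-way segmentation logits, normals, and so on) is reduced with the identical weighted sum before its own post-processing (softmax, clamping, or white-background blending).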