├── models ├── __init__.py ├── quaternion_utils.py ├── sh.py ├── tensorBase.py └── model.py ├── .gitattributes ├── dataLoader ├── __pycache__ │ ├── llff.cpython-310.pyc │ ├── llff.cpython-38.pyc │ ├── nsvf.cpython-310.pyc │ ├── nsvf.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── blender.cpython-310.pyc │ ├── blender.cpython-38.pyc │ ├── __init__.cpython-310.pyc │ ├── ray_utils.cpython-310.pyc │ ├── ray_utils.cpython-38.pyc │ ├── tankstemple.cpython-38.pyc │ ├── tankstemple.cpython-310.pyc │ ├── your_own_data.cpython-38.pyc │ └── your_own_data.cpython-310.pyc ├── __init__.py ├── blender.py ├── your_own_data.py ├── nsvf.py ├── tankstemple.py ├── llff.py ├── ray_utils.py └── colmap2nerf.py ├── configs ├── truck.txt ├── lego.txt ├── wineholder.txt └── your_own_data.txt ├── LICENSE ├── README.md ├── utils.py ├── renderer.py ├── opt.py ├── extra ├── compute_metrics.py └── auto_run_paramsets.py └── train.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /dataLoader/__pycache__/llff.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/llff.cpython-310.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/llff.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/llff.cpython-38.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/nsvf.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/nsvf.cpython-310.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/nsvf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/nsvf.cpython-38.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/blender.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/blender.cpython-310.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/blender.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/blender.cpython-38.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/ray_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/ray_utils.cpython-310.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/ray_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/ray_utils.cpython-38.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/tankstemple.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/tankstemple.cpython-38.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/tankstemple.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/tankstemple.cpython-310.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/your_own_data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/your_own_data.cpython-38.pyc -------------------------------------------------------------------------------- /dataLoader/__pycache__/your_own_data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/your_own_data.cpython-310.pyc -------------------------------------------------------------------------------- /dataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | from .llff import LLFFDataset 2 | from .blender import BlenderDataset 3 | from .nsvf import NSVF 4 | from .tankstemple import TanksTempleDataset 5 | from .your_own_data import YourOwnDataset 6 | 7 | 8 | 9 | dataset_dict = {'blender': BlenderDataset, 10 | 'llff':LLFFDataset, 11 | 'tankstemple':TanksTempleDataset, 12 | 'nsvf':NSVF, 13 | 'own_data':YourOwnDataset} -------------------------------------------------------------------------------- /configs/truck.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = tankstemple 3 | datadir = ../datasets/TanksAndTemple/Truck 4 | expname = nrff_truck 5 | basedir = ./log 6 | 7 | n_iters = 30000 8 | batch_size = 4096 9 | 10 | N_voxel_init = 2097156 # 128**3 11 | N_voxel_final = 27000000 # 300**3 12 | upsamp_list = [2000,3000,4000,5500,7000] 13 | update_AlphaMask_list = [2000,4000] 14 | 15 | N_vis = 5 16 | vis_every = 10000 17 | 18 | render_test = 1 19 | model_name = NRFF 20 | fea2denseAct = softplus 21 | 22 | L1_weight_inital = 8e-5 23 | L1_weight_rest = 4e-5 24 | -------------------------------------------------------------------------------- /configs/lego.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ../datasets/nerf_synthetic/lego 4 | expname = nrff_lego 5 | basedir = ./log 6 | 7 | n_iters = 30000 8 | batch_size = 4096 9 | 10 | N_voxel_init = 2097156 # 128**3 11 | N_voxel_final = 27000000 # 300**3 12 | upsamp_list = [2000,3000,4000,5500,7000] 13 | update_AlphaMask_list = [2000,4000] 14 | 15 | N_vis = 5 16 | vis_every = 10000 17 | 18 | render_test = 1 19 | model_name = NRFF 20 | fea2denseAct = softplus 21 | 22 | L1_weight_inital = 8e-5 23 | L1_weight_rest = 4e-5 24 | rm_weight_mask_thre = 1e-4 25 | -------------------------------------------------------------------------------- /configs/wineholder.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = blender 3 | datadir = ./data/Synthetic_NSVF/Wineholder 4 | expname = tensorf_Wineholder_VM 5 | basedir = ./log 6 | 7 | n_iters = 30000 8 | batch_size = 4096 9 | 10 | N_voxel_init = 2097156 # 128**3 11 | N_voxel_final = 27000000 # 300**3 12 | upsamp_list = [2000,3000,4000,5500,7000] 13 | update_AlphaMask_list = [2000,4000] 14 | 15 | N_vis = 5 16 | vis_every = 10000 17 | 18 | render_test = 1 19 | model_name = NRFF 20 | 21 | fea2denseAct = softplus 22 | 23 | L1_weight_inital = 8e-5 24 | L1_weight_rest = 4e-5 25 | rm_weight_mask_thre = 1e-4 26 | -------------------------------------------------------------------------------- /configs/your_own_data.txt: -------------------------------------------------------------------------------- 1 | 2 | dataset_name = own_data 3 | datadir = ./data/xxx 4 | expname = tensorf_xxx_VM 5 | basedir = ./log 6 | 7 | n_iters = 30000 8 | batch_size = 4096 9 | 10 | N_voxel_init = 2097156 # 128**3 11 | N_voxel_final = 27000000 # 300**3 12 | upsamp_list = [2000,3000,4000,5500,7000] 13 | update_AlphaMask_list = [2000,4000] 14 | 15 | N_vis = 5 16 | vis_every = 10000 17 | 18 | render_test = 1 19 | 20 | n_lamb_sigma = [16,16,16] 21 | n_lamb_sh = [48,48,48] 22 | model_name = TensorVMSplit 23 | 24 | 25 | shadingMode = MLP_Fea 26 | fea2denseAct = softplus 27 | 28 | view_pe = 2 29 | fea_pe = 2 30 | 31 | view_pe = 2 32 | fea_pe = 2 33 | 34 | TV_weight_density = 0.1 35 | TV_weight_app = 0.01 36 | 37 | rm_weight_mask_thre = 1e-4 38 | 39 | ## please uncomment following configuration if hope to training on cp model 40 | #model_name = TensorCP 41 | #n_lamb_sigma = [96] 42 | #n_lamb_sh = [288] 43 | #N_voxel_final = 125000000 # 500**3 44 | #L1_weight_inital = 1e-5 45 | #L1_weight_rest = 1e-5 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /models/quaternion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | 6 | def quaternion_product(p, q): 7 | p_r = p[..., [0]] 8 | p_i = p[..., 1:] 9 | q_r = q[..., [0]] 10 | q_i = q[..., 1:] 11 | 12 | out_r = p_r * q_r - (p_i * q_i).sum(dim=-1) 13 | out_i = p_r * q_i + q_r * p_i + torch.linalg.cross(p_i, q_i, dim=-1) 14 | 15 | return torch.cat([out_r, out_i], dim=-1) 16 | 17 | def quaternion_inverse(p): 18 | p_r = p[..., [0]] 19 | p_i = -p[..., 1:] 20 | 21 | return torch.cat([p_r, p_i], dim=-1) 22 | 23 | def quaternion_rotate(p, q): 24 | q_inv = quaternion_inverse(q) 25 | 26 | qp = quaternion_product(q, p) 27 | out = quaternion_product(qp, q_inv) 28 | return out 29 | 30 | def build_q(vec, angle): 31 | out_r = torch.cos(angle / 2) 32 | out_i = torch.sin(angle / 2) * vec 33 | 34 | return torch.cat([out_r, out_i], dim=-1) 35 | 36 | 37 | def cartesian2quaternion(x): 38 | zeros_ = x.new_zeros([*x.shape[:-1], 1]) 39 | return torch.cat([zeros_, x], dim=-1) 40 | 41 | 42 | def spherical2cartesian(theta, phi): 43 | x = torch.cos(phi) * torch.sin(theta) 44 | y = torch.sin(phi) * torch.sin(theta) 45 | z = torch.cos(theta) 46 | 47 | return [x, y, z] 48 | 49 | def init_predefined_omega(n_theta, n_phi): 50 | theta_list = torch.linspace(0, np.pi, n_theta) 51 | phi_list = torch.linspace(0, np.pi*2, n_phi) 52 | 53 | out_omega = [] 54 | out_omega_lambda = [] 55 | out_omega_mu = [] 56 | 57 | for i in range(n_theta): 58 | theta = theta_list[i].view(1, 1) 59 | 60 | for j in range(n_phi): 61 | phi = phi_list[j].view(1, 1) 62 | 63 | omega = spherical2cartesian(theta, phi) 64 | omega = torch.stack(omega, dim=-1).view(1, 3) 65 | 66 | omega_lambda = spherical2cartesian(theta+np.pi/2, phi) 67 | omega_lambda = torch.stack(omega_lambda, dim=-1).view(1, 3) 68 | 69 | p = cartesian2quaternion(omega_lambda) 70 | q = build_q(omega, torch.tensor(np.pi/2).view(1, 1)) 71 | omega_mu = quaternion_rotate(p, q)[..., 1:] 72 | 73 | out_omega.append(omega) 74 | out_omega_lambda.append(omega_lambda) 75 | out_omega_mu.append(omega_mu) 76 | 77 | 78 | out_omega = torch.stack(out_omega, dim=0) 79 | out_omega_lambda = torch.stack(out_omega_lambda, dim=0) 80 | out_omega_mu = torch.stack(out_omega_mu, dim=0) 81 | 82 | return out_omega, out_omega_lambda, out_omega_mu 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NRFF 2 | 3 | ### [Project page](https://imkanghan.github.io/projects/NRFF/main) | [Paper](https://arxiv.org/abs/2303.03808) 4 | 5 | This repository is an implementation of the view synthesis method described in the paper "Multiscale Tensor Decomposition and Rendering Equation Encoding for View Synthesis", CVPR 2023. 6 | 7 | [Kang Han](https://imkanghan.github.io/)1, [Wei Xiang](https://scholars.latrobe.edu.au/wxiang)2 8 | 9 | 1James Cook University, 2La Trobe University 10 | 11 | ## Abstract 12 | Rendering novel views from captured multi-view images has made considerable progress since the emergence of the neural radiance field. This paper aims to further advance the quality of view synthesis by proposing a novel approach dubbed the neural radiance feature field (NRFF). We first propose a multiscale tensor decomposition scheme to organize learnable features so as to represent scenes from coarse to fine scales. We demonstrate many benefits of the proposed multiscale representation, including more accurate scene shape and appearance reconstruction, and faster convergence compared with the single-scale representation. Instead of encoding view directions to model view-dependent effects, we further propose to encode the rendering equation in the feature space by employing the anisotropic spherical Gaussian mixture predicted from the proposed multiscale representation. The proposed NRFF improves state-of-the-art rendering results by over 1 dB in PSNR on both the NeRF and NSVF synthetic datasets. A significant improvement has also been observed on the real-world Tanks \& Temples dataset. 13 | 14 | ## Installation 15 | 16 | This implementation is based on [PyTorch](https://pytorch.org/) and [TensoRF](https://github.com/apchenstu/TensoRF). You can create a virtual environment using Anaconda by running 17 | 18 | ``` 19 | conda create -n nrff python=3.8 20 | conda activate nrff 21 | pip3 install torch torchvision 22 | pip3 install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg kornia 23 | ``` 24 | 25 | ## Dataset 26 | Please download one of the following datasets: 27 | 28 | [NeRF-synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 29 | 30 | [NSVF-synthetic](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip) 31 | 32 | [Tanks & Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip) 33 | 34 | ## Training 35 | Specify the path of the data in configs/lego.txt and run 36 | ``` 37 | python train.py --config configs/lego.txt 38 | ``` 39 | 40 | ## Rendering 41 | ``` 42 | python train.py --config configs/lego.txt --ckpt path/to/your/checkpoint --render_only 1 --render_test 1 43 | ``` 44 | 45 | ## Citation 46 | If you find this code useful, please cite: 47 | 48 | @inproceedings{han2023nrff, 49 | author={Han, Kang and Xiang, Wei}, 50 | title={Multiscale Tensor Decomposition and Rendering Equation Encoding for View Synthesis}, 51 | booktitle={The IEEE / CVF Computer Vision and Pattern Recognition Conference}, 52 | pages={4232--4241}, 53 | year={2023} 54 | } 55 | 56 | ## Acknowledgements 57 | 58 | Thanks to the awesome neural rendering repositories of [TensoRF](https://github.com/apchenstu/TensoRF) and [Instand-NGP](https://github.com/NVlabs/instant-ngp). -------------------------------------------------------------------------------- /dataLoader/blender.py: -------------------------------------------------------------------------------- 1 | import torch,cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | 10 | from .ray_utils import * 11 | 12 | 13 | class BlenderDataset(Dataset): 14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1): 15 | 16 | self.N_vis = N_vis 17 | self.root_dir = datadir 18 | self.split = split 19 | self.is_stack = is_stack 20 | self.img_wh = (int(800/downsample),int(800/downsample)) 21 | self.define_transforms() 22 | 23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]) 24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 25 | self.read_meta() 26 | self.define_proj_mat() 27 | 28 | self.white_bg = True 29 | self.near_far = [2.0,6.0] 30 | 31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 33 | self.downsample=downsample 34 | 35 | def read_depth(self, filename): 36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 37 | return depth 38 | 39 | def read_meta(self): 40 | 41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 42 | self.meta = json.load(f) 43 | 44 | w, h = self.img_wh 45 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 46 | self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh 47 | 48 | 49 | # ray directions for all pixels, same for all images (same H, W, focal) 50 | self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3) 51 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 52 | self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float() 53 | 54 | self.image_paths = [] 55 | self.poses = [] 56 | self.all_rays = [] 57 | self.all_rgbs = [] 58 | self.all_masks = [] 59 | self.all_depth = [] 60 | self.downsample=1.0 61 | 62 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 63 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 64 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:# 65 | 66 | frame = self.meta['frames'][i] 67 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 68 | c2w = torch.FloatTensor(pose) 69 | self.poses += [c2w] 70 | 71 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 72 | self.image_paths += [image_path] 73 | img = Image.open(image_path) 74 | 75 | if self.downsample!=1.0: 76 | img = img.resize(self.img_wh, Image.LANCZOS) 77 | img = self.transform(img) # (4, h, w) 78 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA 79 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 80 | self.all_rgbs += [img] 81 | 82 | 83 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 84 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 85 | 86 | 87 | self.poses = torch.stack(self.poses) 88 | if not self.is_stack: 89 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 90 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 91 | 92 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 93 | else: 94 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 95 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 96 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 97 | 98 | 99 | def define_transforms(self): 100 | self.transform = T.ToTensor() 101 | 102 | def define_proj_mat(self): 103 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3] 104 | 105 | def world2ndc(self,points,lindisp=None): 106 | device = points.device 107 | return (points - self.center.to(device)) / self.radius.to(device) 108 | 109 | def __len__(self): 110 | return len(self.all_rgbs) 111 | 112 | def __getitem__(self, idx): 113 | 114 | if self.split == 'train': # use data in the buffers 115 | sample = {'rays': self.all_rays[idx], 116 | 'rgbs': self.all_rgbs[idx]} 117 | 118 | else: # create data for each image separately 119 | 120 | img = self.all_rgbs[idx] 121 | rays = self.all_rays[idx] 122 | mask = self.all_masks[idx] # for quantity evaluation 123 | 124 | sample = {'rays': rays, 125 | 'rgbs': img, 126 | 'mask': mask} 127 | return sample 128 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2,torch 2 | import numpy as np 3 | from PIL import Image 4 | import torchvision.transforms as T 5 | import torch.nn.functional as F 6 | import scipy.signal 7 | 8 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.])) 9 | 10 | 11 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): 12 | """ 13 | depth: (H, W) 14 | """ 15 | 16 | x = np.nan_to_num(depth) # change nan to 0 17 | if minmax is None: 18 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 19 | ma = np.max(x) 20 | else: 21 | mi,ma = minmax 22 | 23 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 24 | x = (255*x).astype(np.uint8) 25 | x_ = cv2.applyColorMap(x, cmap) 26 | return x_, [mi,ma] 27 | 28 | def init_log(log, keys): 29 | for key in keys: 30 | log[key] = torch.tensor([0.0], dtype=float) 31 | return log 32 | 33 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): 34 | """ 35 | depth: (H, W) 36 | """ 37 | if type(depth) is not np.ndarray: 38 | depth = depth.cpu().numpy() 39 | 40 | x = np.nan_to_num(depth) # change nan to 0 41 | if minmax is None: 42 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background) 43 | ma = np.max(x) 44 | else: 45 | mi,ma = minmax 46 | 47 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1 48 | x = (255*x).astype(np.uint8) 49 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 50 | x_ = T.ToTensor()(x_) # (3, H, W) 51 | return x_, [mi,ma] 52 | 53 | def N_to_reso(n_voxels, bbox): 54 | xyz_min, xyz_max = bbox 55 | dim = len(xyz_min) 56 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim) 57 | return ((xyz_max - xyz_min) / voxel_size).long().tolist() 58 | 59 | def cal_n_samples(reso, step_ratio=0.5): 60 | return int(np.linalg.norm(reso)/step_ratio) 61 | 62 | 63 | 64 | 65 | __LPIPS__ = {} 66 | def init_lpips(net_name, device): 67 | assert net_name in ['alex', 'vgg'] 68 | import lpips 69 | print(f'init_lpips: lpips_{net_name}') 70 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 71 | 72 | def rgb_lpips(np_gt, np_im, net_name, device): 73 | if net_name not in __LPIPS__: 74 | __LPIPS__[net_name] = init_lpips(net_name, device) 75 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 76 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 77 | return __LPIPS__[net_name](gt, im, normalize=True).item() 78 | 79 | 80 | def findItem(items, target): 81 | for one in items: 82 | if one[:len(target)]==target: 83 | return one 84 | return None 85 | 86 | 87 | ''' Evaluation metrics (ssim, lpips) 88 | ''' 89 | def rgb_ssim(img0, img1, max_val, 90 | filter_size=11, 91 | filter_sigma=1.5, 92 | k1=0.01, 93 | k2=0.03, 94 | return_map=False): 95 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 96 | assert len(img0.shape) == 3 97 | assert img0.shape[-1] == 3 98 | assert img0.shape == img1.shape 99 | 100 | # Construct a 1D Gaussian blur filter. 101 | hw = filter_size // 2 102 | shift = (2 * hw - filter_size + 1) / 2 103 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 104 | filt = np.exp(-0.5 * f_i) 105 | filt /= np.sum(filt) 106 | 107 | # Blur in x and y (faster than the 2D convolution). 108 | def convolve2d(z, f): 109 | return scipy.signal.convolve2d(z, f, mode='valid') 110 | 111 | filt_fn = lambda z: np.stack([ 112 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :]) 113 | for i in range(z.shape[-1])], -1) 114 | mu0 = filt_fn(img0) 115 | mu1 = filt_fn(img1) 116 | mu00 = mu0 * mu0 117 | mu11 = mu1 * mu1 118 | mu01 = mu0 * mu1 119 | sigma00 = filt_fn(img0**2) - mu00 120 | sigma11 = filt_fn(img1**2) - mu11 121 | sigma01 = filt_fn(img0 * img1) - mu01 122 | 123 | # Clip the variances and covariances to valid values. 124 | # Variance must be non-negative: 125 | sigma00 = np.maximum(0., sigma00) 126 | sigma11 = np.maximum(0., sigma11) 127 | sigma01 = np.sign(sigma01) * np.minimum( 128 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 129 | c1 = (k1 * max_val)**2 130 | c2 = (k2 * max_val)**2 131 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 132 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 133 | ssim_map = numer / denom 134 | ssim = np.mean(ssim_map) 135 | return ssim_map if return_map else ssim 136 | 137 | 138 | import torch.nn as nn 139 | class TVLoss(nn.Module): 140 | def __init__(self,TVLoss_weight=1): 141 | super(TVLoss,self).__init__() 142 | self.TVLoss_weight = TVLoss_weight 143 | 144 | def forward(self,x): 145 | batch_size = x.size()[0] 146 | h_x = x.size()[2] 147 | w_x = x.size()[3] 148 | count_h = self._tensor_size(x[:,:,1:,:]) 149 | count_w = self._tensor_size(x[:,:,:,1:]) 150 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 151 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 152 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 153 | 154 | def _tensor_size(self,t): 155 | return t.size()[1]*t.size()[2]*t.size()[3] 156 | -------------------------------------------------------------------------------- /dataLoader/your_own_data.py: -------------------------------------------------------------------------------- 1 | import torch,cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | 10 | from .ray_utils import * 11 | 12 | 13 | class YourOwnDataset(Dataset): 14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1): 15 | 16 | self.N_vis = N_vis 17 | self.root_dir = datadir 18 | self.split = split 19 | self.is_stack = is_stack 20 | self.downsample = downsample 21 | self.define_transforms() 22 | 23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]]) 24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 25 | self.read_meta() 26 | self.define_proj_mat() 27 | 28 | self.white_bg = True 29 | self.near_far = [0.1,100.0] 30 | 31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 33 | self.downsample=downsample 34 | 35 | def read_depth(self, filename): 36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800) 37 | return depth 38 | 39 | def read_meta(self): 40 | 41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 42 | self.meta = json.load(f) 43 | 44 | w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample) 45 | self.img_wh = [w,h] 46 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 47 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length 48 | self.cx, self.cy = self.meta['cx'],self.meta['cy'] 49 | 50 | 51 | # ray directions for all pixels, same for all images (same H, W, focal) 52 | self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3) 53 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 54 | self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float() 55 | 56 | self.image_paths = [] 57 | self.poses = [] 58 | self.all_rays = [] 59 | self.all_rgbs = [] 60 | self.all_masks = [] 61 | self.all_depth = [] 62 | 63 | 64 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 65 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 66 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:# 67 | 68 | frame = self.meta['frames'][i] 69 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv 70 | c2w = torch.FloatTensor(pose) 71 | self.poses += [c2w] 72 | 73 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png") 74 | self.image_paths += [image_path] 75 | img = Image.open(image_path) 76 | 77 | if self.downsample!=1.0: 78 | img = img.resize(self.img_wh, Image.LANCZOS) 79 | img = self.transform(img) # (4, h, w) 80 | img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA 81 | if img.shape[-1]==4: 82 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 83 | self.all_rgbs += [img] 84 | 85 | 86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 87 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 88 | 89 | 90 | self.poses = torch.stack(self.poses) 91 | if not self.is_stack: 92 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 93 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 94 | 95 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3) 96 | else: 97 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 98 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 99 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3) 100 | 101 | 102 | def define_transforms(self): 103 | self.transform = T.ToTensor() 104 | 105 | def define_proj_mat(self): 106 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3] 107 | 108 | def world2ndc(self,points,lindisp=None): 109 | device = points.device 110 | return (points - self.center.to(device)) / self.radius.to(device) 111 | 112 | def __len__(self): 113 | return len(self.all_rgbs) 114 | 115 | def __getitem__(self, idx): 116 | 117 | if self.split == 'train': # use data in the buffers 118 | sample = {'rays': self.all_rays[idx], 119 | 'rgbs': self.all_rgbs[idx]} 120 | 121 | else: # create data for each image separately 122 | 123 | img = self.all_rgbs[idx] 124 | rays = self.all_rays[idx] 125 | mask = self.all_masks[idx] # for quantity evaluation 126 | 127 | sample = {'rays': rays, 128 | 'rgbs': img} 129 | return sample 130 | -------------------------------------------------------------------------------- /models/sh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ################## sh function ################## 4 | C0 = 0.28209479177387814 5 | C1 = 0.4886025119029199 6 | C2 = [ 7 | 1.0925484305920792, 8 | -1.0925484305920792, 9 | 0.31539156525252005, 10 | -1.0925484305920792, 11 | 0.5462742152960396 12 | ] 13 | C3 = [ 14 | -0.5900435899266435, 15 | 2.890611442640554, 16 | -0.4570457994644658, 17 | 0.3731763325901154, 18 | -0.4570457994644658, 19 | 1.445305721320277, 20 | -0.5900435899266435 21 | ] 22 | C4 = [ 23 | 2.5033429417967046, 24 | -1.7701307697799304, 25 | 0.9461746957575601, 26 | -0.6690465435572892, 27 | 0.10578554691520431, 28 | -0.6690465435572892, 29 | 0.47308734787878004, 30 | -1.7701307697799304, 31 | 0.6258357354491761, 32 | ] 33 | 34 | def eval_sh(deg, sh, dirs): 35 | """ 36 | Evaluate spherical harmonics at unit directions 37 | using hardcoded SH polynomials. 38 | Works with torch/np/jnp. 39 | ... Can be 0 or more batch dimensions. 40 | :param deg: int SH max degree. Currently, 0-4 supported 41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2) 42 | :param dirs: torch.Tensor unit directions (..., 3) 43 | :return: (..., C) 44 | """ 45 | assert deg <= 4 and deg >= 0 46 | assert (deg + 1) ** 2 == sh.shape[-1] 47 | C = sh.shape[-2] 48 | 49 | result = C0 * sh[..., 0] 50 | if deg > 0: 51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 52 | result = (result - 53 | C1 * y * sh[..., 1] + 54 | C1 * z * sh[..., 2] - 55 | C1 * x * sh[..., 3]) 56 | if deg > 1: 57 | xx, yy, zz = x * x, y * y, z * z 58 | xy, yz, xz = x * y, y * z, x * z 59 | result = (result + 60 | C2[0] * xy * sh[..., 4] + 61 | C2[1] * yz * sh[..., 5] + 62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 63 | C2[3] * xz * sh[..., 7] + 64 | C2[4] * (xx - yy) * sh[..., 8]) 65 | 66 | if deg > 2: 67 | result = (result + 68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 69 | C3[1] * xy * z * sh[..., 10] + 70 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 73 | C3[5] * z * (xx - yy) * sh[..., 14] + 74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 75 | if deg > 3: 76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 85 | return result 86 | 87 | def eval_sh_bases(deg, dirs): 88 | """ 89 | Evaluate spherical harmonics bases at unit directions, 90 | without taking linear combination. 91 | At each point, the final result may the be 92 | obtained through simple multiplication. 93 | :param deg: int SH max degree. Currently, 0-4 supported 94 | :param dirs: torch.Tensor (..., 3) unit directions 95 | :return: torch.Tensor (..., (deg+1) ** 2) 96 | """ 97 | assert deg <= 4 and deg >= 0 98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device) 99 | result[..., 0] = C0 100 | if deg > 0: 101 | x, y, z = dirs.unbind(-1) 102 | result[..., 1] = -C1 * y; 103 | result[..., 2] = C1 * z; 104 | result[..., 3] = -C1 * x; 105 | if deg > 1: 106 | xx, yy, zz = x * x, y * y, z * z 107 | xy, yz, xz = x * y, y * z, x * z 108 | result[..., 4] = C2[0] * xy; 109 | result[..., 5] = C2[1] * yz; 110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy); 111 | result[..., 7] = C2[3] * xz; 112 | result[..., 8] = C2[4] * (xx - yy); 113 | 114 | if deg > 2: 115 | result[..., 9] = C3[0] * y * (3 * xx - yy); 116 | result[..., 10] = C3[1] * xy * z; 117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy); 118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy); 119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy); 120 | result[..., 14] = C3[5] * z * (xx - yy); 121 | result[..., 15] = C3[6] * x * (xx - 3 * yy); 122 | 123 | if deg > 3: 124 | result[..., 16] = C4[0] * xy * (xx - yy); 125 | result[..., 17] = C4[1] * yz * (3 * xx - yy); 126 | result[..., 18] = C4[2] * xy * (7 * zz - 1); 127 | result[..., 19] = C4[3] * yz * (7 * zz - 3); 128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3); 129 | result[..., 21] = C4[5] * xz * (7 * zz - 3); 130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1); 131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy); 132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)); 133 | return result 134 | -------------------------------------------------------------------------------- /renderer.py: -------------------------------------------------------------------------------- 1 | import torch,os,imageio,sys 2 | from tqdm.auto import tqdm 3 | from dataLoader.ray_utils import get_rays 4 | from models.model import NRFF 5 | from utils import * 6 | from dataLoader.ray_utils import ndc_rays_blender 7 | import time 8 | 9 | 10 | def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda'): 11 | 12 | rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], [] 13 | N_rays_all = rays.shape[0] 14 | for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)): 15 | rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device) 16 | 17 | output = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples) 18 | 19 | rgbs.append(output['rgb_map']) 20 | depth_maps.append(output['depth_map']) 21 | 22 | return torch.cat(rgbs), None, torch.cat(depth_maps), None, output 23 | 24 | @torch.no_grad() 25 | def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, 26 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'): 27 | PSNRs, rgb_maps, depth_maps = [], [], [] 28 | ssims,l_alex,l_vgg=[],[],[] 29 | os.makedirs(savePath, exist_ok=True) 30 | os.makedirs(savePath+"/rgbd", exist_ok=True) 31 | 32 | try: 33 | tqdm._instances.clear() 34 | except Exception: 35 | pass 36 | 37 | near_far = test_dataset.near_far 38 | img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays.shape[0] // N_vis,1) 39 | idxs = list(range(0, test_dataset.all_rays.shape[0], img_eval_interval)) 40 | for idx, samples in tqdm(enumerate(test_dataset.all_rays[0::img_eval_interval]), file=sys.stdout): 41 | 42 | W, H = test_dataset.img_wh 43 | rays = samples.view(-1,samples.shape[-1]) 44 | 45 | rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples, 46 | ndc_ray=ndc_ray, white_bg = white_bg, device=device) 47 | 48 | rgb_map = rgb_map.clamp(0.0, 1.0) 49 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu() 50 | 51 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far) 52 | if len(test_dataset.all_rgbs): 53 | gt_rgb = test_dataset.all_rgbs[idxs[idx]].view(H, W, 3) 54 | loss = torch.mean((rgb_map - gt_rgb) ** 2) 55 | PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0)) 56 | 57 | 58 | if compute_extra_metrics: 59 | ssim = rgb_ssim(rgb_map, gt_rgb, 1) 60 | l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device) 61 | l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device) 62 | 63 | ssims.append(ssim) 64 | l_alex.append(l_a) 65 | l_vgg.append(l_v) 66 | 67 | 68 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 69 | rgb_maps.append(rgb_map) 70 | depth_maps.append(depth_map) 71 | if savePath is not None: 72 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 73 | 74 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10) 75 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10) 76 | 77 | if PSNRs: 78 | psnr = np.mean(np.asarray(PSNRs)) 79 | if compute_extra_metrics: 80 | ssim = np.mean(np.asarray(ssims)) 81 | l_a = np.mean(np.asarray(l_alex)) 82 | l_v = np.mean(np.asarray(l_vgg)) 83 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v])) 84 | else: 85 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr])) 86 | 87 | 88 | return PSNRs 89 | 90 | @torch.no_grad() 91 | def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1, 92 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'): 93 | PSNRs, rgb_maps, depth_maps = [], [], [] 94 | ssims,l_alex,l_vgg=[],[],[] 95 | os.makedirs(savePath, exist_ok=True) 96 | os.makedirs(savePath+"/rgbd", exist_ok=True) 97 | 98 | try: 99 | tqdm._instances.clear() 100 | except Exception: 101 | pass 102 | 103 | near_far = test_dataset.near_far 104 | for idx, c2w in tqdm(enumerate(c2ws)): 105 | 106 | W, H = test_dataset.img_wh 107 | 108 | c2w = torch.FloatTensor(c2w) 109 | rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3) 110 | if ndc_ray: 111 | rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d) 112 | rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6) 113 | 114 | rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples, 115 | ndc_ray=ndc_ray, white_bg = white_bg, device=device) 116 | rgb_map = rgb_map.clamp(0.0, 1.0) 117 | 118 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu() 119 | 120 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far) 121 | 122 | rgb_map = (rgb_map.numpy() * 255).astype('uint8') 123 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 124 | rgb_maps.append(rgb_map) 125 | depth_maps.append(depth_map) 126 | if savePath is not None: 127 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map) 128 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1) 129 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map) 130 | 131 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8) 132 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8) 133 | 134 | if PSNRs: 135 | psnr = np.mean(np.asarray(PSNRs)) 136 | if compute_extra_metrics: 137 | ssim = np.mean(np.asarray(ssims)) 138 | l_a = np.mean(np.asarray(l_alex)) 139 | l_v = np.mean(np.asarray(l_vgg)) 140 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v])) 141 | else: 142 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr])) 143 | 144 | 145 | return PSNRs 146 | 147 | -------------------------------------------------------------------------------- /opt.py: -------------------------------------------------------------------------------- 1 | import configargparse 2 | 3 | def config_parser(cmd=None): 4 | parser = configargparse.ArgumentParser() 5 | parser.add_argument('--config', is_config_file=True, 6 | help='config file path') 7 | parser.add_argument("--expname", type=str, 8 | help='experiment name') 9 | parser.add_argument("--basedir", type=str, default='./log', 10 | help='where to store ckpts and logs') 11 | parser.add_argument("--add_timestamp", type=int, default=0, 12 | help='add timestamp to dir') 13 | parser.add_argument("--datadir", type=str, default='./data/llff/fern', 14 | help='input data directory') 15 | parser.add_argument("--progress_refresh_rate", type=int, default=10, 16 | help='how many iterations to show psnrs or iters') 17 | 18 | parser.add_argument('--with_depth', action='store_true') 19 | parser.add_argument('--downsample_train', type=float, default=1.0) 20 | parser.add_argument('--downsample_test', type=float, default=1.0) 21 | 22 | parser.add_argument('--model_name', type=str, default='TensorVMSplit', 23 | choices=['TensorVMSplit', 'TensorCP', 'NRFF']) 24 | 25 | # loader options 26 | parser.add_argument("--batch_size", type=int, default=4096) 27 | parser.add_argument("--n_iters", type=int, default=30000) 28 | 29 | parser.add_argument('--dataset_name', type=str, default='blender', 30 | choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']) 31 | 32 | 33 | # training options 34 | # learning rate 35 | parser.add_argument("--lr_init", type=float, default=0.02, 36 | help='learning rate') 37 | parser.add_argument("--lr_basis", type=float, default=1e-3, 38 | help='learning rate') 39 | parser.add_argument("--lr_decay_iters", type=int, default=-1, 40 | help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters') 41 | parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1, 42 | help='the target decay ratio; after decay_iters inital lr decays to lr*ratio') 43 | parser.add_argument("--lr_upsample_reset", type=int, default=1, 44 | help='reset lr to inital after upsampling') 45 | 46 | # loss 47 | parser.add_argument("--L1_weight_inital", type=float, default=0.0, 48 | help='loss weight') 49 | parser.add_argument("--L1_weight_rest", type=float, default=0, 50 | help='loss weight') 51 | parser.add_argument("--Ortho_weight", type=float, default=0.0, 52 | help='loss weight') 53 | parser.add_argument("--TV_weight_density", type=float, default=0.0, 54 | help='loss weight') 55 | parser.add_argument("--TV_weight_app", type=float, default=0.0, 56 | help='loss weight') 57 | 58 | # model 59 | # volume options 60 | parser.add_argument("--n_lamb_sigma", type=int, action="append") 61 | parser.add_argument("--n_lamb_sh", type=int, action="append") 62 | parser.add_argument("--data_dim_color", type=int, default=27) 63 | 64 | parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001, 65 | help='mask points in ray marching') 66 | parser.add_argument("--alpha_mask_thre", type=float, default=0.00001, 67 | help='threshold for creating alpha mask volume') 68 | parser.add_argument("--distance_scale", type=float, default=25, 69 | help='scaling sampling distance for computation') 70 | parser.add_argument("--density_shift", type=float, default=-10, 71 | help='shift density in softplus; making density = 0 when feature == 0') 72 | 73 | # network decoder 74 | parser.add_argument("--shadingMode", type=str, default="MLP_PE", 75 | help='which shading mode to use') 76 | parser.add_argument("--pos_pe", type=int, default=6, 77 | help='number of pe for pos') 78 | parser.add_argument("--view_pe", type=int, default=6, 79 | help='number of pe for view') 80 | parser.add_argument("--fea_pe", type=int, default=6, 81 | help='number of pe for features') 82 | parser.add_argument("--featureC", type=int, default=128, 83 | help='hidden feature channel in MLP') 84 | 85 | 86 | 87 | parser.add_argument("--ckpt", type=str, default=None, 88 | help='specific weights npy file to reload for coarse network') 89 | parser.add_argument("--render_only", type=int, default=0) 90 | parser.add_argument("--render_test", type=int, default=0) 91 | parser.add_argument("--render_train", type=int, default=0) 92 | parser.add_argument("--render_path", type=int, default=0) 93 | parser.add_argument("--export_mesh", type=int, default=0) 94 | 95 | # rendering options 96 | parser.add_argument('--lindisp', default=False, action="store_true", 97 | help='use disparity depth sampling') 98 | parser.add_argument("--perturb", type=float, default=1., 99 | help='set to 0. for no jitter, 1. for jitter') 100 | parser.add_argument("--accumulate_decay", type=float, default=0.998) 101 | parser.add_argument("--fea2denseAct", type=str, default='softplus') 102 | parser.add_argument('--ndc_ray', type=int, default=0) 103 | parser.add_argument('--nSamples', type=int, default=1e6, 104 | help='sample point each ray, pass 1e6 if automatic adjust') 105 | parser.add_argument('--step_ratio',type=float,default=0.5) 106 | 107 | 108 | ## blender flags 109 | parser.add_argument("--white_bkgd", action='store_true', 110 | help='set to render synthetic data on a white bkgd (always use for dvoxels)') 111 | 112 | 113 | 114 | parser.add_argument('--N_voxel_init', 115 | type=int, 116 | default=100**3) 117 | parser.add_argument('--N_voxel_final', 118 | type=int, 119 | default=300**3) 120 | parser.add_argument("--upsamp_list", type=int, action="append") 121 | parser.add_argument("--update_AlphaMask_list", type=int, action="append") 122 | 123 | parser.add_argument('--idx_view', 124 | type=int, 125 | default=0) 126 | # logging/saving options 127 | parser.add_argument("--N_vis", type=int, default=5, 128 | help='N images to vis') 129 | parser.add_argument("--vis_every", type=int, default=10000, 130 | help='frequency of visualize the image') 131 | if cmd is not None: 132 | return parser.parse_args(cmd) 133 | else: 134 | return parser.parse_args() 135 | -------------------------------------------------------------------------------- /dataLoader/nsvf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | trans_t = lambda t : torch.Tensor([ 11 | [1,0,0,0], 12 | [0,1,0,0], 13 | [0,0,1,t], 14 | [0,0,0,1]]).float() 15 | 16 | rot_phi = lambda phi : torch.Tensor([ 17 | [1,0,0,0], 18 | [0,np.cos(phi),-np.sin(phi),0], 19 | [0,np.sin(phi), np.cos(phi),0], 20 | [0,0,0,1]]).float() 21 | 22 | rot_theta = lambda th : torch.Tensor([ 23 | [np.cos(th),0,-np.sin(th),0], 24 | [0,1,0,0], 25 | [np.sin(th),0, np.cos(th),0], 26 | [0,0,0,1]]).float() 27 | 28 | 29 | def pose_spherical(theta, phi, radius): 30 | c2w = trans_t(radius) 31 | c2w = rot_phi(phi/180.*np.pi) @ c2w 32 | c2w = rot_theta(theta/180.*np.pi) @ c2w 33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w 34 | return c2w 35 | 36 | class NSVF(Dataset): 37 | """NSVF Generic Dataset.""" 38 | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False): 39 | self.root_dir = datadir 40 | self.split = split 41 | self.is_stack = is_stack 42 | self.downsample = downsample 43 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample)) 44 | self.define_transforms() 45 | 46 | self.white_bg = True 47 | self.near_far = [0.5,6.0] 48 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3) 49 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 50 | self.read_meta() 51 | self.define_proj_mat() 52 | 53 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 54 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 55 | 56 | def bbox2corners(self): 57 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1) 58 | for i in range(3): 59 | corners[i,[0,1],i] = corners[i,[1,0],i] 60 | return corners.view(-1,3) 61 | 62 | 63 | def read_meta(self): 64 | with open(os.path.join(self.root_dir, "intrinsics.txt")) as f: 65 | focal = float(f.readline().split()[0]) 66 | self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]]) 67 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1) 68 | 69 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 70 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 71 | 72 | if self.split == 'train': 73 | pose_files = [x for x in pose_files if x.startswith('0_')] 74 | img_files = [x for x in img_files if x.startswith('0_')] 75 | elif self.split == 'val': 76 | pose_files = [x for x in pose_files if x.startswith('1_')] 77 | img_files = [x for x in img_files if x.startswith('1_')] 78 | elif self.split == 'test': 79 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 80 | test_img_files = [x for x in img_files if x.startswith('2_')] 81 | if len(test_pose_files) == 0: 82 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 83 | test_img_files = [x for x in img_files if x.startswith('1_')] 84 | pose_files = test_pose_files 85 | img_files = test_img_files 86 | 87 | # ray directions for all pixels, same for all images (same H, W, focal) 88 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3) 89 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 90 | 91 | 92 | self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0) 93 | 94 | self.poses = [] 95 | self.all_rays = [] 96 | self.all_rgbs = [] 97 | 98 | assert len(img_files) == len(pose_files) 99 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'): 100 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 101 | img = Image.open(image_path) 102 | if self.downsample!=1.0: 103 | img = img.resize(self.img_wh, Image.LANCZOS) 104 | img = self.transform(img) # (4, h, w) 105 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA 106 | if img.shape[-1]==4: 107 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 108 | self.all_rgbs += [img] 109 | 110 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv 111 | c2w = torch.FloatTensor(c2w) 112 | self.poses.append(c2w) # C2W 113 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 114 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8) 115 | 116 | # w2c = torch.inverse(c2w) 117 | # 118 | 119 | self.poses = torch.stack(self.poses) 120 | if 'train' == self.split: 121 | if self.is_stack: 122 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3) 123 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3) 124 | else: 125 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 126 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 127 | else: 128 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 129 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 130 | 131 | 132 | def define_transforms(self): 133 | self.transform = T.ToTensor() 134 | 135 | def define_proj_mat(self): 136 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3] 137 | 138 | def world2ndc(self, points): 139 | device = points.device 140 | return (points - self.center.to(device)) / self.radius.to(device) 141 | 142 | def __len__(self): 143 | if self.split == 'train': 144 | return len(self.all_rays) 145 | return len(self.all_rgbs) 146 | 147 | def __getitem__(self, idx): 148 | 149 | if self.split == 'train': # use data in the buffers 150 | sample = {'rays': self.all_rays[idx], 151 | 'rgbs': self.all_rgbs[idx]} 152 | 153 | else: # create data for each image separately 154 | 155 | img = self.all_rgbs[idx] 156 | rays = self.all_rays[idx] 157 | 158 | sample = {'rays': rays, 159 | 'rgbs': img} 160 | return sample -------------------------------------------------------------------------------- /extra/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import os, math 2 | import numpy as np 3 | import scipy.signal 4 | from typing import List, Optional 5 | from PIL import Image 6 | import os 7 | import torch 8 | import configargparse 9 | 10 | __LPIPS__ = {} 11 | def init_lpips(net_name, device): 12 | assert net_name in ['alex', 'vgg'] 13 | import lpips 14 | print(f'init_lpips: lpips_{net_name}') 15 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device) 16 | 17 | def rgb_lpips(np_gt, np_im, net_name, device): 18 | if net_name not in __LPIPS__: 19 | __LPIPS__[net_name] = init_lpips(net_name, device) 20 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device) 21 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device) 22 | return __LPIPS__[net_name](gt, im, normalize=True).item() 23 | 24 | 25 | def findItem(items, target): 26 | for one in items: 27 | if one[:len(target)]==target: 28 | return one 29 | return None 30 | 31 | 32 | ''' Evaluation metrics (ssim, lpips) 33 | ''' 34 | def rgb_ssim(img0, img1, max_val, 35 | filter_size=11, 36 | filter_sigma=1.5, 37 | k1=0.01, 38 | k2=0.03, 39 | return_map=False): 40 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58 41 | assert len(img0.shape) == 3 42 | assert img0.shape[-1] == 3 43 | assert img0.shape == img1.shape 44 | 45 | # Construct a 1D Gaussian blur filter. 46 | hw = filter_size // 2 47 | shift = (2 * hw - filter_size + 1) / 2 48 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2 49 | filt = np.exp(-0.5 * f_i) 50 | filt /= np.sum(filt) 51 | 52 | # Blur in x and y (faster than the 2D convolution). 53 | def convolve2d(z, f): 54 | return scipy.signal.convolve2d(z, f, mode='valid') 55 | 56 | filt_fn = lambda z: np.stack([ 57 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :]) 58 | for i in range(z.shape[-1])], -1) 59 | mu0 = filt_fn(img0) 60 | mu1 = filt_fn(img1) 61 | mu00 = mu0 * mu0 62 | mu11 = mu1 * mu1 63 | mu01 = mu0 * mu1 64 | sigma00 = filt_fn(img0**2) - mu00 65 | sigma11 = filt_fn(img1**2) - mu11 66 | sigma01 = filt_fn(img0 * img1) - mu01 67 | 68 | # Clip the variances and covariances to valid values. 69 | # Variance must be non-negative: 70 | sigma00 = np.maximum(0., sigma00) 71 | sigma11 = np.maximum(0., sigma11) 72 | sigma01 = np.sign(sigma01) * np.minimum( 73 | np.sqrt(sigma00 * sigma11), np.abs(sigma01)) 74 | c1 = (k1 * max_val)**2 75 | c2 = (k2 * max_val)**2 76 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 77 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 78 | ssim_map = numer / denom 79 | ssim = np.mean(ssim_map) 80 | return ssim_map if return_map else ssim 81 | 82 | 83 | if __name__ == '__main__': 84 | 85 | parser = configargparse.ArgumentParser() 86 | parser.add_argument("--exp", type=str, help="folder of exps") 87 | parser.add_argument("--paramStr", type=str, help="str of params") 88 | args = parser.parse_args() 89 | 90 | 91 | # datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']# 92 | # gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic" 93 | # expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp 94 | 95 | # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']# 96 | # gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/" 97 | # expFolder = "/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp 98 | paramStr = args.paramStr 99 | fileNum = 200 100 | 101 | 102 | expitems = os.listdir(expFolder) 103 | finalFolder = f'{expFolder}/finals/{paramStr}' 104 | outFile = f'{finalFolder}/{paramStr}_metrics.txt' 105 | os.makedirs(finalFolder, exist_ok=True) 106 | 107 | expitems.sort(reverse=True) 108 | 109 | 110 | with open(outFile, 'w') as f: 111 | all_psnr = [] 112 | all_ssim = [] 113 | all_alex = [] 114 | all_vgg = [] 115 | for dataname in datanames: 116 | 117 | 118 | gtstr = gtFolder+"/"+dataname+"/test/r_%d.png" 119 | expname = findItem(expitems, f'{paramStr}-{dataname}') 120 | print("expname: ", expname) 121 | if expname is None: 122 | print("no ",dataname, "exists") 123 | continue 124 | resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png" 125 | metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt' 126 | video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4' 127 | 128 | exist_metric=False 129 | if os.path.isfile(metric_file): 130 | metrics = np.loadtxt(metric_file) 131 | print(metrics, metrics.tolist()) 132 | if metrics.size == 4: 133 | psnr, ssim, l_a, l_v = metrics.tolist() 134 | exist_metric = True 135 | os.system(f"cp {video_file} {finalFolder}/") 136 | 137 | if not exist_metric: 138 | psnrs = [] 139 | ssims = [] 140 | l_alex = [] 141 | l_vgg = [] 142 | for i in range(fileNum): 143 | gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0 144 | gtmask = gt[...,[3]] 145 | gt = gt[...,:3] 146 | gt = gt*gtmask + (1-gtmask) 147 | img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3] / 255.0 148 | # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max()) 149 | 150 | 151 | psnr = -10. * np.log10(np.mean(np.square(img - gt))) 152 | ssim = rgb_ssim(img, gt, 1) 153 | lpips_alex = rgb_lpips(gt, img, 'alex','cuda') 154 | lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda') 155 | 156 | print(i, psnr, ssim, lpips_alex, lpips_vgg) 157 | psnrs.append(psnr) 158 | ssims.append(ssim) 159 | l_alex.append(lpips_alex) 160 | l_vgg.append(lpips_vgg) 161 | psnr = np.mean(np.array(psnrs)) 162 | ssim = np.mean(np.array(ssims)) 163 | l_a = np.mean(np.array(l_alex)) 164 | l_v = np.mean(np.array(l_vgg)) 165 | 166 | rS=f'{dataname} : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}' 167 | print(rS) 168 | f.write(rS+"\n") 169 | 170 | all_psnr.append(psnr) 171 | all_ssim.append(ssim) 172 | all_alex.append(l_a) 173 | all_vgg.append(l_v) 174 | 175 | psnr = np.mean(np.array(all_psnr)) 176 | ssim = np.mean(np.array(all_ssim)) 177 | l_a = np.mean(np.array(all_alex)) 178 | l_v = np.mean(np.array(all_vgg)) 179 | 180 | rS=f'mean : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}' 181 | print(rS) 182 | f.write(rS+"\n") -------------------------------------------------------------------------------- /extra/auto_run_paramsets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading, queue 3 | import numpy as np 4 | import time 5 | 6 | 7 | def getFolderLocker(logFolder): 8 | while True: 9 | try: 10 | os.makedirs(logFolder+"/lockFolder") 11 | break 12 | except: 13 | time.sleep(0.01) 14 | 15 | def releaseFolderLocker(logFolder): 16 | os.removedirs(logFolder+"/lockFolder") 17 | 18 | def getStopFolder(logFolder): 19 | return os.path.isdir(logFolder+"/stopFolder") 20 | 21 | 22 | def get_param_str(key, val): 23 | if key == 'data_name': 24 | return f'--datadir {datafolder}/{val} ' 25 | else: 26 | return f'--{key} {val} ' 27 | 28 | def get_param_list(param_dict): 29 | param_keys = list(param_dict.keys()) 30 | param_modes = len(param_keys) 31 | param_nums = [len(param_dict[key]) for key in param_keys] 32 | 33 | param_ids = np.zeros(param_nums+[param_modes], dtype=int) 34 | for i in range(param_modes): 35 | broad_tuple = np.ones(param_modes, dtype=int).tolist() 36 | broad_tuple[i] = param_nums[i] 37 | broad_tuple = tuple(broad_tuple) 38 | print(broad_tuple) 39 | param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple) 40 | param_ids = param_ids.reshape(-1, param_modes) 41 | # print(param_ids) 42 | print(len(param_ids)) 43 | 44 | params = [] 45 | expnames = [] 46 | for i in range(param_ids.shape[0]): 47 | one = "" 48 | name = "" 49 | param_id = param_ids[i] 50 | for j in range(param_modes): 51 | key = param_keys[j] 52 | val = param_dict[key][param_id[j]] 53 | if type(key) is tuple: 54 | assert len(key) == len(val) 55 | for k in range(len(key)): 56 | one += get_param_str(key[k], val[k]) 57 | name += f'{val[k]},' 58 | name=name[:-1]+'-' 59 | else: 60 | one += get_param_str(key, val) 61 | name += f'{val}-' 62 | params.append(one) 63 | name=name.replace(' ','') 64 | print(name) 65 | expnames.append(name[:-1]) 66 | # print(params) 67 | return params, expnames 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | if __name__ == '__main__': 76 | 77 | 78 | 79 | # nerf 80 | expFolder = "nerf/" 81 | # parameters to iterate, use tuple to couple multiple parameters 82 | datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/' 83 | param_dict = { 84 | 'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'], 85 | 'data_dim_color': [13, 27, 54] 86 | } 87 | 88 | # n_iters = 30000 89 | # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder' 90 | # cmd = f'CUDA_VISIBLE_DEVICES={cuda} python train.py ' \ 91 | # f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\ 92 | # f'--expname {data_name} --batch_size {batch_size} ' \ 93 | # f'--n_iters {n_iters} ' \ 94 | # f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\ 95 | # f'--N_vis {5} ' \ 96 | # f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \ 97 | # f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \ 98 | # f'--shadingMode MLP_Fea --fea2denseAct softplus --view_pe {2} --fea_pe {2} ' \ 99 | # f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \ 100 | # f'--render_test 1 ' 101 | # print(cmd) 102 | # os.system(cmd) 103 | 104 | # nsvf 105 | # expFolder = "nsvf_0227/" 106 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/' 107 | # param_dict = { 108 | # 'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder' 109 | # 'shadingMode': ['SH'], 110 | # ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")], 111 | # ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)], 112 | # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)], 113 | # ('n_iters','N_voxel_final'): [(30000,300**3)], 114 | # ('dataset_name','N_vis','render_test') : [("nsvf",5,1)], 115 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")] 116 | # 117 | # } 118 | 119 | # tankstemple 120 | # expFolder = "tankstemple_0304/" 121 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/' 122 | # param_dict = { 123 | # 'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'], 124 | # 'shadingMode': ['MLP_Fea'], 125 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")], 126 | # ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)], 127 | # ('TV_weight_density','TV_weight_app'):[(0.1,0.01)], 128 | # # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)], 129 | # ('n_iters','N_voxel_final'): [(15000,300**3)], 130 | # ('dataset_name','N_vis') : [("tankstemple",5)], 131 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")] 132 | # } 133 | 134 | # llff 135 | # expFolder = "real_iconic/" 136 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/' 137 | # List = os.listdir(datafolder) 138 | # param_dict = { 139 | # 'data_name': List, 140 | # ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)], 141 | # ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")], 142 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)], 143 | # ('n_iters','N_voxel_final'): [(25000,640**3)], 144 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)], 145 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")], 146 | # } 147 | 148 | # expFolder = "llff/" 149 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data' 150 | # param_dict = { 151 | # 'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids' 152 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")], 153 | # ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)], 154 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)], 155 | # ('n_iters','N_voxel_final'): [(25000,640**3)], 156 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)], 157 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")], 158 | # } 159 | 160 | #setting available gpus 161 | gpus_que = queue.Queue(3) 162 | for i in [1,2,3]: 163 | gpus_que.put(i) 164 | 165 | os.makedirs(f"log/{expFolder}", exist_ok=True) 166 | 167 | def run_program(gpu, expname, param): 168 | cmd = f'CUDA_VISIBLE_DEVICES={gpu} python train.py ' \ 169 | f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \ 170 | f'{param}' \ 171 | f'> "log/{expFolder}{expname}/{expname}.txt"' 172 | print(cmd) 173 | os.system(cmd) 174 | gpus_que.put(gpu) 175 | 176 | params, expnames = get_param_list(param_dict) 177 | 178 | 179 | logFolder=f"log/{expFolder}" 180 | os.makedirs(logFolder, exist_ok=True) 181 | 182 | ths = [] 183 | for i in range(len(params)): 184 | 185 | if getStopFolder(logFolder): 186 | break 187 | 188 | 189 | targetFolder = f"log/{expFolder}{expnames[i]}" 190 | gpu = gpus_que.get() 191 | getFolderLocker(logFolder) 192 | if os.path.isdir(targetFolder): 193 | releaseFolderLocker(logFolder) 194 | gpus_que.put(gpu) 195 | continue 196 | else: 197 | os.makedirs(targetFolder, exist_ok=True) 198 | print("making",targetFolder, "running",expnames[i], params[i]) 199 | releaseFolderLocker(logFolder) 200 | 201 | 202 | t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True) 203 | t.start() 204 | ths.append(t) 205 | 206 | for th in ths: 207 | th.join() -------------------------------------------------------------------------------- /dataLoader/tankstemple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import os 5 | from PIL import Image 6 | from torchvision import transforms as T 7 | 8 | from .ray_utils import * 9 | 10 | 11 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1): 12 | if axis == 'z': 13 | return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h] 14 | elif axis == 'y': 15 | return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)] 16 | else: 17 | return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)] 18 | 19 | 20 | def cross(x, y, axis=0): 21 | T = torch if isinstance(x, torch.Tensor) else np 22 | return T.cross(x, y, axis) 23 | 24 | 25 | def normalize(x, axis=-1, order=2): 26 | if isinstance(x, torch.Tensor): 27 | l2 = x.norm(p=order, dim=axis, keepdim=True) 28 | return x / (l2 + 1e-8), l2 29 | 30 | else: 31 | l2 = np.linalg.norm(x, order, axis) 32 | l2 = np.expand_dims(l2, axis) 33 | l2[l2 == 0] = 1 34 | return x / l2, 35 | 36 | 37 | def cat(x, axis=1): 38 | if isinstance(x[0], torch.Tensor): 39 | return torch.cat(x, dim=axis) 40 | return np.concatenate(x, axis=axis) 41 | 42 | 43 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False): 44 | """ 45 | This function takes a vector 'camera_position' which specifies the location 46 | of the camera in world coordinates and two vectors `at` and `up` which 47 | indicate the position of the object and the up directions of the world 48 | coordinate system respectively. The object is assumed to be centered at 49 | the origin. 50 | The output is a rotation matrix representing the transformation 51 | from world coordinates -> view coordinates. 52 | Input: 53 | camera_position: 3 54 | at: 1 x 3 or N x 3 (0, 0, 0) in default 55 | up: 1 x 3 or N x 3 (0, 1, 0) in default 56 | """ 57 | 58 | if at is None: 59 | at = torch.zeros_like(camera_position) 60 | else: 61 | at = torch.tensor(at).type_as(camera_position) 62 | if up is None: 63 | up = torch.zeros_like(camera_position) 64 | up[2] = -1 65 | else: 66 | up = torch.tensor(up).type_as(camera_position) 67 | 68 | z_axis = normalize(at - camera_position)[0] 69 | x_axis = normalize(cross(up, z_axis))[0] 70 | y_axis = normalize(cross(z_axis, x_axis))[0] 71 | 72 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1) 73 | return R 74 | 75 | 76 | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180): 77 | c2ws = [] 78 | for t in range(frames): 79 | c2w = torch.eye(4) 80 | cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi)) 81 | cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True) 82 | c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot 83 | c2ws.append(c2w) 84 | return torch.stack(c2ws) 85 | 86 | class TanksTempleDataset(Dataset): 87 | """NSVF Generic Dataset.""" 88 | def __init__(self, datadir, split='train', downsample=1.0, wh=[1920,1080], is_stack=False): 89 | self.root_dir = datadir 90 | self.split = split 91 | self.is_stack = is_stack 92 | self.downsample = downsample 93 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample)) 94 | self.define_transforms() 95 | 96 | self.white_bg = True 97 | self.near_far = [0.01,6.0] 98 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2 99 | 100 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 101 | self.read_meta() 102 | self.define_proj_mat() 103 | 104 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3) 105 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 106 | 107 | def bbox2corners(self): 108 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1) 109 | for i in range(3): 110 | corners[i,[0,1],i] = corners[i,[1,0],i] 111 | return corners.view(-1,3) 112 | 113 | 114 | def read_meta(self): 115 | 116 | self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt")) 117 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1) 118 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose'))) 119 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb'))) 120 | 121 | if self.split == 'train': 122 | pose_files = [x for x in pose_files if x.startswith('0_')] 123 | img_files = [x for x in img_files if x.startswith('0_')] 124 | elif self.split == 'val': 125 | pose_files = [x for x in pose_files if x.startswith('1_')] 126 | img_files = [x for x in img_files if x.startswith('1_')] 127 | elif self.split == 'test': 128 | test_pose_files = [x for x in pose_files if x.startswith('2_')] 129 | test_img_files = [x for x in img_files if x.startswith('2_')] 130 | if len(test_pose_files) == 0: 131 | test_pose_files = [x for x in pose_files if x.startswith('1_')] 132 | test_img_files = [x for x in img_files if x.startswith('1_')] 133 | pose_files = test_pose_files 134 | img_files = test_img_files 135 | 136 | # ray directions for all pixels, same for all images (same H, W, focal) 137 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3) 138 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 139 | 140 | 141 | 142 | self.poses = [] 143 | self.all_rays = [] 144 | self.all_rgbs = [] 145 | 146 | assert len(img_files) == len(pose_files) 147 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'): 148 | image_path = os.path.join(self.root_dir, 'rgb', img_fname) 149 | img = Image.open(image_path) 150 | if self.downsample!=1.0: 151 | img = img.resize(self.img_wh, Image.LANCZOS) 152 | img = self.transform(img) # (4, h, w) 153 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA 154 | if img.shape[-1]==4: 155 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 156 | self.all_rgbs.append(img) 157 | 158 | 159 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans 160 | c2w = torch.FloatTensor(c2w) 161 | self.poses.append(c2w) # C2W 162 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 163 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8) 164 | 165 | self.poses = torch.stack(self.poses) 166 | 167 | center = torch.mean(self.scene_bbox, dim=0) 168 | radius = torch.norm(self.scene_bbox[1]-center)*1.2 169 | up = torch.mean(self.poses[:, :3, 1], dim=0).tolist() 170 | pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y') 171 | self.render_path = gen_path(pos_gen, up=up,frames=200) 172 | self.render_path[:, :3, 3] += center 173 | 174 | 175 | 176 | if 'train' == self.split: 177 | if self.is_stack: 178 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3) 179 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3) 180 | else: 181 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 182 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3) 183 | else: 184 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3) 185 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 186 | 187 | 188 | def define_transforms(self): 189 | self.transform = T.ToTensor() 190 | 191 | def define_proj_mat(self): 192 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3] 193 | 194 | def world2ndc(self, points): 195 | device = points.device 196 | return (points - self.center.to(device)) / self.radius.to(device) 197 | 198 | def __len__(self): 199 | if self.split == 'train': 200 | return len(self.all_rays) 201 | return len(self.all_rgbs) 202 | 203 | def __getitem__(self, idx): 204 | 205 | if self.split == 'train': # use data in the buffers 206 | sample = {'rays': self.all_rays[idx], 207 | 'rgbs': self.all_rgbs[idx]} 208 | 209 | else: # create data for each image separately 210 | 211 | img = self.all_rgbs[idx] 212 | rays = self.all_rays[idx] 213 | 214 | sample = {'rays': rays, 215 | 'rgbs': img} 216 | return sample -------------------------------------------------------------------------------- /dataLoader/llff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import glob 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .ray_utils import * 10 | 11 | 12 | def normalize(v): 13 | """Normalize a vector.""" 14 | return v / np.linalg.norm(v) 15 | 16 | 17 | def average_poses(poses): 18 | """ 19 | Calculate the average pose, which is then used to center all poses 20 | using @center_poses. Its computation is as follows: 21 | 1. Compute the center: the average of pose centers. 22 | 2. Compute the z axis: the normalized average z axis. 23 | 3. Compute axis y': the average y axis. 24 | 4. Compute x' = y' cross product z, then normalize it as the x axis. 25 | 5. Compute the y axis: z cross product x. 26 | 27 | Note that at step 3, we cannot directly use y' as y axis since it's 28 | not necessarily orthogonal to z axis. We need to pass from x to y. 29 | Inputs: 30 | poses: (N_images, 3, 4) 31 | Outputs: 32 | pose_avg: (3, 4) the average pose 33 | """ 34 | # 1. Compute the center 35 | center = poses[..., 3].mean(0) # (3) 36 | 37 | # 2. Compute the z axis 38 | z = normalize(poses[..., 2].mean(0)) # (3) 39 | 40 | # 3. Compute axis y' (no need to normalize as it's not the final output) 41 | y_ = poses[..., 1].mean(0) # (3) 42 | 43 | # 4. Compute the x axis 44 | x = normalize(np.cross(z, y_)) # (3) 45 | 46 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1) 47 | y = np.cross(x, z) # (3) 48 | 49 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4) 50 | 51 | return pose_avg 52 | 53 | 54 | def center_poses(poses, blender2opencv): 55 | """ 56 | Center the poses so that we can use NDC. 57 | See https://github.com/bmild/nerf/issues/34 58 | Inputs: 59 | poses: (N_images, 3, 4) 60 | Outputs: 61 | poses_centered: (N_images, 3, 4) the centered poses 62 | pose_avg: (3, 4) the average pose 63 | """ 64 | poses = poses @ blender2opencv 65 | pose_avg = average_poses(poses) # (3, 4) 66 | pose_avg_homo = np.eye(4) 67 | pose_avg_homo[:3] = pose_avg # convert to homogeneous coordinate for faster computation 68 | pose_avg_homo = pose_avg_homo 69 | # by simply adding 0, 0, 0, 1 as the last row 70 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4) 71 | poses_homo = \ 72 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate 73 | 74 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4) 75 | # poses_centered = poses_centered @ blender2opencv 76 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4) 77 | 78 | return poses_centered, pose_avg_homo 79 | 80 | 81 | def viewmatrix(z, up, pos): 82 | vec2 = normalize(z) 83 | vec1_avg = up 84 | vec0 = normalize(np.cross(vec1_avg, vec2)) 85 | vec1 = normalize(np.cross(vec2, vec0)) 86 | m = np.eye(4) 87 | m[:3] = np.stack([-vec0, vec1, vec2, pos], 1) 88 | return m 89 | 90 | 91 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120): 92 | render_poses = [] 93 | rads = np.array(list(rads) + [1.]) 94 | 95 | for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]: 96 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads) 97 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.]))) 98 | render_poses.append(viewmatrix(z, up, c)) 99 | return render_poses 100 | 101 | 102 | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120): 103 | # center pose 104 | c2w = average_poses(c2ws_all) 105 | 106 | # Get average pose 107 | up = normalize(c2ws_all[:, :3, 1].sum(0)) 108 | 109 | # Find a reasonable "focus depth" for this dataset 110 | dt = 0.75 111 | close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0 112 | focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth)) 113 | 114 | # Get radii for spiral path 115 | zdelta = near_fars.min() * .2 116 | tt = c2ws_all[:, :3, 3] 117 | rads = np.percentile(np.abs(tt), 90, 0) * rads_scale 118 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views) 119 | return np.stack(render_poses) 120 | 121 | 122 | class LLFFDataset(Dataset): 123 | def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8): 124 | """ 125 | spheric_poses: whether the images are taken in a spheric inward-facing manner 126 | default: False (forward-facing) 127 | val_num: number of val images (used for multigpu training, validate same image for all gpus) 128 | """ 129 | 130 | self.root_dir = datadir 131 | self.split = split 132 | self.hold_every = hold_every 133 | self.is_stack = is_stack 134 | self.downsample = downsample 135 | self.define_transforms() 136 | 137 | self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 138 | self.read_meta() 139 | self.white_bg = False 140 | 141 | # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])] 142 | self.near_far = [0.0, 1.0] 143 | self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]]) 144 | # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]]) 145 | self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3) 146 | self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3) 147 | 148 | def read_meta(self): 149 | 150 | 151 | poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17) 152 | self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*'))) 153 | # load full resolution image then resize 154 | if self.split in ['train', 'test']: 155 | assert len(poses_bounds) == len(self.image_paths), \ 156 | 'Mismatch between number of images and number of poses! Please rerun COLMAP!' 157 | 158 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5) 159 | self.near_fars = poses_bounds[:, -2:] # (N_images, 2) 160 | hwf = poses[:, :, -1] 161 | 162 | # Step 1: rescale focal length according to training resolution 163 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images 164 | self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)]) 165 | self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H] 166 | 167 | # Step 2: correct poses 168 | # Original poses has rotation in form "down right back", change to "right up back" 169 | # See https://github.com/bmild/nerf/issues/34 170 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1) 171 | # (N_images, 3, 4) exclude H, W, focal 172 | self.poses, self.pose_avg = center_poses(poses, self.blender2opencv) 173 | 174 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0 175 | # See https://github.com/bmild/nerf/issues/34 176 | near_original = self.near_fars.min() 177 | scale_factor = near_original * 0.75 # 0.75 is the default parameter 178 | # the nearest depth is at 1/0.75=1.33 179 | self.near_fars /= scale_factor 180 | self.poses[..., 3] /= scale_factor 181 | 182 | # build rendering path 183 | N_views, N_rots = 120, 2 184 | tt = self.poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T 185 | up = normalize(self.poses[:, :3, 1].sum(0)) 186 | rads = np.percentile(np.abs(tt), 90, 0) 187 | 188 | self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views) 189 | 190 | # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1) 191 | # val_idx = np.argmin(distances_from_center) # choose val image as the closest to 192 | # center image 193 | 194 | # ray directions for all pixels, same for all images (same H, W, focal) 195 | W, H = self.img_wh 196 | self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3) 197 | 198 | average_pose = average_poses(self.poses) 199 | dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1) 200 | i_test = np.arange(0, self.poses.shape[0], self.hold_every) # [np.argmin(dists)] 201 | img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test)) 202 | 203 | # use first N_images-1 to train, the LAST is val 204 | self.all_rays = [] 205 | self.all_rgbs = [] 206 | for i in img_list: 207 | image_path = self.image_paths[i] 208 | c2w = torch.FloatTensor(self.poses[i]) 209 | 210 | img = Image.open(image_path).convert('RGB') 211 | if self.downsample != 1.0: 212 | img = img.resize(self.img_wh, Image.LANCZOS) 213 | img = self.transform(img) # (3, h, w) 214 | 215 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB 216 | self.all_rgbs += [img] 217 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 218 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d) 219 | # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 220 | 221 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6) 222 | 223 | if not self.is_stack: 224 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3) 225 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3) 226 | else: 227 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h,w, 3) 228 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 229 | 230 | 231 | def define_transforms(self): 232 | self.transform = T.ToTensor() 233 | 234 | def __len__(self): 235 | return len(self.all_rgbs) 236 | 237 | def __getitem__(self, idx): 238 | 239 | sample = {'rays': self.all_rays[idx], 240 | 'rgbs': self.all_rgbs[idx]} 241 | 242 | return sample -------------------------------------------------------------------------------- /dataLoader/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch, re 2 | import numpy as np 3 | from torch import searchsorted 4 | from kornia import create_meshgrid 5 | 6 | 7 | # from utils import index_point_feature 8 | 9 | def depth2dist(z_vals, cos_angle): 10 | # z_vals: [N_ray N_sample] 11 | device = z_vals.device 12 | dists = z_vals[..., 1:] - z_vals[..., :-1] 13 | dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples] 14 | dists = dists * cos_angle.unsqueeze(-1) 15 | return dists 16 | 17 | 18 | def ndc2dist(ndc_pts, cos_angle): 19 | dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1) 20 | dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples] 21 | return dists 22 | 23 | 24 | def get_ray_directions(H, W, focal, center=None): 25 | """ 26 | Get ray directions for all pixels in camera coordinate. 27 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 28 | ray-tracing-generating-camera-rays/standard-coordinate-systems 29 | Inputs: 30 | H, W, focal: image height, width and focal length 31 | Outputs: 32 | directions: (H, W, 3), the direction of the rays in camera coordinate 33 | """ 34 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 35 | 36 | i, j = grid.unbind(-1) 37 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 38 | # see https://github.com/bmild/nerf/issues/24 39 | cent = center if center is not None else [W / 2, H / 2] 40 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) 41 | 42 | return directions 43 | 44 | 45 | def get_ray_directions_blender(H, W, focal, center=None): 46 | """ 47 | Get ray directions for all pixels in camera coordinate. 48 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 49 | ray-tracing-generating-camera-rays/standard-coordinate-systems 50 | Inputs: 51 | H, W, focal: image height, width and focal length 52 | Outputs: 53 | directions: (H, W, 3), the direction of the rays in camera coordinate 54 | """ 55 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5 56 | i, j = grid.unbind(-1) 57 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 58 | # see https://github.com/bmild/nerf/issues/24 59 | cent = center if center is not None else [W / 2, H / 2] 60 | directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)], 61 | -1) # (H, W, 3) 62 | 63 | return directions 64 | 65 | 66 | def get_rays(directions, c2w): 67 | """ 68 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 69 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 70 | ray-tracing-generating-camera-rays/standard-coordinate-systems 71 | Inputs: 72 | directions: (H, W, 3) precomputed ray directions in camera coordinate 73 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 74 | Outputs: 75 | rays_o: (H*W, 3), the origin of the rays in world coordinate 76 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 77 | """ 78 | # Rotate ray directions from camera coordinate to the world coordinate 79 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 80 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 81 | # The origin of all rays is the camera origin in world coordinate 82 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) 83 | 84 | rays_d = rays_d.view(-1, 3) 85 | rays_o = rays_o.view(-1, 3) 86 | 87 | return rays_o, rays_d 88 | 89 | 90 | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d): 91 | # Shift ray origins to near plane 92 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 93 | rays_o = rays_o + t[..., None] * rays_d 94 | 95 | # Projection 96 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 97 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 98 | o2 = 1. + 2. * near / rays_o[..., 2] 99 | 100 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 101 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 102 | d2 = -2. * near / rays_o[..., 2] 103 | 104 | rays_o = torch.stack([o0, o1, o2], -1) 105 | rays_d = torch.stack([d0, d1, d2], -1) 106 | 107 | return rays_o, rays_d 108 | 109 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 110 | # Shift ray origins to near plane 111 | t = (near - rays_o[..., 2]) / rays_d[..., 2] 112 | rays_o = rays_o + t[..., None] * rays_d 113 | 114 | # Projection 115 | o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 116 | o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 117 | o2 = 1. - 2. * near / rays_o[..., 2] 118 | 119 | d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 120 | d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 121 | d2 = 2. * near / rays_o[..., 2] 122 | 123 | rays_o = torch.stack([o0, o1, o2], -1) 124 | rays_d = torch.stack([d0, d1, d2], -1) 125 | 126 | return rays_o, rays_d 127 | 128 | # Hierarchical sampling (section 5.2) 129 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 130 | device = weights.device 131 | # Get pdf 132 | weights = weights + 1e-5 # prevent nans 133 | pdf = weights / torch.sum(weights, -1, keepdim=True) 134 | cdf = torch.cumsum(pdf, -1) 135 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 136 | 137 | # Take uniform samples 138 | if det: 139 | u = torch.linspace(0., 1., steps=N_samples, device=device) 140 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 141 | else: 142 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device) 143 | 144 | # Pytest, overwrite u with numpy's fixed random numbers 145 | if pytest: 146 | np.random.seed(0) 147 | new_shape = list(cdf.shape[:-1]) + [N_samples] 148 | if det: 149 | u = np.linspace(0., 1., N_samples) 150 | u = np.broadcast_to(u, new_shape) 151 | else: 152 | u = np.random.rand(*new_shape) 153 | u = torch.Tensor(u) 154 | 155 | # Invert CDF 156 | u = u.contiguous() 157 | inds = searchsorted(cdf.detach(), u, right=True) 158 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 159 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 160 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 161 | 162 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 163 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 164 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 165 | 166 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 167 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 168 | t = (u - cdf_g[..., 0]) / denom 169 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 170 | 171 | return samples 172 | 173 | 174 | def dda(rays_o, rays_d, bbox_3D): 175 | inv_ray_d = 1.0 / (rays_d + 1e-6) 176 | t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3 177 | t_max = (bbox_3D[1:] - rays_o) * inv_ray_d 178 | t = torch.stack((t_min, t_max)) # 2 N_rays 3 179 | t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0] 180 | t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0] 181 | return t_min, t_max 182 | 183 | 184 | def ray_marcher(rays, 185 | N_samples=64, 186 | lindisp=False, 187 | perturb=0, 188 | bbox_3D=None): 189 | """ 190 | sample points along the rays 191 | Inputs: 192 | rays: () 193 | 194 | Returns: 195 | 196 | """ 197 | 198 | # Decompose the inputs 199 | N_rays = rays.shape[0] 200 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 201 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) 202 | 203 | if bbox_3D is not None: 204 | # cal aabb boundles 205 | near, far = dda(rays_o, rays_d, bbox_3D) 206 | 207 | # Sample depth points 208 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples) 209 | if not lindisp: # use linear sampling in depth space 210 | z_vals = near * (1 - z_steps) + far * z_steps 211 | else: # use linear sampling in disparity space 212 | z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) 213 | 214 | z_vals = z_vals.expand(N_rays, N_samples) 215 | 216 | if perturb > 0: # perturb sampling depths (z_vals) 217 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points 218 | # get intervals between samples 219 | upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1) 220 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1) 221 | 222 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device) 223 | z_vals = lower + (upper - lower) * perturb_rand 224 | 225 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \ 226 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3) 227 | 228 | return xyz_coarse_sampled, rays_o, rays_d, z_vals 229 | 230 | 231 | def read_pfm(filename): 232 | file = open(filename, 'rb') 233 | color = None 234 | width = None 235 | height = None 236 | scale = None 237 | endian = None 238 | 239 | header = file.readline().decode('utf-8').rstrip() 240 | if header == 'PF': 241 | color = True 242 | elif header == 'Pf': 243 | color = False 244 | else: 245 | raise Exception('Not a PFM file.') 246 | 247 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 248 | if dim_match: 249 | width, height = map(int, dim_match.groups()) 250 | else: 251 | raise Exception('Malformed PFM header.') 252 | 253 | scale = float(file.readline().rstrip()) 254 | if scale < 0: # little-endian 255 | endian = '<' 256 | scale = -scale 257 | else: 258 | endian = '>' # big-endian 259 | 260 | data = np.fromfile(file, endian + 'f') 261 | shape = (height, width, 3) if color else (height, width) 262 | 263 | data = np.reshape(data, shape) 264 | data = np.flipud(data) 265 | file.close() 266 | return data, scale 267 | 268 | 269 | def ndc_bbox(all_rays): 270 | near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0] 271 | near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0] 272 | far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0] 273 | far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0] 274 | print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}') 275 | return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max))) -------------------------------------------------------------------------------- /dataLoader/colmap2nerf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | import argparse 12 | import os 13 | from pathlib import Path, PurePosixPath 14 | 15 | import numpy as np 16 | import json 17 | import sys 18 | import math 19 | import cv2 20 | import os 21 | import shutil 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place") 25 | 26 | parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also") 27 | parser.add_argument("--video_fps", default=2) 28 | parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video") 29 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder") 30 | parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images") 31 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename") 32 | parser.add_argument("--images", default="images", help="input path to the images") 33 | parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)") 34 | parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16") 35 | parser.add_argument("--skip_early", default=0, help="skip this many images from the start") 36 | parser.add_argument("--out", default="transforms.json", help="output path") 37 | args = parser.parse_args() 38 | return args 39 | 40 | def do_system(arg): 41 | print(f"==== running: {arg}") 42 | err = os.system(arg) 43 | if err: 44 | print("FATAL: command failed") 45 | sys.exit(err) 46 | 47 | def run_ffmpeg(args): 48 | if not os.path.isabs(args.images): 49 | args.images = os.path.join(os.path.dirname(args.video_in), args.images) 50 | images = args.images 51 | video = args.video_in 52 | fps = float(args.video_fps) or 1.0 53 | print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.") 54 | if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": 55 | sys.exit(1) 56 | try: 57 | shutil.rmtree(images) 58 | except: 59 | pass 60 | do_system(f"mkdir {images}") 61 | 62 | time_slice_value = "" 63 | time_slice = args.time_slice 64 | if time_slice: 65 | start, end = time_slice.split(",") 66 | time_slice_value = f",select='between(t\,{start}\,{end})'" 67 | do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg") 68 | 69 | def run_colmap(args): 70 | db=args.colmap_db 71 | images=args.images 72 | db_noext=str(Path(db).with_suffix("")) 73 | 74 | if args.text=="text": 75 | args.text=db_noext+"_text" 76 | text=args.text 77 | sparse=db_noext+"_sparse" 78 | print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}") 79 | if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y": 80 | sys.exit(1) 81 | if os.path.exists(db): 82 | os.remove(db) 83 | do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}") 84 | do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}") 85 | try: 86 | shutil.rmtree(sparse) 87 | except: 88 | pass 89 | do_system(f"mkdir {sparse}") 90 | do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}") 91 | do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1") 92 | try: 93 | shutil.rmtree(text) 94 | except: 95 | pass 96 | do_system(f"mkdir {text}") 97 | do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT") 98 | 99 | def variance_of_laplacian(image): 100 | return cv2.Laplacian(image, cv2.CV_64F).var() 101 | 102 | def sharpness(imagePath): 103 | image = cv2.imread(imagePath) 104 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 105 | fm = variance_of_laplacian(gray) 106 | return fm 107 | 108 | def qvec2rotmat(qvec): 109 | return np.array([ 110 | [ 111 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 112 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 113 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] 114 | ], [ 115 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 116 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 117 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] 118 | ], [ 119 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 120 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 121 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 122 | ] 123 | ]) 124 | 125 | def rotmat(a, b): 126 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b) 127 | v = np.cross(a, b) 128 | c = np.dot(a, b) 129 | s = np.linalg.norm(v) 130 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 131 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10)) 132 | 133 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel 134 | da = da / np.linalg.norm(da) 135 | db = db / np.linalg.norm(db) 136 | c = np.cross(da, db) 137 | denom = np.linalg.norm(c)**2 138 | t = ob - oa 139 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10) 140 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10) 141 | if ta > 0: 142 | ta = 0 143 | if tb > 0: 144 | tb = 0 145 | return (oa+ta*da+ob+tb*db) * 0.5, denom 146 | 147 | if __name__ == "__main__": 148 | args = parse_args() 149 | if args.video_in != "": 150 | run_ffmpeg(args) 151 | if args.run_colmap: 152 | run_colmap(args) 153 | AABB_SCALE = int(args.aabb_scale) 154 | SKIP_EARLY = int(args.skip_early) 155 | IMAGE_FOLDER = args.images 156 | TEXT_FOLDER = args.text 157 | OUT_PATH = args.out 158 | print(f"outputting to {OUT_PATH}...") 159 | with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f: 160 | angle_x = math.pi / 2 161 | for line in f: 162 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691 163 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224 164 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443 165 | if line[0] == "#": 166 | continue 167 | els = line.split(" ") 168 | w = float(els[2]) 169 | h = float(els[3]) 170 | fl_x = float(els[4]) 171 | fl_y = float(els[4]) 172 | k1 = 0 173 | k2 = 0 174 | p1 = 0 175 | p2 = 0 176 | cx = w / 2 177 | cy = h / 2 178 | if els[1] == "SIMPLE_PINHOLE": 179 | cx = float(els[5]) 180 | cy = float(els[6]) 181 | elif els[1] == "PINHOLE": 182 | fl_y = float(els[5]) 183 | cx = float(els[6]) 184 | cy = float(els[7]) 185 | elif els[1] == "SIMPLE_RADIAL": 186 | cx = float(els[5]) 187 | cy = float(els[6]) 188 | k1 = float(els[7]) 189 | elif els[1] == "RADIAL": 190 | cx = float(els[5]) 191 | cy = float(els[6]) 192 | k1 = float(els[7]) 193 | k2 = float(els[8]) 194 | elif els[1] == "OPENCV": 195 | fl_y = float(els[5]) 196 | cx = float(els[6]) 197 | cy = float(els[7]) 198 | k1 = float(els[8]) 199 | k2 = float(els[9]) 200 | p1 = float(els[10]) 201 | p2 = float(els[11]) 202 | else: 203 | print("unknown camera model ", els[1]) 204 | # fl = 0.5 * w / tan(0.5 * angle_x); 205 | angle_x = math.atan(w / (fl_x * 2)) * 2 206 | angle_y = math.atan(h / (fl_y * 2)) * 2 207 | fovx = angle_x * 180 / math.pi 208 | fovy = angle_y * 180 / math.pi 209 | 210 | print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ") 211 | 212 | with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f: 213 | i = 0 214 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4]) 215 | out = { 216 | "camera_angle_x": angle_x, 217 | "camera_angle_y": angle_y, 218 | "fl_x": fl_x, 219 | "fl_y": fl_y, 220 | "k1": k1, 221 | "k2": k2, 222 | "p1": p1, 223 | "p2": p2, 224 | "cx": cx, 225 | "cy": cy, 226 | "w": w, 227 | "h": h, 228 | "aabb_scale": AABB_SCALE, 229 | "frames": [], 230 | } 231 | 232 | up = np.zeros(3) 233 | for line in f: 234 | line = line.strip() 235 | if line[0] == "#": 236 | continue 237 | i = i + 1 238 | if i < SKIP_EARLY*2: 239 | continue 240 | if i % 2 == 1: 241 | elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces) 242 | #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9]))) 243 | # why is this requireing a relitive path while using ^ 244 | image_rel = os.path.relpath(IMAGE_FOLDER) 245 | name = str(f"./{image_rel}/{'_'.join(elems[9:])}") 246 | b=sharpness(name) 247 | print(name, "sharpness=",b) 248 | image_id = int(elems[0]) 249 | qvec = np.array(tuple(map(float, elems[1:5]))) 250 | tvec = np.array(tuple(map(float, elems[5:8]))) 251 | R = qvec2rotmat(-qvec) 252 | t = tvec.reshape([3,1]) 253 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0) 254 | c2w = np.linalg.inv(m) 255 | c2w[0:3,2] *= -1 # flip the y and z axis 256 | c2w[0:3,1] *= -1 257 | c2w = c2w[[1,0,2,3],:] # swap y and z 258 | c2w[2,:] *= -1 # flip whole world upside down 259 | 260 | up += c2w[0:3,1] 261 | 262 | frame={"file_path":name,"sharpness":b,"transform_matrix": c2w} 263 | out["frames"].append(frame) 264 | nframes = len(out["frames"]) 265 | up = up / np.linalg.norm(up) 266 | print("up vector was", up) 267 | R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1] 268 | R = np.pad(R,[0,1]) 269 | R[-1, -1] = 1 270 | 271 | 272 | for f in out["frames"]: 273 | f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis 274 | 275 | # find a central point they are all looking at 276 | print("computing center of attention...") 277 | totw = 0.0 278 | totp = np.array([0.0, 0.0, 0.0]) 279 | for f in out["frames"]: 280 | mf = f["transform_matrix"][0:3,:] 281 | for g in out["frames"]: 282 | mg = g["transform_matrix"][0:3,:] 283 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2]) 284 | if w > 0.01: 285 | totp += p*w 286 | totw += w 287 | totp /= totw 288 | print(totp) # the cameras are looking at totp 289 | for f in out["frames"]: 290 | f["transform_matrix"][0:3,3] -= totp 291 | 292 | avglen = 0. 293 | for f in out["frames"]: 294 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3]) 295 | avglen /= nframes 296 | print("avg camera distance from origin", avglen) 297 | for f in out["frames"]: 298 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized" 299 | 300 | for f in out["frames"]: 301 | f["transform_matrix"] = f["transform_matrix"].tolist() 302 | print(nframes,"frames") 303 | print(f"writing {OUT_PATH}") 304 | with open(OUT_PATH, "w") as outfile: 305 | json.dump(out, outfile, indent=2) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from tqdm.auto import tqdm 4 | from opt import config_parser 5 | 6 | from renderer import * 7 | from utils import * 8 | import datetime 9 | 10 | from dataLoader import dataset_dict 11 | import sys 12 | 13 | import time 14 | 15 | 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | renderer = OctreeRender_trilinear_fast 20 | 21 | 22 | class SimpleSampler: 23 | def __init__(self, total, batch): 24 | self.total = total 25 | self.batch = batch 26 | self.curr = total 27 | self.ids = None 28 | 29 | def nextids(self): 30 | self.curr+=self.batch 31 | if self.curr + self.batch > self.total: 32 | self.ids = torch.LongTensor(np.random.permutation(self.total)) 33 | self.curr = 0 34 | return self.ids[self.curr:self.curr+self.batch] 35 | 36 | @torch.no_grad() 37 | def render_test(args): 38 | # init dataset 39 | dataset = dataset_dict[args.dataset_name] 40 | test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True) 41 | white_bg = test_dataset.white_bg 42 | ndc_ray = args.ndc_ray 43 | 44 | if not os.path.exists(args.ckpt): 45 | print('the ckpt path does not exists!!') 46 | return 47 | 48 | tensorf = torch.load(args.ckpt, map_location=device) 49 | 50 | logfolder = os.path.dirname(args.ckpt) 51 | if args.render_train: 52 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) 53 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True) 54 | PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/', 55 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 56 | print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================') 57 | 58 | if args.render_test: 59 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True) 60 | PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/', 61 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 62 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================') 63 | 64 | if args.render_path: 65 | c2ws = test_dataset.render_path 66 | os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True) 67 | evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/', 68 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 69 | 70 | def reconstruction(args): 71 | 72 | # init dataset 73 | dataset = dataset_dict[args.dataset_name] 74 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False) 75 | test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True) 76 | white_bg = train_dataset.white_bg 77 | near_far = train_dataset.near_far 78 | ndc_ray = args.ndc_ray 79 | 80 | # init resolution 81 | upsamp_list = args.upsamp_list 82 | update_AlphaMask_list = args.update_AlphaMask_list 83 | n_lamb_sigma = args.n_lamb_sigma 84 | n_lamb_sh = args.n_lamb_sh 85 | 86 | 87 | if args.add_timestamp: 88 | logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}' 89 | else: 90 | logfolder = f'{args.basedir}/{args.expname}' 91 | 92 | 93 | # init log file 94 | os.makedirs(logfolder, exist_ok=True) 95 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True) 96 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True) 97 | os.makedirs(f'{logfolder}/rgba', exist_ok=True) 98 | 99 | 100 | # init parameters 101 | aabb = train_dataset.scene_bbox.to(device) 102 | reso_cur = N_to_reso(args.N_voxel_init, aabb) 103 | nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio)) 104 | 105 | 106 | if args.ckpt is not None: 107 | tensorf = torch.load(args.ckpt, map_location=device) 108 | else: 109 | tensorf = eval(args.model_name)(aabb, reso_cur, device, 110 | density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far, 111 | shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale, 112 | pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct) 113 | 114 | grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis) 115 | if args.lr_decay_iters > 0: 116 | lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters) 117 | else: 118 | args.lr_decay_iters = args.n_iters 119 | lr_factor = args.lr_decay_target_ratio**(1/args.n_iters) 120 | 121 | print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters) 122 | 123 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99)) 124 | 125 | 126 | #linear in logrithmic space 127 | N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:] 128 | 129 | 130 | torch.cuda.empty_cache() 131 | PSNRs,PSNRs_test = [],[0] 132 | batch_size = 2048 133 | 134 | allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs 135 | if not args.ndc_ray: 136 | allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True) 137 | trainingSampler = SimpleSampler(allrays.shape[0], batch_size) 138 | 139 | L1_reg_weight = args.L1_weight_inital 140 | print("initial L1_reg_weight", L1_reg_weight) 141 | 142 | pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout) 143 | for iteration in pbar: 144 | 145 | ray_idx = trainingSampler.nextids() 146 | rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device) 147 | 148 | #rgb_map, alphas_map, depth_map, weights, uncertainty 149 | rgb_map, alphas_map, depth_map, weights, others = renderer(rays_train, 150 | tensorf, 151 | chunk=batch_size, 152 | N_samples=nSamples, 153 | white_bg = white_bg, 154 | ndc_ray=ndc_ray, 155 | device=device, 156 | is_train=True) 157 | 158 | mse_loss = torch.mean((rgb_map - rgb_train) ** 2) 159 | total_loss = mse_loss 160 | 161 | if others['normals'] is not None: 162 | Ro = torch.sum(others['normals'] * others['valid_viewdirs'], dim=-1) 163 | Ro = F.relu(Ro).pow(2) * others['valid_weights'] 164 | Ro = Ro.mean() 165 | total_loss += 0.3 * Ro 166 | 167 | if L1_reg_weight > 0: 168 | loss_reg_L1 = tensorf.density_L1() 169 | total_loss += L1_reg_weight*loss_reg_L1 170 | 171 | optimizer.zero_grad() 172 | total_loss.backward() 173 | optimizer.step() 174 | 175 | mse_loss = mse_loss.detach().item() 176 | PSNRs.append(-10.0 * np.log(mse_loss) / np.log(10.0)) 177 | 178 | 179 | for param_group in optimizer.param_groups: 180 | param_group['lr'] = param_group['lr'] * lr_factor 181 | 182 | # Print the current values of the losses. 183 | if iteration % args.progress_refresh_rate == 0: 184 | pbar.set_description( 185 | f'Iteration {iteration:05d}:' 186 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}' 187 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}' 188 | + f' mse = {mse_loss:.6f}' 189 | ) 190 | PSNRs = [] 191 | 192 | 193 | if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0: 194 | PSNRs_test = evaluation(test_dataset,tensorf, 195 | args, 196 | renderer, 197 | f'{logfolder}/imgs_vis/', 198 | N_vis=args.N_vis, 199 | prtx=f'{iteration:06d}_', 200 | N_samples=nSamples, 201 | white_bg = white_bg, 202 | ndc_ray=ndc_ray, 203 | compute_extra_metrics=False) 204 | 205 | 206 | if iteration in update_AlphaMask_list: 207 | 208 | # reso_cur = N_to_reso(250**3, tensorf.aabb) 209 | if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution 210 | reso_mask = reso_cur 211 | new_aabb = tensorf.updateAlphaMask(tuple(reso_mask)) 212 | if iteration == update_AlphaMask_list[0]: 213 | tensorf.shrink(new_aabb) 214 | L1_reg_weight = args.L1_weight_rest 215 | print("continuing L1_reg_weight", L1_reg_weight) 216 | 217 | 218 | if not args.ndc_ray and iteration == update_AlphaMask_list[1]: 219 | # filter rays outside the bbox 220 | allrays,allrgbs = tensorf.filtering_rays(allrays,allrgbs) 221 | 222 | batch_size = args.batch_size 223 | trainingSampler = SimpleSampler(allrgbs.shape[0], batch_size) 224 | print(f'Update batch size to {args.batch_size}') 225 | 226 | 227 | if iteration in upsamp_list: 228 | n_voxels = N_voxel_list.pop(0) 229 | reso_cur = N_to_reso(n_voxels, tensorf.aabb) 230 | nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio)) 231 | tensorf.upsample_volume_grid(reso_cur) 232 | 233 | if args.lr_upsample_reset: 234 | lr_scale = 1 #0.1 ** (iteration / args.n_iters) 235 | else: 236 | lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters) 237 | grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale) 238 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99)) 239 | 240 | 241 | if iteration == upsamp_list[0]: 242 | batch_size = args.batch_size 243 | trainingSampler = SimpleSampler(allrgbs.shape[0], batch_size) 244 | print(f'Update batch size to {batch_size}') 245 | 246 | 247 | torch.save(tensorf, f'{logfolder}/{args.expname}.pt') 248 | 249 | 250 | if args.render_train: 251 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True) 252 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True) 253 | PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/', 254 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 255 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================') 256 | 257 | if args.render_test: 258 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True) 259 | PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/', 260 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 261 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================') 262 | 263 | if args.render_path: 264 | c2ws = test_dataset.render_path 265 | print('========>',c2ws.shape) 266 | os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True) 267 | evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/', 268 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) 269 | 270 | 271 | if __name__ == '__main__': 272 | 273 | torch.set_default_dtype(torch.float32) 274 | torch.manual_seed(20211202) 275 | np.random.seed(20211202) 276 | 277 | args = config_parser() 278 | 279 | if args.export_mesh: 280 | export_mesh(args) 281 | 282 | if args.render_only and (args.render_test or args.render_path): 283 | render_test(args) 284 | else: 285 | reconstruction(args) 286 | 287 | -------------------------------------------------------------------------------- /models/tensorBase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | # from .sh import eval_sh_bases 5 | import numpy as np 6 | import time 7 | 8 | 9 | def positional_encoding(positions, freqs): 10 | 11 | freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,) 12 | pts = (positions[..., None] * freq_bands).reshape( 13 | positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF) 14 | pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) 15 | return pts 16 | 17 | def raw2alpha(sigma, dist): 18 | # sigma, dist [N_rays, N_samples] 19 | alpha = 1. - torch.exp(-sigma*dist) 20 | 21 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1) 22 | 23 | weights = alpha * T[:, :-1] # [N_rays, N_samples] 24 | return alpha, weights, T[:,-1:] 25 | 26 | 27 | class AlphaGridMask(torch.nn.Module): 28 | def __init__(self, device, aabb, alpha_volume): 29 | super(AlphaGridMask, self).__init__() 30 | self.device = device 31 | 32 | self.aabb=aabb.to(self.device) 33 | self.aabbSize = self.aabb[1] - self.aabb[0] 34 | self.invgridSize = 1.0/self.aabbSize * 2 35 | self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:]) 36 | self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device) 37 | 38 | def sample_alpha(self, xyz_sampled): 39 | xyz_sampled = self.normalize_coord(xyz_sampled) 40 | alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1) 41 | 42 | return alpha_vals 43 | 44 | def normalize_coord(self, xyz_sampled): 45 | return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1 46 | 47 | 48 | class TensorBase(torch.nn.Module): 49 | def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27, 50 | shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0], 51 | density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001, 52 | pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0, 53 | fea2denseAct = 'softplus'): 54 | super(TensorBase, self).__init__() 55 | 56 | self.density_n_comp = density_n_comp 57 | self.app_n_comp = appearance_n_comp 58 | self.app_dim = app_dim 59 | self.aabb = aabb 60 | self.alphaMask = alphaMask 61 | self.device=device 62 | 63 | self.density_shift = density_shift 64 | self.alphaMask_thres = alphaMask_thres 65 | self.distance_scale = distance_scale 66 | self.rayMarch_weight_thres = rayMarch_weight_thres 67 | self.fea2denseAct = fea2denseAct 68 | 69 | self.near_far = near_far 70 | self.step_ratio = step_ratio 71 | 72 | 73 | self.update_stepSize(gridSize) 74 | 75 | self.matMode = [[0,1], [0,2], [1,2]] 76 | self.vecMode = [2, 1, 0] 77 | self.comp_w = [1,1,1] 78 | 79 | self.init_svd_volume(gridSize[0], device) 80 | 81 | self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC 82 | 83 | def update_stepSize(self, gridSize): 84 | self.aabbSize = self.aabb[1] - self.aabb[0] 85 | self.invaabbSize = 2.0/self.aabbSize 86 | self.gridSize= torch.LongTensor(gridSize).to(self.device) 87 | self.units=self.aabbSize / (self.gridSize-1) 88 | self.stepSize=torch.mean(self.units)*self.step_ratio 89 | self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize))) 90 | self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1 91 | 92 | def init_svd_volume(self, res, device): 93 | pass 94 | 95 | def compute_features(self, xyz_sampled): 96 | pass 97 | 98 | def compute_densityfeature(self, xyz_sampled): 99 | pass 100 | 101 | def compute_appfeature(self, xyz_sampled): 102 | pass 103 | 104 | def normalize_coord(self, xyz_sampled): 105 | return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1 106 | 107 | def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001): 108 | pass 109 | 110 | def get_kwargs(self): 111 | return { 112 | 'aabb': self.aabb, 113 | 'gridSize':self.gridSize.tolist(), 114 | 'density_n_comp': self.density_n_comp, 115 | 'appearance_n_comp': self.app_n_comp, 116 | 'app_dim': self.app_dim, 117 | 118 | 'density_shift': self.density_shift, 119 | 'alphaMask_thres': self.alphaMask_thres, 120 | 'distance_scale': self.distance_scale, 121 | 'rayMarch_weight_thres': self.rayMarch_weight_thres, 122 | 'fea2denseAct': self.fea2denseAct, 123 | 124 | 'near_far': self.near_far, 125 | 'step_ratio': self.step_ratio, 126 | 127 | 'shadingMode': self.shadingMode, 128 | 'pos_pe': self.pos_pe, 129 | 'view_pe': self.view_pe, 130 | 'fea_pe': self.fea_pe, 131 | 'featureC': self.featureC 132 | } 133 | 134 | def save(self, path): 135 | kwargs = self.get_kwargs() 136 | ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()} 137 | if self.alphaMask is not None: 138 | alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy() 139 | ckpt.update({'alphaMask.shape':alpha_volume.shape}) 140 | ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))}) 141 | ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()}) 142 | torch.save(ckpt, path) 143 | 144 | def load(self, ckpt): 145 | if 'alphaMask.aabb' in ckpt.keys(): 146 | length = np.prod(ckpt['alphaMask.shape']) 147 | alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape'])) 148 | self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device)) 149 | self.load_state_dict(ckpt['state_dict']) 150 | 151 | 152 | def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1): 153 | N_samples = N_samples if N_samples > 0 else self.nSamples 154 | near, far = self.near_far 155 | interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o) 156 | if is_train: 157 | interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples) 158 | 159 | rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None] 160 | mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1) 161 | return rays_pts, interpx, ~mask_outbbox 162 | 163 | def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1): 164 | N_samples = N_samples if N_samples>0 else self.nSamples 165 | stepsize = self.stepSize 166 | near, far = self.near_far 167 | vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d) 168 | rate_a = (self.aabb[1] - rays_o) / vec 169 | rate_b = (self.aabb[0] - rays_o) / vec 170 | t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far) 171 | 172 | rng = torch.arange(N_samples)[None].float() 173 | if is_train: 174 | rng = rng.repeat(rays_d.shape[-2],1) 175 | rng += torch.rand_like(rng[:,[0]]) 176 | step = stepsize * rng.to(rays_o.device) 177 | interpx = (t_min[...,None] + step) 178 | 179 | rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None] 180 | mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1) 181 | 182 | return rays_pts, interpx, ~mask_outbbox 183 | 184 | 185 | def shrink(self, new_aabb, voxel_size): 186 | pass 187 | 188 | @torch.no_grad() 189 | def getDenseAlpha(self,gridSize=None): 190 | gridSize = self.gridSize if gridSize is None else gridSize 191 | 192 | samples = torch.stack(torch.meshgrid( 193 | torch.linspace(0, 1, gridSize[0]), 194 | torch.linspace(0, 1, gridSize[1]), 195 | torch.linspace(0, 1, gridSize[2]), 196 | ), -1).to(self.device) 197 | dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples 198 | 199 | alpha = torch.zeros_like(dense_xyz[...,0]) 200 | for i in range(gridSize[0]): 201 | alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2])) 202 | return alpha, dense_xyz 203 | 204 | @torch.no_grad() 205 | def updateAlphaMask(self, gridSize=(200,200,200)): 206 | 207 | alpha, dense_xyz = self.getDenseAlpha(gridSize) 208 | dense_xyz = dense_xyz.transpose(0,2).contiguous() 209 | alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None] 210 | total_voxels = gridSize[0] * gridSize[1] * gridSize[2] 211 | 212 | ks = 3 213 | alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1]) 214 | alpha[alpha>=self.alphaMask_thres] = 1 215 | alpha[alpha0.5] 220 | 221 | xyz_min = valid_xyz.amin(0) 222 | xyz_max = valid_xyz.amax(0) 223 | 224 | new_aabb = torch.stack((xyz_min, xyz_max)) 225 | 226 | total = torch.sum(alpha) 227 | print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100)) 228 | return new_aabb 229 | 230 | @torch.no_grad() 231 | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False): 232 | print('========> filtering rays ...') 233 | tt = time.time() 234 | 235 | N = torch.tensor(all_rays.shape[:-1]).prod() 236 | 237 | mask_filtered = [] 238 | idx_chunks = torch.split(torch.arange(N), chunk) 239 | for idx_chunk in idx_chunks: 240 | rays_chunk = all_rays[idx_chunk].to(self.device) 241 | 242 | rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6] 243 | if bbox_only: 244 | vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d) 245 | rate_a = (self.aabb[1] - rays_o) / vec 246 | rate_b = (self.aabb[0] - rays_o) / vec 247 | t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far) 248 | t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far) 249 | mask_inbbox = t_max > t_min 250 | 251 | else: 252 | xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False) 253 | mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1) 254 | 255 | mask_filtered.append(mask_inbbox.cpu()) 256 | 257 | mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1]) 258 | 259 | print(f'Ray filtering done! takes {time.time()-tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}') 260 | return all_rays[mask_filtered], all_rgbs[mask_filtered] 261 | 262 | 263 | def feature2density(self, density_features): 264 | if self.fea2denseAct == "softplus": 265 | return F.softplus(density_features+self.density_shift) 266 | elif self.fea2denseAct == "relu": 267 | return F.relu(density_features) 268 | 269 | 270 | def compute_alpha(self, xyz_locs, length=1): 271 | 272 | if self.alphaMask is not None: 273 | alphas = self.alphaMask.sample_alpha(xyz_locs) 274 | alpha_mask = alphas > 0 275 | else: 276 | alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool) 277 | 278 | 279 | sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device) 280 | 281 | if alpha_mask.any(): 282 | xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask]) 283 | sigma_feature = self.compute_densityfeature(xyz_sampled) 284 | validsigma = self.feature2density(sigma_feature) 285 | sigma[alpha_mask] = validsigma 286 | 287 | 288 | alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1]) 289 | 290 | return alpha 291 | 292 | 293 | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): 294 | 295 | # sample points 296 | viewdirs = rays_chunk[:, 3:6] 297 | if ndc_ray: 298 | xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples) 299 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) 300 | rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True) 301 | dists = dists * rays_norm 302 | viewdirs = viewdirs / rays_norm 303 | else: 304 | xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples) 305 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) 306 | viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape) 307 | 308 | if self.alphaMask is not None: 309 | alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid]) 310 | alpha_mask = alphas > 0 311 | ray_invalid = ~ray_valid 312 | ray_invalid[ray_valid] |= (~alpha_mask) 313 | ray_valid = ~ray_invalid 314 | 315 | 316 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device) 317 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device) 318 | 319 | if ray_valid.any(): 320 | xyz_sampled = self.normalize_coord(xyz_sampled) 321 | sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid]) 322 | 323 | validsigma = self.feature2density(sigma_feature) 324 | sigma[ray_valid] = validsigma 325 | 326 | 327 | alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale) 328 | 329 | app_mask = weight > self.rayMarch_weight_thres 330 | 331 | normals = None 332 | valid_viewdirs = None 333 | valid_weights = None 334 | if app_mask.any(): 335 | valid_viewdirs = viewdirs[app_mask] 336 | valid_weights = weight[app_mask] 337 | 338 | app_features = self.compute_appfeature(xyz_sampled[app_mask]) 339 | valid_rgbs, normals = self.rendering_net(valid_viewdirs, app_features) 340 | rgb[app_mask] = valid_rgbs 341 | 342 | 343 | acc_map = torch.sum(weight, -1) 344 | rgb_map = torch.sum(weight[..., None] * rgb, -2) 345 | 346 | if white_bg or (is_train and torch.rand((1,))<0.5): 347 | rgb_map = rgb_map + (1. - acc_map[..., None]) 348 | 349 | 350 | rgb_map = rgb_map.clamp(0,1) 351 | 352 | with torch.no_grad(): 353 | depth_map = torch.sum(weight * z_vals, -1) 354 | depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1] 355 | 356 | return {'rgb_map':rgb_map, 357 | 'depth_map':depth_map, 358 | 'normals':normals, 359 | 'valid_viewdirs':valid_viewdirs, 360 | 'valid_weights':valid_weights 361 | } 362 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from torch.functional import align_tensors 2 | from .tensorBase import * 3 | from .quaternion_utils import * 4 | from utils import N_to_reso 5 | 6 | import numpy as np 7 | import math 8 | 9 | 10 | class TensorDecomposition(torch.nn.Module): 11 | def __init__(self, grid_size, num_features, scale, device, reduce_sum=False): 12 | super(TensorDecomposition, self).__init__() 13 | self.grid_size = torch.tensor(grid_size) 14 | self.num_voxels = grid_size[0] * grid_size[1] * grid_size[2] 15 | self.reduce_sum = reduce_sum 16 | 17 | X, Y, Z = grid_size 18 | self.plane_xy = torch.nn.Parameter(scale * torch.randn((1, num_features, Y, X), device=device)) 19 | self.plane_yz = torch.nn.Parameter(scale * torch.randn((1, num_features, Z, Y), device=device)) 20 | self.plane_xz = torch.nn.Parameter(scale * torch.randn((1, num_features, Z, X), device=device)) 21 | 22 | self.line_z = torch.nn.Parameter(scale * torch.randn((1, num_features, Z, 1), device=device)) 23 | self.line_x = torch.nn.Parameter(scale * torch.randn((1, num_features, X, 1), device=device)) 24 | self.line_y = torch.nn.Parameter(scale * torch.randn((1, num_features, Y, 1), device=device)) 25 | 26 | def forward(self, coords_plane, coords_line): 27 | feature_xy = F.grid_sample(self.plane_xy, coords_plane[0], mode='bilinear', align_corners=True) 28 | feature_yz = F.grid_sample(self.plane_yz, coords_plane[1], mode='bilinear', align_corners=True) 29 | feature_xz = F.grid_sample(self.plane_xz, coords_plane[2], mode='bilinear', align_corners=True) 30 | 31 | feature_x = F.grid_sample(self.line_x, coords_line[0], mode='bilinear', align_corners=True) 32 | feature_y = F.grid_sample(self.line_y, coords_line[1], mode='bilinear', align_corners=True) 33 | feature_z = F.grid_sample(self.line_z, coords_line[2], mode='bilinear', align_corners=True) 34 | 35 | out_x = feature_yz * feature_x 36 | out_y = feature_xz * feature_y 37 | out_z = feature_xy * feature_z 38 | 39 | _, C, N, _ = out_x.size() 40 | if self.reduce_sum: 41 | output = out_x.sum(dim=(0, 1, 3)) + out_y.sum(dim=(0, 1, 3)) + out_z.sum(dim=(0, 1, 3)) 42 | else: 43 | output = [out_x.view(-1, N).T, out_y.view(-1, N).T, out_z.view(-1, N).T] 44 | 45 | return output 46 | 47 | def L1loss(self): 48 | loss = torch.abs(self.plane_xy).mean() + torch.abs(self.plane_yz).mean() + torch.abs(self.plane_xz).mean() 49 | loss += torch.abs(self.line_x).mean() + torch.abs(self.line_y).mean() + torch.abs(self.line_z).mean() 50 | loss = loss / 6 51 | 52 | return loss 53 | 54 | def TV_loss(self): 55 | loss = self.TV_loss_com(self.plane_xy) 56 | loss += self.TV_loss_com(self.plane_yz) 57 | loss += self.TV_loss_com(self.plane_xz) 58 | loss = loss / 6 59 | 60 | return loss 61 | 62 | def TV_loss_com(self, x): 63 | loss = (x[:, :, 1:] - x[:, :, :-1]).pow(2).mean() + (x[:, :, :, 1:] - x[:, :, :, :-1]).pow(2).mean() 64 | return loss 65 | 66 | 67 | def shrink(self, bound): 68 | # bound [3, 2] 69 | x, y, z = bound[0], bound[1], bound[2] 70 | self.plane_xy = torch.nn.Parameter(self.plane_xy.data[:, :, y[0]:y[1], x[0]:x[1]]) 71 | self.plane_yz = torch.nn.Parameter(self.plane_yz.data[:, :, z[0]:z[1], y[0]:y[1]]) 72 | self.plane_xz = torch.nn.Parameter(self.plane_xz.data[:, :, z[0]:z[1], x[0]:x[1]]) 73 | 74 | self.line_x = torch.nn.Parameter(self.line_x.data[:, :, x[0]:x[1]]) 75 | self.line_y = torch.nn.Parameter(self.line_y.data[:, :, y[0]:y[1]]) 76 | self.line_z = torch.nn.Parameter(self.line_z.data[:, :, z[0]:z[1]]) 77 | 78 | self.grid_size = bound[:, 1] - bound[:, 0] 79 | 80 | 81 | def upsample(self, aabb): 82 | target_res = N_to_reso(self.num_voxels, aabb) 83 | 84 | 85 | self.grid_size = torch.tensor(target_res) 86 | 87 | self.plane_xy = torch.nn.Parameter(F.interpolate(self.plane_xy.data, 88 | size=(target_res[1], target_res[0]), mode='bilinear', align_corners=True)) 89 | self.plane_yz = torch.nn.Parameter(F.interpolate(self.plane_yz.data, 90 | size=(target_res[2], target_res[1]), mode='bilinear', align_corners=True)) 91 | self.plane_xz = torch.nn.Parameter(F.interpolate(self.plane_xz.data, 92 | size=(target_res[2], target_res[0]), mode='bilinear', align_corners=True)) 93 | 94 | 95 | self.line_x = torch.nn.Parameter(F.interpolate(self.line_x.data, 96 | size=(target_res[0], 1), mode='bilinear', align_corners=True)) 97 | self.line_y = torch.nn.Parameter(F.interpolate(self.line_y.data, 98 | size=(target_res[1], 1), mode='bilinear', align_corners=True)) 99 | self.line_z = torch.nn.Parameter(F.interpolate(self.line_z.data, 100 | size=(target_res[2], 1), mode='bilinear', align_corners=True)) 101 | 102 | class MultiscaleTensorDecom(torch.nn.Module): 103 | def __init__(self, num_levels, num_features, base_resolution, max_resolution, device, reduce_sum=False, scale=0.1): 104 | super(MultiscaleTensorDecom, self).__init__() 105 | self.reduce_sum = reduce_sum 106 | 107 | tensors = [] 108 | if num_levels == 1: 109 | factor = 1 110 | else: 111 | factor = math.exp( (math.log(max_resolution) - math.log(base_resolution)) / (num_levels-1) ) 112 | 113 | for i in range(num_levels): 114 | level_resolution = int(base_resolution * factor**i) 115 | level_grid = (level_resolution, level_resolution, level_resolution) 116 | tensors.append(TensorDecomposition(level_grid, num_features, scale, device, reduce_sum=reduce_sum)) 117 | 118 | self.tensors = torch.nn.ModuleList(tensors) 119 | 120 | def coords_split(self, pts, dim=2, z_vals=None): 121 | N, D = pts.size() 122 | pts = pts.view(1, N, 1, D) 123 | 124 | out_plane = [] 125 | if dim == 2: 126 | out_plane.append(pts[..., [0, 1]]) 127 | out_plane.append(pts[..., [1, 2]]) 128 | out_plane.append(pts[..., [0, 2]]) 129 | elif dim == 3: 130 | out_plane.append(pts[..., [0, 1, 2]][:, :, None]) 131 | out_plane.append(pts[..., [1, 2, 0]][:, :, None]) 132 | out_plane.append(pts[..., [0, 2, 1]][:, :, None]) 133 | 134 | if z_vals is None: 135 | coord_x = pts.new_zeros(1, N, 1, 1) 136 | else: 137 | coord_x = z_vals.view(1, N, 1, 1) 138 | out_line = [] 139 | out_line.append(torch.cat((coord_x, pts[..., [0]]), dim=-1)) 140 | out_line.append(torch.cat((coord_x, pts[..., [1]]), dim=-1)) 141 | out_line.append(torch.cat((coord_x, pts[..., [2]]), dim=-1)) 142 | 143 | return out_plane, out_line 144 | 145 | def L1loss(self): 146 | loss = 0. 147 | for tensor in self.tensors: 148 | loss += tensor.L1loss() 149 | 150 | return loss / len(self.tensors) 151 | 152 | def shrink(self, aabb, new_aabb): 153 | aabb_size = aabb[1] - aabb[0] 154 | xyz_min, xyz_max = new_aabb 155 | 156 | for tensor in self.tensors: 157 | grid_size = tensor.grid_size 158 | units = aabb_size / (grid_size - 1) 159 | t_l, b_r = (xyz_min - aabb[0]) / units, (xyz_max - aabb[0]) / units 160 | 161 | t_l, b_r = torch.floor(t_l).long(), torch.ceil(b_r).long() 162 | b_r = torch.stack([b_r, grid_size]).amin(0) 163 | 164 | bound = torch.stack((t_l, b_r), dim=-1) 165 | tensor.shrink(bound) 166 | 167 | def upsample(self, aabb): 168 | for tensor in self.tensors: 169 | tensor.upsample(aabb) 170 | 171 | def forward(self, pts): 172 | coords_plane, coords_line = self.coords_split(pts) 173 | 174 | if self.reduce_sum: 175 | output = pts.new_zeros(pts.size(0)) 176 | else: 177 | output = [] 178 | 179 | for level_tensor in self.tensors: 180 | output += level_tensor(coords_plane, coords_line) 181 | 182 | return output 183 | 184 | class RenderingEquationEncoding(torch.nn.Module): 185 | def __init__(self, num_theta, num_phi, device): 186 | super(RenderingEquationEncoding, self).__init__() 187 | 188 | self.num_theta = num_theta 189 | self.num_phi = num_phi 190 | 191 | omega, omega_la, omega_mu = init_predefined_omega(num_theta, num_phi) 192 | self.omega = omega.view(1, num_theta, num_phi, 3).to(device) 193 | self.omega_la = omega_la.view(1, num_theta, num_phi, 3).to(device) 194 | self.omega_mu = omega_mu.view(1, num_theta, num_phi, 3).to(device) 195 | 196 | def forward(self, omega_o, a, la, mu): 197 | Smooth = F.relu((omega_o[:, None, None] * self.omega).sum(dim=-1, keepdim=True)) # N, num_theta, num_phi, 1 198 | 199 | la = F.softplus(la - 1) 200 | mu = F.softplus(mu - 1) 201 | exp_input = -la * (self.omega_la * omega_o[:, None, None]).sum(dim=-1, keepdim=True).pow(2) -mu * (self.omega_mu * omega_o[:, None, None]).sum(dim=-1, keepdim=True).pow(2) 202 | out = a * Smooth * torch.exp(exp_input) 203 | 204 | return out 205 | 206 | class RenderingNet(torch.nn.Module): 207 | def __init__(self, num_theta = 8, num_phi=16, data_dim_color=192, featureC=256, device='cpu'): 208 | super(RenderingNet, self).__init__() 209 | 210 | self.ch_cd = 3 211 | self.ch_s = 3 212 | self.ch_normal = 3 213 | self.ch_bottleneck = 128 214 | 215 | self.num_theta = 8 216 | self.num_phi = 16 217 | self.num_asg = self.num_theta * self.num_phi 218 | 219 | self.ch_asg_feature = 128 220 | self.ch_per_theta = self.ch_asg_feature // self.num_theta 221 | 222 | self.ch_a = 2 223 | self.ch_la = 1 224 | self.ch_mu = 1 225 | self.ch_per_asg = self.ch_a + self.ch_la + self.ch_mu 226 | 227 | self.ch_normal_dot_viewdir = 1 228 | 229 | 230 | self.ree_function = RenderingEquationEncoding(num_theta, num_phi, device) 231 | 232 | self.spatial_mlp = torch.nn.Sequential( 233 | torch.nn.Linear(data_dim_color, featureC), 234 | torch.nn.GELU(), 235 | torch.nn.Linear(featureC, featureC), 236 | torch.nn.GELU(), 237 | torch.nn.Linear(featureC, self.ch_cd + self.ch_s + self.ch_bottleneck + self.ch_normal + self.ch_asg_feature)).to(device) 238 | 239 | self.asg_mlp = torch.nn.Sequential(torch.nn.Linear(self.ch_per_theta, self.num_phi * self.ch_per_asg)).to(device) 240 | 241 | self.directional_mlp = torch.nn.Sequential( 242 | torch.nn.Linear(self.ch_bottleneck + self.num_asg * self.ch_a + self.ch_normal_dot_viewdir, featureC), 243 | torch.nn.GELU(), 244 | torch.nn.Linear(featureC, featureC), 245 | torch.nn.GELU(), 246 | torch.nn.Linear(featureC, featureC), 247 | torch.nn.GELU(), 248 | torch.nn.Linear(featureC, featureC), 249 | torch.nn.GELU(), 250 | torch.nn.Linear(featureC, featureC), 251 | torch.nn.GELU(), 252 | torch.nn.Linear(featureC, 3)).to(device) 253 | 254 | 255 | def spatial_mlp_forward(self, x): 256 | out = self.spatial_mlp(x) 257 | sections = [self.ch_cd, self.ch_s, self.ch_normal, self.ch_bottleneck, self.ch_asg_feature] 258 | diffuse_color, tint, normals, bottleneck, asg_features = torch.split(out, sections, dim=-1) 259 | normals = -F.normalize(normals, dim=1) 260 | return diffuse_color, tint, normals, bottleneck, asg_features 261 | 262 | def asg_mlp_forward(self, asg_feature): 263 | N = asg_feature.size(0) 264 | asg_feature = asg_feature.view(N, self.num_theta, -1) 265 | asg_params = self.asg_mlp(asg_feature) 266 | asg_params = asg_params.view(N, self.num_theta, self.num_phi, -1) 267 | 268 | a, la, mu = torch.split(asg_params, [self.ch_a, self.ch_la, self.ch_mu], dim=-1) 269 | return a, la, mu 270 | 271 | def directional_mlp_forward(self, x): 272 | out = self.directional_mlp(x) 273 | return out 274 | 275 | def reflect(self, viewdir, normal): 276 | out = 2 * (viewdir * normal).sum(dim=-1, keepdim=True) * normal - viewdir 277 | return out 278 | 279 | def forward(self, viewdir, feature): 280 | diffuse_color, tint, normal, bottleneck, asg_feature = self.spatial_mlp_forward(feature) 281 | refdir = self.reflect(-viewdir, normal) 282 | 283 | a, la, mu = self.asg_mlp_forward(asg_feature) 284 | ree = self.ree_function(refdir, a, la, mu) # N, num_theta, num_phi, ch_per_asg 285 | ree = ree.view(ree.size(0), -1) 286 | 287 | normal_dot_viewdir = ((-viewdir) * normal).sum(dim=-1, keepdim=True) 288 | dir_mlp_input = torch.cat([bottleneck, ree, normal_dot_viewdir], dim=-1) 289 | specular_color = self.directional_mlp_forward(dir_mlp_input) 290 | 291 | raw_rgb = diffuse_color + tint * specular_color 292 | rgb = torch.sigmoid(raw_rgb) 293 | 294 | return rgb, normal 295 | 296 | 297 | ######################################################################################## 298 | 299 | class NRFF(TensorBase): 300 | def __init__(self, aabb, gridSize, device, **kargs): 301 | super(NRFF, self).__init__(aabb, gridSize, device, **kargs) 302 | 303 | self.rendering_net = RenderingNet(8, 16, device=device) 304 | self.init_feature_field(device) 305 | 306 | def init_feature_field(self, device): 307 | self.density_field = MultiscaleTensorDecom(num_levels=16, num_features=2, base_resolution=16, max_resolution=512, device=device, reduce_sum=True) 308 | self.appearance_field = MultiscaleTensorDecom(num_levels=16, num_features=4, base_resolution=16, max_resolution=512, device=device) 309 | 310 | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001): 311 | grad_vars = [] 312 | 313 | grad_vars += [{'params': self.density_field.parameters(), 'lr': lr_init_spatialxyz}] 314 | grad_vars += [{'params': self.appearance_field.parameters(), 'lr': lr_init_spatialxyz}] 315 | grad_vars += [{'params': self.rendering_net.parameters(), 'lr':lr_init_network}] 316 | 317 | return grad_vars 318 | 319 | 320 | def density_L1(self): 321 | return self.density_field.L1loss() 322 | 323 | def compute_densityfeature(self, pts): 324 | output = self.density_field(pts) 325 | return output 326 | 327 | def compute_appfeature(self, pts): 328 | app_feature = self.appearance_field(pts) 329 | app_feature = torch.cat(app_feature, dim=-1) 330 | return app_feature 331 | 332 | @torch.no_grad() 333 | def shrink(self, new_aabb): 334 | self.train_aabb = new_aabb 335 | 336 | self.density_field.shrink(self.aabb.cpu(), new_aabb.cpu()) 337 | self.appearance_field.shrink(self.aabb.cpu(), new_aabb.cpu()) 338 | 339 | xyz_min, xyz_max = new_aabb 340 | t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units 341 | 342 | 343 | t_l, b_r = torch.floor(t_l).long(), torch.ceil(b_r).long() 344 | b_r = torch.stack([b_r, self.gridSize]).amin(0) 345 | 346 | if not torch.all(self.alphaMask.gridSize == self.gridSize): 347 | t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1) 348 | correct_aabb = torch.zeros_like(new_aabb) 349 | correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1] 350 | correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1] 351 | new_aabb = correct_aabb 352 | 353 | newSize = b_r - t_l 354 | self.aabb = new_aabb 355 | 356 | self.density_field.upsample(new_aabb.cpu()) 357 | self.appearance_field.upsample(new_aabb.cpu()) 358 | 359 | self.update_stepSize((newSize[0], newSize[1], newSize[2])) 360 | 361 | 362 | @torch.no_grad() 363 | def upsample_volume_grid(self, res_target): 364 | self.update_stepSize(res_target) 365 | --------------------------------------------------------------------------------