├── models
│   ├── __init__.py
│   ├── quaternion_utils.py
│   ├── sh.py
│   ├── tensorBase.py
│   └── model.py
├── .gitattributes
├── dataLoader
│   ├── __pycache__
│   │   ├── llff.cpython-310.pyc
│   │   ├── llff.cpython-38.pyc
│   │   ├── nsvf.cpython-310.pyc
│   │   ├── nsvf.cpython-38.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── blender.cpython-310.pyc
│   │   ├── blender.cpython-38.pyc
│   │   ├── __init__.cpython-310.pyc
│   │   ├── ray_utils.cpython-310.pyc
│   │   ├── ray_utils.cpython-38.pyc
│   │   ├── tankstemple.cpython-38.pyc
│   │   ├── tankstemple.cpython-310.pyc
│   │   ├── your_own_data.cpython-38.pyc
│   │   └── your_own_data.cpython-310.pyc
│   ├── __init__.py
│   ├── blender.py
│   ├── your_own_data.py
│   ├── nsvf.py
│   ├── tankstemple.py
│   ├── llff.py
│   ├── ray_utils.py
│   └── colmap2nerf.py
├── configs
│   ├── truck.txt
│   ├── lego.txt
│   ├── wineholder.txt
│   └── your_own_data.txt
├── LICENSE
├── README.md
├── utils.py
├── renderer.py
├── opt.py
├── extra
│   ├── compute_metrics.py
│   └── auto_run_paramsets.py
└── train.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/dataLoader/__pycache__/llff.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/llff.cpython-310.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/llff.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/llff.cpython-38.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/nsvf.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/nsvf.cpython-310.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/nsvf.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/nsvf.cpython-38.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/blender.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/blender.cpython-310.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/blender.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/blender.cpython-38.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/ray_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/ray_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/ray_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/ray_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/tankstemple.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/tankstemple.cpython-38.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/tankstemple.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/tankstemple.cpython-310.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/your_own_data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/your_own_data.cpython-38.pyc
--------------------------------------------------------------------------------
/dataLoader/__pycache__/your_own_data.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imkanghan/nrff/HEAD/dataLoader/__pycache__/your_own_data.cpython-310.pyc
--------------------------------------------------------------------------------
/dataLoader/__init__.py:
--------------------------------------------------------------------------------
1 | from .llff import LLFFDataset
2 | from .blender import BlenderDataset
3 | from .nsvf import NSVF
4 | from .tankstemple import TanksTempleDataset
5 | from .your_own_data import YourOwnDataset
6 |
7 |
8 |
9 | dataset_dict = {'blender': BlenderDataset,
10 | 'llff':LLFFDataset,
11 | 'tankstemple':TanksTempleDataset,
12 | 'nsvf':NSVF,
13 | 'own_data':YourOwnDataset}
--------------------------------------------------------------------------------
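The `dataset_dict` above is the registry that the `dataset_name` field of a config file selects from. A minimal usage sketch, assuming the repository root is on the import path; the data path is only illustrative:

```
from dataLoader import dataset_dict

# Pick a loader class by the same key used in the configs (e.g. args.dataset_name).
dataset_cls = dataset_dict['blender']

# Constructor signature matches BlenderDataset below: (datadir, split, downsample, is_stack, N_vis).
train_dataset = dataset_cls('./data/nerf_synthetic/lego', split='train', downsample=1.0, is_stack=False)
test_dataset  = dataset_cls('./data/nerf_synthetic/lego', split='test',  downsample=1.0, is_stack=True)

print(train_dataset.all_rays.shape)   # (N_rays, 6): per-ray origin + direction
print(train_dataset.all_rgbs.shape)   # (N_rays, 3): per-ray supervision color
```
--------------------------------------------------------------------------------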
/configs/truck.txt:
--------------------------------------------------------------------------------
1 |
2 | dataset_name = tankstemple
3 | datadir = ../datasets/TanksAndTemple/Truck
4 | expname = nrff_truck
5 | basedir = ./log
6 |
7 | n_iters = 30000
8 | batch_size = 4096
9 |
10 | N_voxel_init = 2097152 # 128**3
11 | N_voxel_final = 27000000 # 300**3
12 | upsamp_list = [2000,3000,4000,5500,7000]
13 | update_AlphaMask_list = [2000,4000]
14 |
15 | N_vis = 5
16 | vis_every = 10000
17 |
18 | render_test = 1
19 | model_name = NRFF
20 | fea2denseAct = softplus
21 |
22 | L1_weight_inital = 8e-5
23 | L1_weight_rest = 4e-5
24 |
--------------------------------------------------------------------------------
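For reference, the two voxel counts and `upsamp_list` drive a coarse-to-fine grid schedule in the TensoRF-style training loop this code builds on. A sketch of the usual log-linear interpolation between them (the exact schedule lives in train.py, which is not reproduced here, so treat this formula as an assumption):

```
import numpy as np

N_voxel_init, N_voxel_final = 2097152, 27000000          # 128**3 -> 300**3, as in the config above
upsamp_list = [2000, 3000, 4000, 5500, 7000]

# Log-linear interpolation of the total voxel budget, one target per upsampling step.
N_voxel_list = np.round(np.exp(np.linspace(np.log(N_voxel_init),
                                           np.log(N_voxel_final),
                                           len(upsamp_list) + 1))).astype(int)[1:]
for it, n_vox in zip(upsamp_list, N_voxel_list):
    print(f'iteration {it}: upsample grid to ~{round(n_vox ** (1 / 3))}^3 voxels')
```
--------------------------------------------------------------------------------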
/configs/lego.txt:
--------------------------------------------------------------------------------
1 |
2 | dataset_name = blender
3 | datadir = ../datasets/nerf_synthetic/lego
4 | expname = nrff_lego
5 | basedir = ./log
6 |
7 | n_iters = 30000
8 | batch_size = 4096
9 |
10 | N_voxel_init = 2097152 # 128**3
11 | N_voxel_final = 27000000 # 300**3
12 | upsamp_list = [2000,3000,4000,5500,7000]
13 | update_AlphaMask_list = [2000,4000]
14 |
15 | N_vis = 5
16 | vis_every = 10000
17 |
18 | render_test = 1
19 | model_name = NRFF
20 | fea2denseAct = softplus
21 |
22 | L1_weight_inital = 8e-5
23 | L1_weight_rest = 4e-5
24 | rm_weight_mask_thre = 1e-4
25 |
--------------------------------------------------------------------------------
/configs/wineholder.txt:
--------------------------------------------------------------------------------
1 |
2 | dataset_name = nsvf
3 | datadir = ./data/Synthetic_NSVF/Wineholder
4 | expname = tensorf_Wineholder_VM
5 | basedir = ./log
6 |
7 | n_iters = 30000
8 | batch_size = 4096
9 |
10 | N_voxel_init = 2097152 # 128**3
11 | N_voxel_final = 27000000 # 300**3
12 | upsamp_list = [2000,3000,4000,5500,7000]
13 | update_AlphaMask_list = [2000,4000]
14 |
15 | N_vis = 5
16 | vis_every = 10000
17 |
18 | render_test = 1
19 | model_name = NRFF
20 |
21 | fea2denseAct = softplus
22 |
23 | L1_weight_inital = 8e-5
24 | L1_weight_rest = 4e-5
25 | rm_weight_mask_thre = 1e-4
26 |
--------------------------------------------------------------------------------
/configs/your_own_data.txt:
--------------------------------------------------------------------------------
1 |
2 | dataset_name = own_data
3 | datadir = ./data/xxx
4 | expname = tensorf_xxx_VM
5 | basedir = ./log
6 |
7 | n_iters = 30000
8 | batch_size = 4096
9 |
10 | N_voxel_init = 2097152 # 128**3
11 | N_voxel_final = 27000000 # 300**3
12 | upsamp_list = [2000,3000,4000,5500,7000]
13 | update_AlphaMask_list = [2000,4000]
14 |
15 | N_vis = 5
16 | vis_every = 10000
17 |
18 | render_test = 1
19 |
20 | n_lamb_sigma = [16,16,16]
21 | n_lamb_sh = [48,48,48]
22 | model_name = TensorVMSplit
23 |
24 |
25 | shadingMode = MLP_Fea
26 | fea2denseAct = softplus
27 |
28 | view_pe = 2
29 | fea_pe = 2
30 |
31 | TV_weight_density = 0.1
32 | TV_weight_app = 0.01
33 |
34 | rm_weight_mask_thre = 1e-4
35 |
36 | ## please uncomment the following configuration to train the CP model
37 | #model_name = TensorCP
38 | #n_lamb_sigma = [96]
39 | #n_lamb_sh = [288]
40 | #N_voxel_final = 125000000 # 500**3
41 | #L1_weight_inital = 1e-5
42 | #L1_weight_rest = 1e-5
43 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
20 |
--------------------------------------------------------------------------------
/models/quaternion_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 |
6 | def quaternion_product(p, q):
7 | p_r = p[..., [0]]
8 | p_i = p[..., 1:]
9 | q_r = q[..., [0]]
10 | q_i = q[..., 1:]
11 |
12 |     out_r = p_r * q_r - (p_i * q_i).sum(dim=-1, keepdim=True)  # keepdim keeps shape (..., 1)
13 | out_i = p_r * q_i + q_r * p_i + torch.linalg.cross(p_i, q_i, dim=-1)
14 |
15 | return torch.cat([out_r, out_i], dim=-1)
16 |
17 | def quaternion_inverse(p):
18 | p_r = p[..., [0]]
19 | p_i = -p[..., 1:]
20 |
21 | return torch.cat([p_r, p_i], dim=-1)
22 |
23 | def quaternion_rotate(p, q):
24 | q_inv = quaternion_inverse(q)
25 |
26 | qp = quaternion_product(q, p)
27 | out = quaternion_product(qp, q_inv)
28 | return out
29 |
30 | def build_q(vec, angle):
31 | out_r = torch.cos(angle / 2)
32 | out_i = torch.sin(angle / 2) * vec
33 |
34 | return torch.cat([out_r, out_i], dim=-1)
35 |
36 |
37 | def cartesian2quaternion(x):
38 | zeros_ = x.new_zeros([*x.shape[:-1], 1])
39 | return torch.cat([zeros_, x], dim=-1)
40 |
41 |
42 | def spherical2cartesian(theta, phi):
43 | x = torch.cos(phi) * torch.sin(theta)
44 | y = torch.sin(phi) * torch.sin(theta)
45 | z = torch.cos(theta)
46 |
47 | return [x, y, z]
48 |
49 | def init_predefined_omega(n_theta, n_phi):
50 | theta_list = torch.linspace(0, np.pi, n_theta)
51 | phi_list = torch.linspace(0, np.pi*2, n_phi)
52 |
53 | out_omega = []
54 | out_omega_lambda = []
55 | out_omega_mu = []
56 |
57 | for i in range(n_theta):
58 | theta = theta_list[i].view(1, 1)
59 |
60 | for j in range(n_phi):
61 | phi = phi_list[j].view(1, 1)
62 |
63 | omega = spherical2cartesian(theta, phi)
64 | omega = torch.stack(omega, dim=-1).view(1, 3)
65 |
66 | omega_lambda = spherical2cartesian(theta+np.pi/2, phi)
67 | omega_lambda = torch.stack(omega_lambda, dim=-1).view(1, 3)
68 |
69 | p = cartesian2quaternion(omega_lambda)
70 | q = build_q(omega, torch.tensor(np.pi/2).view(1, 1))
71 | omega_mu = quaternion_rotate(p, q)[..., 1:]
72 |
73 | out_omega.append(omega)
74 | out_omega_lambda.append(omega_lambda)
75 | out_omega_mu.append(omega_mu)
76 |
77 |
78 | out_omega = torch.stack(out_omega, dim=0)
79 | out_omega_lambda = torch.stack(out_omega_lambda, dim=0)
80 | out_omega_mu = torch.stack(out_omega_mu, dim=0)
81 |
82 | return out_omega, out_omega_lambda, out_omega_mu
83 |
--------------------------------------------------------------------------------
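A small sanity-check sketch of `init_predefined_omega`: it returns a fixed set of unit directions on the sphere together with two tangent axes per direction, presumably the per-lobe frames for the anisotropic spherical Gaussian mixture described in the README. Only the functions defined above are assumed:

```
import torch
from models.quaternion_utils import init_predefined_omega

omega, omega_la, omega_mu = init_predefined_omega(4, 8)     # each of shape (4*8, 1, 3)
omega, omega_la, omega_mu = omega.squeeze(1), omega_la.squeeze(1), omega_mu.squeeze(1)

# Each direction is a unit vector, and the two tangent axes are orthogonal to it.
print((omega.norm(dim=-1) - 1).abs().max())        # ~0
print((omega * omega_la).sum(-1).abs().max())      # ~0
print((omega * omega_mu).sum(-1).abs().max())      # ~0
```
--------------------------------------------------------------------------------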
/README.md:
--------------------------------------------------------------------------------
1 | # NRFF
2 |
3 | ### [Project page](https://imkanghan.github.io/projects/NRFF/main) | [Paper](https://arxiv.org/abs/2303.03808)
4 |
5 | This repository is an implementation of the view synthesis method described in the paper "Multiscale Tensor Decomposition and Rendering Equation Encoding for View Synthesis", CVPR 2023.
6 |
7 | [Kang Han](https://imkanghan.github.io/)<sup>1</sup>, [Wei Xiang](https://scholars.latrobe.edu.au/wxiang)<sup>2</sup>
8 |
9 | <sup>1</sup>James Cook University, <sup>2</sup>La Trobe University
10 |
11 | ## Abstract
12 | Rendering novel views from captured multi-view images has made considerable progress since the emergence of the neural radiance field. This paper aims to further advance the quality of view synthesis by proposing a novel approach dubbed the neural radiance feature field (NRFF). We first propose a multiscale tensor decomposition scheme to organize learnable features so as to represent scenes from coarse to fine scales. We demonstrate many benefits of the proposed multiscale representation, including more accurate scene shape and appearance reconstruction, and faster convergence compared with the single-scale representation. Instead of encoding view directions to model view-dependent effects, we further propose to encode the rendering equation in the feature space by employing the anisotropic spherical Gaussian mixture predicted from the proposed multiscale representation. The proposed NRFF improves state-of-the-art rendering results by over 1 dB in PSNR on both the NeRF and NSVF synthetic datasets. A significant improvement has also been observed on the real-world Tanks \& Temples dataset.
13 |
14 | ## Installation
15 |
16 | This implementation is based on [PyTorch](https://pytorch.org/) and [TensoRF](https://github.com/apchenstu/TensoRF). You can create a virtual environment using Anaconda by running
17 |
18 | ```
19 | conda create -n nrff python=3.8
20 | conda activate nrff
21 | pip3 install torch torchvision
22 | pip3 install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg kornia
23 | ```
24 |
25 | ## Dataset
26 | Please download one of the following datasets:
27 |
28 | [NeRF-synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
29 |
30 | [NSVF-synthetic](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip)
31 |
32 | [Tanks & Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)
33 |
34 | ## Training
35 | Specify the dataset path (`datadir`) in configs/lego.txt and run
36 | ```
37 | python train.py --config configs/lego.txt
38 | ```
39 |
40 | ## Rendering
41 | ```
42 | python train.py --config configs/lego.txt --ckpt path/to/your/checkpoint --render_only 1 --render_test 1
43 | ```
44 |
45 | ## Citation
46 | If you find this code useful, please cite:
47 |
48 |     @inproceedings{han2023nrff,
49 |         author={Han, Kang and Xiang, Wei},
50 |         title={Multiscale Tensor Decomposition and Rendering Equation Encoding for View Synthesis},
51 |         booktitle={The IEEE / CVF Computer Vision and Pattern Recognition Conference},
52 |         pages={4232--4241},
53 |         year={2023}
54 |     }
55 |
56 | ## Acknowledgements
57 |
58 | Thanks to the awesome neural rendering repositories of [TensoRF](https://github.com/apchenstu/TensoRF) and [Instant-NGP](https://github.com/NVlabs/instant-ngp).
--------------------------------------------------------------------------------
/dataLoader/blender.py:
--------------------------------------------------------------------------------
1 | import torch,cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 |
10 | from .ray_utils import *
11 |
12 |
13 | class BlenderDataset(Dataset):
14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
15 |
16 | self.N_vis = N_vis
17 | self.root_dir = datadir
18 | self.split = split
19 | self.is_stack = is_stack
20 | self.img_wh = (int(800/downsample),int(800/downsample))
21 | self.define_transforms()
22 |
23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
25 | self.read_meta()
26 | self.define_proj_mat()
27 |
28 | self.white_bg = True
29 | self.near_far = [2.0,6.0]
30 |
31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
33 | self.downsample=downsample
34 |
35 | def read_depth(self, filename):
36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
37 | return depth
38 |
39 | def read_meta(self):
40 |
41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
42 | self.meta = json.load(f)
43 |
44 | w, h = self.img_wh
45 | self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
46 | self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh
47 |
48 |
49 | # ray directions for all pixels, same for all images (same H, W, focal)
50 | self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3)
51 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
52 | self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float()
53 |
54 | self.image_paths = []
55 | self.poses = []
56 | self.all_rays = []
57 | self.all_rgbs = []
58 | self.all_masks = []
59 | self.all_depth = []
60 | self.downsample=1.0
61 |
62 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
63 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
64 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#
65 |
66 | frame = self.meta['frames'][i]
67 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv
68 | c2w = torch.FloatTensor(pose)
69 | self.poses += [c2w]
70 |
71 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
72 | self.image_paths += [image_path]
73 | img = Image.open(image_path)
74 |
75 | if self.downsample!=1.0:
76 | img = img.resize(self.img_wh, Image.LANCZOS)
77 | img = self.transform(img) # (4, h, w)
78 | img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
79 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
80 | self.all_rgbs += [img]
81 |
82 |
83 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
84 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
85 |
86 |
87 | self.poses = torch.stack(self.poses)
88 | if not self.is_stack:
89 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
90 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
91 |
92 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
93 | else:
94 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
95 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
96 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
97 |
98 |
99 | def define_transforms(self):
100 | self.transform = T.ToTensor()
101 |
102 | def define_proj_mat(self):
103 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]
104 |
105 | def world2ndc(self,points,lindisp=None):
106 | device = points.device
107 | return (points - self.center.to(device)) / self.radius.to(device)
108 |
109 | def __len__(self):
110 | return len(self.all_rgbs)
111 |
112 | def __getitem__(self, idx):
113 |
114 | if self.split == 'train': # use data in the buffers
115 | sample = {'rays': self.all_rays[idx],
116 | 'rgbs': self.all_rgbs[idx]}
117 |
118 | else: # create data for each image separately
119 |
120 | img = self.all_rgbs[idx]
121 | rays = self.all_rays[idx]
122 |             mask = self.all_masks[idx] # for quantitative evaluation
123 |
124 | sample = {'rays': rays,
125 | 'rgbs': img,
126 | 'mask': mask}
127 | return sample
128 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import cv2,torch
2 | import numpy as np
3 | from PIL import Image
4 | import torchvision.transforms as T
5 | import torch.nn.functional as F
6 | import scipy.signal
7 |
8 | mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
9 |
10 |
11 | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
12 | """
13 | depth: (H, W)
14 | """
15 |
16 | x = np.nan_to_num(depth) # change nan to 0
17 | if minmax is None:
18 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
19 | ma = np.max(x)
20 | else:
21 | mi,ma = minmax
22 |
23 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
24 | x = (255*x).astype(np.uint8)
25 | x_ = cv2.applyColorMap(x, cmap)
26 | return x_, [mi,ma]
27 |
28 | def init_log(log, keys):
29 | for key in keys:
30 | log[key] = torch.tensor([0.0], dtype=float)
31 | return log
32 |
33 | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
34 | """
35 | depth: (H, W)
36 | """
37 | if type(depth) is not np.ndarray:
38 | depth = depth.cpu().numpy()
39 |
40 | x = np.nan_to_num(depth) # change nan to 0
41 | if minmax is None:
42 | mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
43 | ma = np.max(x)
44 | else:
45 | mi,ma = minmax
46 |
47 | x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
48 | x = (255*x).astype(np.uint8)
49 | x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
50 | x_ = T.ToTensor()(x_) # (3, H, W)
51 | return x_, [mi,ma]
52 |
53 | def N_to_reso(n_voxels, bbox):
54 | xyz_min, xyz_max = bbox
55 | dim = len(xyz_min)
56 | voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
57 | return ((xyz_max - xyz_min) / voxel_size).long().tolist()
58 |
59 | def cal_n_samples(reso, step_ratio=0.5):
60 | return int(np.linalg.norm(reso)/step_ratio)
61 |
62 |
63 |
64 |
65 | __LPIPS__ = {}
66 | def init_lpips(net_name, device):
67 | assert net_name in ['alex', 'vgg']
68 | import lpips
69 | print(f'init_lpips: lpips_{net_name}')
70 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
71 |
72 | def rgb_lpips(np_gt, np_im, net_name, device):
73 | if net_name not in __LPIPS__:
74 | __LPIPS__[net_name] = init_lpips(net_name, device)
75 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
76 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
77 | return __LPIPS__[net_name](gt, im, normalize=True).item()
78 |
79 |
80 | def findItem(items, target):
81 | for one in items:
82 | if one[:len(target)]==target:
83 | return one
84 | return None
85 |
86 |
87 | ''' Evaluation metrics (ssim, lpips)
88 | '''
89 | def rgb_ssim(img0, img1, max_val,
90 | filter_size=11,
91 | filter_sigma=1.5,
92 | k1=0.01,
93 | k2=0.03,
94 | return_map=False):
95 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
96 | assert len(img0.shape) == 3
97 | assert img0.shape[-1] == 3
98 | assert img0.shape == img1.shape
99 |
100 | # Construct a 1D Gaussian blur filter.
101 | hw = filter_size // 2
102 | shift = (2 * hw - filter_size + 1) / 2
103 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
104 | filt = np.exp(-0.5 * f_i)
105 | filt /= np.sum(filt)
106 |
107 | # Blur in x and y (faster than the 2D convolution).
108 | def convolve2d(z, f):
109 | return scipy.signal.convolve2d(z, f, mode='valid')
110 |
111 | filt_fn = lambda z: np.stack([
112 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
113 | for i in range(z.shape[-1])], -1)
114 | mu0 = filt_fn(img0)
115 | mu1 = filt_fn(img1)
116 | mu00 = mu0 * mu0
117 | mu11 = mu1 * mu1
118 | mu01 = mu0 * mu1
119 | sigma00 = filt_fn(img0**2) - mu00
120 | sigma11 = filt_fn(img1**2) - mu11
121 | sigma01 = filt_fn(img0 * img1) - mu01
122 |
123 | # Clip the variances and covariances to valid values.
124 | # Variance must be non-negative:
125 | sigma00 = np.maximum(0., sigma00)
126 | sigma11 = np.maximum(0., sigma11)
127 | sigma01 = np.sign(sigma01) * np.minimum(
128 | np.sqrt(sigma00 * sigma11), np.abs(sigma01))
129 | c1 = (k1 * max_val)**2
130 | c2 = (k2 * max_val)**2
131 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
132 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
133 | ssim_map = numer / denom
134 | ssim = np.mean(ssim_map)
135 | return ssim_map if return_map else ssim
136 |
137 |
138 | import torch.nn as nn
139 | class TVLoss(nn.Module):
140 | def __init__(self,TVLoss_weight=1):
141 | super(TVLoss,self).__init__()
142 | self.TVLoss_weight = TVLoss_weight
143 |
144 | def forward(self,x):
145 | batch_size = x.size()[0]
146 | h_x = x.size()[2]
147 | w_x = x.size()[3]
148 | count_h = self._tensor_size(x[:,:,1:,:])
149 | count_w = self._tensor_size(x[:,:,:,1:])
150 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
151 | w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
152 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
153 |
154 | def _tensor_size(self,t):
155 | return t.size()[1]*t.size()[2]*t.size()[3]
156 |
--------------------------------------------------------------------------------
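A short sketch of the grid-sizing helpers `N_to_reso` and `cal_n_samples` above; the bounding box matches the `scene_bbox` used by the blender loader and `step_ratio` matches the default in opt.py (assumes the repository root is the working directory and the README dependencies are installed):

```
import torch
from utils import N_to_reso, cal_n_samples

# Convert a target voxel budget into a per-axis grid resolution for a given bounding box,
# then derive the number of samples per ray from that resolution.
bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
reso = N_to_reso(128 ** 3, bbox)                 # [128, 128, 128] for a cubic box
n_samples = cal_n_samples(reso, step_ratio=0.5)  # samples along the grid diagonal
print(reso, n_samples)                           # [128, 128, 128] 443
```
--------------------------------------------------------------------------------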
/dataLoader/your_own_data.py:
--------------------------------------------------------------------------------
1 | import torch,cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 |
10 | from .ray_utils import *
11 |
12 |
13 | class YourOwnDataset(Dataset):
14 | def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
15 |
16 | self.N_vis = N_vis
17 | self.root_dir = datadir
18 | self.split = split
19 | self.is_stack = is_stack
20 | self.downsample = downsample
21 | self.define_transforms()
22 |
23 | self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
24 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
25 | self.read_meta()
26 | self.define_proj_mat()
27 |
28 | self.white_bg = True
29 | self.near_far = [0.1,100.0]
30 |
31 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
32 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
33 | self.downsample=downsample
34 |
35 | def read_depth(self, filename):
36 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
37 | return depth
38 |
39 | def read_meta(self):
40 |
41 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
42 | self.meta = json.load(f)
43 |
44 | w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
45 | self.img_wh = [w,h]
46 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
47 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length
48 | self.cx, self.cy = self.meta['cx'],self.meta['cy']
49 |
50 |
51 | # ray directions for all pixels, same for all images (same H, W, focal)
52 | self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3)
53 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
54 | self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()
55 |
56 | self.image_paths = []
57 | self.poses = []
58 | self.all_rays = []
59 | self.all_rgbs = []
60 | self.all_masks = []
61 | self.all_depth = []
62 |
63 |
64 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
65 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
66 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#
67 |
68 | frame = self.meta['frames'][i]
69 | pose = np.array(frame['transform_matrix']) @ self.blender2opencv
70 | c2w = torch.FloatTensor(pose)
71 | self.poses += [c2w]
72 |
73 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
74 | self.image_paths += [image_path]
75 | img = Image.open(image_path)
76 |
77 | if self.downsample!=1.0:
78 | img = img.resize(self.img_wh, Image.LANCZOS)
79 | img = self.transform(img) # (4, h, w)
80 | img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA
81 | if img.shape[-1]==4:
82 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
83 | self.all_rgbs += [img]
84 |
85 |
86 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
87 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
88 |
89 |
90 | self.poses = torch.stack(self.poses)
91 | if not self.is_stack:
92 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
93 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
94 |
95 | # self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
96 | else:
97 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
98 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
99 | # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
100 |
101 |
102 | def define_transforms(self):
103 | self.transform = T.ToTensor()
104 |
105 | def define_proj_mat(self):
106 | self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]
107 |
108 | def world2ndc(self,points,lindisp=None):
109 | device = points.device
110 | return (points - self.center.to(device)) / self.radius.to(device)
111 |
112 | def __len__(self):
113 | return len(self.all_rgbs)
114 |
115 | def __getitem__(self, idx):
116 |
117 | if self.split == 'train': # use data in the buffers
118 | sample = {'rays': self.all_rays[idx],
119 | 'rgbs': self.all_rgbs[idx]}
120 |
121 | else: # create data for each image separately
122 |
123 | img = self.all_rgbs[idx]
124 | rays = self.all_rays[idx]
125 |             mask = self.all_masks[idx] # for quantitative evaluation
126 |
127 | sample = {'rays': rays,
128 | 'rgbs': img}
129 | return sample
130 |
--------------------------------------------------------------------------------
/models/sh.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | ################## sh function ##################
4 | C0 = 0.28209479177387814
5 | C1 = 0.4886025119029199
6 | C2 = [
7 | 1.0925484305920792,
8 | -1.0925484305920792,
9 | 0.31539156525252005,
10 | -1.0925484305920792,
11 | 0.5462742152960396
12 | ]
13 | C3 = [
14 | -0.5900435899266435,
15 | 2.890611442640554,
16 | -0.4570457994644658,
17 | 0.3731763325901154,
18 | -0.4570457994644658,
19 | 1.445305721320277,
20 | -0.5900435899266435
21 | ]
22 | C4 = [
23 | 2.5033429417967046,
24 | -1.7701307697799304,
25 | 0.9461746957575601,
26 | -0.6690465435572892,
27 | 0.10578554691520431,
28 | -0.6690465435572892,
29 | 0.47308734787878004,
30 | -1.7701307697799304,
31 | 0.6258357354491761,
32 | ]
33 |
34 | def eval_sh(deg, sh, dirs):
35 | """
36 | Evaluate spherical harmonics at unit directions
37 | using hardcoded SH polynomials.
38 | Works with torch/np/jnp.
39 | ... Can be 0 or more batch dimensions.
40 | :param deg: int SH max degree. Currently, 0-4 supported
41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
42 | :param dirs: torch.Tensor unit directions (..., 3)
43 | :return: (..., C)
44 | """
45 | assert deg <= 4 and deg >= 0
46 | assert (deg + 1) ** 2 == sh.shape[-1]
47 | C = sh.shape[-2]
48 |
49 | result = C0 * sh[..., 0]
50 | if deg > 0:
51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
52 | result = (result -
53 | C1 * y * sh[..., 1] +
54 | C1 * z * sh[..., 2] -
55 | C1 * x * sh[..., 3])
56 | if deg > 1:
57 | xx, yy, zz = x * x, y * y, z * z
58 | xy, yz, xz = x * y, y * z, x * z
59 | result = (result +
60 | C2[0] * xy * sh[..., 4] +
61 | C2[1] * yz * sh[..., 5] +
62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
63 | C2[3] * xz * sh[..., 7] +
64 | C2[4] * (xx - yy) * sh[..., 8])
65 |
66 | if deg > 2:
67 | result = (result +
68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
69 | C3[1] * xy * z * sh[..., 10] +
70 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
73 | C3[5] * z * (xx - yy) * sh[..., 14] +
74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
75 | if deg > 3:
76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
85 | return result
86 |
87 | def eval_sh_bases(deg, dirs):
88 | """
89 | Evaluate spherical harmonics bases at unit directions,
90 | without taking linear combination.
91 |     At each point, the final result may then be
92 | obtained through simple multiplication.
93 | :param deg: int SH max degree. Currently, 0-4 supported
94 | :param dirs: torch.Tensor (..., 3) unit directions
95 | :return: torch.Tensor (..., (deg+1) ** 2)
96 | """
97 | assert deg <= 4 and deg >= 0
98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
99 | result[..., 0] = C0
100 | if deg > 0:
101 | x, y, z = dirs.unbind(-1)
102 | result[..., 1] = -C1 * y;
103 | result[..., 2] = C1 * z;
104 | result[..., 3] = -C1 * x;
105 | if deg > 1:
106 | xx, yy, zz = x * x, y * y, z * z
107 | xy, yz, xz = x * y, y * z, x * z
108 | result[..., 4] = C2[0] * xy;
109 | result[..., 5] = C2[1] * yz;
110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy);
111 | result[..., 7] = C2[3] * xz;
112 | result[..., 8] = C2[4] * (xx - yy);
113 |
114 | if deg > 2:
115 | result[..., 9] = C3[0] * y * (3 * xx - yy);
116 | result[..., 10] = C3[1] * xy * z;
117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy);
118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy);
119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy);
120 | result[..., 14] = C3[5] * z * (xx - yy);
121 | result[..., 15] = C3[6] * x * (xx - 3 * yy);
122 |
123 | if deg > 3:
124 | result[..., 16] = C4[0] * xy * (xx - yy);
125 | result[..., 17] = C4[1] * yz * (3 * xx - yy);
126 | result[..., 18] = C4[2] * xy * (7 * zz - 1);
127 | result[..., 19] = C4[3] * yz * (7 * zz - 3);
128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3);
129 | result[..., 21] = C4[5] * xz * (7 * zz - 3);
130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1);
131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy);
132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy));
133 | return result
134 |
--------------------------------------------------------------------------------
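A minimal sketch of `eval_sh_bases`: combining the returned basis with per-channel coefficients reproduces what `eval_sh` computes with its hardcoded polynomials (the random tensors below stand in for learned SH coefficients):

```
import torch
from models.sh import eval_sh_bases

deg = 2
dirs = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)
basis = eval_sh_bases(deg, dirs)                     # (1024, (deg+1)**2) = (1024, 9)

sh_coeffs = torch.randn(1024, 3, (deg + 1) ** 2)     # per-point RGB coefficients
rgb = (sh_coeffs * basis.unsqueeze(1)).sum(-1)       # (1024, 3)
```
--------------------------------------------------------------------------------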
/renderer.py:
--------------------------------------------------------------------------------
1 | import torch,os,imageio,sys
2 | from tqdm.auto import tqdm
3 | from dataLoader.ray_utils import get_rays
4 | from models.model import NRFF
5 | from utils import *
6 | from dataLoader.ray_utils import ndc_rays_blender
7 | import time
8 |
9 |
10 | def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda'):
11 |
12 | rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], []
13 | N_rays_all = rays.shape[0]
14 | for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
15 | rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
16 |
17 | output = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)
18 |
19 | rgbs.append(output['rgb_map'])
20 | depth_maps.append(output['depth_map'])
21 |
22 | return torch.cat(rgbs), None, torch.cat(depth_maps), None, output
23 |
24 | @torch.no_grad()
25 | def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
26 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
27 | PSNRs, rgb_maps, depth_maps = [], [], []
28 | ssims,l_alex,l_vgg=[],[],[]
29 | os.makedirs(savePath, exist_ok=True)
30 | os.makedirs(savePath+"/rgbd", exist_ok=True)
31 |
32 | try:
33 | tqdm._instances.clear()
34 | except Exception:
35 | pass
36 |
37 | near_far = test_dataset.near_far
38 | img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays.shape[0] // N_vis,1)
39 | idxs = list(range(0, test_dataset.all_rays.shape[0], img_eval_interval))
40 | for idx, samples in tqdm(enumerate(test_dataset.all_rays[0::img_eval_interval]), file=sys.stdout):
41 |
42 | W, H = test_dataset.img_wh
43 | rays = samples.view(-1,samples.shape[-1])
44 |
45 | rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples,
46 | ndc_ray=ndc_ray, white_bg = white_bg, device=device)
47 |
48 | rgb_map = rgb_map.clamp(0.0, 1.0)
49 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
50 |
51 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
52 | if len(test_dataset.all_rgbs):
53 | gt_rgb = test_dataset.all_rgbs[idxs[idx]].view(H, W, 3)
54 | loss = torch.mean((rgb_map - gt_rgb) ** 2)
55 | PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))
56 |
57 |
58 | if compute_extra_metrics:
59 | ssim = rgb_ssim(rgb_map, gt_rgb, 1)
60 | l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device)
61 | l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device)
62 |
63 | ssims.append(ssim)
64 | l_alex.append(l_a)
65 | l_vgg.append(l_v)
66 |
67 |
68 | rgb_map = (rgb_map.numpy() * 255).astype('uint8')
69 | rgb_maps.append(rgb_map)
70 | depth_maps.append(depth_map)
71 | if savePath is not None:
72 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
73 |
74 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10)
75 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10)
76 |
77 | if PSNRs:
78 | psnr = np.mean(np.asarray(PSNRs))
79 | if compute_extra_metrics:
80 | ssim = np.mean(np.asarray(ssims))
81 | l_a = np.mean(np.asarray(l_alex))
82 | l_v = np.mean(np.asarray(l_vgg))
83 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
84 | else:
85 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
86 |
87 |
88 | return PSNRs
89 |
90 | @torch.no_grad()
91 | def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
92 | white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
93 | PSNRs, rgb_maps, depth_maps = [], [], []
94 | ssims,l_alex,l_vgg=[],[],[]
95 | os.makedirs(savePath, exist_ok=True)
96 | os.makedirs(savePath+"/rgbd", exist_ok=True)
97 |
98 | try:
99 | tqdm._instances.clear()
100 | except Exception:
101 | pass
102 |
103 | near_far = test_dataset.near_far
104 | for idx, c2w in tqdm(enumerate(c2ws)):
105 |
106 | W, H = test_dataset.img_wh
107 |
108 | c2w = torch.FloatTensor(c2w)
109 | rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3)
110 | if ndc_ray:
111 | rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
112 | rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6)
113 |
114 | rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples,
115 | ndc_ray=ndc_ray, white_bg = white_bg, device=device)
116 | rgb_map = rgb_map.clamp(0.0, 1.0)
117 |
118 | rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
119 |
120 | depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
121 |
122 | rgb_map = (rgb_map.numpy() * 255).astype('uint8')
123 | # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
124 | rgb_maps.append(rgb_map)
125 | depth_maps.append(depth_map)
126 | if savePath is not None:
127 | imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
128 | rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
129 | imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
130 |
131 | imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
132 | imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8)
133 |
134 | if PSNRs:
135 | psnr = np.mean(np.asarray(PSNRs))
136 | if compute_extra_metrics:
137 | ssim = np.mean(np.asarray(ssims))
138 | l_a = np.mean(np.asarray(l_alex))
139 | l_v = np.mean(np.asarray(l_vgg))
140 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
141 | else:
142 | np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
143 |
144 |
145 | return PSNRs
146 |
147 |
--------------------------------------------------------------------------------
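The renderer above bounds GPU memory by splitting the ray batch into fixed-size chunks. A standalone sketch of the same chunking arithmetic, with a dummy function standing in for the model's forward pass:

```
import torch

def render_in_chunks(rays, fn, chunk=4096):
    # Same chunk-count formula as OctreeRender_trilinear_fast: ceil(n_rays / chunk).
    n_rays = rays.shape[0]
    outputs = []
    for chunk_idx in range(n_rays // chunk + int(n_rays % chunk > 0)):
        outputs.append(fn(rays[chunk_idx * chunk:(chunk_idx + 1) * chunk]))
    return torch.cat(outputs)

rays = torch.rand(10000, 6)                        # (origin, direction) per ray
rgb = render_in_chunks(rays, lambda r: r[:, :3])   # dummy fn in place of the NRFF model
print(rgb.shape)                                   # torch.Size([10000, 3])
```
--------------------------------------------------------------------------------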
/opt.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 |
3 | def config_parser(cmd=None):
4 | parser = configargparse.ArgumentParser()
5 | parser.add_argument('--config', is_config_file=True,
6 | help='config file path')
7 | parser.add_argument("--expname", type=str,
8 | help='experiment name')
9 | parser.add_argument("--basedir", type=str, default='./log',
10 | help='where to store ckpts and logs')
11 | parser.add_argument("--add_timestamp", type=int, default=0,
12 | help='add timestamp to dir')
13 | parser.add_argument("--datadir", type=str, default='./data/llff/fern',
14 | help='input data directory')
15 | parser.add_argument("--progress_refresh_rate", type=int, default=10,
16 |                         help='how often (in iterations) to print training PSNRs')
17 |
18 | parser.add_argument('--with_depth', action='store_true')
19 | parser.add_argument('--downsample_train', type=float, default=1.0)
20 | parser.add_argument('--downsample_test', type=float, default=1.0)
21 |
22 | parser.add_argument('--model_name', type=str, default='TensorVMSplit',
23 | choices=['TensorVMSplit', 'TensorCP', 'NRFF'])
24 |
25 | # loader options
26 | parser.add_argument("--batch_size", type=int, default=4096)
27 | parser.add_argument("--n_iters", type=int, default=30000)
28 |
29 | parser.add_argument('--dataset_name', type=str, default='blender',
30 | choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'])
31 |
32 |
33 | # training options
34 | # learning rate
35 | parser.add_argument("--lr_init", type=float, default=0.02,
36 | help='learning rate')
37 | parser.add_argument("--lr_basis", type=float, default=1e-3,
38 | help='learning rate')
39 | parser.add_argument("--lr_decay_iters", type=int, default=-1,
40 |                         help='number of iterations over which the lr decays to the target ratio; -1 sets it to n_iters')
41 |     parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1,
42 |                         help='the target decay ratio; after decay_iters the initial lr decays to lr*ratio')
43 |     parser.add_argument("--lr_upsample_reset", type=int, default=1,
44 |                         help='reset lr to its initial value after upsampling')
45 |
46 | # loss
47 | parser.add_argument("--L1_weight_inital", type=float, default=0.0,
48 | help='loss weight')
49 | parser.add_argument("--L1_weight_rest", type=float, default=0,
50 | help='loss weight')
51 | parser.add_argument("--Ortho_weight", type=float, default=0.0,
52 | help='loss weight')
53 | parser.add_argument("--TV_weight_density", type=float, default=0.0,
54 | help='loss weight')
55 | parser.add_argument("--TV_weight_app", type=float, default=0.0,
56 | help='loss weight')
57 |
58 | # model
59 | # volume options
60 | parser.add_argument("--n_lamb_sigma", type=int, action="append")
61 | parser.add_argument("--n_lamb_sh", type=int, action="append")
62 | parser.add_argument("--data_dim_color", type=int, default=27)
63 |
64 | parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001,
65 | help='mask points in ray marching')
66 | parser.add_argument("--alpha_mask_thre", type=float, default=0.00001,
67 | help='threshold for creating alpha mask volume')
68 | parser.add_argument("--distance_scale", type=float, default=25,
69 | help='scaling sampling distance for computation')
70 | parser.add_argument("--density_shift", type=float, default=-10,
71 | help='shift density in softplus; making density = 0 when feature == 0')
72 |
73 | # network decoder
74 | parser.add_argument("--shadingMode", type=str, default="MLP_PE",
75 | help='which shading mode to use')
76 | parser.add_argument("--pos_pe", type=int, default=6,
77 | help='number of pe for pos')
78 | parser.add_argument("--view_pe", type=int, default=6,
79 | help='number of pe for view')
80 | parser.add_argument("--fea_pe", type=int, default=6,
81 | help='number of pe for features')
82 | parser.add_argument("--featureC", type=int, default=128,
83 | help='hidden feature channel in MLP')
84 |
85 |
86 |
87 | parser.add_argument("--ckpt", type=str, default=None,
88 |                         help='checkpoint path to reload (e.g. with --render_only)')
89 | parser.add_argument("--render_only", type=int, default=0)
90 | parser.add_argument("--render_test", type=int, default=0)
91 | parser.add_argument("--render_train", type=int, default=0)
92 | parser.add_argument("--render_path", type=int, default=0)
93 | parser.add_argument("--export_mesh", type=int, default=0)
94 |
95 | # rendering options
96 | parser.add_argument('--lindisp', default=False, action="store_true",
97 | help='use disparity depth sampling')
98 | parser.add_argument("--perturb", type=float, default=1.,
99 | help='set to 0. for no jitter, 1. for jitter')
100 | parser.add_argument("--accumulate_decay", type=float, default=0.998)
101 | parser.add_argument("--fea2denseAct", type=str, default='softplus')
102 | parser.add_argument('--ndc_ray', type=int, default=0)
103 | parser.add_argument('--nSamples', type=int, default=1e6,
104 |                         help='number of sample points per ray; pass 1e6 for automatic adjustment')
105 | parser.add_argument('--step_ratio',type=float,default=0.5)
106 |
107 |
108 | ## blender flags
109 | parser.add_argument("--white_bkgd", action='store_true',
110 | help='set to render synthetic data on a white bkgd (always use for dvoxels)')
111 |
112 |
113 |
114 | parser.add_argument('--N_voxel_init',
115 | type=int,
116 | default=100**3)
117 | parser.add_argument('--N_voxel_final',
118 | type=int,
119 | default=300**3)
120 | parser.add_argument("--upsamp_list", type=int, action="append")
121 | parser.add_argument("--update_AlphaMask_list", type=int, action="append")
122 |
123 | parser.add_argument('--idx_view',
124 | type=int,
125 | default=0)
126 | # logging/saving options
127 | parser.add_argument("--N_vis", type=int, default=5,
128 | help='N images to vis')
129 | parser.add_argument("--vis_every", type=int, default=10000,
130 |                         help='how often (in iterations) to visualize rendered images')
131 | if cmd is not None:
132 | return parser.parse_args(cmd)
133 | else:
134 | return parser.parse_args()
135 |
--------------------------------------------------------------------------------
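`config_parser` accepts an explicit argument list, so it can also be driven programmatically. A small usage sketch, assuming it runs from the repository root so that configs/lego.txt resolves:

```
from opt import config_parser

# Values from the config file can be overridden by extra command-line style arguments.
args = config_parser(cmd=['--config', 'configs/lego.txt', '--render_test', '1'])
print(args.expname, args.n_iters, args.N_voxel_final)
```
--------------------------------------------------------------------------------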
/dataLoader/nsvf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from tqdm import tqdm
4 | import os
5 | from PIL import Image
6 | from torchvision import transforms as T
7 |
8 | from .ray_utils import *
9 |
10 | trans_t = lambda t : torch.Tensor([
11 | [1,0,0,0],
12 | [0,1,0,0],
13 | [0,0,1,t],
14 | [0,0,0,1]]).float()
15 |
16 | rot_phi = lambda phi : torch.Tensor([
17 | [1,0,0,0],
18 | [0,np.cos(phi),-np.sin(phi),0],
19 | [0,np.sin(phi), np.cos(phi),0],
20 | [0,0,0,1]]).float()
21 |
22 | rot_theta = lambda th : torch.Tensor([
23 | [np.cos(th),0,-np.sin(th),0],
24 | [0,1,0,0],
25 | [np.sin(th),0, np.cos(th),0],
26 | [0,0,0,1]]).float()
27 |
28 |
29 | def pose_spherical(theta, phi, radius):
30 | c2w = trans_t(radius)
31 | c2w = rot_phi(phi/180.*np.pi) @ c2w
32 | c2w = rot_theta(theta/180.*np.pi) @ c2w
33 | c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
34 | return c2w
35 |
36 | class NSVF(Dataset):
37 | """NSVF Generic Dataset."""
38 | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
39 | self.root_dir = datadir
40 | self.split = split
41 | self.is_stack = is_stack
42 | self.downsample = downsample
43 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
44 | self.define_transforms()
45 |
46 | self.white_bg = True
47 | self.near_far = [0.5,6.0]
48 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
49 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
50 | self.read_meta()
51 | self.define_proj_mat()
52 |
53 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
54 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
55 |
56 | def bbox2corners(self):
57 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
58 | for i in range(3):
59 | corners[i,[0,1],i] = corners[i,[1,0],i]
60 | return corners.view(-1,3)
61 |
62 |
63 | def read_meta(self):
64 | with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
65 | focal = float(f.readline().split()[0])
66 | self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
67 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)
68 |
69 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
70 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
71 |
72 | if self.split == 'train':
73 | pose_files = [x for x in pose_files if x.startswith('0_')]
74 | img_files = [x for x in img_files if x.startswith('0_')]
75 | elif self.split == 'val':
76 | pose_files = [x for x in pose_files if x.startswith('1_')]
77 | img_files = [x for x in img_files if x.startswith('1_')]
78 | elif self.split == 'test':
79 | test_pose_files = [x for x in pose_files if x.startswith('2_')]
80 | test_img_files = [x for x in img_files if x.startswith('2_')]
81 | if len(test_pose_files) == 0:
82 | test_pose_files = [x for x in pose_files if x.startswith('1_')]
83 | test_img_files = [x for x in img_files if x.startswith('1_')]
84 | pose_files = test_pose_files
85 | img_files = test_img_files
86 |
87 | # ray directions for all pixels, same for all images (same H, W, focal)
88 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
89 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
90 |
91 |
92 | self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
93 |
94 | self.poses = []
95 | self.all_rays = []
96 | self.all_rgbs = []
97 |
98 | assert len(img_files) == len(pose_files)
99 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
100 | image_path = os.path.join(self.root_dir, 'rgb', img_fname)
101 | img = Image.open(image_path)
102 | if self.downsample!=1.0:
103 | img = img.resize(self.img_wh, Image.LANCZOS)
104 | img = self.transform(img) # (4, h, w)
105 | img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA
106 | if img.shape[-1]==4:
107 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
108 | self.all_rgbs += [img]
109 |
110 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv
111 | c2w = torch.FloatTensor(c2w)
112 | self.poses.append(c2w) # C2W
113 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
114 |             self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
115 |
116 | # w2c = torch.inverse(c2w)
117 | #
118 |
119 | self.poses = torch.stack(self.poses)
120 | if 'train' == self.split:
121 | if self.is_stack:
122 | self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3)
123 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3)
124 | else:
125 | self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
126 | self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
127 | else:
128 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
129 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
130 |
131 |
132 | def define_transforms(self):
133 | self.transform = T.ToTensor()
134 |
135 | def define_proj_mat(self):
136 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
137 |
138 | def world2ndc(self, points):
139 | device = points.device
140 | return (points - self.center.to(device)) / self.radius.to(device)
141 |
142 | def __len__(self):
143 | if self.split == 'train':
144 | return len(self.all_rays)
145 | return len(self.all_rgbs)
146 |
147 | def __getitem__(self, idx):
148 |
149 | if self.split == 'train': # use data in the buffers
150 | sample = {'rays': self.all_rays[idx],
151 | 'rgbs': self.all_rgbs[idx]}
152 |
153 | else: # create data for each image separately
154 |
155 | img = self.all_rgbs[idx]
156 | rays = self.all_rays[idx]
157 |
158 | sample = {'rays': rays,
159 | 'rgbs': img}
160 | return sample
--------------------------------------------------------------------------------
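`pose_spherical` can also be used on its own to build a circular novel-view camera path; this mirrors the `render_path` line in `read_meta` above and produces the camera-to-world matrices that `evaluation_path` in renderer.py expects:

```
import numpy as np
import torch
from dataLoader.nsvf import pose_spherical

# 40 camera-to-world poses on a circle at elevation -30 degrees and radius 4.
render_path = torch.stack(
    [pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180, 180, 40 + 1)[:-1]], 0)
print(render_path.shape)   # torch.Size([40, 4, 4])
```
--------------------------------------------------------------------------------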
/extra/compute_metrics.py:
--------------------------------------------------------------------------------
1 | import os, math
2 | import numpy as np
3 | import scipy.signal
4 | from typing import List, Optional
5 | from PIL import Image
6 | import os
7 | import torch
8 | import configargparse
9 |
10 | __LPIPS__ = {}
11 | def init_lpips(net_name, device):
12 | assert net_name in ['alex', 'vgg']
13 | import lpips
14 | print(f'init_lpips: lpips_{net_name}')
15 | return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
16 |
17 | def rgb_lpips(np_gt, np_im, net_name, device):
18 | if net_name not in __LPIPS__:
19 | __LPIPS__[net_name] = init_lpips(net_name, device)
20 | gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
21 | im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
22 | return __LPIPS__[net_name](gt, im, normalize=True).item()
23 |
24 |
25 | def findItem(items, target):
26 | for one in items:
27 | if one[:len(target)]==target:
28 | return one
29 | return None
30 |
31 |
32 | ''' Evaluation metrics (ssim, lpips)
33 | '''
34 | def rgb_ssim(img0, img1, max_val,
35 | filter_size=11,
36 | filter_sigma=1.5,
37 | k1=0.01,
38 | k2=0.03,
39 | return_map=False):
40 | # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
41 | assert len(img0.shape) == 3
42 | assert img0.shape[-1] == 3
43 | assert img0.shape == img1.shape
44 |
45 | # Construct a 1D Gaussian blur filter.
46 | hw = filter_size // 2
47 | shift = (2 * hw - filter_size + 1) / 2
48 | f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
49 | filt = np.exp(-0.5 * f_i)
50 | filt /= np.sum(filt)
51 |
52 | # Blur in x and y (faster than the 2D convolution).
53 | def convolve2d(z, f):
54 | return scipy.signal.convolve2d(z, f, mode='valid')
55 |
56 | filt_fn = lambda z: np.stack([
57 | convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
58 | for i in range(z.shape[-1])], -1)
59 | mu0 = filt_fn(img0)
60 | mu1 = filt_fn(img1)
61 | mu00 = mu0 * mu0
62 | mu11 = mu1 * mu1
63 | mu01 = mu0 * mu1
64 | sigma00 = filt_fn(img0**2) - mu00
65 | sigma11 = filt_fn(img1**2) - mu11
66 | sigma01 = filt_fn(img0 * img1) - mu01
67 |
68 | # Clip the variances and covariances to valid values.
69 | # Variance must be non-negative:
70 | sigma00 = np.maximum(0., sigma00)
71 | sigma11 = np.maximum(0., sigma11)
72 | sigma01 = np.sign(sigma01) * np.minimum(
73 | np.sqrt(sigma00 * sigma11), np.abs(sigma01))
74 | c1 = (k1 * max_val)**2
75 | c2 = (k2 * max_val)**2
76 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
77 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
78 | ssim_map = numer / denom
79 | ssim = np.mean(ssim_map)
80 | return ssim_map if return_map else ssim
81 |
82 |
83 | if __name__ == '__main__':
84 |
85 | parser = configargparse.ArgumentParser()
86 | parser.add_argument("--exp", type=str, help="folder of exps")
87 | parser.add_argument("--paramStr", type=str, help="str of params")
88 | args = parser.parse_args()
89 |
90 |
91 |     datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] # scenes to evaluate; adjust to your setup
92 |     gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic" # ground-truth image root; adjust to your setup
93 |     expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp # experiment log root; adjust to your setup
94 |
95 | # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']#
96 | # gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/"
97 | # expFolder = "/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp
98 | paramStr = args.paramStr
99 | fileNum = 200
100 |
101 |
102 | expitems = os.listdir(expFolder)
103 | finalFolder = f'{expFolder}/finals/{paramStr}'
104 | outFile = f'{finalFolder}/{paramStr}_metrics.txt'
105 | os.makedirs(finalFolder, exist_ok=True)
106 |
107 | expitems.sort(reverse=True)
108 |
109 |
110 | with open(outFile, 'w') as f:
111 | all_psnr = []
112 | all_ssim = []
113 | all_alex = []
114 | all_vgg = []
115 | for dataname in datanames:
116 |
117 |
118 | gtstr = gtFolder+"/"+dataname+"/test/r_%d.png"
119 | expname = findItem(expitems, f'{paramStr}-{dataname}')
120 | print("expname: ", expname)
121 | if expname is None:
122 | print("no ",dataname, "exists")
123 | continue
124 | resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png"
125 | metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt'
126 | video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4'
127 |
128 | exist_metric=False
129 | if os.path.isfile(metric_file):
130 | metrics = np.loadtxt(metric_file)
131 | print(metrics, metrics.tolist())
132 | if metrics.size == 4:
133 | psnr, ssim, l_a, l_v = metrics.tolist()
134 | exist_metric = True
135 | os.system(f"cp {video_file} {finalFolder}/")
136 |
137 | if not exist_metric:
138 | psnrs = []
139 | ssims = []
140 | l_alex = []
141 | l_vgg = []
142 | for i in range(fileNum):
143 | gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0
144 | gtmask = gt[...,[3]]
145 | gt = gt[...,:3]
146 | gt = gt*gtmask + (1-gtmask)
147 | img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3] / 255.0
148 | # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max())
149 |
150 |
151 | psnr = -10. * np.log10(np.mean(np.square(img - gt)))
152 | ssim = rgb_ssim(img, gt, 1)
153 | lpips_alex = rgb_lpips(gt, img, 'alex','cuda')
154 | lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda')
155 |
156 | print(i, psnr, ssim, lpips_alex, lpips_vgg)
157 | psnrs.append(psnr)
158 | ssims.append(ssim)
159 | l_alex.append(lpips_alex)
160 | l_vgg.append(lpips_vgg)
161 | psnr = np.mean(np.array(psnrs))
162 | ssim = np.mean(np.array(ssims))
163 | l_a = np.mean(np.array(l_alex))
164 | l_v = np.mean(np.array(l_vgg))
165 |
166 | rS=f'{dataname} : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}'
167 | print(rS)
168 | f.write(rS+"\n")
169 |
170 | all_psnr.append(psnr)
171 | all_ssim.append(ssim)
172 | all_alex.append(l_a)
173 | all_vgg.append(l_v)
174 |
175 | psnr = np.mean(np.array(all_psnr))
176 | ssim = np.mean(np.array(all_ssim))
177 | l_a = np.mean(np.array(all_alex))
178 | l_v = np.mean(np.array(all_vgg))
179 |
180 | rS=f'mean : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}'
181 | print(rS)
182 | f.write(rS+"\n")
--------------------------------------------------------------------------------
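
A quick, hedged usage sketch for the metric helpers above. The arrays are random placeholders rather than rendered results, the import assumes extra/ is on PYTHONPATH, and rgb_lpips additionally needs the lpips package installed.

import numpy as np
from compute_metrics import rgb_ssim, rgb_lpips  # assumption: extra/ is on PYTHONPATH

gt = np.random.rand(64, 64, 3).astype(np.float32)                    # fake ground truth in [0, 1]
img = np.clip(gt + 0.05 * np.random.randn(64, 64, 3), 0, 1).astype(np.float32)

psnr = -10.0 * np.log10(np.mean(np.square(img - gt)))                # same formula as in the loop above
ssim = rgb_ssim(img, gt, max_val=1)                                  # scalar SSIM
lpips_alex = rgb_lpips(gt, img, 'alex', 'cpu')                       # HWC float32 images in [0, 1]
print(psnr, ssim, lpips_alex)
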
/extra/auto_run_paramsets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import threading, queue
3 | import numpy as np
4 | import time
5 |
6 |
7 | def getFolderLocker(logFolder):
8 | while True:
9 | try:
10 | os.makedirs(logFolder+"/lockFolder")
11 | break
12 | except:
13 | time.sleep(0.01)
14 |
15 | def releaseFolderLocker(logFolder):
16 | os.removedirs(logFolder+"/lockFolder")
17 |
18 | def getStopFolder(logFolder):
19 | return os.path.isdir(logFolder+"/stopFolder")
20 |
21 |
22 | def get_param_str(key, val):
23 | if key == 'data_name':
24 | return f'--datadir {datafolder}/{val} '
25 | else:
26 | return f'--{key} {val} '
27 |
28 | def get_param_list(param_dict):
29 | param_keys = list(param_dict.keys())
30 | param_modes = len(param_keys)
31 | param_nums = [len(param_dict[key]) for key in param_keys]
32 |
33 | param_ids = np.zeros(param_nums+[param_modes], dtype=int)
34 | for i in range(param_modes):
35 | broad_tuple = np.ones(param_modes, dtype=int).tolist()
36 | broad_tuple[i] = param_nums[i]
37 | broad_tuple = tuple(broad_tuple)
38 | print(broad_tuple)
39 | param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple)
40 | param_ids = param_ids.reshape(-1, param_modes)
41 | # print(param_ids)
42 | print(len(param_ids))
43 |
44 | params = []
45 | expnames = []
46 | for i in range(param_ids.shape[0]):
47 | one = ""
48 | name = ""
49 | param_id = param_ids[i]
50 | for j in range(param_modes):
51 | key = param_keys[j]
52 | val = param_dict[key][param_id[j]]
53 | if type(key) is tuple:
54 | assert len(key) == len(val)
55 | for k in range(len(key)):
56 | one += get_param_str(key[k], val[k])
57 | name += f'{val[k]},'
58 | name=name[:-1]+'-'
59 | else:
60 | one += get_param_str(key, val)
61 | name += f'{val}-'
62 | params.append(one)
63 | name=name.replace(' ','')
64 | print(name)
65 | expnames.append(name[:-1])
66 | # print(params)
67 | return params, expnames
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 | if __name__ == '__main__':
76 |
77 |
78 |
79 | # nerf
80 | expFolder = "nerf/"
81 | # parameters to iterate, use tuple to couple multiple parameters
82 | datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/'
83 | param_dict = {
84 | 'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'],
85 | 'data_dim_color': [13, 27, 54]
86 | }
87 |
88 | # n_iters = 30000
89 | # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
90 | # cmd = f'CUDA_VISIBLE_DEVICES={cuda} python train.py ' \
91 | # f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\
92 | # f'--expname {data_name} --batch_size {batch_size} ' \
93 | # f'--n_iters {n_iters} ' \
94 | # f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\
95 | # f'--N_vis {5} ' \
96 | # f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \
97 | # f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \
98 | # f'--shadingMode MLP_Fea --fea2denseAct softplus --view_pe {2} --fea_pe {2} ' \
99 | # f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \
100 | # f'--render_test 1 '
101 | # print(cmd)
102 | # os.system(cmd)
103 |
104 | # nsvf
105 | # expFolder = "nsvf_0227/"
106 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/'
107 | # param_dict = {
108 | # 'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
109 | # 'shadingMode': ['SH'],
110 | # ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")],
111 | # ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)],
112 | # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)],
113 | # ('n_iters','N_voxel_final'): [(30000,300**3)],
114 | # ('dataset_name','N_vis','render_test') : [("nsvf",5,1)],
115 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")]
116 | #
117 | # }
118 |
119 | # tankstemple
120 | # expFolder = "tankstemple_0304/"
121 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/'
122 | # param_dict = {
123 | # 'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'],
124 | # 'shadingMode': ['MLP_Fea'],
125 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")],
126 | # ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)],
127 | # ('TV_weight_density','TV_weight_app'):[(0.1,0.01)],
128 | # # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)],
129 | # ('n_iters','N_voxel_final'): [(15000,300**3)],
130 | # ('dataset_name','N_vis') : [("tankstemple",5)],
131 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")]
132 | # }
133 |
134 | # llff
135 | # expFolder = "real_iconic/"
136 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/'
137 | # List = os.listdir(datafolder)
138 | # param_dict = {
139 | # 'data_name': List,
140 | # ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)],
141 | # ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")],
142 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
143 | # ('n_iters','N_voxel_final'): [(25000,640**3)],
144 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)],
145 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
146 | # }
147 |
148 | # expFolder = "llff/"
149 | # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data'
150 | # param_dict = {
151 | # 'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'
152 | # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")],
153 | # ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)],
154 | # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
155 | # ('n_iters','N_voxel_final'): [(25000,640**3)],
156 | # ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)],
157 | # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
158 | # }
159 |
160 | #setting available gpus
161 | gpus_que = queue.Queue(3)
162 | for i in [1,2,3]:
163 | gpus_que.put(i)
164 |
165 | os.makedirs(f"log/{expFolder}", exist_ok=True)
166 |
167 | def run_program(gpu, expname, param):
168 | cmd = f'CUDA_VISIBLE_DEVICES={gpu} python train.py ' \
169 | f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \
170 | f'{param}' \
171 | f'> "log/{expFolder}{expname}/{expname}.txt"'
172 | print(cmd)
173 | os.system(cmd)
174 | gpus_que.put(gpu)
175 |
176 | params, expnames = get_param_list(param_dict)
177 |
178 |
179 | logFolder=f"log/{expFolder}"
180 | os.makedirs(logFolder, exist_ok=True)
181 |
182 | ths = []
183 | for i in range(len(params)):
184 |
185 | if getStopFolder(logFolder):
186 | break
187 |
188 |
189 | targetFolder = f"log/{expFolder}{expnames[i]}"
190 | gpu = gpus_que.get()
191 | getFolderLocker(logFolder)
192 | if os.path.isdir(targetFolder):
193 | releaseFolderLocker(logFolder)
194 | gpus_que.put(gpu)
195 | continue
196 | else:
197 | os.makedirs(targetFolder, exist_ok=True)
198 | print("making",targetFolder, "running",expnames[i], params[i])
199 | releaseFolderLocker(logFolder)
200 |
201 |
202 | t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True)
203 | t.start()
204 | ths.append(t)
205 |
206 | for th in ths:
207 | th.join()
--------------------------------------------------------------------------------
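
For reference, get_param_list expands a param_dict into the Cartesian product of all value combinations and returns one command-line fragment plus one experiment name per combination; tuple keys couple parameters that must vary together. A hedged illustration with placeholder values (the 'data_name' key is avoided so the module-level datafolder variable is not required):

from auto_run_paramsets import get_param_list  # assumption: extra/ is on PYTHONPATH

param_dict = {
    'shadingMode': ['SH', 'MLP_Fea'],
    'data_dim_color': [13, 27],
}
params, expnames = get_param_list(param_dict)
# 2 x 2 = 4 combinations, e.g.
#   params[0]   == '--shadingMode SH --data_dim_color 13 '
#   expnames    == ['SH-13', 'SH-27', 'MLP_Fea-13', 'MLP_Fea-27']
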
/dataLoader/tankstemple.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from tqdm import tqdm
4 | import os
5 | from PIL import Image
6 | from torchvision import transforms as T
7 |
8 | from .ray_utils import *
9 |
10 |
11 | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
12 | if axis == 'z':
13 | return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]
14 | elif axis == 'y':
15 | return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]
16 | else:
17 | return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]
18 |
19 |
20 | def cross(x, y, axis=0):
21 | T = torch if isinstance(x, torch.Tensor) else np
22 | return T.cross(x, y, axis)
23 |
24 |
25 | def normalize(x, axis=-1, order=2):
26 | if isinstance(x, torch.Tensor):
27 | l2 = x.norm(p=order, dim=axis, keepdim=True)
28 | return x / (l2 + 1e-8), l2
29 |
30 | else:
31 | l2 = np.linalg.norm(x, order, axis)
32 | l2 = np.expand_dims(l2, axis)
33 | l2[l2 == 0] = 1
34 |         return x / l2, l2
35 |
36 |
37 | def cat(x, axis=1):
38 | if isinstance(x[0], torch.Tensor):
39 | return torch.cat(x, dim=axis)
40 | return np.concatenate(x, axis=axis)
41 |
42 |
43 | def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
44 | """
45 | This function takes a vector 'camera_position' which specifies the location
46 | of the camera in world coordinates and two vectors `at` and `up` which
47 | indicate the position of the object and the up directions of the world
48 | coordinate system respectively. The object is assumed to be centered at
49 | the origin.
50 | The output is a rotation matrix representing the transformation
51 | from world coordinates -> view coordinates.
52 | Input:
53 | camera_position: 3
54 |         at: 1 x 3 or N x 3, defaults to (0, 0, 0)
55 |         up: 1 x 3 or N x 3, defaults to (0, 0, -1)
56 | """
57 |
58 | if at is None:
59 | at = torch.zeros_like(camera_position)
60 | else:
61 | at = torch.tensor(at).type_as(camera_position)
62 | if up is None:
63 | up = torch.zeros_like(camera_position)
64 | up[2] = -1
65 | else:
66 | up = torch.tensor(up).type_as(camera_position)
67 |
68 | z_axis = normalize(at - camera_position)[0]
69 | x_axis = normalize(cross(up, z_axis))[0]
70 | y_axis = normalize(cross(z_axis, x_axis))[0]
71 |
72 | R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
73 | return R
74 |
75 |
76 | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
77 | c2ws = []
78 | for t in range(frames):
79 | c2w = torch.eye(4)
80 | cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
81 | cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
82 | c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
83 | c2ws.append(c2w)
84 | return torch.stack(c2ws)
85 |
86 | class TanksTempleDataset(Dataset):
87 |     """Tanks and Temples dataset (NSVF-style directory layout)."""
88 | def __init__(self, datadir, split='train', downsample=1.0, wh=[1920,1080], is_stack=False):
89 | self.root_dir = datadir
90 | self.split = split
91 | self.is_stack = is_stack
92 | self.downsample = downsample
93 | self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
94 | self.define_transforms()
95 |
96 | self.white_bg = True
97 | self.near_far = [0.01,6.0]
98 | self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2
99 |
100 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
101 | self.read_meta()
102 | self.define_proj_mat()
103 |
104 | self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
105 | self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
106 |
107 | def bbox2corners(self):
108 | corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
109 | for i in range(3):
110 | corners[i,[0,1],i] = corners[i,[1,0],i]
111 | return corners.view(-1,3)
112 |
113 |
114 | def read_meta(self):
115 |
116 | self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
117 | self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1)
118 | pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
119 | img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
120 |
121 | if self.split == 'train':
122 | pose_files = [x for x in pose_files if x.startswith('0_')]
123 | img_files = [x for x in img_files if x.startswith('0_')]
124 | elif self.split == 'val':
125 | pose_files = [x for x in pose_files if x.startswith('1_')]
126 | img_files = [x for x in img_files if x.startswith('1_')]
127 | elif self.split == 'test':
128 | test_pose_files = [x for x in pose_files if x.startswith('2_')]
129 | test_img_files = [x for x in img_files if x.startswith('2_')]
130 | if len(test_pose_files) == 0:
131 | test_pose_files = [x for x in pose_files if x.startswith('1_')]
132 | test_img_files = [x for x in img_files if x.startswith('1_')]
133 | pose_files = test_pose_files
134 | img_files = test_img_files
135 |
136 | # ray directions for all pixels, same for all images (same H, W, focal)
137 | self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
138 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
139 |
140 |
141 |
142 | self.poses = []
143 | self.all_rays = []
144 | self.all_rgbs = []
145 |
146 | assert len(img_files) == len(pose_files)
147 | for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
148 | image_path = os.path.join(self.root_dir, 'rgb', img_fname)
149 | img = Image.open(image_path)
150 | if self.downsample!=1.0:
151 | img = img.resize(self.img_wh, Image.LANCZOS)
152 |             img = self.transform(img)  # (C, h, w), C = 3 or 4
153 |             img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, C)
154 | if img.shape[-1]==4:
155 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
156 | self.all_rgbs.append(img)
157 |
158 |
159 | c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans
160 | c2w = torch.FloatTensor(c2w)
161 | self.poses.append(c2w) # C2W
162 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
163 |             self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
164 |
165 | self.poses = torch.stack(self.poses)
166 |
167 | center = torch.mean(self.scene_bbox, dim=0)
168 | radius = torch.norm(self.scene_bbox[1]-center)*1.2
169 | up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
170 | pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y')
171 | self.render_path = gen_path(pos_gen, up=up,frames=200)
172 | self.render_path[:, :3, 3] += center
173 |
174 |
175 |
176 | if 'train' == self.split:
177 | if self.is_stack:
178 |                 self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (N_images, h, w, 6)
179 |                 self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
180 |             else:
181 |                 self.all_rays = torch.cat(self.all_rays, 0) # (N_images*h*w, 6)
182 |                 self.all_rgbs = torch.cat(self.all_rgbs, 0) # (N_images*h*w, 3)
183 |         else:
184 |             self.all_rays = torch.stack(self.all_rays, 0) # (N_images, h*w, 6)
185 |             self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
186 |
187 |
188 | def define_transforms(self):
189 | self.transform = T.ToTensor()
190 |
191 | def define_proj_mat(self):
192 | self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
193 |
194 | def world2ndc(self, points):
195 | device = points.device
196 | return (points - self.center.to(device)) / self.radius.to(device)
197 |
198 | def __len__(self):
199 | if self.split == 'train':
200 | return len(self.all_rays)
201 | return len(self.all_rgbs)
202 |
203 | def __getitem__(self, idx):
204 |
205 | if self.split == 'train': # use data in the buffers
206 | sample = {'rays': self.all_rays[idx],
207 | 'rgbs': self.all_rgbs[idx]}
208 |
209 | else: # create data for each image separately
210 |
211 | img = self.all_rgbs[idx]
212 | rays = self.all_rays[idx]
213 |
214 | sample = {'rays': rays,
215 | 'rgbs': img}
216 | return sample
--------------------------------------------------------------------------------
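
The circle, look_at_rotation and gen_path helpers above can be exercised on their own to build a circular fly-around path. A small hedged sketch; radius, height and frame count are arbitrary, and the import assumes the repository root is on PYTHONPATH.

import torch
from dataLoader.tankstemple import circle, gen_path

pos_gen = circle(radius=3.0, h=0.5, axis='y')               # camera positions on a circle at height 0.5
c2ws = gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=30)
print(c2ws.shape)                                           # torch.Size([30, 4, 4]) camera-to-world matrices
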
/dataLoader/llff.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import glob
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 | from torchvision import transforms as T
8 |
9 | from .ray_utils import *
10 |
11 |
12 | def normalize(v):
13 | """Normalize a vector."""
14 | return v / np.linalg.norm(v)
15 |
16 |
17 | def average_poses(poses):
18 | """
19 | Calculate the average pose, which is then used to center all poses
20 | using @center_poses. Its computation is as follows:
21 | 1. Compute the center: the average of pose centers.
22 | 2. Compute the z axis: the normalized average z axis.
23 | 3. Compute axis y': the average y axis.
24 | 4. Compute x' = y' cross product z, then normalize it as the x axis.
25 | 5. Compute the y axis: z cross product x.
26 |
27 | Note that at step 3, we cannot directly use y' as y axis since it's
28 | not necessarily orthogonal to z axis. We need to pass from x to y.
29 | Inputs:
30 | poses: (N_images, 3, 4)
31 | Outputs:
32 | pose_avg: (3, 4) the average pose
33 | """
34 | # 1. Compute the center
35 | center = poses[..., 3].mean(0) # (3)
36 |
37 | # 2. Compute the z axis
38 | z = normalize(poses[..., 2].mean(0)) # (3)
39 |
40 | # 3. Compute axis y' (no need to normalize as it's not the final output)
41 | y_ = poses[..., 1].mean(0) # (3)
42 |
43 | # 4. Compute the x axis
44 | x = normalize(np.cross(z, y_)) # (3)
45 |
46 | # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
47 | y = np.cross(x, z) # (3)
48 |
49 | pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
50 |
51 | return pose_avg
52 |
53 |
54 | def center_poses(poses, blender2opencv):
55 | """
56 | Center the poses so that we can use NDC.
57 | See https://github.com/bmild/nerf/issues/34
58 | Inputs:
59 | poses: (N_images, 3, 4)
60 | Outputs:
61 | poses_centered: (N_images, 3, 4) the centered poses
62 | pose_avg: (3, 4) the average pose
63 | """
64 | poses = poses @ blender2opencv
65 | pose_avg = average_poses(poses) # (3, 4)
66 | pose_avg_homo = np.eye(4)
67 |     pose_avg_homo[:3] = pose_avg # convert to a homogeneous (4, 4) matrix
68 |     # each (3, 4) pose below likewise gets [0, 0, 0, 1] appended as its last row
69 |     # so that centering can be done with 4x4 matrix products
70 | last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
71 | poses_homo = \
72 | np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
73 |
74 | poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
75 | # poses_centered = poses_centered @ blender2opencv
76 | poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
77 |
78 | return poses_centered, pose_avg_homo
79 |
80 |
81 | def viewmatrix(z, up, pos):
82 | vec2 = normalize(z)
83 | vec1_avg = up
84 | vec0 = normalize(np.cross(vec1_avg, vec2))
85 | vec1 = normalize(np.cross(vec2, vec0))
86 | m = np.eye(4)
87 | m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
88 | return m
89 |
90 |
91 | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
92 | render_poses = []
93 | rads = np.array(list(rads) + [1.])
94 |
95 | for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
96 | c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
97 | z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
98 | render_poses.append(viewmatrix(z, up, c))
99 | return render_poses
100 |
101 |
102 | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
103 | # center pose
104 | c2w = average_poses(c2ws_all)
105 |
106 | # Get average pose
107 | up = normalize(c2ws_all[:, :3, 1].sum(0))
108 |
109 | # Find a reasonable "focus depth" for this dataset
110 | dt = 0.75
111 | close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
112 | focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))
113 |
114 | # Get radii for spiral path
115 | zdelta = near_fars.min() * .2
116 | tt = c2ws_all[:, :3, 3]
117 | rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
118 | render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
119 | return np.stack(render_poses)
120 |
121 |
122 | class LLFFDataset(Dataset):
123 | def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8):
124 | """
125 |         downsample: factor by which images are downsampled (default 4 for LLFF)
126 |         is_stack: if True, keep rays/RGBs stacked per image; if False, flatten across all images
127 |         hold_every: hold out every `hold_every`-th image as the test split
128 | """
129 |
130 | self.root_dir = datadir
131 | self.split = split
132 | self.hold_every = hold_every
133 | self.is_stack = is_stack
134 | self.downsample = downsample
135 | self.define_transforms()
136 |
137 | self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
138 | self.read_meta()
139 | self.white_bg = False
140 |
141 | # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
142 | self.near_far = [0.0, 1.0]
143 | self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])
144 | # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
145 | self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
146 | self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
147 |
148 | def read_meta(self):
149 |
150 |
151 | poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17)
152 | self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
153 |         # images are loaded from the pre-downsampled images_4/ folder, then resized to img_wh below
154 | if self.split in ['train', 'test']:
155 | assert len(poses_bounds) == len(self.image_paths), \
156 | 'Mismatch between number of images and number of poses! Please rerun COLMAP!'
157 |
158 | poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
159 | self.near_fars = poses_bounds[:, -2:] # (N_images, 2)
160 | hwf = poses[:, :, -1]
161 |
162 | # Step 1: rescale focal length according to training resolution
163 | H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images
164 | self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
165 | self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]
166 |
167 | # Step 2: correct poses
168 | # Original poses has rotation in form "down right back", change to "right up back"
169 | # See https://github.com/bmild/nerf/issues/34
170 | poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
171 | # (N_images, 3, 4) exclude H, W, focal
172 | self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)
173 |
174 | # Step 3: correct scale so that the nearest depth is at a little more than 1.0
175 | # See https://github.com/bmild/nerf/issues/34
176 | near_original = self.near_fars.min()
177 | scale_factor = near_original * 0.75 # 0.75 is the default parameter
178 | # the nearest depth is at 1/0.75=1.33
179 | self.near_fars /= scale_factor
180 | self.poses[..., 3] /= scale_factor
181 |
182 | # build rendering path
183 | N_views, N_rots = 120, 2
184 | tt = self.poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
185 | up = normalize(self.poses[:, :3, 1].sum(0))
186 | rads = np.percentile(np.abs(tt), 90, 0)
187 |
188 | self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
189 |
190 | # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
191 | # val_idx = np.argmin(distances_from_center) # choose val image as the closest to
192 | # center image
193 |
194 | # ray directions for all pixels, same for all images (same H, W, focal)
195 | W, H = self.img_wh
196 | self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3)
197 |
198 | average_pose = average_poses(self.poses)
199 | dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
200 | i_test = np.arange(0, self.poses.shape[0], self.hold_every) # [np.argmin(dists)]
201 | img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))
202 |
203 | # use first N_images-1 to train, the LAST is val
204 | self.all_rays = []
205 | self.all_rgbs = []
206 | for i in img_list:
207 | image_path = self.image_paths[i]
208 | c2w = torch.FloatTensor(self.poses[i])
209 |
210 | img = Image.open(image_path).convert('RGB')
211 | if self.downsample != 1.0:
212 | img = img.resize(self.img_wh, Image.LANCZOS)
213 | img = self.transform(img) # (3, h, w)
214 |
215 | img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB
216 | self.all_rgbs += [img]
217 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
218 | rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
219 | # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
220 |
221 | self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
222 |
223 | if not self.is_stack:
224 |             self.all_rays = torch.cat(self.all_rays, 0) # (N_images*h*w, 6)
225 |             self.all_rgbs = torch.cat(self.all_rgbs, 0) # (N_images*h*w, 3)
226 |         else:
227 |             self.all_rays = torch.stack(self.all_rays, 0) # (N_images, h*w, 6)
228 |             self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
229 |
230 |
231 | def define_transforms(self):
232 | self.transform = T.ToTensor()
233 |
234 | def __len__(self):
235 | return len(self.all_rgbs)
236 |
237 | def __getitem__(self, idx):
238 |
239 | sample = {'rays': self.all_rays[idx],
240 | 'rgbs': self.all_rgbs[idx]}
241 |
242 | return sample
--------------------------------------------------------------------------------
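
average_poses and center_poses can be sanity-checked on toy data: with two axis-aligned poses the average camera centre is simply the mean of the two centres, and center_poses re-expresses every pose relative to that average. A hedged sketch with made-up poses, assuming the repository root is on PYTHONPATH:

import numpy as np
from dataLoader.llff import average_poses, center_poses

# two toy camera-to-world poses, each (3, 4) = [rotation | translation]
poses = np.stack([
    np.hstack([np.eye(3), [[0.0], [0.0], [1.0]]]),
    np.hstack([np.eye(3), [[1.0], [0.0], [1.0]]]),
])

pose_avg = average_poses(poses)                              # (3, 4) average pose
poses_centered, pose_avg_homo = center_poses(poses, np.eye(4))
print(pose_avg[:, 3])                                        # average camera centre -> [0.5, 0.0, 1.0]
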
/dataLoader/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch, re
2 | import numpy as np
3 | from torch import searchsorted
4 | from kornia import create_meshgrid
5 |
6 |
7 | # from utils import index_point_feature
8 |
9 | def depth2dist(z_vals, cos_angle):
10 | # z_vals: [N_ray N_sample]
11 | device = z_vals.device
12 | dists = z_vals[..., 1:] - z_vals[..., :-1]
13 | dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples]
14 | dists = dists * cos_angle.unsqueeze(-1)
15 | return dists
16 |
17 |
18 | def ndc2dist(ndc_pts, cos_angle):
19 | dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
20 | dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples]
21 | return dists
22 |
23 |
24 | def get_ray_directions(H, W, focal, center=None):
25 | """
26 | Get ray directions for all pixels in camera coordinate.
27 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
28 | ray-tracing-generating-camera-rays/standard-coordinate-systems
29 | Inputs:
30 | H, W, focal: image height, width and focal length
31 | Outputs:
32 | directions: (H, W, 3), the direction of the rays in camera coordinate
33 | """
34 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5
35 |
36 | i, j = grid.unbind(-1)
37 |     # the +0.5 offset shifts ray origins to pixel centers
38 | # see https://github.com/bmild/nerf/issues/24
39 | cent = center if center is not None else [W / 2, H / 2]
40 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
41 |
42 | return directions
43 |
44 |
45 | def get_ray_directions_blender(H, W, focal, center=None):
46 | """
47 | Get ray directions for all pixels in camera coordinate.
48 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
49 | ray-tracing-generating-camera-rays/standard-coordinate-systems
50 | Inputs:
51 | H, W, focal: image height, width and focal length
52 | Outputs:
53 | directions: (H, W, 3), the direction of the rays in camera coordinate
54 | """
55 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5
56 | i, j = grid.unbind(-1)
57 |     # the +0.5 offset shifts ray origins to pixel centers
58 | # see https://github.com/bmild/nerf/issues/24
59 | cent = center if center is not None else [W / 2, H / 2]
60 | directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],
61 | -1) # (H, W, 3)
62 |
63 | return directions
64 |
65 |
66 | def get_rays(directions, c2w):
67 | """
68 | Get ray origin and normalized directions in world coordinate for all pixels in one image.
69 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
70 | ray-tracing-generating-camera-rays/standard-coordinate-systems
71 | Inputs:
72 | directions: (H, W, 3) precomputed ray directions in camera coordinate
73 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
74 | Outputs:
75 | rays_o: (H*W, 3), the origin of the rays in world coordinate
76 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
77 | """
78 | # Rotate ray directions from camera coordinate to the world coordinate
79 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3)
80 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
81 | # The origin of all rays is the camera origin in world coordinate
82 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3)
83 |
84 | rays_d = rays_d.view(-1, 3)
85 | rays_o = rays_o.view(-1, 3)
86 |
87 | return rays_o, rays_d
88 |
89 |
90 | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
91 | # Shift ray origins to near plane
92 | t = -(near + rays_o[..., 2]) / rays_d[..., 2]
93 | rays_o = rays_o + t[..., None] * rays_d
94 |
95 | # Projection
96 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
97 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
98 | o2 = 1. + 2. * near / rays_o[..., 2]
99 |
100 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
101 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
102 | d2 = -2. * near / rays_o[..., 2]
103 |
104 | rays_o = torch.stack([o0, o1, o2], -1)
105 | rays_d = torch.stack([d0, d1, d2], -1)
106 |
107 | return rays_o, rays_d
108 |
109 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
110 | # Shift ray origins to near plane
111 | t = (near - rays_o[..., 2]) / rays_d[..., 2]
112 | rays_o = rays_o + t[..., None] * rays_d
113 |
114 | # Projection
115 | o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
116 | o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
117 | o2 = 1. - 2. * near / rays_o[..., 2]
118 |
119 | d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
120 | d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
121 | d2 = 2. * near / rays_o[..., 2]
122 |
123 | rays_o = torch.stack([o0, o1, o2], -1)
124 | rays_d = torch.stack([d0, d1, d2], -1)
125 |
126 | return rays_o, rays_d
127 |
128 | # Hierarchical sampling (section 5.2)
129 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
130 | device = weights.device
131 | # Get pdf
132 | weights = weights + 1e-5 # prevent nans
133 | pdf = weights / torch.sum(weights, -1, keepdim=True)
134 | cdf = torch.cumsum(pdf, -1)
135 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
136 |
137 | # Take uniform samples
138 | if det:
139 | u = torch.linspace(0., 1., steps=N_samples, device=device)
140 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
141 | else:
142 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)
143 |
144 | # Pytest, overwrite u with numpy's fixed random numbers
145 | if pytest:
146 | np.random.seed(0)
147 | new_shape = list(cdf.shape[:-1]) + [N_samples]
148 | if det:
149 | u = np.linspace(0., 1., N_samples)
150 | u = np.broadcast_to(u, new_shape)
151 | else:
152 | u = np.random.rand(*new_shape)
153 | u = torch.Tensor(u)
154 |
155 | # Invert CDF
156 | u = u.contiguous()
157 | inds = searchsorted(cdf.detach(), u, right=True)
158 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
159 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
160 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
161 |
162 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
163 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
164 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
165 |
166 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
167 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
168 | t = (u - cdf_g[..., 0]) / denom
169 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
170 |
171 | return samples
172 |
173 |
174 | def dda(rays_o, rays_d, bbox_3D):
175 | inv_ray_d = 1.0 / (rays_d + 1e-6)
176 | t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3
177 | t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
178 | t = torch.stack((t_min, t_max)) # 2 N_rays 3
179 | t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
180 | t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
181 | return t_min, t_max
182 |
183 |
184 | def ray_marcher(rays,
185 | N_samples=64,
186 | lindisp=False,
187 | perturb=0,
188 | bbox_3D=None):
189 | """
190 |     Sample points along the rays.
191 |     Inputs:
192 |         rays: (N_rays, 8) with [origin (3), direction (3), near (1), far (1)]
193 |         N_samples: number of depth samples per ray
194 |     Returns:
195 |         xyz_coarse_sampled: (N_rays, N_samples, 3), plus rays_o, rays_d, z_vals
196 | """
197 |
198 | # Decompose the inputs
199 | N_rays = rays.shape[0]
200 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
201 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)
202 |
203 | if bbox_3D is not None:
204 | # cal aabb boundles
205 | near, far = dda(rays_o, rays_d, bbox_3D)
206 |
207 | # Sample depth points
208 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples)
209 | if not lindisp: # use linear sampling in depth space
210 | z_vals = near * (1 - z_steps) + far * z_steps
211 | else: # use linear sampling in disparity space
212 | z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)
213 |
214 | z_vals = z_vals.expand(N_rays, N_samples)
215 |
216 | if perturb > 0: # perturb sampling depths (z_vals)
217 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points
218 | # get intervals between samples
219 | upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
220 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)
221 |
222 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
223 | z_vals = lower + (upper - lower) * perturb_rand
224 |
225 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \
226 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3)
227 |
228 | return xyz_coarse_sampled, rays_o, rays_d, z_vals
229 |
230 |
231 | def read_pfm(filename):
232 | file = open(filename, 'rb')
233 | color = None
234 | width = None
235 | height = None
236 | scale = None
237 | endian = None
238 |
239 | header = file.readline().decode('utf-8').rstrip()
240 | if header == 'PF':
241 | color = True
242 | elif header == 'Pf':
243 | color = False
244 | else:
245 | raise Exception('Not a PFM file.')
246 |
247 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
248 | if dim_match:
249 | width, height = map(int, dim_match.groups())
250 | else:
251 | raise Exception('Malformed PFM header.')
252 |
253 | scale = float(file.readline().rstrip())
254 | if scale < 0: # little-endian
255 | endian = '<'
256 | scale = -scale
257 | else:
258 | endian = '>' # big-endian
259 |
260 | data = np.fromfile(file, endian + 'f')
261 | shape = (height, width, 3) if color else (height, width)
262 |
263 | data = np.reshape(data, shape)
264 | data = np.flipud(data)
265 | file.close()
266 | return data, scale
267 |
268 |
269 | def ndc_bbox(all_rays):
270 | near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]
271 | near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]
272 | far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]
273 | far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]
274 | print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
275 | return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))
--------------------------------------------------------------------------------
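
A hedged sketch tying the ray helpers together for a toy pinhole camera: build per-pixel directions, rotate them into world space with a camera-to-world pose, then map them to NDC the way the LLFF loader does. Image size and focal length are arbitrary, and the import assumes the repository root is on PYTHONPATH.

import torch
from dataLoader.ray_utils import get_ray_directions_blender, get_rays, ndc_rays_blender

H, W, focal = 8, 10, 20.0
directions = get_ray_directions_blender(H, W, [focal, focal])   # (H, W, 3) camera-space directions
c2w = torch.eye(4)[:3]                                          # identity pose: camera at the origin
rays_o, rays_d = get_rays(directions, c2w)                      # both (H*W, 3) in world coordinates
rays_o_ndc, rays_d_ndc = ndc_rays_blender(H, W, focal, 1.0, rays_o, rays_d)
print(rays_o.shape, rays_d_ndc.shape)                           # torch.Size([80, 3]) each
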
/dataLoader/colmap2nerf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | import argparse
12 | import os
13 | from pathlib import Path, PurePosixPath
14 |
15 | import numpy as np
16 | import json
17 | import sys
18 | import math
19 | import cv2
20 | import os
21 | import shutil
22 |
23 | def parse_args():
24 | parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")
25 |
26 | parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
27 | parser.add_argument("--video_fps", default=2)
28 | parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
29 | parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
30 | parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
31 | parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
32 | parser.add_argument("--images", default="images", help="input path to the images")
33 | parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
34 | parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
35 | parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
36 | parser.add_argument("--out", default="transforms.json", help="output path")
37 | args = parser.parse_args()
38 | return args
39 |
40 | def do_system(arg):
41 | print(f"==== running: {arg}")
42 | err = os.system(arg)
43 | if err:
44 | print("FATAL: command failed")
45 | sys.exit(err)
46 |
47 | def run_ffmpeg(args):
48 | if not os.path.isabs(args.images):
49 | args.images = os.path.join(os.path.dirname(args.video_in), args.images)
50 | images = args.images
51 | video = args.video_in
52 | fps = float(args.video_fps) or 1.0
53 | print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
54 | if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
55 | sys.exit(1)
56 | try:
57 | shutil.rmtree(images)
58 | except:
59 | pass
60 | do_system(f"mkdir {images}")
61 |
62 | time_slice_value = ""
63 | time_slice = args.time_slice
64 | if time_slice:
65 | start, end = time_slice.split(",")
66 | time_slice_value = f",select='between(t\,{start}\,{end})'"
67 | do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")
68 |
69 | def run_colmap(args):
70 | db=args.colmap_db
71 | images=args.images
72 | db_noext=str(Path(db).with_suffix(""))
73 |
74 | if args.text=="text":
75 | args.text=db_noext+"_text"
76 | text=args.text
77 | sparse=db_noext+"_sparse"
78 | print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
79 | if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
80 | sys.exit(1)
81 | if os.path.exists(db):
82 | os.remove(db)
83 | do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
84 | do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}")
85 | try:
86 | shutil.rmtree(sparse)
87 | except:
88 | pass
89 | do_system(f"mkdir {sparse}")
90 | do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
91 | do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
92 | try:
93 | shutil.rmtree(text)
94 | except:
95 | pass
96 | do_system(f"mkdir {text}")
97 | do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")
98 |
99 | def variance_of_laplacian(image):
100 | return cv2.Laplacian(image, cv2.CV_64F).var()
101 |
102 | def sharpness(imagePath):
103 | image = cv2.imread(imagePath)
104 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
105 | fm = variance_of_laplacian(gray)
106 | return fm
107 |
108 | def qvec2rotmat(qvec):
109 | return np.array([
110 | [
111 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
112 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
113 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
114 | ], [
115 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
116 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
117 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
118 | ], [
119 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
120 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
121 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
122 | ]
123 | ])
124 |
125 | def rotmat(a, b):
126 | a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
127 | v = np.cross(a, b)
128 | c = np.dot(a, b)
129 | s = np.linalg.norm(v)
130 | kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
131 | return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))
132 |
133 | def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
134 | da = da / np.linalg.norm(da)
135 | db = db / np.linalg.norm(db)
136 | c = np.cross(da, db)
137 | denom = np.linalg.norm(c)**2
138 | t = ob - oa
139 | ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
140 | tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
141 | if ta > 0:
142 | ta = 0
143 | if tb > 0:
144 | tb = 0
145 | return (oa+ta*da+ob+tb*db) * 0.5, denom
146 |
147 | if __name__ == "__main__":
148 | args = parse_args()
149 | if args.video_in != "":
150 | run_ffmpeg(args)
151 | if args.run_colmap:
152 | run_colmap(args)
153 | AABB_SCALE = int(args.aabb_scale)
154 | SKIP_EARLY = int(args.skip_early)
155 | IMAGE_FOLDER = args.images
156 | TEXT_FOLDER = args.text
157 | OUT_PATH = args.out
158 | print(f"outputting to {OUT_PATH}...")
159 | with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
160 | angle_x = math.pi / 2
161 | for line in f:
162 | # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
163 | # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
164 | # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
165 | if line[0] == "#":
166 | continue
167 | els = line.split(" ")
168 | w = float(els[2])
169 | h = float(els[3])
170 | fl_x = float(els[4])
171 | fl_y = float(els[4])
172 | k1 = 0
173 | k2 = 0
174 | p1 = 0
175 | p2 = 0
176 | cx = w / 2
177 | cy = h / 2
178 | if els[1] == "SIMPLE_PINHOLE":
179 | cx = float(els[5])
180 | cy = float(els[6])
181 | elif els[1] == "PINHOLE":
182 | fl_y = float(els[5])
183 | cx = float(els[6])
184 | cy = float(els[7])
185 | elif els[1] == "SIMPLE_RADIAL":
186 | cx = float(els[5])
187 | cy = float(els[6])
188 | k1 = float(els[7])
189 | elif els[1] == "RADIAL":
190 | cx = float(els[5])
191 | cy = float(els[6])
192 | k1 = float(els[7])
193 | k2 = float(els[8])
194 | elif els[1] == "OPENCV":
195 | fl_y = float(els[5])
196 | cx = float(els[6])
197 | cy = float(els[7])
198 | k1 = float(els[8])
199 | k2 = float(els[9])
200 | p1 = float(els[10])
201 | p2 = float(els[11])
202 | else:
203 | print("unknown camera model ", els[1])
204 | # fl = 0.5 * w / tan(0.5 * angle_x);
205 | angle_x = math.atan(w / (fl_x * 2)) * 2
206 | angle_y = math.atan(h / (fl_y * 2)) * 2
207 | fovx = angle_x * 180 / math.pi
208 | fovy = angle_y * 180 / math.pi
209 |
210 | print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")
211 |
212 | with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
213 | i = 0
214 | bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
215 | out = {
216 | "camera_angle_x": angle_x,
217 | "camera_angle_y": angle_y,
218 | "fl_x": fl_x,
219 | "fl_y": fl_y,
220 | "k1": k1,
221 | "k2": k2,
222 | "p1": p1,
223 | "p2": p2,
224 | "cx": cx,
225 | "cy": cy,
226 | "w": w,
227 | "h": h,
228 | "aabb_scale": AABB_SCALE,
229 | "frames": [],
230 | }
231 |
232 | up = np.zeros(3)
233 | for line in f:
234 | line = line.strip()
235 | if line[0] == "#":
236 | continue
237 | i = i + 1
238 | if i < SKIP_EARLY*2:
239 | continue
240 | if i % 2 == 1:
241 | elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
242 | #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
243 |                 # why is this requiring a relative path while using ^
244 | image_rel = os.path.relpath(IMAGE_FOLDER)
245 | name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
246 | b=sharpness(name)
247 | print(name, "sharpness=",b)
248 | image_id = int(elems[0])
249 | qvec = np.array(tuple(map(float, elems[1:5])))
250 | tvec = np.array(tuple(map(float, elems[5:8])))
251 | R = qvec2rotmat(-qvec)
252 | t = tvec.reshape([3,1])
253 | m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
254 | c2w = np.linalg.inv(m)
255 | c2w[0:3,2] *= -1 # flip the y and z axis
256 | c2w[0:3,1] *= -1
257 | c2w = c2w[[1,0,2,3],:] # swap y and z
258 | c2w[2,:] *= -1 # flip whole world upside down
259 |
260 | up += c2w[0:3,1]
261 |
262 | frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
263 | out["frames"].append(frame)
264 | nframes = len(out["frames"])
265 | up = up / np.linalg.norm(up)
266 | print("up vector was", up)
267 | R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
268 | R = np.pad(R,[0,1])
269 | R[-1, -1] = 1
270 |
271 |
272 | for f in out["frames"]:
273 | f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis
274 |
275 | # find a central point they are all looking at
276 | print("computing center of attention...")
277 | totw = 0.0
278 | totp = np.array([0.0, 0.0, 0.0])
279 | for f in out["frames"]:
280 | mf = f["transform_matrix"][0:3,:]
281 | for g in out["frames"]:
282 | mg = g["transform_matrix"][0:3,:]
283 | p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
284 | if w > 0.01:
285 | totp += p*w
286 | totw += w
287 | totp /= totw
288 | print(totp) # the cameras are looking at totp
289 | for f in out["frames"]:
290 | f["transform_matrix"][0:3,3] -= totp
291 |
292 | avglen = 0.
293 | for f in out["frames"]:
294 | avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
295 | avglen /= nframes
296 | print("avg camera distance from origin", avglen)
297 | for f in out["frames"]:
298 | f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"
299 |
300 | for f in out["frames"]:
301 | f["transform_matrix"] = f["transform_matrix"].tolist()
302 | print(nframes,"frames")
303 | print(f"writing {OUT_PATH}")
304 | with open(OUT_PATH, "w") as outfile:
305 | json.dump(out, outfile, indent=2)
--------------------------------------------------------------------------------
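
A small, hedged sanity check of the quaternion convention used above (COLMAP's images.txt stores quaternions as w, x, y, z): the identity quaternion should yield the identity rotation, and (cos(pi/4), 0, 0, sin(pi/4)) a 90-degree rotation about z. The import assumes the repository root is on PYTHONPATH.

import numpy as np
from dataLoader.colmap2nerf import qvec2rotmat

assert np.allclose(qvec2rotmat(np.array([1.0, 0.0, 0.0, 0.0])), np.eye(3))

q = np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)])  # 90-degree rotation about z
print(np.round(qvec2rotmat(q), 3))                              # ~ [[0, -1, 0], [1, 0, 0], [0, 0, 1]]
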
/train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from tqdm.auto import tqdm
4 | from opt import config_parser
5 |
6 | from renderer import *
7 | from utils import *
8 | import datetime
9 |
10 | from dataLoader import dataset_dict
11 | import sys
12 |
13 | import time
14 |
15 |
16 |
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 |
19 | renderer = OctreeRender_trilinear_fast
20 |
21 |
22 | class SimpleSampler:
23 | def __init__(self, total, batch):
24 | self.total = total
25 | self.batch = batch
26 | self.curr = total
27 | self.ids = None
28 |
29 | def nextids(self):
30 | self.curr+=self.batch
31 | if self.curr + self.batch > self.total:
32 | self.ids = torch.LongTensor(np.random.permutation(self.total))
33 | self.curr = 0
34 | return self.ids[self.curr:self.curr+self.batch]
35 |
36 | @torch.no_grad()
37 | def render_test(args):
38 | # init dataset
39 | dataset = dataset_dict[args.dataset_name]
40 | test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
41 | white_bg = test_dataset.white_bg
42 | ndc_ray = args.ndc_ray
43 |
44 | if not os.path.exists(args.ckpt):
45 |         print('the ckpt path does not exist!!')
46 | return
47 |
48 | tensorf = torch.load(args.ckpt, map_location=device)
49 |
50 | logfolder = os.path.dirname(args.ckpt)
51 | if args.render_train:
52 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
53 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
54 | PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
55 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
56 | print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
57 |
58 | if args.render_test:
59 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
60 | PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
61 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
62 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
63 |
64 | if args.render_path:
65 | c2ws = test_dataset.render_path
66 | os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
67 | evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
68 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
69 |
70 | def reconstruction(args):
71 |
72 | # init dataset
73 | dataset = dataset_dict[args.dataset_name]
74 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
75 | test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
76 | white_bg = train_dataset.white_bg
77 | near_far = train_dataset.near_far
78 | ndc_ray = args.ndc_ray
79 |
80 | # init resolution
81 | upsamp_list = args.upsamp_list
82 | update_AlphaMask_list = args.update_AlphaMask_list
83 | n_lamb_sigma = args.n_lamb_sigma
84 | n_lamb_sh = args.n_lamb_sh
85 |
86 |
87 | if args.add_timestamp:
88 | logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
89 | else:
90 | logfolder = f'{args.basedir}/{args.expname}'
91 |
92 |
93 | # init log file
94 | os.makedirs(logfolder, exist_ok=True)
95 | os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
96 | os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
97 | os.makedirs(f'{logfolder}/rgba', exist_ok=True)
98 |
99 |
100 | # init parameters
101 | aabb = train_dataset.scene_bbox.to(device)
102 | reso_cur = N_to_reso(args.N_voxel_init, aabb)
103 | nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
104 |
105 |
106 | if args.ckpt is not None:
107 | tensorf = torch.load(args.ckpt, map_location=device)
108 | else:
109 | tensorf = eval(args.model_name)(aabb, reso_cur, device,
110 | density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,
111 | shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,
112 | pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct)
113 |
114 | grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
115 | if args.lr_decay_iters > 0:
116 | lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
117 | else:
118 | args.lr_decay_iters = args.n_iters
119 | lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)
120 |
121 | print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
122 |
123 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))
124 |
125 |
126 | # linear in logarithmic space
127 | N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]
128 |
129 |
130 | torch.cuda.empty_cache()
131 | PSNRs,PSNRs_test = [],[0]
132 | batch_size = 2048
133 |
134 | allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
135 | if not args.ndc_ray:
136 | allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True)
137 | trainingSampler = SimpleSampler(allrays.shape[0], batch_size)
138 |
139 | L1_reg_weight = args.L1_weight_inital
140 | print("initial L1_reg_weight", L1_reg_weight)
141 |
142 | pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
143 | for iteration in pbar:
144 |
145 | ray_idx = trainingSampler.nextids()
146 | rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)
147 |
148 | # rgb_map, alphas_map, depth_map, weights, others
149 | rgb_map, alphas_map, depth_map, weights, others = renderer(rays_train,
150 | tensorf,
151 | chunk=batch_size,
152 | N_samples=nSamples,
153 | white_bg = white_bg,
154 | ndc_ray=ndc_ray,
155 | device=device,
156 | is_train=True)
157 |
158 | mse_loss = torch.mean((rgb_map - rgb_train) ** 2)
159 | total_loss = mse_loss
160 |
161 | if others['normals'] is not None:
162 | Ro = torch.sum(others['normals'] * others['valid_viewdirs'], dim=-1)
163 | Ro = F.relu(Ro).pow(2) * others['valid_weights']
164 | Ro = Ro.mean()
165 | total_loss += 0.3 * Ro
166 |
167 | if L1_reg_weight > 0:
168 | loss_reg_L1 = tensorf.density_L1()
169 | total_loss += L1_reg_weight*loss_reg_L1
170 |
171 | optimizer.zero_grad()
172 | total_loss.backward()
173 | optimizer.step()
174 |
175 | mse_loss = mse_loss.detach().item()
176 | PSNRs.append(-10.0 * np.log(mse_loss) / np.log(10.0))
177 |
178 |
179 | for param_group in optimizer.param_groups:
180 | param_group['lr'] = param_group['lr'] * lr_factor
181 |
182 | # Print the current values of the losses.
183 | if iteration % args.progress_refresh_rate == 0:
184 | pbar.set_description(
185 | f'Iteration {iteration:05d}:'
186 | + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
187 | + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
188 | + f' mse = {mse_loss:.6f}'
189 | )
190 | PSNRs = []
191 |
192 |
193 | if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:
194 | PSNRs_test = evaluation(test_dataset,tensorf,
195 | args,
196 | renderer,
197 | f'{logfolder}/imgs_vis/',
198 | N_vis=args.N_vis,
199 | prtx=f'{iteration:06d}_',
200 | N_samples=nSamples,
201 | white_bg = white_bg,
202 | ndc_ray=ndc_ray,
203 | compute_extra_metrics=False)
204 |
205 |
206 | if iteration in update_AlphaMask_list:
207 |
208 | # reso_cur = N_to_reso(250**3, tensorf.aabb)
209 | if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution
210 | reso_mask = reso_cur
211 | new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
212 | if iteration == update_AlphaMask_list[0]:
213 | tensorf.shrink(new_aabb)
214 | L1_reg_weight = args.L1_weight_rest
215 | print("continuing L1_reg_weight", L1_reg_weight)
216 |
217 |
218 | if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
219 | # filter rays outside the bbox
220 | allrays,allrgbs = tensorf.filtering_rays(allrays,allrgbs)
221 |
222 | batch_size = args.batch_size
223 | trainingSampler = SimpleSampler(allrgbs.shape[0], batch_size)
224 | print(f'Update batch size to {args.batch_size}')
225 |
226 |
227 | if iteration in upsamp_list:
228 | n_voxels = N_voxel_list.pop(0)
229 | reso_cur = N_to_reso(n_voxels, tensorf.aabb)
230 | nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
231 | tensorf.upsample_volume_grid(reso_cur)
232 |
233 | if args.lr_upsample_reset:
234 | lr_scale = 1 #0.1 ** (iteration / args.n_iters)
235 | else:
236 | lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
237 | grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
238 | optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
239 |
240 |
241 | if iteration == upsamp_list[0]:
242 | batch_size = args.batch_size
243 | trainingSampler = SimpleSampler(allrgbs.shape[0], batch_size)
244 | print(f'Update batch size to {batch_size}')
245 |
246 |
247 | torch.save(tensorf, f'{logfolder}/{args.expname}.pt')
248 |
249 |
250 | if args.render_train:
251 | os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
252 | train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
253 | PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
254 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
255 | print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
256 |
257 | if args.render_test:
258 | os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
259 | PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
260 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
261 | print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
262 |
263 | if args.render_path:
264 | c2ws = test_dataset.render_path
265 | print('========>',c2ws.shape)
266 | os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
267 | evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
268 | N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
269 |
270 |
271 | if __name__ == '__main__':
272 |
273 | torch.set_default_dtype(torch.float32)
274 | torch.manual_seed(20211202)
275 | np.random.seed(20211202)
276 |
277 | args = config_parser()
278 |
279 | if args.export_mesh:
280 | export_mesh(args)
281 |
282 | if args.render_only and (args.render_test or args.render_path):
283 | render_test(args)
284 | else:
285 | reconstruction(args)
286 |
287 |
--------------------------------------------------------------------------------
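A minimal usage sketch (illustrative sizes, not values taken from the configs above) of how the SimpleSampler in train.py cycles through the training rays: it reshuffles all ray indices whenever the next batch would run past the end, so batches are drawn without replacement within each pass.

    import torch
    import numpy as np

    allrays = torch.rand(100_000, 6)            # stand-in for train_dataset.all_rays
    sampler = SimpleSampler(allrays.shape[0], 4096)
    ids = sampler.nextids()                     # LongTensor with 4096 shuffled ray indices
    rays_batch = allrays[ids]                   # (4096, 6) rays for one optimization step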
/models/tensorBase.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import torch.nn.functional as F
4 | # from .sh import eval_sh_bases
5 | import numpy as np
6 | import time
7 |
8 |
9 | def positional_encoding(positions, freqs):
10 |
11 | freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,)
12 | pts = (positions[..., None] * freq_bands).reshape(
13 | positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF)
14 | pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
15 | return pts
16 |
17 | def raw2alpha(sigma, dist):
18 | # sigma, dist [N_rays, N_samples]
19 | alpha = 1. - torch.exp(-sigma*dist)
20 |
21 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)
22 |
23 | weights = alpha * T[:, :-1] # [N_rays, N_samples]
24 | return alpha, weights, T[:,-1:]
25 |
26 |
27 | class AlphaGridMask(torch.nn.Module):
28 | def __init__(self, device, aabb, alpha_volume):
29 | super(AlphaGridMask, self).__init__()
30 | self.device = device
31 |
32 | self.aabb=aabb.to(self.device)
33 | self.aabbSize = self.aabb[1] - self.aabb[0]
34 | self.invgridSize = 1.0/self.aabbSize * 2
35 | self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:])
36 | self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device)
37 |
38 | def sample_alpha(self, xyz_sampled):
39 | xyz_sampled = self.normalize_coord(xyz_sampled)
40 | alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1)
41 |
42 | return alpha_vals
43 |
44 | def normalize_coord(self, xyz_sampled):
45 | return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1
46 |
47 |
48 | class TensorBase(torch.nn.Module):
49 | def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27,
50 | shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0],
51 | density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,
52 | pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0,
53 | fea2denseAct = 'softplus'):
54 | super(TensorBase, self).__init__()
55 |
56 | self.density_n_comp = density_n_comp
57 | self.app_n_comp = appearance_n_comp
58 | self.app_dim = app_dim
59 | self.aabb = aabb
60 | self.alphaMask = alphaMask
61 | self.device=device
62 |
63 | self.density_shift = density_shift
64 | self.alphaMask_thres = alphaMask_thres
65 | self.distance_scale = distance_scale
66 | self.rayMarch_weight_thres = rayMarch_weight_thres
67 | self.fea2denseAct = fea2denseAct
68 |
69 | self.near_far = near_far
70 | self.step_ratio = step_ratio
71 |
72 |
73 | self.update_stepSize(gridSize)
74 |
75 | self.matMode = [[0,1], [0,2], [1,2]]
76 | self.vecMode = [2, 1, 0]
77 | self.comp_w = [1,1,1]
78 |
79 | self.init_svd_volume(gridSize[0], device)
80 |
81 | self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC
82 |
83 | def update_stepSize(self, gridSize):
84 | self.aabbSize = self.aabb[1] - self.aabb[0]
85 | self.invaabbSize = 2.0/self.aabbSize
86 | self.gridSize= torch.LongTensor(gridSize).to(self.device)
87 | self.units=self.aabbSize / (self.gridSize-1)
88 | self.stepSize=torch.mean(self.units)*self.step_ratio
89 | self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
90 | self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1
91 |
92 | def init_svd_volume(self, res, device):
93 | pass
94 |
95 | def compute_features(self, xyz_sampled):
96 | pass
97 |
98 | def compute_densityfeature(self, xyz_sampled):
99 | pass
100 |
101 | def compute_appfeature(self, xyz_sampled):
102 | pass
103 |
104 | def normalize_coord(self, xyz_sampled):
105 | return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1
106 |
107 | def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001):
108 | pass
109 |
110 | def get_kwargs(self):
111 | return {
112 | 'aabb': self.aabb,
113 | 'gridSize':self.gridSize.tolist(),
114 | 'density_n_comp': self.density_n_comp,
115 | 'appearance_n_comp': self.app_n_comp,
116 | 'app_dim': self.app_dim,
117 |
118 | 'density_shift': self.density_shift,
119 | 'alphaMask_thres': self.alphaMask_thres,
120 | 'distance_scale': self.distance_scale,
121 | 'rayMarch_weight_thres': self.rayMarch_weight_thres,
122 | 'fea2denseAct': self.fea2denseAct,
123 |
124 | 'near_far': self.near_far,
125 | 'step_ratio': self.step_ratio,
126 |
127 | 'shadingMode': self.shadingMode,
128 | 'pos_pe': self.pos_pe,
129 | 'view_pe': self.view_pe,
130 | 'fea_pe': self.fea_pe,
131 | 'featureC': self.featureC
132 | }
133 |
134 | def save(self, path):
135 | kwargs = self.get_kwargs()
136 | ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}
137 | if self.alphaMask is not None:
138 | alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
139 | ckpt.update({'alphaMask.shape':alpha_volume.shape})
140 | ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))})
141 | ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
142 | torch.save(ckpt, path)
143 |
144 | def load(self, ckpt):
145 | if 'alphaMask.aabb' in ckpt.keys():
146 | length = np.prod(ckpt['alphaMask.shape'])
147 | alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
148 | self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device))
149 | self.load_state_dict(ckpt['state_dict'])
150 |
151 |
152 | def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
153 | N_samples = N_samples if N_samples > 0 else self.nSamples
154 | near, far = self.near_far
155 | interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
156 | if is_train:
157 | interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)
158 |
159 | rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
160 | mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
161 | return rays_pts, interpx, ~mask_outbbox
162 |
163 | def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
164 | N_samples = N_samples if N_samples>0 else self.nSamples
165 | stepsize = self.stepSize
166 | near, far = self.near_far
167 | vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d)
168 | rate_a = (self.aabb[1] - rays_o) / vec
169 | rate_b = (self.aabb[0] - rays_o) / vec
170 | t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)
171 |
172 | rng = torch.arange(N_samples)[None].float()
173 | if is_train:
174 | rng = rng.repeat(rays_d.shape[-2],1)
175 | rng += torch.rand_like(rng[:,[0]])
176 | step = stepsize * rng.to(rays_o.device)
177 | interpx = (t_min[...,None] + step)
178 |
179 | rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]
180 | mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1)
181 |
182 | return rays_pts, interpx, ~mask_outbbox
183 |
184 |
185 | def shrink(self, new_aabb, voxel_size):
186 | pass
187 |
188 | @torch.no_grad()
189 | def getDenseAlpha(self,gridSize=None):
190 | gridSize = self.gridSize if gridSize is None else gridSize
191 |
192 | samples = torch.stack(torch.meshgrid(
193 | torch.linspace(0, 1, gridSize[0]),
194 | torch.linspace(0, 1, gridSize[1]),
195 | torch.linspace(0, 1, gridSize[2]),
196 | ), -1).to(self.device)
197 | dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples
198 |
199 | alpha = torch.zeros_like(dense_xyz[...,0])
200 | for i in range(gridSize[0]):
201 | alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2]))
202 | return alpha, dense_xyz
203 |
204 | @torch.no_grad()
205 | def updateAlphaMask(self, gridSize=(200,200,200)):
206 |
207 | alpha, dense_xyz = self.getDenseAlpha(gridSize)
208 | dense_xyz = dense_xyz.transpose(0,2).contiguous()
209 | alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None]
210 | total_voxels = gridSize[0] * gridSize[1] * gridSize[2]
211 |
212 | ks = 3
213 | alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])
214 | alpha[alpha>=self.alphaMask_thres] = 1
215 | alpha[alpha<self.alphaMask_thres] = 0
216 |
217 | self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha)
218 |
219 | valid_xyz = dense_xyz[alpha>0.5]
220 |
221 | xyz_min = valid_xyz.amin(0)
222 | xyz_max = valid_xyz.amax(0)
223 |
224 | new_aabb = torch.stack((xyz_min, xyz_max))
225 |
226 | total = torch.sum(alpha)
227 | print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100))
228 | return new_aabb
229 |
230 | @torch.no_grad()
231 | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False):
232 | print('========> filtering rays ...')
233 | tt = time.time()
234 |
235 | N = torch.tensor(all_rays.shape[:-1]).prod()
236 |
237 | mask_filtered = []
238 | idx_chunks = torch.split(torch.arange(N), chunk)
239 | for idx_chunk in idx_chunks:
240 | rays_chunk = all_rays[idx_chunk].to(self.device)
241 |
242 | rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
243 | if bbox_only:
244 | vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
245 | rate_a = (self.aabb[1] - rays_o) / vec
246 | rate_b = (self.aabb[0] - rays_o) / vec
247 | t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far)
248 | t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far)
249 | mask_inbbox = t_max > t_min
250 |
251 | else:
252 | xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)
253 | mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)
254 |
255 | mask_filtered.append(mask_inbbox.cpu())
256 |
257 | mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])
258 |
259 | print(f'Ray filtering done! took {time.time()-tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}')
260 | return all_rays[mask_filtered], all_rgbs[mask_filtered]
261 |
262 |
263 | def feature2density(self, density_features):
264 | if self.fea2denseAct == "softplus":
265 | return F.softplus(density_features+self.density_shift)
266 | elif self.fea2denseAct == "relu":
267 | return F.relu(density_features)
268 |
269 |
270 | def compute_alpha(self, xyz_locs, length=1):
271 |
272 | if self.alphaMask is not None:
273 | alphas = self.alphaMask.sample_alpha(xyz_locs)
274 | alpha_mask = alphas > 0
275 | else:
276 | alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool)
277 |
278 |
279 | sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)
280 |
281 | if alpha_mask.any():
282 | xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
283 | sigma_feature = self.compute_densityfeature(xyz_sampled)
284 | validsigma = self.feature2density(sigma_feature)
285 | sigma[alpha_mask] = validsigma
286 |
287 |
288 | alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1])
289 |
290 | return alpha
291 |
292 |
293 | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
294 |
295 | # sample points
296 | viewdirs = rays_chunk[:, 3:6]
297 | if ndc_ray:
298 | xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
299 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
300 | rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
301 | dists = dists * rays_norm
302 | viewdirs = viewdirs / rays_norm
303 | else:
304 | xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
305 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
306 | viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
307 |
308 | if self.alphaMask is not None:
309 | alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
310 | alpha_mask = alphas > 0
311 | ray_invalid = ~ray_valid
312 | ray_invalid[ray_valid] |= (~alpha_mask)
313 | ray_valid = ~ray_invalid
314 |
315 |
316 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
317 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)
318 |
319 | if ray_valid.any():
320 | xyz_sampled = self.normalize_coord(xyz_sampled)
321 | sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
322 |
323 | validsigma = self.feature2density(sigma_feature)
324 | sigma[ray_valid] = validsigma
325 |
326 |
327 | alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)
328 |
329 | app_mask = weight > self.rayMarch_weight_thres
330 |
331 | normals = None
332 | valid_viewdirs = None
333 | valid_weights = None
334 | if app_mask.any():
335 | valid_viewdirs = viewdirs[app_mask]
336 | valid_weights = weight[app_mask]
337 |
338 | app_features = self.compute_appfeature(xyz_sampled[app_mask])
339 | valid_rgbs, normals = self.rendering_net(valid_viewdirs, app_features)
340 | rgb[app_mask] = valid_rgbs
341 |
342 |
343 | acc_map = torch.sum(weight, -1)
344 | rgb_map = torch.sum(weight[..., None] * rgb, -2)
345 |
346 | if white_bg or (is_train and torch.rand((1,))<0.5):
347 | rgb_map = rgb_map + (1. - acc_map[..., None])
348 |
349 |
350 | rgb_map = rgb_map.clamp(0,1)
351 |
352 | with torch.no_grad():
353 | depth_map = torch.sum(weight * z_vals, -1)
354 | depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]
355 |
356 | return {'rgb_map':rgb_map,
357 | 'depth_map':depth_map,
358 | 'normals':normals,
359 | 'valid_viewdirs':valid_viewdirs,
360 | 'valid_weights':valid_weights
361 | }
362 |
--------------------------------------------------------------------------------
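A small self-contained check (toy numbers only) of the two helpers at the top of tensorBase.py: positional_encoding expands each coordinate into sin/cos features at 2^k frequencies, and raw2alpha turns per-sample densities and step sizes into alphas, compositing weights, and the remaining background transmittance.

    import torch

    pts = torch.rand(4, 3)                       # 4 points in 3D
    enc = positional_encoding(pts, freqs=2)      # -> (4, 12): 3 dims * 2 freqs * (sin, cos)

    sigma = torch.tensor([[0.0, 1.0, 50.0]])     # densities along a single ray
    dist = torch.full_like(sigma, 0.05)          # step size between samples
    alpha, weights, bg = raw2alpha(sigma, dist)  # alpha/weights: (1, 3), bg transmittance: (1, 1)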
/models/model.py:
--------------------------------------------------------------------------------
1 | from torch.functional import align_tensors
2 | from .tensorBase import *
3 | from .quaternion_utils import *
4 | from utils import N_to_reso
5 |
6 | import numpy as np
7 | import math
8 |
9 |
10 | class TensorDecomposition(torch.nn.Module):
11 | def __init__(self, grid_size, num_features, scale, device, reduce_sum=False):
12 | super(TensorDecomposition, self).__init__()
13 | self.grid_size = torch.tensor(grid_size)
14 | self.num_voxels = grid_size[0] * grid_size[1] * grid_size[2]
15 | self.reduce_sum = reduce_sum
16 |
17 | X, Y, Z = grid_size
18 | self.plane_xy = torch.nn.Parameter(scale * torch.randn((1, num_features, Y, X), device=device))
19 | self.plane_yz = torch.nn.Parameter(scale * torch.randn((1, num_features, Z, Y), device=device))
20 | self.plane_xz = torch.nn.Parameter(scale * torch.randn((1, num_features, Z, X), device=device))
21 |
22 | self.line_z = torch.nn.Parameter(scale * torch.randn((1, num_features, Z, 1), device=device))
23 | self.line_x = torch.nn.Parameter(scale * torch.randn((1, num_features, X, 1), device=device))
24 | self.line_y = torch.nn.Parameter(scale * torch.randn((1, num_features, Y, 1), device=device))
25 |
26 | def forward(self, coords_plane, coords_line):
27 | feature_xy = F.grid_sample(self.plane_xy, coords_plane[0], mode='bilinear', align_corners=True)
28 | feature_yz = F.grid_sample(self.plane_yz, coords_plane[1], mode='bilinear', align_corners=True)
29 | feature_xz = F.grid_sample(self.plane_xz, coords_plane[2], mode='bilinear', align_corners=True)
30 |
31 | feature_x = F.grid_sample(self.line_x, coords_line[0], mode='bilinear', align_corners=True)
32 | feature_y = F.grid_sample(self.line_y, coords_line[1], mode='bilinear', align_corners=True)
33 | feature_z = F.grid_sample(self.line_z, coords_line[2], mode='bilinear', align_corners=True)
34 |
35 | out_x = feature_yz * feature_x
36 | out_y = feature_xz * feature_y
37 | out_z = feature_xy * feature_z
38 |
39 | _, C, N, _ = out_x.size()
40 | if self.reduce_sum:
41 | output = out_x.sum(dim=(0, 1, 3)) + out_y.sum(dim=(0, 1, 3)) + out_z.sum(dim=(0, 1, 3))
42 | else:
43 | output = [out_x.view(-1, N).T, out_y.view(-1, N).T, out_z.view(-1, N).T]
44 |
45 | return output
46 |
47 | def L1loss(self):
48 | loss = torch.abs(self.plane_xy).mean() + torch.abs(self.plane_yz).mean() + torch.abs(self.plane_xz).mean()
49 | loss += torch.abs(self.line_x).mean() + torch.abs(self.line_y).mean() + torch.abs(self.line_z).mean()
50 | loss = loss / 6
51 |
52 | return loss
53 |
54 | def TV_loss(self):
55 | loss = self.TV_loss_com(self.plane_xy)
56 | loss += self.TV_loss_com(self.plane_yz)
57 | loss += self.TV_loss_com(self.plane_xz)
58 | loss = loss / 6
59 |
60 | return loss
61 |
62 | def TV_loss_com(self, x):
63 | loss = (x[:, :, 1:] - x[:, :, :-1]).pow(2).mean() + (x[:, :, :, 1:] - x[:, :, :, :-1]).pow(2).mean()
64 | return loss
65 |
66 |
67 | def shrink(self, bound):
68 | # bound [3, 2]
69 | x, y, z = bound[0], bound[1], bound[2]
70 | self.plane_xy = torch.nn.Parameter(self.plane_xy.data[:, :, y[0]:y[1], x[0]:x[1]])
71 | self.plane_yz = torch.nn.Parameter(self.plane_yz.data[:, :, z[0]:z[1], y[0]:y[1]])
72 | self.plane_xz = torch.nn.Parameter(self.plane_xz.data[:, :, z[0]:z[1], x[0]:x[1]])
73 |
74 | self.line_x = torch.nn.Parameter(self.line_x.data[:, :, x[0]:x[1]])
75 | self.line_y = torch.nn.Parameter(self.line_y.data[:, :, y[0]:y[1]])
76 | self.line_z = torch.nn.Parameter(self.line_z.data[:, :, z[0]:z[1]])
77 |
78 | self.grid_size = bound[:, 1] - bound[:, 0]
79 |
80 |
81 | def upsample(self, aabb):
82 | target_res = N_to_reso(self.num_voxels, aabb)
83 |
84 |
85 | self.grid_size = torch.tensor(target_res)
86 |
87 | self.plane_xy = torch.nn.Parameter(F.interpolate(self.plane_xy.data,
88 | size=(target_res[1], target_res[0]), mode='bilinear', align_corners=True))
89 | self.plane_yz = torch.nn.Parameter(F.interpolate(self.plane_yz.data,
90 | size=(target_res[2], target_res[1]), mode='bilinear', align_corners=True))
91 | self.plane_xz = torch.nn.Parameter(F.interpolate(self.plane_xz.data,
92 | size=(target_res[2], target_res[0]), mode='bilinear', align_corners=True))
93 |
94 |
95 | self.line_x = torch.nn.Parameter(F.interpolate(self.line_x.data,
96 | size=(target_res[0], 1), mode='bilinear', align_corners=True))
97 | self.line_y = torch.nn.Parameter(F.interpolate(self.line_y.data,
98 | size=(target_res[1], 1), mode='bilinear', align_corners=True))
99 | self.line_z = torch.nn.Parameter(F.interpolate(self.line_z.data,
100 | size=(target_res[2], 1), mode='bilinear', align_corners=True))
101 |
102 | class MultiscaleTensorDecom(torch.nn.Module):
103 | def __init__(self, num_levels, num_features, base_resolution, max_resolution, device, reduce_sum=False, scale=0.1):
104 | super(MultiscaleTensorDecom, self).__init__()
105 | self.reduce_sum = reduce_sum
106 |
107 | tensors = []
108 | if num_levels == 1:
109 | factor = 1
110 | else:
111 | factor = math.exp( (math.log(max_resolution) - math.log(base_resolution)) / (num_levels-1) )
112 |
113 | for i in range(num_levels):
114 | level_resolution = int(base_resolution * factor**i)
115 | level_grid = (level_resolution, level_resolution, level_resolution)
116 | tensors.append(TensorDecomposition(level_grid, num_features, scale, device, reduce_sum=reduce_sum))
117 |
118 | self.tensors = torch.nn.ModuleList(tensors)
119 |
120 | def coords_split(self, pts, dim=2, z_vals=None):
121 | N, D = pts.size()
122 | pts = pts.view(1, N, 1, D)
123 |
124 | out_plane = []
125 | if dim == 2:
126 | out_plane.append(pts[..., [0, 1]])
127 | out_plane.append(pts[..., [1, 2]])
128 | out_plane.append(pts[..., [0, 2]])
129 | elif dim == 3:
130 | out_plane.append(pts[..., [0, 1, 2]][:, :, None])
131 | out_plane.append(pts[..., [1, 2, 0]][:, :, None])
132 | out_plane.append(pts[..., [0, 2, 1]][:, :, None])
133 |
134 | if z_vals is None:
135 | coord_x = pts.new_zeros(1, N, 1, 1)
136 | else:
137 | coord_x = z_vals.view(1, N, 1, 1)
138 | out_line = []
139 | out_line.append(torch.cat((coord_x, pts[..., [0]]), dim=-1))
140 | out_line.append(torch.cat((coord_x, pts[..., [1]]), dim=-1))
141 | out_line.append(torch.cat((coord_x, pts[..., [2]]), dim=-1))
142 |
143 | return out_plane, out_line
144 |
145 | def L1loss(self):
146 | loss = 0.
147 | for tensor in self.tensors:
148 | loss += tensor.L1loss()
149 |
150 | return loss / len(self.tensors)
151 |
152 | def shrink(self, aabb, new_aabb):
153 | aabb_size = aabb[1] - aabb[0]
154 | xyz_min, xyz_max = new_aabb
155 |
156 | for tensor in self.tensors:
157 | grid_size = tensor.grid_size
158 | units = aabb_size / (grid_size - 1)
159 | t_l, b_r = (xyz_min - aabb[0]) / units, (xyz_max - aabb[0]) / units
160 |
161 | t_l, b_r = torch.floor(t_l).long(), torch.ceil(b_r).long()
162 | b_r = torch.stack([b_r, grid_size]).amin(0)
163 |
164 | bound = torch.stack((t_l, b_r), dim=-1)
165 | tensor.shrink(bound)
166 |
167 | def upsample(self, aabb):
168 | for tensor in self.tensors:
169 | tensor.upsample(aabb)
170 |
171 | def forward(self, pts):
172 | coords_plane, coords_line = self.coords_split(pts)
173 |
174 | if self.reduce_sum:
175 | output = pts.new_zeros(pts.size(0))
176 | else:
177 | output = []
178 |
179 | for level_tensor in self.tensors:
180 | output += level_tensor(coords_plane, coords_line)
181 |
182 | return output
183 |
184 | class RenderingEquationEncoding(torch.nn.Module):
185 | def __init__(self, num_theta, num_phi, device):
186 | super(RenderingEquationEncoding, self).__init__()
187 |
188 | self.num_theta = num_theta
189 | self.num_phi = num_phi
190 |
191 | omega, omega_la, omega_mu = init_predefined_omega(num_theta, num_phi)
192 | self.omega = omega.view(1, num_theta, num_phi, 3).to(device)
193 | self.omega_la = omega_la.view(1, num_theta, num_phi, 3).to(device)
194 | self.omega_mu = omega_mu.view(1, num_theta, num_phi, 3).to(device)
195 |
196 | def forward(self, omega_o, a, la, mu):
197 | Smooth = F.relu((omega_o[:, None, None] * self.omega).sum(dim=-1, keepdim=True)) # N, num_theta, num_phi, 1
198 |
199 | la = F.softplus(la - 1)
200 | mu = F.softplus(mu - 1)
201 | exp_input = -la * (self.omega_la * omega_o[:, None, None]).sum(dim=-1, keepdim=True).pow(2) -mu * (self.omega_mu * omega_o[:, None, None]).sum(dim=-1, keepdim=True).pow(2)
202 | out = a * Smooth * torch.exp(exp_input)
203 |
204 | return out
205 |
206 | class RenderingNet(torch.nn.Module):
207 | def __init__(self, num_theta = 8, num_phi=16, data_dim_color=192, featureC=256, device='cpu'):
208 | super(RenderingNet, self).__init__()
209 |
210 | self.ch_cd = 3
211 | self.ch_s = 3
212 | self.ch_normal = 3
213 | self.ch_bottleneck = 128
214 |
215 | self.num_theta = num_theta
216 | self.num_phi = num_phi
217 | self.num_asg = self.num_theta * self.num_phi
218 |
219 | self.ch_asg_feature = 128
220 | self.ch_per_theta = self.ch_asg_feature // self.num_theta
221 |
222 | self.ch_a = 2
223 | self.ch_la = 1
224 | self.ch_mu = 1
225 | self.ch_per_asg = self.ch_a + self.ch_la + self.ch_mu
226 |
227 | self.ch_normal_dot_viewdir = 1
228 |
229 |
230 | self.ree_function = RenderingEquationEncoding(num_theta, num_phi, device)
231 |
232 | self.spatial_mlp = torch.nn.Sequential(
233 | torch.nn.Linear(data_dim_color, featureC),
234 | torch.nn.GELU(),
235 | torch.nn.Linear(featureC, featureC),
236 | torch.nn.GELU(),
237 | torch.nn.Linear(featureC, self.ch_cd + self.ch_s + self.ch_bottleneck + self.ch_normal + self.ch_asg_feature)).to(device)
238 |
239 | self.asg_mlp = torch.nn.Sequential(torch.nn.Linear(self.ch_per_theta, self.num_phi * self.ch_per_asg)).to(device)
240 |
241 | self.directional_mlp = torch.nn.Sequential(
242 | torch.nn.Linear(self.ch_bottleneck + self.num_asg * self.ch_a + self.ch_normal_dot_viewdir, featureC),
243 | torch.nn.GELU(),
244 | torch.nn.Linear(featureC, featureC),
245 | torch.nn.GELU(),
246 | torch.nn.Linear(featureC, featureC),
247 | torch.nn.GELU(),
248 | torch.nn.Linear(featureC, featureC),
249 | torch.nn.GELU(),
250 | torch.nn.Linear(featureC, featureC),
251 | torch.nn.GELU(),
252 | torch.nn.Linear(featureC, 3)).to(device)
253 |
254 |
255 | def spatial_mlp_forward(self, x):
256 | out = self.spatial_mlp(x)
257 | sections = [self.ch_cd, self.ch_s, self.ch_normal, self.ch_bottleneck, self.ch_asg_feature]
258 | diffuse_color, tint, normals, bottleneck, asg_features = torch.split(out, sections, dim=-1)
259 | normals = -F.normalize(normals, dim=1)
260 | return diffuse_color, tint, normals, bottleneck, asg_features
261 |
262 | def asg_mlp_forward(self, asg_feature):
263 | N = asg_feature.size(0)
264 | asg_feature = asg_feature.view(N, self.num_theta, -1)
265 | asg_params = self.asg_mlp(asg_feature)
266 | asg_params = asg_params.view(N, self.num_theta, self.num_phi, -1)
267 |
268 | a, la, mu = torch.split(asg_params, [self.ch_a, self.ch_la, self.ch_mu], dim=-1)
269 | return a, la, mu
270 |
271 | def directional_mlp_forward(self, x):
272 | out = self.directional_mlp(x)
273 | return out
274 |
275 | def reflect(self, viewdir, normal):
276 | out = 2 * (viewdir * normal).sum(dim=-1, keepdim=True) * normal - viewdir
277 | return out
278 |
279 | def forward(self, viewdir, feature):
280 | diffuse_color, tint, normal, bottleneck, asg_feature = self.spatial_mlp_forward(feature)
281 | refdir = self.reflect(-viewdir, normal)
282 |
283 | a, la, mu = self.asg_mlp_forward(asg_feature)
284 | ree = self.ree_function(refdir, a, la, mu) # N, num_theta, num_phi, ch_per_asg
285 | ree = ree.view(ree.size(0), -1)
286 |
287 | normal_dot_viewdir = ((-viewdir) * normal).sum(dim=-1, keepdim=True)
288 | dir_mlp_input = torch.cat([bottleneck, ree, normal_dot_viewdir], dim=-1)
289 | specular_color = self.directional_mlp_forward(dir_mlp_input)
290 |
291 | raw_rgb = diffuse_color + tint * specular_color
292 | rgb = torch.sigmoid(raw_rgb)
293 |
294 | return rgb, normal
295 |
296 |
297 | ########################################################################################
298 |
299 | class NRFF(TensorBase):
300 | def __init__(self, aabb, gridSize, device, **kargs):
301 | super(NRFF, self).__init__(aabb, gridSize, device, **kargs)
302 |
303 | self.rendering_net = RenderingNet(8, 16, device=device)
304 | self.init_feature_field(device)
305 |
306 | def init_feature_field(self, device):
307 | self.density_field = MultiscaleTensorDecom(num_levels=16, num_features=2, base_resolution=16, max_resolution=512, device=device, reduce_sum=True)
308 | self.appearance_field = MultiscaleTensorDecom(num_levels=16, num_features=4, base_resolution=16, max_resolution=512, device=device)
309 |
310 | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
311 | grad_vars = []
312 |
313 | grad_vars += [{'params': self.density_field.parameters(), 'lr': lr_init_spatialxyz}]
314 | grad_vars += [{'params': self.appearance_field.parameters(), 'lr': lr_init_spatialxyz}]
315 | grad_vars += [{'params': self.rendering_net.parameters(), 'lr':lr_init_network}]
316 |
317 | return grad_vars
318 |
319 |
320 | def density_L1(self):
321 | return self.density_field.L1loss()
322 |
323 | def compute_densityfeature(self, pts):
324 | output = self.density_field(pts)
325 | return output
326 |
327 | def compute_appfeature(self, pts):
328 | app_feature = self.appearance_field(pts)
329 | app_feature = torch.cat(app_feature, dim=-1)
330 | return app_feature
331 |
332 | @torch.no_grad()
333 | def shrink(self, new_aabb):
334 | self.train_aabb = new_aabb
335 |
336 | self.density_field.shrink(self.aabb.cpu(), new_aabb.cpu())
337 | self.appearance_field.shrink(self.aabb.cpu(), new_aabb.cpu())
338 |
339 | xyz_min, xyz_max = new_aabb
340 | t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
341 |
342 |
343 | t_l, b_r = torch.floor(t_l).long(), torch.ceil(b_r).long()
344 | b_r = torch.stack([b_r, self.gridSize]).amin(0)
345 |
346 | if not torch.all(self.alphaMask.gridSize == self.gridSize):
347 | t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
348 | correct_aabb = torch.zeros_like(new_aabb)
349 | correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
350 | correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
351 | new_aabb = correct_aabb
352 |
353 | newSize = b_r - t_l
354 | self.aabb = new_aabb
355 |
356 | self.density_field.upsample(new_aabb.cpu())
357 | self.appearance_field.upsample(new_aabb.cpu())
358 |
359 | self.update_stepSize((newSize[0], newSize[1], newSize[2]))
360 |
361 |
362 | @torch.no_grad()
363 | def upsample_volume_grid(self, res_target):
364 | self.update_stepSize(res_target)
365 |
--------------------------------------------------------------------------------
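As a rough usage sketch of MultiscaleTensorDecom above (tiny level counts and resolutions chosen only for illustration; NRFF itself uses 16 levels up to resolution 512): with reduce_sum=True the module acts like the density field and returns one scalar per query point, while reduce_sum=False returns the per-axis feature lists that compute_appfeature concatenates.

    import torch

    field = MultiscaleTensorDecom(num_levels=2, num_features=2,
                                  base_resolution=16, max_resolution=32,
                                  device='cpu', reduce_sum=True)
    pts = torch.rand(1024, 3) * 2 - 1            # coordinates normalized to [-1, 1]
    density_feature = field(pts)                 # -> shape (1024,), summed over levels and components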