├── media └── overview.png ├── requirements.txt ├── models ├── __pycache__ │ ├── dataset.cpython-37.pyc │ ├── fields.cpython-310.pyc │ ├── fields.cpython-37.pyc │ ├── dataset.cpython-310.pyc │ ├── embedder.cpython-310.pyc │ ├── embedder.cpython-37.pyc │ ├── ray_utils.cpython-310.pyc │ ├── renderer.cpython-310.pyc │ ├── renderer.cpython-37.pyc │ └── blender_swap.cpython-310.pyc ├── embedder.py ├── sh.py ├── scannet_blender.py ├── dataset.py ├── fields.py ├── nerf2neus.py ├── ray_utils.py ├── tensorf2neus.py ├── blender_swap.py ├── tensorBase.py ├── renderer.py └── tensoRF.py ├── confs ├── replica.conf └── blendswap.conf ├── README.md ├── mesh_metrics.py ├── cull_mesh.py └── exp_runner.py /media/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/media/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | open3d 2 | pytorch3d 3 | trimesh 4 | matplotlib 5 | opencv-python 6 | tqdm 7 | scikit-image -------------------------------------------------------------------------------- /models/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/fields.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/fields.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/fields.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/fields.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/embedder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/embedder.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/embedder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/embedder.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/ray_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/ray_utils.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/renderer.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/renderer.cpython-310.pyc -------------------------------------------------------------------------------- /models/__pycache__/renderer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/renderer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/blender_swap.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/blender_swap.cpython-310.pyc -------------------------------------------------------------------------------- /models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 6 | class Embedder: 7 | def __init__(self, **kwargs): 8 | self.kwargs = kwargs 9 | self.create_embedding_fn() 10 | 11 | def create_embedding_fn(self): 12 | embed_fns = [] 13 | d = self.kwargs['input_dims'] 14 | out_dim = 0 15 | if self.kwargs['include_input']: 16 | embed_fns.append(lambda x: x) 17 | out_dim += d 18 | 19 | max_freq = self.kwargs['max_freq_log2'] 20 | N_freqs = self.kwargs['num_freqs'] 21 | 22 | if self.kwargs['log_sampling']: 23 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) 24 | else: 25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs['periodic_fns']: 29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | 39 | def get_embedder(multires, input_dims=3): 40 | embed_kwargs = { 41 | 'include_input': True, 42 | 'input_dims': input_dims, 43 | 'max_freq_log2': multires-1, 44 | 'num_freqs': multires, 45 | 'log_sampling': True, 46 | 'periodic_fns': [torch.sin, torch.cos], 47 | } 48 | 49 | embedder_obj = Embedder(**embed_kwargs) 50 | def embed(x, eo=embedder_obj): return eo.embed(x) 51 | return embed, embedder_obj.out_dim 52 | -------------------------------------------------------------------------------- /confs/replica.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = log/room0_v2 3 | recording = [ 4 | ./, 5 | ./models 6 | ] 7 | } 8 | 9 | dataset { 10 | type = Blender 11 | data_dir = data/Replica 12 | scene = room0 13 | } 14 | 15 | train { 16 | learning_rate = 5e-4 17 | learning_rate_alpha = 0.05 18 | end_iter = 400000 19 | 20 | batch_size = 512 21 | validate_resolution_level = 1 22 | warm_up_end = 5000 23 | anneal_end = 50000 24 | use_white_bkgd = False 25 | 26 | save_freq = 10000 27 | val_freq = 2500 28 | val_mesh_freq = 5000 29 | report_freq = 100 30 | 31 | igr_weight = 0.1 32 | mask_weight = 0.0 33 | } 34 | 35 | model { 36 | nerf { 37 | D = 8, 38 | d_in = 4, 39 | d_in_view = 3, 40 | W = 256, 41 | multires = 10, 42 | multires_view = 4, 43 | output_ch = 4, 44 | skips=[4], 45 | use_viewdirs=True 46 | } 47 | 48 | sdf_network { 49 | d_out = 257 50 | d_in = 3 51 | d_hidden = 256 52 | n_layers = 8 
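# Note: d_out = 257 above is 1 SDF value plus a 256-dim feature vector consumed by rendering_network (d_feature = 256)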
53 | skip_in = [4] 54 | multires = 6 55 | bias = 0.5 56 | scale = 3.0 57 | geometric_init = True 58 | weight_norm = True 59 | } 60 | 61 | variance_network { 62 | init_val = 0.3 63 | } 64 | 65 | rendering_network { 66 | d_feature = 256 67 | mode = idr 68 | d_in = 9 69 | d_out = 3 70 | d_hidden = 256 71 | n_layers = 4 72 | weight_norm = True 73 | multires_view = 4 74 | squeeze_out = True 75 | } 76 | 77 | neus_renderer { 78 | n_samples = 64 79 | n_importance = 64 80 | n_outside = 0 81 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling 82 | perturb = 1.0 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /confs/blendswap.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = log/breakfast_room 3 | recording = [ 4 | ./, 5 | ./models 6 | ] 7 | } 8 | 9 | dataset { 10 | type = Blender 11 | data_dir = /data/blendswap 12 | scene = breakfast_room 13 | } 14 | 15 | train { 16 | learning_rate = 5e-4 17 | learning_rate_alpha = 0.05 18 | end_iter = 300000 19 | 20 | batch_size = 512 21 | validate_resolution_level = 1 22 | warm_up_end = 5000 23 | anneal_end = 50000 24 | use_white_bkgd = False 25 | 26 | save_freq = 10000 27 | val_freq = 2500 28 | val_mesh_freq = 5000 29 | report_freq = 100 30 | 31 | igr_weight = 0.1 32 | mask_weight = 0.0 33 | } 34 | 35 | model { 36 | nerf { 37 | D = 8, 38 | d_in = 4, 39 | d_in_view = 3, 40 | W = 256, 41 | multires = 10, 42 | multires_view = 4, 43 | output_ch = 4, 44 | skips=[4], 45 | use_viewdirs=True 46 | } 47 | 48 | sdf_network { 49 | d_out = 257 50 | d_in = 3 51 | d_hidden = 256 52 | n_layers = 8 53 | skip_in = [4] 54 | multires = 6 55 | bias = 0.5 56 | scale = 3.0 57 | geometric_init = True 58 | weight_norm = True 59 | } 60 | 61 | variance_network { 62 | init_val = 0.3 63 | } 64 | 65 | rendering_network { 66 | d_feature = 256 67 | mode = idr 68 | d_in = 9 69 | d_out = 3 70 | d_hidden = 256 71 | n_layers = 4 72 | weight_norm = True 73 | multires_view = 4 74 | squeeze_out = True 75 | } 76 | 77 | neus_renderer { 78 | n_samples = 64 79 | n_importance = 64 80 | n_outside = 0 81 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling 82 | perturb = 1.0 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # NeRFPrior: Learning Neural Radiance Field as a Prior for Indoor Scene Reconstruction
3 | 
4 | 
5 | Paper | Project Page
6 | 
7 | 
8 | 
9 | 
10 | 
11 | 12 | In this paper, we introduce NeRFPrior. Given multi-view images of a scene as input, we first train a grid-based NeRF to obtain the density field and color field as priors. We then learn a signed distance function by imposing a multi-view consistency constraint using volume rendering. For each sampled point on the ray, we query the prior density and prior color as additional supervision of the predicted density and color, respectively. To improve the smoothness and completeness of textureless areas in the scene, we propose a depth consistency loss, which forces surface points in the same textureless plane to have similar depths. 13 | 14 | 15 | # Preprocessed Datasets 16 | 17 | Our preprocessed ScanNet and Replica datasets are provided in [This link](https://huggingface.co/datasets/zParquet/MonoInstance/tree/main). 18 | 19 | 20 | # Setup 21 | 22 | ## Installation 23 | 24 | Clone the repository and create an anaconda environment using 25 | ```shell 26 | git clone https://github.com/wen-yuan-zhang/NeRFPrior.git 27 | cd NeRFPrior 28 | 29 | conda create -n nerfprior python=3.10 30 | conda activate nerfprior 31 | 32 | conda install pytorch=1.13.0 torchvision=0.14.0 cudatoolkit=11.7 -c pytorch 33 | conda install cudatoolkit-dev=11.7 -c conda-forge 34 | 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | You should also clone TensoRF to obtain NeRF prior. 39 | ```shell 40 | git clone https://github.com/apchenstu/TensoRF.git 41 | ``` 42 | 43 | 44 | # Training 45 | 46 | Firstly, train TensoRF on the given scene to obtain the ```.th``` checkpoint for further neural implicit surface reconstruction. 47 | 48 | To train BlendSwap dataset, use 49 | ```shell 50 | CUDA_VISIBLE_DEVICES=1 python exp_runner.py --conf confs/blendswap.conf 51 | ``` 52 | To train Replica dataset, use 53 | ```shell 54 | CUDA_VISIBLE_DEVICES=1 python exp_runner.py --conf confs/replica.conf 55 | ``` 56 | 57 | 58 | 59 | 60 | # Evaluation 61 | 62 | To evaluate the reconstructed meshes first use ```cull_mesh.py``` to cull the meshes according to view frustums. Then use ```mesh_metrics.py``` to specify the path to meshes and evaluate the metrics. 63 | ```shell 64 | python cull_mesh.py 65 | python mesh_metrics.py 66 | ``` 67 | 68 | 69 | 70 | 71 | # Acknowledgements 72 | 73 | This project is built upon [NeuS](https://lingjie0206.github.io/papers/NeuS/) and [TensoRF](https://apchenstu.github.io/TensoRF/). We thank all the authors for their great repos. 
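For readers who want a feel for how the prior supervision and the depth consistency loss described at the top of this README fit together, the snippet below is a minimal, hand-written sketch. It is not the code used in `exp_runner.py`: the tensor names, the textureless-plane mask, and the loss weights are all assumptions made purely for illustration.

```python
import torch.nn.functional as F

def nerf_prior_losses(pred_density, pred_color, prior_density, prior_color,
                      surf_depth, plane_mask,
                      w_density=0.1, w_color=0.1, w_depth=0.05):
    """Illustrative only: supervision from a frozen grid-based NeRF prior plus a
    depth consistency term for surface points on the same textureless plane.
    All names and weights are hypothetical, not taken from this repository."""
    # Queried prior density/color act as extra supervision for the density and
    # color predicted along each ray by the SDF-based volume renderer.
    loss_density = F.l1_loss(pred_density, prior_density.detach())
    loss_color = F.l1_loss(pred_color, prior_color.detach())
    # Depth consistency: surface points belonging to one textureless plane are
    # encouraged to have similar depths by penalizing deviation from their mean.
    if plane_mask.any():
        plane_depths = surf_depth[plane_mask]
        loss_depth = ((plane_depths - plane_depths.mean()) ** 2).mean()
    else:
        loss_depth = surf_depth.new_zeros(())
    return w_density * loss_density + w_color * loss_color + w_depth * loss_depth
```

In the actual training loop these terms would presumably be added on top of the standard NeuS color and eikonal losses (the latter weighted by `igr_weight` in the `.conf` files).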
74 | 75 | 76 | # Citation 77 | 78 | If you find our code or paper useful, please consider citing 79 | ```bibtex 80 | @inproceedings{zhang2025nerfprior, 81 | title={{NeRFPrior}: Learning neural radiance field as a prior for indoor scene reconstruction}, 82 | author={Zhang, Wenyuan and Jia, Emily Yue-ting and Zhou, Junsheng and Ma, Baorui and Shi, Kanle and Liu, Yu-Shen and Han, Zhizhong}, 83 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference}, 84 | pages={11317--11327}, 85 | year={2025} 86 | } 87 | ``` 88 | 89 | -------------------------------------------------------------------------------- /models/sh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ################## sh function ################## 4 | C0 = 0.28209479177387814 5 | C1 = 0.4886025119029199 6 | C2 = [ 7 | 1.0925484305920792, 8 | -1.0925484305920792, 9 | 0.31539156525252005, 10 | -1.0925484305920792, 11 | 0.5462742152960396 12 | ] 13 | C3 = [ 14 | -0.5900435899266435, 15 | 2.890611442640554, 16 | -0.4570457994644658, 17 | 0.3731763325901154, 18 | -0.4570457994644658, 19 | 1.445305721320277, 20 | -0.5900435899266435 21 | ] 22 | C4 = [ 23 | 2.5033429417967046, 24 | -1.7701307697799304, 25 | 0.9461746957575601, 26 | -0.6690465435572892, 27 | 0.10578554691520431, 28 | -0.6690465435572892, 29 | 0.47308734787878004, 30 | -1.7701307697799304, 31 | 0.6258357354491761, 32 | ] 33 | 34 | def eval_sh(deg, sh, dirs): 35 | """ 36 | Evaluate spherical harmonics at unit directions 37 | using hardcoded SH polynomials. 38 | Works with torch/np/jnp. 39 | ... Can be 0 or more batch dimensions. 40 | :param deg: int SH max degree. Currently, 0-4 supported 41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2) 42 | :param dirs: torch.Tensor unit directions (..., 3) 43 | :return: (..., C) 44 | """ 45 | assert deg <= 4 and deg >= 0 46 | assert (deg + 1) ** 2 == sh.shape[-1] 47 | C = sh.shape[-2] 48 | 49 | result = C0 * sh[..., 0] 50 | if deg > 0: 51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] 52 | result = (result - 53 | C1 * y * sh[..., 1] + 54 | C1 * z * sh[..., 2] - 55 | C1 * x * sh[..., 3]) 56 | if deg > 1: 57 | xx, yy, zz = x * x, y * y, z * z 58 | xy, yz, xz = x * y, y * z, x * z 59 | result = (result + 60 | C2[0] * xy * sh[..., 4] + 61 | C2[1] * yz * sh[..., 5] + 62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + 63 | C2[3] * xz * sh[..., 7] + 64 | C2[4] * (xx - yy) * sh[..., 8]) 65 | 66 | if deg > 2: 67 | result = (result + 68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] + 69 | C3[1] * xy * z * sh[..., 10] + 70 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + 71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + 72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + 73 | C3[5] * z * (xx - yy) * sh[..., 14] + 74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15]) 75 | if deg > 3: 76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + 77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] + 78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] + 79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] + 80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + 81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] + 82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + 83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + 84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) 85 | return result 86 | 87 | def eval_sh_bases(deg, dirs): 88 | """ 89 | Evaluate spherical harmonics bases at unit directions, 90 | without taking linear combination. 
91 | At each point, the final result may the be 92 | obtained through simple multiplication. 93 | :param deg: int SH max degree. Currently, 0-4 supported 94 | :param dirs: torch.Tensor (..., 3) unit directions 95 | :return: torch.Tensor (..., (deg+1) ** 2) 96 | """ 97 | assert deg <= 4 and deg >= 0 98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device) 99 | result[..., 0] = C0 100 | if deg > 0: 101 | x, y, z = dirs.unbind(-1) 102 | result[..., 1] = -C1 * y; 103 | result[..., 2] = C1 * z; 104 | result[..., 3] = -C1 * x; 105 | if deg > 1: 106 | xx, yy, zz = x * x, y * y, z * z 107 | xy, yz, xz = x * y, y * z, x * z 108 | result[..., 4] = C2[0] * xy; 109 | result[..., 5] = C2[1] * yz; 110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy); 111 | result[..., 7] = C2[3] * xz; 112 | result[..., 8] = C2[4] * (xx - yy); 113 | 114 | if deg > 2: 115 | result[..., 9] = C3[0] * y * (3 * xx - yy); 116 | result[..., 10] = C3[1] * xy * z; 117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy); 118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy); 119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy); 120 | result[..., 14] = C3[5] * z * (xx - yy); 121 | result[..., 15] = C3[6] * x * (xx - 3 * yy); 122 | 123 | if deg > 3: 124 | result[..., 16] = C4[0] * xy * (xx - yy); 125 | result[..., 17] = C4[1] * yz * (3 * xx - yy); 126 | result[..., 18] = C4[2] * xy * (7 * zz - 1); 127 | result[..., 19] = C4[3] * yz * (7 * zz - 3); 128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3); 129 | result[..., 21] = C4[5] * xz * (7 * zz - 3); 130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1); 131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy); 132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)); 133 | return result 134 | -------------------------------------------------------------------------------- /mesh_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import trimesh 3 | from scipy.spatial import cKDTree 4 | import numpy as np 5 | import open3d as o3d 6 | # from matplotlib import pyplot as plt 7 | 8 | 9 | def compute_iou(mesh_pred, mesh_target): 10 | res = 0.05 11 | v_pred = mesh_pred.voxelized(pitch=res) 12 | v_target = mesh_target.voxelized(pitch=res) 13 | v_target_mesh = v_target.as_boxes() 14 | v_pred_mesh = v_pred.as_boxes() 15 | 16 | v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred.points) 17 | v_target_filled = set(tuple(np.round(x, 4)) for x in v_target.points) 18 | inter = v_pred_filled.intersection(v_target_filled) 19 | union = v_pred_filled.union(v_target_filled) 20 | iou = len(inter) / len(union) 21 | return iou, v_target_mesh, v_pred_mesh 22 | 23 | 24 | def get_colored_pcd(pcd, metric): 25 | cmap = plt.cm.get_cmap("jet") 26 | color = cmap(metric / 0.10)[..., :3] 27 | pcd_o3d = o3d.geometry.PointCloud() 28 | pcd_o3d.points = o3d.utility.Vector3dVector(pcd) 29 | pcd_o3d.colors = o3d.utility.Vector3dVector(color) 30 | return pcd_o3d 31 | 32 | 33 | def compute_metrics(mesh_pred, mesh_target): 34 | # mesh_pred = trimesh.load_mesh(path_pred) 35 | # mesh_target = trimesh.load_mesh(path_target) 36 | area_pred = int(mesh_pred.area * 1e4) 37 | area_tgt = int(mesh_target.area * 1e4) 38 | print("pred: {}, target: {}".format(area_pred, area_tgt)) 39 | 40 | iou, v_gt, v_pred = compute_iou(mesh_pred, mesh_target) 41 | 42 | pointcloud_pred, idx = mesh_pred.sample(area_pred, return_index=True) 43 | pointcloud_pred = pointcloud_pred.astype(np.float32) 44 | normals_pred = 
mesh_pred.face_normals[idx] 45 | 46 | pointcloud_tgt, idx = mesh_target.sample(area_tgt, return_index=True) 47 | pointcloud_tgt = pointcloud_tgt.astype(np.float32) 48 | normals_tgt = mesh_target.face_normals[idx] 49 | 50 | thresholds = np.array([0.05]) 51 | 52 | # for every point in gt compute the min distance to points in pred 53 | completeness, completeness_normals = distance_p2p( 54 | pointcloud_tgt, normals_tgt, pointcloud_pred, normals_pred 55 | ) 56 | recall = get_threshold_percentage(completeness, thresholds) 57 | completeness2 = completeness ** 2 58 | 59 | # color gt_point_cloud using completion 60 | # com_mesh = get_colored_pcd(pointcloud_tgt, completeness) 61 | 62 | completeness = completeness.mean() 63 | completeness2 = completeness2.mean() 64 | completeness_normals = completeness_normals.mean() 65 | 66 | # Accuracy: how far are th points of the predicted pointcloud 67 | # from the target pointcloud 68 | accuracy, accuracy_normals = distance_p2p( 69 | pointcloud_pred, normals_pred, pointcloud_tgt, normals_tgt 70 | ) 71 | precision = get_threshold_percentage(accuracy, thresholds) 72 | accuracy2 = accuracy ** 2 73 | 74 | # color pred_point_cloud using completion 75 | # acc_mesh = get_colored_pcd(pointcloud_pred, accuracy) 76 | 77 | accuracy = accuracy.mean() 78 | accuracy2 = accuracy2.mean() 79 | accuracy_normals = accuracy_normals.mean() 80 | 81 | # Chamfer distance 82 | chamferL2 = 0.5 * (completeness2 + accuracy2) 83 | normals_correctness = ( 84 | 0.5 * completeness_normals + 0.5 * accuracy_normals 85 | ) 86 | chamferL1 = 0.5 * (completeness + accuracy) 87 | 88 | # F-Score 89 | F = [ 90 | 2 * precision[i] * recall[i] / (precision[i] + recall[i]) 91 | for i in range(len(precision)) 92 | ] 93 | rst = { 94 | "IoU": iou, 95 | # "Acc": accuracy, 96 | # "Comp": completeness, 97 | "C-L1": chamferL1, 98 | "NC": normals_correctness, 99 | 'precision': precision[0], 100 | 'recall': recall[0], 101 | "F-score": F[0] 102 | } 103 | 104 | return rst 105 | 106 | 107 | def distance_p2p(points_src, normals_src, points_tgt, normals_tgt): 108 | """ Computes minimal distances of each point in points_src to points_tgt. 109 | Args: 110 | points_src (numpy array): source points 111 | normals_src (numpy array): source normals 112 | points_tgt (numpy array): target points 113 | normals_tgt (numpy array): target normals 114 | """ 115 | kdtree = cKDTree(points_tgt) 116 | dist, idx = kdtree.query(points_src) 117 | 118 | if normals_src is not None and normals_tgt is not None: 119 | normals_src = \ 120 | normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True) 121 | normals_tgt = \ 122 | normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True) 123 | 124 | normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1) 125 | # Handle normals that point into wrong direction gracefully 126 | # (mostly due to mehtod not caring about this in generation) 127 | normals_dot_product = np.abs(normals_dot_product) 128 | else: 129 | normals_dot_product = np.array( 130 | [np.nan] * points_src.shape[0], dtype=np.float32) 131 | return dist, normals_dot_product 132 | 133 | 134 | def get_threshold_percentage(dist, thresholds): 135 | """ Evaluates a point cloud. 
136 | Args: 137 | dist (numpy array): calculated distance 138 | thresholds (numpy array): threshold values for the F-score calculation 139 | """ 140 | in_threshold = [ 141 | (dist <= t).astype(np.float32).mean() for t in thresholds 142 | ] 143 | return in_threshold 144 | 145 | 146 | def save_meshes(meshes, mesh_dir, save_name): 147 | for key in meshes: 148 | mesh = meshes[key] 149 | if isinstance(mesh, o3d.geometry.PointCloud): 150 | o3d.io.write_point_cloud(os.path.join(mesh_dir, "{}_{}.ply".format(key, save_name)), mesh) 151 | elif isinstance(mesh, trimesh.Trimesh): 152 | mesh.export(os.path.join(mesh_dir, "{}_{}.ply".format(key, save_name))) 153 | 154 | 155 | if __name__ == '__main__': 156 | gt_mesh_name = 'out_meshes/room0/gt_cropped_culled_transformed.ply' 157 | ours_mesh_name = 'out_meshes/room0/neus_cropped_culled_transformed.ply' 158 | 159 | gt_mesh = trimesh.load_mesh(gt_mesh_name) 160 | # neuralrgbd_mesh = trimesh.load_mesh(neuralrgbd_mesh_name) 161 | ours_mesh = trimesh.load(ours_mesh_name) 162 | 163 | # iou1, target_mesh, pred_mesh = compute_iou(neuralrgbd_mesh, gt_mesh) 164 | # iou2, target_mesh, pred_mesh = compute_iou(ours_mesh, gt_mesh) 165 | # print(iou1) 166 | # print(iou2) 167 | 168 | # ret = compute_metrics(neuralrgbd_mesh, gt_mesh) 169 | # print(ret) 170 | ret = compute_metrics(ours_mesh, gt_mesh) 171 | print(ret) 172 | -------------------------------------------------------------------------------- /models/scannet_blender.py: -------------------------------------------------------------------------------- 1 | from cv2 import cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision import transforms as T 9 | 10 | import torch 11 | from kornia import create_meshgrid 12 | 13 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" 14 | 15 | def get_ray_directions(H, W, focal, center=None): 16 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 17 | 18 | i, j = grid.unbind(-1) 19 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 20 | # see https://github.com/bmild/nerf/issues/24 21 | cent = center if center is not None else [W / 2, H / 2] 22 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) 23 | 24 | return directions 25 | 26 | def get_rays(directions, c2w): 27 | # Rotate ray directions from camera coordinate to the world coordinate 28 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 29 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 30 | # The origin of all rays is the camera origin in world coordinate 31 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) 32 | 33 | rays_d = rays_d.view(-1, 3) 34 | rays_o = rays_o.view(-1, 3) 35 | 36 | return rays_o, rays_d 37 | 38 | class ScanNetDataset(Dataset): 39 | def __init__(self, conf, split='train', N_vis=-1): 40 | self.device = torch.device('cuda') 41 | self.N_vis = N_vis 42 | self.root_dir = conf.get_string('data_dir') 43 | scene = conf.get_string('scene') 44 | 45 | self.split = split 46 | self.is_stack = False 47 | self.downsample = 1.0 48 | self.transform = T.ToTensor() 49 | 50 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 51 | 52 | self.poses = [] 53 | self.all_rays = [] 54 | self.all_rgbs = [] 55 | self.all_depth = [] 56 | self.directions = [] 57 | self.read_meta(scene) 58 | self.n_images = len(self.all_rgbs) 59 | self.white_bg = True 
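        # NOTE: the unit-cube bounds assigned in the block below are immediately
        # overwritten, so mesh extraction for this dataset effectively uses [-0.6, 0.6]^3.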
60 | 61 | # for Neus exp_runner 62 | self.pose_all = self.poses 63 | self.object_bbox_min = np.array([-1.01, -1.01, -1.01]) # only used in extract 64 | self.object_bbox_max = np.array([ 1.01, 1.01, 1.01]) 65 | self.object_bbox_max = np.array([0.6, 0.6, 0.6]) 66 | self.object_bbox_min = np.array([-0.6, -0.6, -0.6]) 67 | self.scale_mats_np = [np.eye(4)] 68 | 69 | def read_meta(self, scene): 70 | root_dir = os.path.join(self.root_dir, scene) 71 | with open(os.path.join(root_dir, f"transforms_{self.split}.json"), 'r') as f: 72 | self.meta = json.load(f) 73 | 74 | w, h = int(self.meta['w'] / self.downsample), int(self.meta['h'] / self.downsample) 75 | self.img_wh = [w, h] 76 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 77 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length 78 | self.cx, self.cy = self.meta['cx'], self.meta['cy'] 79 | 80 | # ray directions for all pixels, same for all images (same H, W, focal) 81 | direction = get_ray_directions(h, w, [self.focal_x, self.focal_y], center=[self.cx, self.cy]) # (h, w, 3) 82 | self.directions.append(direction) 83 | # self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) # TODO 84 | self.intrinsics = torch.tensor([[self.focal_x, 0, self.cx], [0, self.focal_y, self.cy], [0, 0, 1]]).float() 85 | 86 | idxs = list(range(0, len(self.meta['frames']))) 87 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:# 88 | frame = self.meta['frames'][i] 89 | pose = np.array(frame['transform_matrix']) 90 | pose = pose @ self.blender2opencv 91 | c2w = torch.FloatTensor(pose) 92 | self.poses.append(c2w) 93 | image_path = os.path.join(root_dir, f"{frame['file_path']}") 94 | image_id = int(os.path.basename(image_path)[:-4]) 95 | img = Image.open(image_path) 96 | img = self.transform(img) # (4, h, w) 97 | img = img.permute(1, 2, 0) # (h*w, 4) RGBA 98 | if img.shape[-1] == 4: 99 | img = img[..., :3] * img[..., -1:] + (1 - img[..., -1:]) # blend A to RGB 100 | self.all_rgbs.append(img) 101 | 102 | 103 | self.poses = torch.stack(self.poses) #(N, 4, 4) 104 | 105 | rays_o = self.poses[:, :3, 3] 106 | 107 | # self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 6) 108 | # self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 109 | 110 | self.h, self.w = h, w 111 | 112 | # def define_proj_mat(self): 113 | # self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3] 114 | 115 | def __len__(self): 116 | return len(self.all_rgbs) 117 | 118 | # 多个场景版本的get_random_rays 119 | def gen_random_rays_at(self, img_idx, batch_size): 120 | scene_id = 0 121 | pixels_x = torch.randint(low=0, high=self.w, size=[batch_size]) 122 | pixels_y = torch.randint(low=0, high=self.h, size=[batch_size]) 123 | color = self.all_rgbs[img_idx] # [h, w, 3] 124 | color = color[(pixels_y, pixels_x)].cpu() # [batch_size, 3] 125 | # depth = self.all_depth[img_idx].cuda()[(pixels_y, pixels_x)] # [batch_size,] 126 | # depth = depth.unsqueeze(-1).cpu() 127 | depth = torch.ones(batch_size, 1).cpu() 128 | mask = torch.ones_like(color, dtype=torch.float) 129 | directions = self.directions[scene_id] 130 | batch_direction = directions[(pixels_y, pixels_x)] 131 | rays_o, rays_d = get_rays(batch_direction, self.poses[img_idx]) # both (batch_size, 3) 132 | rand_rays = torch.cat([rays_o, rays_d], -1) 133 | return torch.cat([rand_rays, color, mask[:, :1]], 
dim=-1).to(self.device) 134 | 135 | def near_far_from_sphere(self, rays_o, rays_d): 136 | near = torch.zeros(rays_o.shape[0], 1).cuda() 137 | far = torch.ones(rays_o.shape[0], 1).cuda() * 2 138 | return near, far 139 | 140 | def gen_rays_at(self, img_idx, resolution_level=1): 141 | tx = torch.linspace(0, self.w-resolution_level, self.w // resolution_level) 142 | ty = torch.linspace(0, self.h-resolution_level, self.h // resolution_level) 143 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 144 | directions = self.directions[0] 145 | batch_direction = directions[(pixels_y.long(), pixels_x.long())] 146 | rays_o, rays_d = get_rays(batch_direction, self.poses[img_idx]) # both (batch_size, 3) 147 | rays_o = rays_o.reshape(self.w//resolution_level, self.h//resolution_level, 3).to(self.device).transpose(0,1) 148 | rays_d = rays_d.reshape(self.w//resolution_level, self.h//resolution_level, 3).to(self.device).transpose(0,1) 149 | return rays_o, rays_d 150 | 151 | def image_at(self, idx, resolution_level): 152 | img = self.all_rgbs[idx].cpu().numpy() * 255 153 | img = cv2.resize(img, (self.w//resolution_level, self.h//resolution_level)) 154 | return img[:,:,[2,1,0]] 155 | 156 | def get_scene_id(self, img_idx): 157 | return 0 158 | 159 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 160 | # only used in novel view synthesis 161 | raise NotImplementedError() 162 | 163 | 164 | -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | from glob import glob 7 | # from icecream import ic 8 | from scipy.spatial.transform import Rotation as Rot 9 | from scipy.spatial.transform import Slerp 10 | 11 | 12 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 13 | def load_K_Rt_from_P(filename, P=None): 14 | if P is None: 15 | lines = open(filename).read().splitlines() 16 | if len(lines) == 4: 17 | lines = lines[1:] 18 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 19 | P = np.asarray(lines).astype(np.float32).squeeze() 20 | 21 | out = cv.decomposeProjectionMatrix(P) 22 | K = out[0] 23 | R = out[1] 24 | t = out[2] 25 | 26 | K = K / K[2, 2] 27 | intrinsics = np.eye(4) 28 | intrinsics[:3, :3] = K 29 | 30 | pose = np.eye(4, dtype=np.float32) 31 | pose[:3, :3] = R.transpose() 32 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 33 | 34 | return intrinsics, pose 35 | 36 | 37 | class Dataset: 38 | def __init__(self, conf): 39 | super(Dataset, self).__init__() 40 | print('Load data: Begin') 41 | self.device = torch.device('cuda') 42 | self.conf = conf 43 | 44 | self.data_dir = conf.get_string('data_dir') 45 | self.render_cameras_name = conf.get_string('render_cameras_name') 46 | self.object_cameras_name = conf.get_string('object_cameras_name') 47 | 48 | self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True) 49 | self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1) 50 | 51 | camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name)) 52 | self.camera_dict = camera_dict 53 | self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png'))) 54 | self.n_images = len(self.images_lis) 55 | self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 256.0 56 | self.masks_lis = sorted(glob(os.path.join(self.data_dir, 'mask/*.png'))) 57 | 
self.masks_np = np.stack([cv.imread(im_name) for im_name in self.masks_lis]) / 256.0 58 | 59 | # world_mat is a projection matrix from world to image 60 | self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 61 | 62 | self.scale_mats_np = [] 63 | 64 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin. 65 | self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 66 | 67 | self.intrinsics_all = [] 68 | self.pose_all = [] 69 | 70 | for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np): 71 | P = world_mat @ scale_mat 72 | P = P[:3, :4] 73 | intrinsics, pose = load_K_Rt_from_P(None, P) 74 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) 75 | self.pose_all.append(torch.from_numpy(pose).float()) 76 | 77 | self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu() # [n_images, H, W, 3] 78 | self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cpu() # [n_images, H, W, 3] 79 | self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4] 80 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4] 81 | self.focal = self.intrinsics_all[0][0, 0] 82 | self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4] 83 | self.H, self.W = self.images.shape[1], self.images.shape[2] 84 | self.image_pixels = self.H * self.W 85 | 86 | object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0]) 87 | object_bbox_max = np.array([ 1.01, 1.01, 1.01, 1.0]) 88 | # Object scale mat: region of interest to **extract mesh** 89 | object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0'] 90 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None] 91 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None] 92 | self.object_bbox_min = object_bbox_min[:3, 0] 93 | self.object_bbox_max = object_bbox_max[:3, 0] 94 | 95 | print('Load data: End') 96 | 97 | def gen_rays_at(self, img_idx, resolution_level=1): 98 | """ 99 | Generate rays at world space from one camera. 100 | """ 101 | l = resolution_level 102 | tx = torch.linspace(0, self.W - 1, self.W // l) 103 | ty = torch.linspace(0, self.H - 1, self.H // l) 104 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 105 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 106 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 107 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 108 | rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 109 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3 110 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 111 | 112 | def gen_random_rays_at(self, img_idx, batch_size): 113 | """ 114 | Generate random rays at world space from one camera. 
115 | """ 116 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) 117 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) 118 | color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3 119 | mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3 120 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3 121 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3 122 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3 123 | rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze() # batch_size, 3 124 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape) # batch_size, 3 125 | return torch.cat([rays_o.cpu(), rays_v.cpu(), color, mask[:, :1]], dim=-1).cuda() # batch_size, 10 126 | 127 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 128 | """ 129 | Interpolate pose between two cameras. 130 | """ 131 | l = resolution_level 132 | tx = torch.linspace(0, self.W - 1, self.W // l) 133 | ty = torch.linspace(0, self.H - 1, self.H // l) 134 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 135 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3 136 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 137 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 138 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio 139 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy() 140 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy() 141 | pose_0 = np.linalg.inv(pose_0) 142 | pose_1 = np.linalg.inv(pose_1) 143 | rot_0 = pose_0[:3, :3] 144 | rot_1 = pose_1[:3, :3] 145 | rots = Rot.from_matrix(np.stack([rot_0, rot_1])) 146 | key_times = [0, 1] 147 | slerp = Slerp(key_times, rots) 148 | rot = slerp(ratio) 149 | pose = np.diag([1.0, 1.0, 1.0, 1.0]) 150 | pose = pose.astype(np.float32) 151 | pose[:3, :3] = rot.as_matrix() 152 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] 153 | pose = np.linalg.inv(pose) 154 | rot = torch.from_numpy(pose[:3, :3]).cuda() 155 | trans = torch.from_numpy(pose[:3, 3]).cuda() 156 | rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 157 | rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3 158 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1) 159 | 160 | def near_far_from_sphere(self, rays_o, rays_d): 161 | a = torch.sum(rays_d**2, dim=-1, keepdim=True) 162 | b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True) 163 | mid = 0.5 * (-b) / a 164 | near = mid - 1.0 165 | far = mid + 1.0 166 | return near, far 167 | 168 | def image_at(self, idx, resolution_level): 169 | img = cv.imread(self.images_lis[idx]) 170 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255) 171 | 172 | -------------------------------------------------------------------------------- /models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | 7 | 8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 9 | class SDFNetwork(nn.Module): 10 | def __init__(self, 11 | d_in, 12 | d_out, 13 | d_hidden, 14 | n_layers, 15 | 
skip_in=(4,), 16 | multires=0, 17 | bias=0.5, 18 | scale=1, 19 | geometric_init=True, 20 | weight_norm=True, 21 | inside_outside=False): 22 | super(SDFNetwork, self).__init__() 23 | 24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 25 | 26 | self.embed_fn_fine = None 27 | 28 | if multires > 0: 29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 30 | self.embed_fn_fine = embed_fn 31 | dims[0] = input_ch 32 | 33 | self.num_layers = len(dims) 34 | self.skip_in = skip_in 35 | self.scale = scale 36 | 37 | for l in range(0, self.num_layers - 1): 38 | if l + 1 in self.skip_in: 39 | out_dim = dims[l + 1] - dims[0] 40 | else: 41 | out_dim = dims[l + 1] 42 | 43 | lin = nn.Linear(dims[l], out_dim) 44 | 45 | if geometric_init: 46 | if l == self.num_layers - 2: 47 | if not inside_outside: 48 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 49 | torch.nn.init.constant_(lin.bias, -bias) 50 | else: 51 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 52 | torch.nn.init.constant_(lin.bias, bias) 53 | elif multires > 0 and l == 0: 54 | torch.nn.init.constant_(lin.bias, 0.0) 55 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 56 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 57 | elif multires > 0 and l in self.skip_in: 58 | torch.nn.init.constant_(lin.bias, 0.0) 59 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 60 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 61 | else: 62 | torch.nn.init.constant_(lin.bias, 0.0) 63 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 64 | 65 | if weight_norm: 66 | lin = nn.utils.weight_norm(lin) 67 | 68 | setattr(self, "lin" + str(l), lin) 69 | 70 | self.activation = nn.Softplus(beta=100) 71 | 72 | def forward(self, inputs): 73 | inputs = inputs * self.scale 74 | if self.embed_fn_fine is not None: 75 | inputs = self.embed_fn_fine(inputs) 76 | 77 | x = inputs 78 | for l in range(0, self.num_layers - 1): 79 | lin = getattr(self, "lin" + str(l)) 80 | 81 | if l in self.skip_in: 82 | x = torch.cat([x, inputs], 1) / np.sqrt(2) 83 | 84 | x = lin(x) 85 | 86 | if l < self.num_layers - 2: 87 | x = self.activation(x) 88 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1) 89 | 90 | def sdf(self, x): 91 | return self.forward(x)[:, :1] 92 | 93 | def sdf_hidden_appearance(self, x): 94 | return self.forward(x) 95 | 96 | def gradient(self, x): # gradients assumed to be normalized because of eikonal_loss 97 | x.requires_grad_(True) 98 | y = self.sdf(x) 99 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 100 | gradients = torch.autograd.grad( 101 | outputs=y, 102 | inputs=x, 103 | grad_outputs=d_output, 104 | create_graph=True, 105 | retain_graph=True, 106 | only_inputs=True)[0] 107 | return gradients.unsqueeze(1) 108 | 109 | 110 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr 111 | class RenderingNetwork(nn.Module): 112 | def __init__(self, 113 | d_feature, 114 | mode, 115 | d_in, 116 | d_out, 117 | d_hidden, 118 | n_layers, 119 | weight_norm=True, 120 | multires_view=0, 121 | squeeze_out=True): 122 | super().__init__() 123 | 124 | self.mode = mode 125 | self.squeeze_out = squeeze_out 126 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] 127 | 128 | self.embedview_fn = None 129 | if multires_view > 0: 130 | embedview_fn, input_ch = get_embedder(multires_view) 131 | self.embedview_fn = embedview_fn 132 | 
dims[0] += (input_ch - 3) 133 | 134 | self.num_layers = len(dims) 135 | 136 | for l in range(0, self.num_layers - 1): 137 | out_dim = dims[l + 1] 138 | lin = nn.Linear(dims[l], out_dim) 139 | 140 | if weight_norm: 141 | lin = nn.utils.weight_norm(lin) 142 | 143 | setattr(self, "lin" + str(l), lin) 144 | 145 | self.relu = nn.ReLU() 146 | 147 | def forward(self, points, normals, view_dirs, feature_vectors): 148 | if self.embedview_fn is not None: 149 | view_dirs = self.embedview_fn(view_dirs) 150 | 151 | rendering_input = None 152 | 153 | if self.mode == 'idr': 154 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) 155 | elif self.mode == 'no_view_dir': 156 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) 157 | elif self.mode == 'no_normal': 158 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) 159 | 160 | x = rendering_input 161 | 162 | for l in range(0, self.num_layers - 1): 163 | lin = getattr(self, "lin" + str(l)) 164 | 165 | x = lin(x) 166 | 167 | if l < self.num_layers - 2: 168 | x = self.relu(x) 169 | 170 | if self.squeeze_out: 171 | x = torch.sigmoid(x) 172 | return x 173 | 174 | 175 | # This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch 176 | class NeRF(nn.Module): 177 | def __init__(self, 178 | D=8, 179 | W=256, 180 | d_in=3, 181 | d_in_view=3, 182 | multires=0, 183 | multires_view=0, 184 | output_ch=4, 185 | skips=[4], 186 | use_viewdirs=False): 187 | super(NeRF, self).__init__() 188 | self.D = D 189 | self.W = W 190 | self.d_in = d_in 191 | self.d_in_view = d_in_view 192 | self.input_ch = 3 193 | self.input_ch_view = 3 194 | self.embed_fn = None 195 | self.embed_fn_view = None 196 | 197 | if multires > 0: 198 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 199 | self.embed_fn = embed_fn 200 | self.input_ch = input_ch 201 | 202 | if multires_view > 0: 203 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view) 204 | self.embed_fn_view = embed_fn_view 205 | self.input_ch_view = input_ch_view 206 | 207 | self.skips = skips 208 | self.use_viewdirs = use_viewdirs 209 | 210 | self.pts_linears = nn.ModuleList( 211 | [nn.Linear(self.input_ch, W)] + 212 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)]) 213 | 214 | ### Implementation according to the official code release 215 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 216 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) 217 | 218 | ### Implementation according to the paper 219 | # self.views_linears = nn.ModuleList( 220 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 221 | 222 | if use_viewdirs: 223 | self.feature_linear = nn.Linear(W, W) 224 | self.alpha_linear = nn.Linear(W, 1) 225 | self.rgb_linear = nn.Linear(W // 2, 3) 226 | else: 227 | self.output_linear = nn.Linear(W, output_ch) 228 | 229 | def forward(self, input_pts, input_views): 230 | if self.embed_fn is not None: 231 | input_pts = self.embed_fn(input_pts) 232 | if self.embed_fn_view is not None: 233 | input_views = self.embed_fn_view(input_views) 234 | 235 | h = input_pts 236 | for i, l in enumerate(self.pts_linears): 237 | h = self.pts_linears[i](h) 238 | h = F.relu(h) 239 | if i in self.skips: 240 | h = torch.cat([input_pts, h], -1) 241 | 242 | if self.use_viewdirs: 243 | alpha = self.alpha_linear(h) 244 | feature = self.feature_linear(h) 245 | h 
= torch.cat([feature, input_views], -1) 246 | 247 | for i, l in enumerate(self.views_linears): 248 | h = self.views_linears[i](h) 249 | h = F.relu(h) 250 | 251 | rgb = self.rgb_linear(h) 252 | return alpha, rgb 253 | else: 254 | assert False 255 | 256 | 257 | class SingleVarianceNetwork(nn.Module): 258 | def __init__(self, init_val): 259 | super(SingleVarianceNetwork, self).__init__() 260 | self.register_parameter('variance', nn.Parameter(torch.tensor(init_val))) 261 | 262 | def forward(self, x): 263 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0) 264 | -------------------------------------------------------------------------------- /models/nerf2neus.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import mcubes 3 | import torch 4 | from models.tensoRF import TensorVMSplit 5 | import numpy as np 6 | import skimage 7 | import plyfile 8 | import os 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class PriorNeRF(nn.Module): 14 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False): 15 | """ 16 | """ 17 | super(PriorNeRF, self).__init__() 18 | self.D = D 19 | self.W = W 20 | self.input_ch = input_ch 21 | self.input_ch_views = input_ch_views 22 | self.skips = skips 23 | self.use_viewdirs = use_viewdirs 24 | 25 | self.pts_linears = nn.ModuleList( 26 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in 27 | range(D - 1)]) 28 | 29 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) 30 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)]) 31 | 32 | ### Implementation according to the paper 33 | # self.views_linears = nn.ModuleList( 34 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) 35 | 36 | if use_viewdirs: 37 | self.feature_linear = nn.Linear(W, W) 38 | self.alpha_linear = nn.Linear(W, 1) 39 | self.rgb_linear = nn.Linear(W // 2, 3) 40 | else: 41 | self.output_linear = nn.Linear(W, output_ch) 42 | 43 | def forward(self, x): 44 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1) 45 | h = input_pts 46 | for i, l in enumerate(self.pts_linears): 47 | h = self.pts_linears[i](h) 48 | h = F.relu(h) 49 | if i in self.skips: 50 | h = torch.cat([input_pts, h], -1) 51 | 52 | if self.use_viewdirs: 53 | alpha = self.alpha_linear(h) 54 | feature = self.feature_linear(h) 55 | h = torch.cat([feature, input_views], -1) 56 | 57 | for i, l in enumerate(self.views_linears): 58 | h = self.views_linears[i](h) 59 | h = F.relu(h) 60 | 61 | rgb = self.rgb_linear(h) 62 | outputs = torch.cat([rgb, alpha], -1) 63 | else: 64 | outputs = self.output_linear(h) 65 | 66 | return outputs 67 | 68 | def load_weights_from_keras(self, weights): 69 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False" 70 | 71 | # Load pts_linears 72 | for i in range(self.D): 73 | idx_pts_linears = 2 * i 74 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears])) 75 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears + 1])) 76 | 77 | # Load feature_linear 78 | idx_feature_linear = 2 * self.D 79 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear])) 80 | self.feature_linear.bias.data = 
torch.from_numpy(np.transpose(weights[idx_feature_linear + 1])) 81 | 82 | # Load views_linears 83 | idx_views_linears = 2 * self.D + 2 84 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears])) 85 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears + 1])) 86 | 87 | # Load rgb_linear 88 | idx_rbg_linear = 2 * self.D + 4 89 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear])) 90 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear + 1])) 91 | 92 | # Load alpha_linear 93 | idx_alpha_linear = 2 * self.D + 6 94 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear])) 95 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear + 1])) 96 | 97 | 98 | class Embedder: 99 | def __init__(self, **kwargs): 100 | self.kwargs = kwargs 101 | self.create_embedding_fn() 102 | 103 | def create_embedding_fn(self): 104 | embed_fns = [] 105 | d = self.kwargs['input_dims'] 106 | out_dim = 0 107 | if self.kwargs['include_input']: 108 | embed_fns.append(lambda x: x) 109 | out_dim += d 110 | 111 | max_freq = self.kwargs['max_freq_log2'] 112 | N_freqs = self.kwargs['num_freqs'] 113 | 114 | if self.kwargs['log_sampling']: 115 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 116 | else: 117 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs) 118 | 119 | for freq in freq_bands: 120 | for p_fn in self.kwargs['periodic_fns']: 121 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 122 | out_dim += d 123 | 124 | self.embed_fns = embed_fns 125 | self.out_dim = out_dim 126 | 127 | def embed(self, inputs): 128 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 129 | 130 | 131 | def get_embedder(multires, i=0): 132 | if i == -1: 133 | return nn.Identity(), 3 134 | 135 | embed_kwargs = { 136 | 'include_input': True, 137 | 'input_dims': 3, 138 | 'max_freq_log2': multires - 1, 139 | 'num_freqs': multires, 140 | 'log_sampling': True, 141 | 'periodic_fns': [torch.sin, torch.cos], 142 | } 143 | 144 | embedder_obj = Embedder(**embed_kwargs) 145 | embed = lambda x, eo=embedder_obj: eo.embed(x) 146 | return embed, embedder_obj.out_dim 147 | 148 | 149 | 150 | def query_alpha_color_nerf(nerf: PriorNeRF, _xyz_sampled, viewdirs=None, z_vals=None): 151 | # xyz_sampled: [N, 3] 152 | # viewdirs: [N_rays, 3] 153 | xyz_sampled = _xyz_sampled * 2 154 | batch_size, n_samples, _ = xyz_sampled.shape 155 | inputs_flat = torch.reshape(xyz_sampled, [-1, 3]) 156 | embedded = nerf.embed_fn(inputs_flat) 157 | 158 | if viewdirs is None: 159 | viewdirs = torch.zeros(batch_size,3,dtype=torch.float32).cuda() 160 | input_dirs = viewdirs[:, None].expand(xyz_sampled.shape) 161 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 162 | embedded_dirs = nerf.embeddirs_fn(input_dirs_flat) 163 | embedded = torch.cat([embedded, embedded_dirs], -1) 164 | 165 | outputs_flat = nerf(embedded) 166 | 167 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1. 
- torch.exp(-act_fn(raw) * dists) 168 | 169 | dists = z_vals[..., 1:] - z_vals[..., :-1] 170 | dists = torch.cat([dists, dists[0][0].expand(dists[..., :1].shape)], -1) 171 | 172 | # dists = dists * torch.norm(rays_d[..., None, :], dim=-1) 173 | outputs_flat = outputs_flat.reshape(batch_size, n_samples, -1) 174 | alpha = raw2alpha(outputs_flat[..., 3], dists) # [N_rays, N_samples] 175 | 176 | return alpha, None 177 | 178 | def query_density_nerf(nerf, _xyz_sampled): 179 | # xyz_sampled: [N, 3] 180 | # viewdirs: [N_rays, 3] 181 | xyz_sampled = _xyz_sampled * 2 182 | batch_size, n_samples, _ = xyz_sampled.shape 183 | inputs_flat = torch.reshape(xyz_sampled, [-1, 3]) 184 | embedded = nerf.embed_fn(inputs_flat) 185 | 186 | viewdirs = torch.zeros(batch_size, 3, dtype=torch.float32).cuda() 187 | input_dirs = viewdirs[:, None].expand(xyz_sampled.shape) 188 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) 189 | embedded_dirs = nerf.embeddirs_fn(input_dirs_flat) 190 | embedded = torch.cat([embedded, embedded_dirs], -1) 191 | 192 | outputs_flat = nerf(embedded) 193 | 194 | return outputs_flat[..., 3] 195 | 196 | 197 | def load_nerf(ckpt_path): 198 | embed_fn, input_ch = get_embedder(10, 0) 199 | embeddirs_fn, input_ch_views = get_embedder(4, 0) 200 | model_fine = PriorNeRF(D=8, W=256, input_ch=input_ch, output_ch=5, skips=[4], 201 | input_ch_views=input_ch_views, use_viewdirs=True).cuda() 202 | print('Reloading from', ckpt_path) 203 | ckpt = torch.load(ckpt_path) 204 | model_fine.load_state_dict(ckpt['network_fine_state_dict']) 205 | model_fine.embed_fn = embed_fn 206 | model_fine.embeddirs_fn = embeddirs_fn 207 | return model_fine 208 | 209 | 210 | def validate_mesh(nerf): 211 | object_bbox_min = np.array([-1.01, -1.01, -1.01]) # only used in extract 212 | object_bbox_max = np.array([1.01, 1.01, 1.01]) 213 | bound_min = torch.tensor(object_bbox_min, dtype=torch.float32) 214 | bound_max = torch.tensor(object_bbox_max, dtype=torch.float32) 215 | 216 | resolution = 256 217 | N = 64 218 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 219 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 220 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 221 | 222 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 223 | with torch.no_grad(): 224 | for xi, xs in enumerate(X): 225 | for yi, ys in enumerate(Y): 226 | for zi, zs in enumerate(Z): 227 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 228 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) 229 | val = query_density_nerf(nerf,pts.reshape(1,-1,3)).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() 230 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val 231 | 232 | vertices, triangles = mcubes.marching_cubes(u, 30) 233 | b_max_np = bound_max.detach().cpu().numpy() 234 | b_min_np = bound_min.detach().cpu().numpy() 235 | 236 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] 237 | 238 | os.makedirs(os.path.join('debug'), exist_ok=True) 239 | mesh = trimesh.Trimesh(vertices, triangles) 240 | mesh.export(os.path.join('debug/nerf.ply')) -------------------------------------------------------------------------------- /models/ray_utils.py: -------------------------------------------------------------------------------- 1 | import torch, re 2 | import numpy as np 3 | from torch import searchsorted 4 | from kornia import create_meshgrid 5 | 6 | 7 | # 
from utils import index_point_feature 8 | 9 | def depth2dist(z_vals, cos_angle): 10 | # z_vals: [N_ray N_sample] 11 | device = z_vals.device 12 | dists = z_vals[..., 1:] - z_vals[..., :-1] 13 | dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples] 14 | dists = dists * cos_angle.unsqueeze(-1) 15 | return dists 16 | 17 | 18 | def ndc2dist(ndc_pts, cos_angle): 19 | dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1) 20 | dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples] 21 | return dists 22 | 23 | 24 | def get_ray_directions(H, W, focal, center=None): 25 | """ 26 | Get ray directions for all pixels in camera coordinate. 27 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 28 | ray-tracing-generating-camera-rays/standard-coordinate-systems 29 | Inputs: 30 | H, W, focal: image height, width and focal length 31 | Outputs: 32 | directions: (H, W, 3), the direction of the rays in camera coordinate 33 | """ 34 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 35 | 36 | i, j = grid.unbind(-1) 37 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 38 | # see https://github.com/bmild/nerf/issues/24 39 | cent = center if center is not None else [W / 2, H / 2] 40 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) 41 | 42 | return directions 43 | 44 | 45 | def get_ray_directions_blender(H, W, focal, center=None): 46 | """ 47 | Get ray directions for all pixels in camera coordinate. 48 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 49 | ray-tracing-generating-camera-rays/standard-coordinate-systems 50 | Inputs: 51 | H, W, focal: image height, width and focal length 52 | Outputs: 53 | directions: (H, W, 3), the direction of the rays in camera coordinate 54 | """ 55 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5 56 | i, j = grid.unbind(-1) 57 | # the direction here is without +0.5 pixel centering as calibration is not so accurate 58 | # see https://github.com/bmild/nerf/issues/24 59 | cent = center if center is not None else [W / 2, H / 2] 60 | directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)], 61 | -1) # (H, W, 3) 62 | 63 | return directions 64 | 65 | 66 | def get_rays(directions, c2w): 67 | """ 68 | Get ray origin and normalized directions in world coordinate for all pixels in one image. 
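        Note: the directions are only rotated into world coordinates here and are not re-normalized
        (the normalization line below is commented out); callers such as BlendSwapDataset.read_meta
        normalize self.directions before calling this function.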
69 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ 70 | ray-tracing-generating-camera-rays/standard-coordinate-systems 71 | Inputs: 72 | directions: (H, W, 3) precomputed ray directions in camera coordinate 73 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate 74 | Outputs: 75 | rays_o: (H*W, 3), the origin of the rays in world coordinate 76 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate 77 | """ 78 | # Rotate ray directions from camera coordinate to the world coordinate 79 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3) 80 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True) 81 | # The origin of all rays is the camera origin in world coordinate 82 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3) 83 | 84 | rays_d = rays_d.view(-1, 3) 85 | rays_o = rays_o.view(-1, 3) 86 | 87 | return rays_o, rays_d 88 | 89 | 90 | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d): 91 | # Shift ray origins to near plane 92 | t = -(near + rays_o[..., 2]) / rays_d[..., 2] 93 | rays_o = rays_o + t[..., None] * rays_d 94 | 95 | # Projection 96 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 97 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 98 | o2 = 1. + 2. * near / rays_o[..., 2] 99 | 100 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 101 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 102 | d2 = -2. * near / rays_o[..., 2] 103 | 104 | rays_o = torch.stack([o0, o1, o2], -1) 105 | rays_d = torch.stack([d0, d1, d2], -1) 106 | 107 | return rays_o, rays_d 108 | 109 | def ndc_rays(H, W, focal, near, rays_o, rays_d): 110 | # Shift ray origins to near plane 111 | t = (near - rays_o[..., 2]) / rays_d[..., 2] 112 | rays_o = rays_o + t[..., None] * rays_d 113 | 114 | # Projection 115 | o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2] 116 | o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2] 117 | o2 = 1. - 2. * near / rays_o[..., 2] 118 | 119 | d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2]) 120 | d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2]) 121 | d2 = 2. 
* near / rays_o[..., 2] 122 | 123 | rays_o = torch.stack([o0, o1, o2], -1) 124 | rays_d = torch.stack([d0, d1, d2], -1) 125 | 126 | return rays_o, rays_d 127 | 128 | # Hierarchical sampling (section 5.2) 129 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False): 130 | device = weights.device 131 | # Get pdf 132 | weights = weights + 1e-5 # prevent nans 133 | pdf = weights / torch.sum(weights, -1, keepdim=True) 134 | cdf = torch.cumsum(pdf, -1) 135 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 136 | 137 | # Take uniform samples 138 | if det: 139 | u = torch.linspace(0., 1., steps=N_samples, device=device) 140 | u = u.expand(list(cdf.shape[:-1]) + [N_samples]) 141 | else: 142 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device) 143 | 144 | # Pytest, overwrite u with numpy's fixed random numbers 145 | if pytest: 146 | np.random.seed(0) 147 | new_shape = list(cdf.shape[:-1]) + [N_samples] 148 | if det: 149 | u = np.linspace(0., 1., N_samples) 150 | u = np.broadcast_to(u, new_shape) 151 | else: 152 | u = np.random.rand(*new_shape) 153 | u = torch.Tensor(u) 154 | 155 | # Invert CDF 156 | u = u.contiguous() 157 | inds = searchsorted(cdf.detach(), u, right=True) 158 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 159 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 160 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 161 | 162 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 163 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 164 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 165 | 166 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 167 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 168 | t = (u - cdf_g[..., 0]) / denom 169 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 170 | 171 | return samples 172 | 173 | 174 | def dda(rays_o, rays_d, bbox_3D): 175 | inv_ray_d = 1.0 / (rays_d + 1e-6) 176 | t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3 177 | t_max = (bbox_3D[1:] - rays_o) * inv_ray_d 178 | t = torch.stack((t_min, t_max)) # 2 N_rays 3 179 | t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0] 180 | t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0] 181 | return t_min, t_max 182 | 183 | 184 | def ray_marcher(rays, 185 | N_samples=64, 186 | lindisp=False, 187 | perturb=0, 188 | bbox_3D=None): 189 | """ 190 | sample points along the rays 191 | Inputs: 192 | rays: () 193 | 194 | Returns: 195 | 196 | """ 197 | 198 | # Decompose the inputs 199 | N_rays = rays.shape[0] 200 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) 201 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1) 202 | 203 | if bbox_3D is not None: 204 | # cal aabb boundles 205 | near, far = dda(rays_o, rays_d, bbox_3D) 206 | 207 | # Sample depth points 208 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples) 209 | if not lindisp: # use linear sampling in depth space 210 | z_vals = near * (1 - z_steps) + far * z_steps 211 | else: # use linear sampling in disparity space 212 | z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps) 213 | 214 | z_vals = z_vals.expand(N_rays, N_samples) 215 | 216 | if perturb > 0: # perturb sampling depths (z_vals) 217 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points 218 | # get intervals between samples 219 | upper = torch.cat([z_vals_mid, z_vals[:, 
-1:]], -1) 220 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1) 221 | 222 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device) 223 | z_vals = lower + (upper - lower) * perturb_rand 224 | 225 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \ 226 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3) 227 | 228 | return xyz_coarse_sampled, rays_o, rays_d, z_vals 229 | 230 | 231 | def read_pfm(filename): 232 | file = open(filename, 'rb') 233 | color = None 234 | width = None 235 | height = None 236 | scale = None 237 | endian = None 238 | 239 | header = file.readline().decode('utf-8').rstrip() 240 | if header == 'PF': 241 | color = True 242 | elif header == 'Pf': 243 | color = False 244 | else: 245 | raise Exception('Not a PFM file.') 246 | 247 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 248 | if dim_match: 249 | width, height = map(int, dim_match.groups()) 250 | else: 251 | raise Exception('Malformed PFM header.') 252 | 253 | scale = float(file.readline().rstrip()) 254 | if scale < 0: # little-endian 255 | endian = '<' 256 | scale = -scale 257 | else: 258 | endian = '>' # big-endian 259 | 260 | data = np.fromfile(file, endian + 'f') 261 | shape = (height, width, 3) if color else (height, width) 262 | 263 | data = np.reshape(data, shape) 264 | data = np.flipud(data) 265 | file.close() 266 | return data, scale 267 | 268 | 269 | def ndc_bbox(all_rays): 270 | near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0] 271 | near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0] 272 | far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0] 273 | far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0] 274 | print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}') 275 | return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max))) -------------------------------------------------------------------------------- /cull_mesh.py: -------------------------------------------------------------------------------- 1 | # import argparse 2 | 3 | import numpy as np 4 | import torch 5 | import trimesh 6 | from pyhocon import ConfigFactory 7 | from tqdm import tqdm 8 | import os 9 | from pytorch3d.structures import Meshes 10 | from pytorch3d.renderer.mesh import rasterizer 11 | from pytorch3d.renderer.cameras import PerspectiveCameras 12 | 13 | 14 | H = 680 15 | W = 1200 16 | 17 | 18 | 19 | translation_dict = { 20 | # BlendSwap 21 | "breakfast_room": [0.0, -1.42, 0.0], 22 | "kitchen": [0.3, -3.42, -0.12], 23 | "green_room": [-1, -0.38, 0.3], 24 | "complete_kitchen": [1.5, -2.25, 0.0], 25 | "grey_white_room": [-0.12, -1.94, -0.69], 26 | "morning_apartment": [-0.22, -0.43, 0.0], 27 | "staircase": [0.0, -2.42, 0.0], 28 | "whiteroom": [0.3, -3.42, -0.12], 29 | # Replica 30 | 'office0': [-0.1944, 0.6488, -0.3271], 31 | 'office1': [-0.585, -0.4703, -0.3507], 32 | 'office2': [0.1909, -1.2262, -0.1574], 33 | 'office3': [0.7893, 1.3371, -0.3305], 34 | 'office4': [-2.0684, -0.9268, -0.1993], 35 | 'room0': [-3.00, -1.1631, 0.1235], 36 | 'room1': [2.0795, 0.1747, 0.0314], 37 | 'room2': [-2.5681, 0.7727, 1.1110], 38 | } 39 | 40 | scale_dict = { 41 | # BlendSwap 42 | "breakfast_room": 0.4, 43 | "kitchen": 0.25, 44 | "green_room": 0.25, 45 | "complete_kitchen": 0.20, 46 | "grey_white_room": 0.25, 47 | "morning_apartment": 0.5, 48 | "staircase": 0.25, 49 | "whiteroom": 0.25, 50 | # Replica 51 | 'office0': 0.4, 52 | 'office1': 0.41, 53 | 'office2': 
0.24, 54 | 'office3': 0.21, 55 | 'office4': 0.30, 56 | 'room0': 0.25, 57 | 'room1': 0.30, 58 | 'room2': 0.29, 59 | } 60 | 61 | scene_bounds_dict = { 62 | # BlendSwap 63 | "whiteroom": np.array([[-2.46, -0.1, 0.36], 64 | [3.06, 3.3, 8.2]]), 65 | "kitchen": np.array([[-3.12, -0.1, -3.18], 66 | [3.75, 3.3, 5.45]]), 67 | "breakfast_room": np.array([[-2.23, -0.5, -1.7], 68 | [1.85, 2.77, 3.0]]), 69 | "staircase":np.array([[-4.14, -0.1, -5.25], 70 | [2.52, 3.43, 1.08]]), 71 | "complete_kitchen":np.array([[-5.55, 0.0, -6.45], 72 | [3.65, 3.1, 3.5]]), 73 | "green_room":np.array([[-2.5, -0.1, 0.4], 74 | [5.4, 2.8, 5.0]]), 75 | "grey_white_room":np.array([[-0.55, -0.1, -3.75], 76 | [5.3, 3.0, 0.65]]), 77 | "morning_apartment":np.array([[-1.38, -0.1, -2.2], 78 | [2.1, 2.1, 1.75]]), 79 | "thin_geometry":np.array([[-2.15, 0.0, 0.0], 80 | [0.77, 0.75, 3.53]]), 81 | # Replica 82 | 'office0': np.array([[-2.0056, -3.1537, -1.1689], 83 | [2.3944, 1.8561, 1.8230]]), 84 | 'office1': np.array([[-1.8204, -1.5824, -1.0477], 85 | [2.9904, 2.5231, 1.7491]]), 86 | 'office2': np.array([[-3.4272, -2.8455, -1.2265], 87 | [3.0453, 5.2980, 1.5414]]), 88 | 'office3': np.array([[-5.1116, -5.9395, -1.2207], 89 | [3.5329, 3.2652, 1.8816]]), 90 | 'office4': np.array([[-1.2047, -2.3258, -1.2093], 91 | [5.3415, 4.1794, 1.6078]]), 92 | 'room0': np.array([[-0.8794, -1.1860, -1.5274], 93 | [6.8852, 3.5123, 1.2804]]), 94 | 'room1': np.array([[-5.4027, -3.0385, -1.4080], 95 | [1.2436, 2.6891, 1.3452]]), 96 | 'room2': np.array([[-0.8171, -3.2454, -2.9081], 97 | [5.9533, 1.7000, 0.6861]]), 98 | } 99 | 100 | 101 | 102 | def clean_invisible_vertices(mesh, train_dataset): 103 | 104 | poses = train_dataset.poses 105 | n_imgs = train_dataset.__len__() 106 | pc = mesh.vertices 107 | faces = mesh.faces 108 | xyz = torch.Tensor(pc) 109 | xyz = xyz.reshape(1, -1, 3) 110 | xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1) 111 | 112 | # delete mesh vertices that are not inside any camera's viewing frustum 113 | whole_mask = np.ones(pc.shape[0]).astype(np.bool) 114 | for i in tqdm(range(0, n_imgs, 1), desc='clean_vertices'): 115 | intrinsics = train_dataset.intrinsics 116 | pose = poses[i] 117 | # adjusted for blender 118 | camera_pos = torch.einsum('abj,ij->abi', xyz_h, pose.inverse()) 119 | projections = torch.einsum('ij, abj->abi', intrinsics, camera_pos[..., :3]) # [W, H, 3] 120 | pixel_locations = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8) - 0.5 121 | pixel_locations = pixel_locations[:, :, [1, 0]] 122 | pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6) 123 | uv = pixel_locations.reshape(-1, 2) 124 | z = pixel_locations[..., -1:] + 1e-5 125 | z = z.reshape(-1) 126 | edge = 0 127 | mask = (0 <= z) & (uv[:, 0] < H - edge) & (uv[:, 0] > edge) & (uv[:, 1] < W-edge) & (uv[:, 1] > edge) 128 | whole_mask &= ~mask.cpu().numpy() 129 | 130 | pc = mesh.vertices 131 | faces = mesh.faces 132 | face_mask = whole_mask[mesh.faces].all(axis=1) 133 | mesh.update_faces(~face_mask) 134 | 135 | return mesh 136 | 137 | # correction from pytorch3d (v0.5.0) 138 | def corrected_cameras_from_opencv_projection( R, tvec, camera_matrix, image_size): 139 | focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1) 140 | principal_point = camera_matrix[:, :2, 2] 141 | 142 | # Retype the image_size correctly and flip to width, height. 143 | image_size_wh = image_size.to(R).flip(dims=(1,)) 144 | 145 | # Get the PyTorch3D focal length and principal point. 
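    # This follows PyTorch3D's non-square NDC convention (the function above is adapted from
    # pytorch3d v0.5.0): intrinsics are rescaled by half of the smaller image dimension,
    # focal_ndc = focal_px / (0.5 * min(W, H)), and the principal point is measured from the
    # image center and negated because the screen x/y axes point opposite to OpenCV's.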
146 | s = (image_size_wh).min(dim=1).values 147 | 148 | focal_pytorch3d = focal_length / (0.5 * s) 149 | p0_pytorch3d = -(principal_point - image_size_wh / 2) * 2 / s 150 | 151 | # For R, T we flip x, y axes (opencv screen space has an opposite 152 | # orientation of screen axes). 153 | # We also transpose R (opencv multiplies points from the opposite=left side). 154 | R_pytorch3d = R.clone().permute(0, 2, 1) 155 | # R_pytorch3d = R.clone() 156 | T_pytorch3d = tvec.clone() 157 | R_pytorch3d[:, :, :2] *= -1 158 | T_pytorch3d[:, :2] *= -1 159 | # T_pytorch3d[:, 0] *= -1 160 | 161 | return PerspectiveCameras( 162 | R=R_pytorch3d, 163 | T=T_pytorch3d, 164 | focal_length=focal_pytorch3d, 165 | principal_point=p0_pytorch3d, 166 | ) 167 | 168 | 169 | def clean_triangle_faces(mesh, train_dataset): 170 | # returns a mask of triangles that reprojects on at least nb_visible images 171 | num_view = train_dataset.__len__() 172 | K = train_dataset.intrinsics[:3, :3].unsqueeze(0).repeat([num_view, 1, 1]) 173 | R = train_dataset.poses[:, :3, :3].transpose(2, 1) 174 | t = - train_dataset.poses[:, :3, :3].transpose(2, 1) @ train_dataset.poses[:, :3, 3:] 175 | sizes = torch.Tensor([[train_dataset.w, train_dataset.h]]).repeat([num_view, 1]) 176 | cams = [K, R, t, sizes] 177 | num_faces = len(mesh.faces) 178 | nb_visible = 1 179 | count = torch.zeros(num_faces, device="cuda") 180 | K, R, t, sizes = cams[:4] 181 | 182 | n = len(K) 183 | with torch.no_grad(): 184 | for i in tqdm(range(n), desc="clean_faces"): 185 | intr = torch.zeros(1, 4, 4).cuda() # 186 | intr[:, :3, :3] = K[i:i + 1] 187 | intr[:, 3, 3] = 1 188 | vertices = torch.from_numpy(mesh.vertices).cuda().float() # 189 | faces = torch.from_numpy(mesh.faces).cuda().long() # 190 | meshes = Meshes(verts=[vertices], 191 | faces=[faces]) 192 | 193 | cam = corrected_cameras_from_opencv_projection(camera_matrix=intr, R=R[i:i + 1].cuda(), # 194 | tvec=t[i:i + 1].squeeze(2).cuda(), # 195 | image_size=sizes[i:i + 1, [1, 0]].cuda()) # 196 | cam = cam.cuda() # 197 | raster_settings = rasterizer.RasterizationSettings(image_size=tuple(sizes[i, [1, 0]].long().tolist()), 198 | faces_per_pixel=1) 199 | meshRasterizer = rasterizer.MeshRasterizer(cam, raster_settings) 200 | 201 | with torch.no_grad(): 202 | ret = meshRasterizer(meshes) 203 | pix_to_face = ret.pix_to_face 204 | # pix_to_face, zbuf, bar, pixd = 205 | 206 | visible_faces = pix_to_face.view(-1).unique() 207 | count[visible_faces[visible_faces > -1]] += 1 208 | 209 | pred_visible_mask = (count >= nb_visible).cpu() 210 | 211 | mesh.update_faces(pred_visible_mask) 212 | return mesh 213 | 214 | def cull_by_bounds(points, scene_bounds): 215 | eps = 0.02 216 | inside_mask = np.all(points >= (scene_bounds[0] - eps), axis=1) & np.all(points <= (scene_bounds[1] + eps), axis=1) 217 | return inside_mask 218 | 219 | 220 | 221 | def crop_mesh(scene, mesh, subdivide=True, max_edge=0.015): 222 | vertices = mesh.vertices 223 | triangles = mesh.faces 224 | 225 | if subdivide: 226 | vertices, triangles = trimesh.remesh.subdivide_to_size(vertices, triangles, max_edge=max_edge, max_iter=10) 227 | 228 | # Cull with the bounding box first 229 | inside_mask = None 230 | scene_bounds = scene_bounds_dict[scene] 231 | if scene_bounds is not None: 232 | inside_mask = cull_by_bounds(vertices, scene_bounds) 233 | 234 | inside_mask = inside_mask[triangles[:, 0]] | inside_mask[triangles[:, 1]] | inside_mask[triangles[:, 2]] 235 | triangles = triangles[inside_mask, :] 236 | print("Processed culling by bound") 237 | 
os.environ['PYOPENGL_PLATFORM'] = 'egl' 238 | # we don't need subdivided mesh to render depth 239 | mesh = trimesh.Trimesh(vertices, triangles, process=False) 240 | mesh.remove_unreferenced_vertices() 241 | return mesh 242 | 243 | 244 | def transform_mesh(scene, mesh): 245 | pc = mesh.vertices 246 | faces = mesh.faces 247 | pc = (pc / scale_dict[scene]) - np.array([translation_dict[scene]]) 248 | mesh = trimesh.Trimesh(pc, faces, process=False) 249 | return mesh 250 | 251 | 252 | def detransform_mesh(scene, mesh): 253 | pc = mesh.vertices 254 | faces = mesh.faces 255 | pc = (pc + np.array([translation_dict[scene]])) * scale_dict[scene] 256 | mesh = trimesh.Trimesh(pc, faces, process=False) 257 | return mesh 258 | 259 | if __name__ == '__main__': 260 | from models.blender_swap import BlendSwapDataset 261 | 262 | out_dir_pat = 'out_meshes/%s' 263 | scene = 'room0' 264 | 265 | f = open('confs/room0.conf') 266 | conf_text = f.read() 267 | conf_text = conf_text.replace('CASE_NAME', '') 268 | f.close() 269 | conf = ConfigFactory.parse_string(conf_text) 270 | conf['dataset.data_dir'] = conf['dataset.data_dir'].replace('CASE_NAME', '') 271 | train_dataset = BlendSwapDataset(conf['dataset']) 272 | 273 | exp_name = 'room0' 274 | mesh_name = 'neus' 275 | dir_pth = out_dir_pat % (exp_name) 276 | mesh_pth = os.path.join(dir_pth, mesh_name+'.ply') 277 | print(dir_pth, mesh_pth) 278 | mesh = trimesh.load_mesh(mesh_pth) 279 | 280 | mesh.vertices = mesh.vertices[:, [0,2,1]] * np.array([[1, -1, 1]]) 281 | mesh = transform_mesh(scene, mesh) 282 | mesh = crop_mesh(scene, mesh) 283 | # mesh.export(os.path.join(dir_pth, '%s_cropped.ply' % mesh_name)) 284 | mesh = detransform_mesh(scene, mesh) 285 | mesh.vertices = (mesh.vertices * np.array([[1, -1, 1]]))[:, [0, 2, 1]] 286 | mesh = clean_invisible_vertices(mesh, train_dataset) 287 | mesh = clean_triangle_faces(mesh, train_dataset) 288 | # mesh.export(os.path.join(dir_pth, '%s_cropped_culled.ply' % mesh_name)) 289 | mesh.vertices = mesh.vertices[:, [0, 2, 1]] * np.array([[1, -1, 1]]) 290 | mesh = transform_mesh(scene, mesh) 291 | mesh.export(os.path.join(dir_pth, '%s_cropped_culled_transformed.ply' % mesh_name)) 292 | -------------------------------------------------------------------------------- /models/tensorf2neus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.tensoRF import TensorVMSplit 3 | import numpy as np 4 | import skimage 5 | import plyfile 6 | import os 7 | # from test_ours import * 8 | 9 | 10 | def cal_n_samples(reso, step_ratio=0.5): 11 | return int(np.linalg.norm(reso)/step_ratio) 12 | 13 | 14 | def query_alpha_color(tensorf: TensorVMSplit, xyz_sampled, viewdirs=None, prior_mode='cat', requires_grad=False): 15 | # xyz_sampled: [N, 3] 16 | # viewdirs: [N_rays, 3] 17 | 18 | dim = 3 19 | if len(xyz_sampled.shape) == 3: 20 | N_rays, N_samples, _ = xyz_sampled.shape 21 | xyz_sampled = xyz_sampled.reshape(-1, 3) 22 | # if len(xyz_sampled.shape) == 2: 23 | # dim = 2 24 | # xyz_sampled = xyz_sampled.unsqueeze(0) 25 | 26 | # neighboring 8 vertices 27 | # after: xyz_samples: [N, 9, 3] or [N, 1, 3] 28 | if prior_mode.startswith('local'): 29 | xyz_sampled = query_grid_vertices(tensorf, xyz_sampled) 30 | else: 31 | xyz_sampled = xyz_sampled.unsqueeze(1) 32 | 33 | mask_outbbox = ((tensorf.aabb[0] > xyz_sampled) | (xyz_sampled > tensorf.aabb[1])).any(dim=-1) 34 | ray_valid = ~mask_outbbox 35 | 36 | if tensorf.alphaMask is not None: 37 | alphas = 
tensorf.alphaMask.sample_alpha(xyz_sampled[ray_valid]) 38 | alpha_mask = alphas > 0 39 | ray_invalid = ~ray_valid 40 | ray_invalid[ray_valid] |= (~alpha_mask) 41 | ray_valid = ~ray_invalid 42 | 43 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device) 44 | 45 | if ray_valid.any(): 46 | xyz_sampled = tensorf.normalize_coord(xyz_sampled) 47 | sigma_feature = tensorf.compute_densityfeature(xyz_sampled[ray_valid], requires_grad=requires_grad) 48 | 49 | validsigma = tensorf.feature2density(sigma_feature) 50 | sigma[ray_valid] = validsigma 51 | 52 | # alpha, weight, bg_weight = raw2alpha(sigma, tensorf.stepSize * tensorf.distance_scale) 53 | # app_mask = weight > tensorf.rayMarch_weight_thres 54 | alpha = 1. - torch.exp(-sigma * tensorf.stepSize * tensorf.distance_scale) 55 | # alpha = 1-alpha 56 | 57 | if prior_mode == 'local_mean': 58 | alpha = torch.mean(alpha, dim=1, keepdim=True) 59 | 60 | if viewdirs is None: 61 | # alpha: [N, 1] or [N, 9] 62 | return alpha, None 63 | 64 | 65 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device) 66 | viewdirs = viewdirs.reshape(-1, 1, 3).expand(xyz_sampled.shape) 67 | app_mask = ray_valid 68 | if app_mask.any(): 69 | app_features = tensorf.compute_appfeature(xyz_sampled[app_mask]) 70 | valid_rgbs = tensorf.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features) 71 | rgb[app_mask] = valid_rgbs 72 | 73 | # alpha: [N, 1] or [N, 9] 74 | # rgb: [N, 3] or [N, 9*3] 75 | if prior_mode == 'local_mean': 76 | rgb = torch.mean(rgb, dim=1, keepdim=False) 77 | elif prior_mode == 'local_cat': 78 | rgb = rgb.reshape(rgb.shape[0], -1) 79 | else: 80 | rgb = rgb.reshape(-1, 3) 81 | return alpha, rgb 82 | 83 | 84 | def load_tensorf(ckpt, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")): 85 | ckpt = torch.load(ckpt, map_location=device) 86 | kwargs = ckpt['kwargs'] 87 | kwargs.update({'device': device}) 88 | tensorf = TensorVMSplit(**kwargs) 89 | tensorf.load(ckpt) 90 | 91 | # Preload the coordinates of the dense voxel grid 92 | gridSize = tensorf.gridSize 93 | samples = torch.stack(torch.meshgrid( 94 | torch.linspace(0, 1, gridSize[0]), 95 | torch.linspace(0, 1, gridSize[1]), 96 | torch.linspace(0, 1, gridSize[2]), 97 | ), -1).to(tensorf.device) 98 | dense_xyz = tensorf.aabb[0] * (1 - samples) + tensorf.aabb[1] * samples 99 | tensorf.dense_xyz = dense_xyz # [X, Y, Z, 3] 100 | 101 | return tensorf 102 | 103 | 104 | def construct_alpha_color_grid(tensorf: TensorVMSplit): 105 | alpha, dense_xyz = tensorf.getDenseAlpha() 106 | tensorf.alpha_grid = alpha 107 | 108 | rgb = alpha = torch.zeros((*dense_xyz.shape[:3], 3)) 109 | 110 | 111 | def query_grid_vertices(tensorf: TensorVMSplit, pts): 112 | # pts: [N, 3] 113 | # return: [N, 9, 3] 114 | bbox = tensorf.aabb 115 | pts_normalized = (pts-bbox[0]) / (bbox[1] - bbox[0]) # \in [0, 1] 116 | pts_gridded = pts_normalized * (tensorf.gridSize-1) # \in [0, gridSize-1] 117 | pts_lower = torch.floor(pts_gridded) # [N, 3] 118 | verts = pts_lower.unsqueeze(1).repeat([1,8,1]) 119 | # x-, y-, z- 120 | 121 | # x+, y-, z- 122 | verts[:, 1, 0] += 1 123 | # x+, y+, z- 124 | verts[:, 2, :2] += 1 125 | # x-, y+, z- 126 | verts[:, 3, 1] += 1 127 | # x-, y-, z+ 128 | verts[:, 4, 2] += 1 129 | # x+, y-, z+ 130 | verts[:, 5, 0] += 1 131 | verts[:, 5, 2] += 1 132 | # x+, y+, z+ 133 | verts[:, 6, :] += 1 134 | # x-, y+, z+ 135 | verts[:, 7, 1:] += 1 136 | 137 | mask_outbbox = ((pts_gridded <= 0) | (pts_gridded >= (tensorf.gridSize-1))).any(dim=-1) # [N,] 138 | mask_outbbox = mask_outbbox.unsqueeze(-1).repeat([1, 8]) # [N, 8]
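    # verts now holds the 8 corner indices of the grid cell containing each query point, in the
    # order (x-,y-,z-), (x+,y-,z-), (x+,y+,z-), (x-,y+,z-), (x-,y-,z+), (x+,y-,z+), (x+,y+,z+),
    # (x-,y+,z+). E.g. a point at grid coordinate (1.3, 2.7, 0.2) has pts_lower = (1, 2, 0) and
    # corners (1,2,0), (2,2,0), (2,3,0), (1,3,0), (1,2,1), (2,2,1), (2,3,1), (1,3,1).
    # Points on or outside the grid boundary are flagged by mask_outbbox and simply reuse the
    # query point itself below, so the result is always [N, 9, 3] (query point + 8 corners).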
139 | mask_inbbox = ~mask_outbbox 140 | 141 | # [N,8,3] 142 | grid_vertices = pts.unsqueeze(1).repeat([1,8,1]) 143 | verts = verts.long()[mask_inbbox] 144 | grid_vertices[mask_inbbox] = tensorf.dense_xyz[verts[:,0], verts[:,1], verts[:,2]] 145 | # grid_vertices = tensorf.dense_xyz[verts[:,:,0], verts[:,:,1], verts[:,:,2]] 146 | all_vertices = torch.cat([pts.unsqueeze(1), grid_vertices], 1) 147 | 148 | 149 | return all_vertices 150 | 151 | 152 | def query_occupancy_confidence(model, pts): 153 | # model: rendering.network 154 | # pts: [N, 3] 155 | N, _ = pts.shape 156 | grid_points = query_grid_vertices(model.tensorf, pts) # [N, 9, 3] 157 | grid_points = grid_points.reshape(N*9, 3) 158 | 159 | g = [] 160 | occs = [] 161 | chunk = N 162 | for i in range(0, N*9, chunk): 163 | with torch.enable_grad(): 164 | p = grid_points[i:i+N] 165 | p.requires_grad_(True) 166 | queried_occ, _ = query_alpha_color(model.tensorf, p, None, prior_mode='cat', requires_grad=True) 167 | y = queried_occ 168 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 169 | gradients = torch.autograd.grad( 170 | outputs=y, 171 | inputs=p, 172 | grad_outputs=d_output, 173 | create_graph=True, 174 | retain_graph=True, 175 | only_inputs=True, allow_unused=True)[0] 176 | _g = gradients.unsqueeze(1) 177 | grid_points.requires_grad_(False) 178 | g.append(_g.detach()) 179 | occs.append(queried_occ.squeeze()) 180 | del _g 181 | 182 | g = torch.cat(g, 0).squeeze().reshape(N, 9, 3) # [N, 9, 3] 183 | occs = torch.cat(occs, 0).reshape(N, 9) # only for test 184 | normals_ = g[:, :, :] / (g[:, :, :].norm(2, dim=2).unsqueeze(-1) + 10 ** (-5)) 185 | normals_ = normals_.reshape(N, 9, 3) 186 | 187 | 188 | confidence = torch.var(normals_, 1).sum(1) 189 | length = g[:, :, :].norm(2, dim=2) 190 | confidence = torch.var(length, 1) 191 | confidence = torch.var(occs, 1) 192 | ref_pts = normals_[:, 0:1, :].repeat([1,8,1]) 193 | grid_pts = normals_[:, 1:, :] 194 | cosine = torch.mul(ref_pts, grid_pts).sum(-1) 195 | confidence = torch.var(cosine, 1) 196 | 197 | 198 | return confidence, normals_[:, 0, :] 199 | 200 | 201 | def query_entropy(occ, entropy_num=5): 202 | # pts: [N_rays, N_samples, 3], occ: [N_rays, N_samples] 203 | # pts_4_ = pts[:, :-4, :] 204 | # pts_3_ = pts[:, 1:-3, :] 205 | expand = (entropy_num-1)//2 206 | occs = [] 207 | for i in range(expand): 208 | occ_ = occ[:, i:(-expand-(expand-i))] 209 | occs.append(occ_) 210 | occs.append(occ[:, expand:-expand]) 211 | for i in range(expand): 212 | if i != expand -1: 213 | occ_ = occ[:, expand+i+1:(-expand+i+1)] 214 | else: 215 | occ_ = occ[:, expand+i+1:] 216 | occs.append(occ_) 217 | # occ_2_ = occ[:, 0:-4] 218 | # occ_1_ = occ[:, 1:-3] 219 | # occ_ = occ[:, 2:-2] 220 | # occ_1 = occ[:, 3:-1] 221 | # occ_2 = occ[:, 4:] 222 | # occ_5around = torch.stack([occ_2_, occ_1_, occ_, occ_1, occ_2], 2) # [N_rays, N_samples, 5] 223 | occ_around = torch.stack(occs, 2) 224 | occ_around = torch.softmax(occ_around, 2) 225 | log_occ_around = torch.log(occ_around+1e-8) 226 | entropy = (-occ_around * log_occ_around).sum(2) # [N_rays, N_samples] 227 | 228 | return entropy 229 | 230 | 231 | def tensorf_volume_rendering(xyz_sampled, z_vals, camera_world, viewdirs, tensorf): 232 | # tensorf volume rendering 233 | # xyz_sampled: [N_rays, N_samples, 3] 234 | # z_vals: [N_rays, N_samples] 235 | # viewdirs: [N_rays, 3] 236 | 237 | # from test_ours import visual_ray_points 238 | # visual_ray_points(xyz_sampled) 239 | 240 | # xyz_sampled, z_vals, ray_valid = tensorf.sample_ray(camera_world, viewdirs, 
is_train=False, N_samples=tensorf.nSamples) 241 | 242 | mask_outbbox = ((tensorf.aabb[0] > xyz_sampled) | (xyz_sampled > tensorf.aabb[1])).any(dim=-1) 243 | ray_valid = ~mask_outbbox 244 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) 245 | 246 | 247 | viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape) 248 | 249 | if tensorf.alphaMask is not None: 250 | alphas = tensorf.alphaMask.sample_alpha(xyz_sampled[ray_valid]) 251 | alpha_mask = alphas > 0 252 | ray_invalid = ~ray_valid 253 | ray_invalid[ray_valid] |= (~alpha_mask) 254 | ray_valid = ~ray_invalid 255 | 256 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device) 257 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device) 258 | 259 | if ray_valid.any(): 260 | xyz_sampled = tensorf.normalize_coord(xyz_sampled) 261 | sigma_feature = tensorf.compute_densityfeature(xyz_sampled[ray_valid]) 262 | 263 | validsigma = tensorf.feature2density(sigma_feature) 264 | sigma[ray_valid] = validsigma 265 | 266 | alpha, weight, bg_weight = raw2alpha(sigma, dists * tensorf.distance_scale) 267 | 268 | app_mask = weight > tensorf.rayMarch_weight_thres 269 | 270 | if app_mask.any(): 271 | app_features = tensorf.compute_appfeature(xyz_sampled[app_mask]) 272 | valid_rgbs = tensorf.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features) 273 | rgb[app_mask] = valid_rgbs 274 | 275 | acc_map = torch.sum(weight, -1) 276 | rgb_map = torch.sum(weight[..., None] * rgb, -2) 277 | 278 | return rgb_map 279 | 280 | 281 | ### deprecated 282 | def raw2alpha(sigma, dist): 283 | # sigma, dist [N_rays, N_samples] 284 | alpha = 1. - torch.exp(-sigma*dist) 285 | 286 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. 
- alpha + 1e-10], -1), -1) 287 | 288 | weights = alpha * T[:, :-1] # [N_rays, N_samples] 289 | return alpha, weights, T[:,-1:] 290 | 291 | def convert_sdf_samples_to_ply( 292 | pytorch_3d_sdf_tensor, 293 | ply_filename_out, 294 | bbox, 295 | level=0.5, 296 | offset=None, 297 | scale=None, 298 | ): 299 | """ 300 | Convert sdf samples to .ply 301 | 302 | :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n) 303 | :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid 304 | :voxel_size: float, the size of the voxels 305 | :ply_filename_out: string, path of the filename to save to 306 | 307 | This function adapted from: https://github.com/RobotLocomotion/spartan 308 | """ 309 | 310 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy() 311 | voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape)) 312 | 313 | verts, faces, normals, values = skimage.measure.marching_cubes( 314 | numpy_3d_sdf_tensor, level=level, spacing=voxel_size 315 | ) 316 | faces = faces[...,::-1] # inverse face orientation 317 | 318 | # transform from voxel coordinates to camera coordinates 319 | # note x and y are flipped in the output of marching_cubes 320 | mesh_points = np.zeros_like(verts) 321 | mesh_points[:, 0] = bbox[0,0] + verts[:, 0] 322 | mesh_points[:, 1] = bbox[0,1] + verts[:, 1] 323 | mesh_points[:, 2] = bbox[0,2] + verts[:, 2] 324 | 325 | # apply additional offset and scale 326 | if scale is not None: 327 | mesh_points = mesh_points / scale 328 | if offset is not None: 329 | mesh_points = mesh_points - offset 330 | 331 | # try writing to the ply file 332 | 333 | num_verts = verts.shape[0] 334 | num_faces = faces.shape[0] 335 | 336 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")]) 337 | 338 | for i in range(0, num_verts): 339 | verts_tuple[i] = tuple(mesh_points[i, :]) 340 | 341 | faces_building = [] 342 | for i in range(0, num_faces): 343 | faces_building.append(((faces[i, :].tolist(),))) 344 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))]) 345 | 346 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex") 347 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face") 348 | 349 | ply_data = plyfile.PlyData([el_verts, el_faces]) 350 | print("saving mesh to %s" % (ply_filename_out)) 351 | ply_data.write(ply_filename_out) -------------------------------------------------------------------------------- /models/blender_swap.py: -------------------------------------------------------------------------------- 1 | import torch, cv2 2 | from torch.utils.data import Dataset 3 | import json 4 | from tqdm import tqdm 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision import transforms as T 9 | 10 | from .ray_utils import * 11 | 12 | 13 | class BlendSwapDataset(Dataset): 14 | def __init__(self, conf, split='train', N_vis=-1): 15 | self.device = torch.device('cuda') 16 | self.N_vis = N_vis 17 | self.root_dir = conf.get_string('data_dir') 18 | scene = conf.get_string('scene') 19 | self.root_dir = os.path.join(self.root_dir, scene) 20 | 21 | self.split = split 22 | self.is_stack = False 23 | self.downsample = 1.0 24 | self.define_transforms() 25 | 26 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 27 | self.read_meta() 28 | # self.define_proj_mat() 29 | 30 | self.white_bg = True 31 | 32 | def read_meta(self): 33 | 34 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f: 35 
| self.meta = json.load(f) 36 | 37 | w, h = int(self.meta['w'] / self.downsample), int(self.meta['h'] / self.downsample) 38 | self.img_wh = [w, h] 39 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length 40 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length 41 | self.cx, self.cy = self.meta['cx'], self.meta['cy'] 42 | 43 | # ray directions for all pixels, same for all images (same H, W, focal) 44 | self.directions = get_ray_directions(h, w, [self.focal_x, self.focal_y], center=[self.cx, self.cy]) # (h, w, 3) 45 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) 46 | self.intrinsics = torch.tensor([[self.focal_x, 0, self.cx], [0, self.focal_y, self.cy], [0, 0, 1]]).float() 47 | 48 | self.image_paths = [] 49 | self.poses = [] 50 | self.all_rays = [] 51 | self.all_rgbs = [] 52 | self.all_pretrained_rgbs = [] 53 | self.all_depth = [] 54 | self.all_rgb_std = [] 55 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis 56 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval)) 57 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:# 58 | frame = self.meta['frames'][i] 59 | pose = np.array(frame['transform_matrix']) 60 | pose = pose @ self.blender2opencv 61 | c2w = torch.FloatTensor(pose) 62 | self.poses.append(c2w) 63 | 64 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}") 65 | self.image_paths += [image_path] 66 | img = Image.open(image_path) 67 | 68 | img = self.transform(img) # (4, h, w) 69 | 70 | rgb_std = self.cal_rgb_std(img.permute(1, 2, 0)) 71 | rgb_std = np.where(rgb_std > 10 / 255.0, 1.0, 0.0) 72 | self.all_rgb_std.append(torch.Tensor(rgb_std).cpu()) 73 | 74 | img = img.view(-1, w * h).permute(1, 0) # (h*w, 4) RGBA 75 | if img.shape[-1] == 4: 76 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB 77 | self.all_rgbs.append(img) 78 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3) 79 | self.all_rays.append(torch.cat([rays_o, rays_d], 1)) # (h*w, 6) 80 | 81 | self.poses = torch.stack(self.poses) #(N, 4, 4) 82 | self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 6) 83 | self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3) 84 | 85 | self.h, self.w = h, w 86 | # [w*h, 9] 87 | self.all_neighbor_idx = self.query_neighbor_idx(torch.arange(w*h)) 88 | 89 | # for Neus exp_runner 90 | self.n_images = len(self.image_paths) 91 | self.pose_all = self.poses 92 | self.object_bbox_min = np.array([-1.01, -1.01, -1.01]) # only used in extract 93 | self.object_bbox_max = np.array([ 1.01, 1.01, 1.01]) 94 | self.scale_mats_np = [np.eye(4)] 95 | 96 | def define_transforms(self): 97 | self.transform = T.ToTensor() 98 | 99 | # def define_proj_mat(self): 100 | # self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3] 101 | 102 | def __len__(self): 103 | return len(self.all_rgbs) 104 | 105 | def cal_rgb_std(self, img): 106 | """ 107 | :param img: [h, w, 3], rgb\in [0, 255] 108 | :return: [h, w] \in [0, 1] 109 | """ 110 | img = np.array(img, np.float64) 111 | kernel = (3, 3) 112 | kernel = (9, 9) 113 | E_square = cv2.blur(img ** 2, kernel) 114 | square_E = cv2.blur(img, kernel) ** 2 115 | sharp_img = np.sqrt(np.abs(E_square - square_E)) 116 | 117 | gray_img = sharp_img.max(2) 118 | gray_min = np.min(gray_img) 119 | gray_max = np.max(gray_img) 120 | gray_img = 
np.clip(gray_img, gray_min, gray_max) 121 | gray_img = (gray_img - gray_min) / (gray_max - gray_min + 1e-6) 122 | return gray_img 123 | 124 | def query_neighbor_idx(self, idx): 125 | # query idx's neighbors. idx: [N,] 126 | # 0 1 2 127 | # 3 * 4 128 | # 5 6 7 129 | # return: [N, 9] 130 | pad = 4 131 | row, col = idx // self.w, idx % self.w 132 | r0 = r1 = r2 = torch.maximum(row-pad, torch.zeros_like(row)) 133 | r3 = r4 = row 134 | r5 = r6 = r7 = torch.minimum(row+pad, torch.ones_like(row)*(self.h-1)) 135 | c0 = c3 = c5 = torch.maximum(col-pad, torch.zeros_like(col)) 136 | c1 = c6 = col 137 | c2 = c4 = c7 = torch.minimum(col+pad, torch.ones_like(col)*(self.w-1)) 138 | 139 | idx0 = r0 * self.w + c0 140 | idx1 = r1 * self.w + c1 141 | idx2 = r2 * self.w + c2 142 | idx3 = r3 * self.w + c3 143 | idx4 = r4 * self.w + c4 144 | idx5 = r5 * self.w + c5 145 | idx6 = r6 * self.w + c6 146 | idx7 = r7 * self.w + c7 147 | neighbor_idx = torch.stack([idx0, idx1, idx2, idx3, idx, idx4, idx5, idx6, idx7], 1) # [N, 9] 148 | border_mask = (idx % self.w == 0) | (idx % self.w == (self.w - 1)) | (idx // self.w == 0) | ( 149 | idx // self.w == (self.h - 1)) 150 | neighbor_idx[border_mask] = idx[border_mask].unsqueeze(1).repeat([1, 9]) 151 | return neighbor_idx 152 | 153 | 154 | 155 | # def gen_random_rays_at(self, img_idx, batch_size): 156 | # pixels_x = torch.randint(low=0, high=self.w, size=[batch_size]) 157 | # pixels_y = torch.randint(low=0, high=self.h, size=[batch_size]) 158 | # color = self.all_rgbs[img_idx] # [h, w, 3] 159 | # color = color[(pixels_y, pixels_x)] # [batch_size, 3] 160 | # mask = torch.ones_like(color, dtype=torch.float) 161 | # all_rays = self.all_rays[img_idx].reshape(self.h, self.w, 6) # [h, w, 6] 162 | # rand_rays = all_rays[(pixels_y, pixels_x)] # [batch_size, 6] 163 | # return torch.cat([rand_rays, color, mask[:, :1]], dim=-1).to(self.device) 164 | 165 | def gen_random_rays_at(self, img_idx, batch_size, mode): 166 | if mode == 'batch': 167 | pixels_x = torch.randint(low=0, high=self.w, size=[batch_size]) 168 | pixels_y = torch.randint(low=0, high=self.h, size=[batch_size]) 169 | 170 | color = self.all_rgbs[img_idx] # [h, w, 3] 171 | color = color[(pixels_y, pixels_x)] # [batch_size, 3] 172 | mask = torch.ones_like(color, dtype=torch.float) 173 | all_rays = self.all_rays[img_idx].reshape(self.h, self.w, 6) # [h, w, 6] 174 | rand_rays = all_rays[(pixels_y, pixels_x)] # [batch_size, 6] 175 | return torch.cat([rand_rays, color, mask[:, :1]], dim=-1).to(self.device) 176 | elif mode == 'patch': 177 | pixels_x = torch.randint(low=0, high=self.w, size=[batch_size // 9]) 178 | pixels_y = torch.randint(low=0, high=self.h, size=[batch_size // 9]) 179 | pixel_idx = pixels_y * self.w + pixels_x # flat pixel index over the (h, w) image is row * w + col 180 | rand_idx = self.all_neighbor_idx[pixel_idx].reshape(-1) 181 | pixels_x = rand_idx % self.w 182 | pixels_y = rand_idx // self.w 183 | patch_rgb_std = self.all_rgb_std[img_idx].reshape(-1, 1)[rand_idx] 184 | color = self.all_rgbs[img_idx] # [h, w, 3] 185 | color = color[(pixels_y, pixels_x)] # [batch_size, 3] 186 | mask = torch.ones_like(color, dtype=torch.float) 187 | all_rays = self.all_rays[img_idx].reshape(self.h, self.w, 6) # [h, w, 6] 188 | rand_rays = all_rays[(pixels_y, pixels_x)] # [batch_size, 6] 189 | return torch.cat([rand_rays, color, mask[:, :1], patch_rgb_std], dim=-1).to(self.device) 190 | 191 | 192 | def near_far_from_sphere(self, rays_o, rays_d): 193 | # copied from dataset.py 194 | # a = torch.sum(rays_d**2, dim=-1, keepdim=True) 195 | # b = 2.0 * torch.sum(rays_o * rays_d, dim=-1,
keepdim=True) 196 | # mid = 0.5 * (-b) / a 197 | # near = mid - 1.0 198 | # far = mid + 1.0 199 | 200 | near = torch.zeros(rays_o.shape[0], 1).cuda() 201 | far = torch.ones(rays_o.shape[0], 1).cuda() * 3 202 | return near, far 203 | 204 | def gen_rays_at(self, img_idx, resolution_level=1): 205 | all_rays = self.all_rays[img_idx].reshape(self.h, self.w, 6) # [h, w, 6] 206 | rays_o = all_rays[:, :, :3].to(self.device) 207 | rays_d = all_rays[:, :, 3:].to(self.device) 208 | return rays_o, rays_d 209 | 210 | def image_at(self, idx, resolution_level): 211 | img = cv2.imread(self.image_paths[idx]) 212 | return (cv2.resize(img, (self.w // resolution_level, self.h // resolution_level))).clip(0, 255) 213 | 214 | 215 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1): 216 | # only used in novel view synthesis 217 | raise NotImplementedError() 218 | 219 | def __getitem__(self, idx): 220 | sample = {} 221 | img_raynum = self.h*self.w 222 | if self.split == 'train': 223 | imgid = self.__len__() - 1 - idx 224 | # imgid = 100 225 | all_rays = self.all_rays[imgid] 226 | rand_idx = np.random.choice(img_raynum, self.batch_size, replace=False) 227 | # rand_idx = np.arange(10740, 10840) 228 | 229 | # batch_rand_idx = np.random.choice(img_raynum, self.batch_size // 9, replace=False) 230 | # # batch_rand_idx = np.array([64500, 64510, 64520, 64530, 64540, 64550, 64560, 64570, 64580, 64590, 64560]) # debug 231 | # rand_idx = self.all_neighbor_idx[batch_rand_idx].reshape(-1) # [N*9] 232 | # sample.update({'patch_rgb_std': self.all_rgb_std[imgid].reshape(-1)[batch_rand_idx]}) 233 | sample.update({'patch_rgb_std': self.all_rgb_std[imgid].reshape(-1)[rand_idx]}) 234 | 235 | sample.update({'rays_o': all_rays[rand_idx, :3], 236 | 'rays_d': all_rays[rand_idx, 3:6], 237 | 'rgb': self.all_rgbs[imgid].reshape(-1, 3)[rand_idx], 238 | # 'all_trg_rays_o': all_rays[:, :3], 239 | # 'all_trg_rays_d': all_rays[:, 3:6], 240 | # 'trg_rgb': self.all_rgbs[imgid], 241 | # 'trg_extrinsics': self.poses[imgid], 242 | 'idx': idx, 243 | 'imgid': imgid, 244 | 'h': self.h, 245 | 'w': self.w, 246 | 'intrinsics': self.intrinsics, 247 | 'pixel_idx': rand_idx, 248 | }) 249 | 250 | # debug 251 | # rays_d = sample['rays_d'].unsqueeze(0) 252 | # rays_o = sample['rays_o'].unsqueeze(0) 253 | # c2w = self.poses[imgid] 254 | # world_pos = rays_d * 0.5 + rays_o # [W, H, 3] 255 | # world_pos = torch.cat([world_pos, torch.ones([rays_d.shape[0], rays_d.shape[1], 1])], 2) # [h*w, 4] 256 | # # K * R^-1 * xyz 257 | # camera_pos = torch.einsum('ij,abj->abi', c2w.inverse(), world_pos) # [W, H, 4] 258 | # camera_pos = torch.einsum('abj,ij->abi', world_pos, c2w.inverse()) 259 | # # camera_pos[..., 1] *= -1 260 | # # camera_pos[..., 2] *= -1 261 | # uv1 = torch.einsum('ij, abj->abi', self.intrinsics, camera_pos[..., :3]) # [W, H, 3] 262 | # uv1[..., :] /= uv1[..., 2:] 263 | # uv1 = uv1[..., :2] - 0.5 264 | # uv1 = uv1[:,:,[1,0]] 265 | 266 | 267 | # newview 268 | src_view_num = 3 269 | src_offsets = np.random.choice(np.concatenate((np.arange(-60, -9), np.arange(10, 61))), src_view_num, replace=False) 270 | src_idx = [] 271 | for offset in src_offsets: 272 | _idx = imgid + offset 273 | if _idx >= len(self.all_rgbs) or _idx < 0: 274 | _idx = imgid - offset 275 | src_idx.append(_idx) 276 | src_imgid = src_idx 277 | src_all_rays = self.all_rays[src_imgid] 278 | # sample.update({'src_rgb': self.all_rgbs[src_imgid], 279 | # 'src_extrinsics': self.poses[src_imgid], 280 | # 'src_rays_o': src_all_rays[..., :3].reshape(src_view_num, self.h, self.w, 3), 281 
| # 'src_rays_d': src_all_rays[..., 3:6].reshape(src_view_num, self.h, self.w, 3), 282 | # 'offsets': src_offsets,}) 283 | 284 | return sample 285 | elif self.split == 'test': 286 | idx = np.random.randint(self.__len__()) 287 | # idx = 901 288 | all_rays = self.all_rays[idx] 289 | sample = {'rays_o': all_rays[:, :3], 290 | 'rays_d': all_rays[:, 3:6], 291 | 'rgb': self.all_rgbs[idx].reshape(-1,3), 292 | 'idx': idx} 293 | return sample 294 | -------------------------------------------------------------------------------- /models/tensorBase.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | from models.sh import eval_sh_bases 5 | import numpy as np 6 | import time 7 | 8 | 9 | def positional_encoding(positions, freqs): 10 | freq_bands = (2 ** torch.arange(freqs).float()).to(positions.device) # (F,) 11 | pts = (positions[..., None] * freq_bands).reshape( 12 | positions.shape[:-1] + (freqs * positions.shape[-1],)) # (..., DF) 13 | pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) 14 | return pts 15 | 16 | 17 | def raw2alpha(sigma, dist): 18 | # sigma, dist [N_rays, N_samples] 19 | alpha = 1. - torch.exp(-sigma * dist) 20 | 21 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1) 22 | 23 | weights = alpha * T[:, :-1] # [N_rays, N_samples] 24 | return alpha, weights, T[:, -1:] 25 | 26 | 27 | def SHRender(xyz_sampled, viewdirs, features): 28 | sh_mult = eval_sh_bases(2, viewdirs)[:, None] 29 | rgb_sh = features.view(-1, 3, sh_mult.shape[-1]) 30 | rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5) 31 | return rgb 32 | 33 | 34 | def RGBRender(xyz_sampled, viewdirs, features): 35 | rgb = features 36 | return rgb 37 | 38 | 39 | class AlphaGridMask(torch.nn.Module): 40 | def __init__(self, device, aabb, alpha_volume): 41 | super(AlphaGridMask, self).__init__() 42 | self.device = device 43 | 44 | self.aabb = aabb.to(self.device) 45 | self.aabbSize = self.aabb[1] - self.aabb[0] 46 | self.invgridSize = 1.0 / self.aabbSize * 2 47 | self.alpha_volume = alpha_volume.view(1, 1, *alpha_volume.shape[-3:]) 48 | self.gridSize = torch.LongTensor([alpha_volume.shape[-1], alpha_volume.shape[-2], alpha_volume.shape[-3]]).to( 49 | self.device) 50 | 51 | def sample_alpha(self, xyz_sampled): 52 | xyz_sampled = self.normalize_coord(xyz_sampled) 53 | alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1, -1, 1, 1, 3), align_corners=True).view(-1) 54 | 55 | return alpha_vals 56 | 57 | def normalize_coord(self, xyz_sampled): 58 | return (xyz_sampled - self.aabb[0]) * self.invgridSize - 1 59 | 60 | 61 | class MLPRender_Fea(torch.nn.Module): 62 | def __init__(self, inChanel, viewpe=6, feape=6, featureC=128): 63 | super(MLPRender_Fea, self).__init__() 64 | 65 | self.in_mlpC = 2 * viewpe * 3 + 2 * feape * inChanel + 3 + inChanel 66 | self.viewpe = viewpe 67 | self.feape = feape 68 | layer1 = torch.nn.Linear(self.in_mlpC, featureC) 69 | layer2 = torch.nn.Linear(featureC, featureC) 70 | layer3 = torch.nn.Linear(featureC, 3) 71 | 72 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) 73 | torch.nn.init.constant_(self.mlp[-1].bias, 0) 74 | 75 | def forward(self, pts, viewdirs, features): 76 | indata = [features, viewdirs] 77 | if self.feape > 0: 78 | indata += [positional_encoding(features, self.feape)] 79 | if self.viewpe > 0: 80 | indata += [positional_encoding(viewdirs, 
self.viewpe)] 81 | mlp_in = torch.cat(indata, dim=-1) 82 | rgb = self.mlp(mlp_in) 83 | rgb = torch.sigmoid(rgb) 84 | 85 | return rgb 86 | 87 | 88 | class MLPRender_PE(torch.nn.Module): 89 | def __init__(self, inChanel, viewpe=6, pospe=6, featureC=128): 90 | super(MLPRender_PE, self).__init__() 91 | 92 | self.in_mlpC = (3 + 2 * viewpe * 3) + (3 + 2 * pospe * 3) + inChanel # 93 | self.viewpe = viewpe 94 | self.pospe = pospe 95 | layer1 = torch.nn.Linear(self.in_mlpC, featureC) 96 | layer2 = torch.nn.Linear(featureC, featureC) 97 | layer3 = torch.nn.Linear(featureC, 3) 98 | 99 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) 100 | torch.nn.init.constant_(self.mlp[-1].bias, 0) 101 | 102 | def forward(self, pts, viewdirs, features): 103 | indata = [features, viewdirs] 104 | if self.pospe > 0: 105 | indata += [positional_encoding(pts, self.pospe)] 106 | if self.viewpe > 0: 107 | indata += [positional_encoding(viewdirs, self.viewpe)] 108 | mlp_in = torch.cat(indata, dim=-1) 109 | rgb = self.mlp(mlp_in) 110 | rgb = torch.sigmoid(rgb) 111 | 112 | return rgb 113 | 114 | 115 | class MLPRender(torch.nn.Module): 116 | def __init__(self, inChanel, viewpe=6, featureC=128): 117 | super(MLPRender, self).__init__() 118 | 119 | self.in_mlpC = (3 + 2 * viewpe * 3) + inChanel 120 | self.viewpe = viewpe 121 | 122 | layer1 = torch.nn.Linear(self.in_mlpC, featureC) 123 | layer2 = torch.nn.Linear(featureC, featureC) 124 | layer3 = torch.nn.Linear(featureC, 3) 125 | 126 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) 127 | torch.nn.init.constant_(self.mlp[-1].bias, 0) 128 | 129 | def forward(self, pts, viewdirs, features): 130 | indata = [features, viewdirs] 131 | if self.viewpe > 0: 132 | indata += [positional_encoding(viewdirs, self.viewpe)] 133 | mlp_in = torch.cat(indata, dim=-1) 134 | rgb = self.mlp(mlp_in) 135 | rgb = torch.sigmoid(rgb) 136 | 137 | return rgb 138 | 139 | 140 | class TensorBase(torch.nn.Module): 141 | def __init__(self, aabb, gridSize, device, density_n_comp=8, appearance_n_comp=24, app_dim=27, 142 | shadingMode='MLP_PE', alphaMask=None, near_far=[2.0, 6.0], 143 | density_shift=-10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001, 144 | pos_pe=6, view_pe=6, fea_pe=6, featureC=128, step_ratio=2.0, 145 | fea2denseAct='softplus'): 146 | super(TensorBase, self).__init__() 147 | 148 | self.density_n_comp = density_n_comp 149 | self.app_n_comp = appearance_n_comp 150 | self.app_dim = app_dim 151 | self.aabb = aabb 152 | self.alphaMask = alphaMask 153 | self.device = device 154 | 155 | self.density_shift = density_shift 156 | self.alphaMask_thres = alphaMask_thres 157 | self.distance_scale = distance_scale 158 | self.rayMarch_weight_thres = rayMarch_weight_thres 159 | self.fea2denseAct = fea2denseAct 160 | 161 | self.near_far = near_far 162 | self.step_ratio = step_ratio 163 | 164 | self.update_stepSize(gridSize) 165 | 166 | self.matMode = [[0, 1], [0, 2], [1, 2]] 167 | self.vecMode = [2, 1, 0] 168 | self.comp_w = [1, 1, 1] 169 | 170 | self.init_svd_volume(gridSize[0], device) 171 | 172 | self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC 173 | self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device) 174 | 175 | def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device): 176 | if shadingMode == 'MLP_PE': 177 | self.renderModule = 
MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device) 178 | elif shadingMode == 'MLP_Fea': 179 | self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device) 180 | elif shadingMode == 'MLP': 181 | self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device) 182 | elif shadingMode == 'SH': 183 | self.renderModule = SHRender 184 | elif shadingMode == 'RGB': 185 | assert self.app_dim == 3 186 | self.renderModule = RGBRender 187 | else: 188 | print("Unrecognized shading module") 189 | exit() 190 | print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe) 191 | print(self.renderModule) 192 | 193 | def update_stepSize(self, gridSize): 194 | print("aabb", self.aabb.view(-1)) 195 | print("grid size", gridSize) 196 | self.aabbSize = self.aabb[1] - self.aabb[0] 197 | self.invaabbSize = 2.0 / self.aabbSize 198 | self.gridSize = torch.LongTensor(gridSize).to(self.device) 199 | self.units = self.aabbSize / (self.gridSize - 1) 200 | self.stepSize = torch.mean(self.units) * self.step_ratio 201 | self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize))) 202 | self.nSamples = int((self.aabbDiag / self.stepSize).item()) + 1 203 | print("sampling step size: ", self.stepSize) 204 | print("sampling number: ", self.nSamples) 205 | 206 | def init_svd_volume(self, res, device): 207 | pass 208 | 209 | def compute_features(self, xyz_sampled): 210 | pass 211 | 212 | def compute_densityfeature(self, xyz_sampled): 213 | pass 214 | 215 | def compute_appfeature(self, xyz_sampled): 216 | pass 217 | 218 | def normalize_coord(self, xyz_sampled): 219 | return (xyz_sampled - self.aabb[0]) * self.invaabbSize - 1 220 | 221 | def get_optparam_groups(self, lr_init_spatial=0.02, lr_init_network=0.001): 222 | pass 223 | 224 | def get_kwargs(self): 225 | return { 226 | 'aabb': self.aabb, 227 | 'gridSize': self.gridSize.tolist(), 228 | 'density_n_comp': self.density_n_comp, 229 | 'appearance_n_comp': self.app_n_comp, 230 | 'app_dim': self.app_dim, 231 | 232 | 'density_shift': self.density_shift, 233 | 'alphaMask_thres': self.alphaMask_thres, 234 | 'distance_scale': self.distance_scale, 235 | 'rayMarch_weight_thres': self.rayMarch_weight_thres, 236 | 'fea2denseAct': self.fea2denseAct, 237 | 238 | 'near_far': self.near_far, 239 | 'step_ratio': self.step_ratio, 240 | 241 | 'shadingMode': self.shadingMode, 242 | 'pos_pe': self.pos_pe, 243 | 'view_pe': self.view_pe, 244 | 'fea_pe': self.fea_pe, 245 | 'featureC': self.featureC 246 | } 247 | 248 | def save(self, path): 249 | kwargs = self.get_kwargs() 250 | ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()} 251 | if self.alphaMask is not None: 252 | alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy() 253 | ckpt.update({'alphaMask.shape': alpha_volume.shape}) 254 | ckpt.update({'alphaMask.mask': np.packbits(alpha_volume.reshape(-1))}) 255 | ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()}) 256 | torch.save(ckpt, path) 257 | 258 | def load(self, ckpt): 259 | if 'alphaMask.aabb' in ckpt.keys(): 260 | length = np.prod(ckpt['alphaMask.shape']) 261 | alpha_volume = torch.from_numpy( 262 | np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape'])) 263 | self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), 264 | alpha_volume.float().to(self.device)) 265 | self.load_state_dict(ckpt['state_dict']) 266 | 267 | def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1): 268 | N_samples = N_samples if N_samples > 0 else self.nSamples 269 | near, far 
= self.near_far 270 | interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o) 271 | if is_train: 272 | interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples) 273 | 274 | rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None] 275 | mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1) 276 | return rays_pts, interpx, ~mask_outbbox 277 | 278 | def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1): 279 | N_samples = N_samples if N_samples > 0 else self.nSamples 280 | stepsize = self.stepSize 281 | near, far = self.near_far 282 | vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d) 283 | rate_a = (self.aabb[1] - rays_o) / vec 284 | rate_b = (self.aabb[0] - rays_o) / vec 285 | t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far) 286 | 287 | rng = torch.arange(N_samples)[None].float() 288 | if is_train: 289 | rng = rng.repeat(rays_d.shape[-2], 1) 290 | rng += torch.rand_like(rng[:, [0]]) 291 | step = stepsize * rng.to(rays_o.device) 292 | interpx = (t_min[..., None] + step) 293 | 294 | rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None] 295 | mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1) 296 | 297 | return rays_pts, interpx, ~mask_outbbox 298 | 299 | def shrink(self, new_aabb, voxel_size): 300 | pass 301 | 302 | @torch.no_grad() 303 | def getDenseAlpha(self, gridSize=None): 304 | gridSize = self.gridSize if gridSize is None else gridSize 305 | 306 | samples = torch.stack(torch.meshgrid( 307 | torch.linspace(0, 1, gridSize[0]), 308 | torch.linspace(0, 1, gridSize[1]), 309 | torch.linspace(0, 1, gridSize[2]), 310 | ), -1).to(self.device) 311 | dense_xyz = self.aabb[0] * (1 - samples) + self.aabb[1] * samples 312 | 313 | # dense_xyz = dense_xyz 314 | # print(self.stepSize, self.distance_scale*self.aabbDiag) 315 | alpha = torch.zeros_like(dense_xyz[..., 0]) 316 | for i in range(gridSize[0]): 317 | alpha[i] = self.compute_alpha(dense_xyz[i].view(-1, 3), self.stepSize).view((gridSize[1], gridSize[2])) 318 | return alpha, dense_xyz 319 | 320 | @torch.no_grad() 321 | def updateAlphaMask(self, gridSize=(200, 200, 200)): 322 | 323 | alpha, dense_xyz = self.getDenseAlpha(gridSize) 324 | dense_xyz = dense_xyz.transpose(0, 2).contiguous() 325 | alpha = alpha.clamp(0, 1).transpose(0, 2).contiguous()[None, None] 326 | total_voxels = gridSize[0] * gridSize[1] * gridSize[2] 327 | 328 | ks = 3 329 | alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1]) 330 | alpha[alpha >= self.alphaMask_thres] = 1 331 | alpha[alpha < self.alphaMask_thres] = 0 332 | 333 | self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha) 334 | 335 | valid_xyz = dense_xyz[alpha > 0.5] 336 | 337 | xyz_min = valid_xyz.amin(0) 338 | xyz_max = valid_xyz.amax(0) 339 | 340 | new_aabb = torch.stack((xyz_min, xyz_max)) 341 | 342 | total = torch.sum(alpha) 343 | print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f" % (total / total_voxels * 100)) 344 | return new_aabb 345 | 346 | @torch.no_grad() 347 | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240 * 5, bbox_only=False): 348 | print('========> filtering rays ...') 349 | tt = time.time() 350 | 351 | N = torch.tensor(all_rays.shape[:-1]).prod() 352 | 353 | mask_filtered = [] 354 | idx_chunks = torch.split(torch.arange(N), chunk) 355 | for idx_chunk in idx_chunks: 356 | rays_chunk = all_rays[idx_chunk].to(self.device) 357 | 358 | rays_o, rays_d = 
rays_chunk[..., :3], rays_chunk[..., 3:6] 359 | if bbox_only: 360 | vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d) 361 | rate_a = (self.aabb[1] - rays_o) / vec 362 | rate_b = (self.aabb[0] - rays_o) / vec 363 | t_min = torch.minimum(rate_a, rate_b).amax(-1) # .clamp(min=near, max=far) 364 | t_max = torch.maximum(rate_a, rate_b).amin(-1) # .clamp(min=near, max=far) 365 | mask_inbbox = t_max > t_min 366 | 367 | else: 368 | xyz_sampled, _, _ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False) 369 | mask_inbbox = (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1) 370 | 371 | mask_filtered.append(mask_inbbox.cpu()) 372 | 373 | mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1]) 374 | 375 | print(f'Ray filtering done! takes {time.time() - tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}') 376 | return all_rays[mask_filtered], all_rgbs[mask_filtered] 377 | 378 | def feature2density(self, density_features): 379 | if self.fea2denseAct == "softplus": 380 | return F.softplus(density_features + self.density_shift) 381 | elif self.fea2denseAct == "relu": 382 | return F.relu(density_features) 383 | 384 | def compute_alpha(self, xyz_locs, length=1): 385 | 386 | if self.alphaMask is not None: 387 | alphas = self.alphaMask.sample_alpha(xyz_locs) 388 | alpha_mask = alphas > 0 389 | else: 390 | alpha_mask = torch.ones_like(xyz_locs[:, 0], dtype=bool) 391 | 392 | sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device) 393 | 394 | if alpha_mask.any(): 395 | xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask]) 396 | sigma_feature = self.compute_densityfeature(xyz_sampled) 397 | validsigma = self.feature2density(sigma_feature) 398 | sigma[alpha_mask] = validsigma 399 | 400 | alpha = 1 - torch.exp(-sigma * length).view(xyz_locs.shape[:-1]) 401 | 402 | return alpha 403 | 404 | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): 405 | 406 | # sample points 407 | viewdirs = rays_chunk[:, 3:6] 408 | if ndc_ray: 409 | xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train, 410 | N_samples=N_samples) 411 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) 412 | rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True) 413 | dists = dists * rays_norm 414 | viewdirs = viewdirs / rays_norm 415 | else: 416 | xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train, 417 | N_samples=N_samples) 418 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) 419 | viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape) 420 | 421 | if self.alphaMask is not None: 422 | alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid]) 423 | alpha_mask = alphas > 0 424 | ray_invalid = ~ray_valid 425 | ray_invalid[ray_valid] |= (~alpha_mask) 426 | ray_valid = ~ray_invalid 427 | 428 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device) 429 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device) 430 | 431 | if ray_valid.any(): 432 | xyz_sampled = self.normalize_coord(xyz_sampled) 433 | sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid]) 434 | 435 | validsigma = self.feature2density(sigma_feature) 436 | sigma[ray_valid] = validsigma 437 | 438 | alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale) 439 | 440 | app_mask = weight > self.rayMarch_weight_thres 441 | 
442 | if app_mask.any(): 443 | app_features = self.compute_appfeature(xyz_sampled[app_mask]) 444 | valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features) 445 | rgb[app_mask] = valid_rgbs 446 | 447 | acc_map = torch.sum(weight, -1) 448 | rgb_map = torch.sum(weight[..., None] * rgb, -2) 449 | 450 | if white_bg or (is_train and torch.rand((1,)) < 0.5): 451 | rgb_map = rgb_map + (1. - acc_map[..., None]) 452 | 453 | rgb_map = rgb_map.clamp(0, 1) 454 | 455 | with torch.no_grad(): 456 | depth_map = torch.sum(weight * z_vals, -1) 457 | depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1] 458 | 459 | return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight 460 | 461 | -------------------------------------------------------------------------------- /models/renderer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import logging 6 | import mcubes 7 | from models.tensorf2neus import * 8 | from models.nerf2neus import * 9 | 10 | 11 | def extract_fields(bound_min, bound_max, resolution, query_func): 12 | N = 64 13 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 14 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 15 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 16 | 17 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 18 | with torch.no_grad(): 19 | for xi, xs in enumerate(X): 20 | for yi, ys in enumerate(Y): 21 | for zi, zs in enumerate(Z): 22 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 23 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) 24 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() 25 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val 26 | return u 27 | 28 | 29 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): 30 | print('threshold: {}'.format(threshold)) 31 | u = extract_fields(bound_min, bound_max, resolution, query_func) 32 | vertices, triangles = mcubes.marching_cubes(u, threshold) 33 | b_max_np = bound_max.detach().cpu().numpy() 34 | b_min_np = bound_min.detach().cpu().numpy() 35 | 36 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] 37 | return vertices, triangles 38 | 39 | 40 | def sample_pdf(bins, weights, n_samples, det=False): 41 | # This implementation is from NeRF 42 | # Get pdf 43 | weights = weights + 1e-5 # prevent nans 44 | pdf = weights / torch.sum(weights, -1, keepdim=True) 45 | cdf = torch.cumsum(pdf, -1) 46 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 47 | # Take uniform samples 48 | if det: 49 | u = torch.linspace(0. + 0.5 / n_samples, 1. 
- 0.5 / n_samples, steps=n_samples) 50 | u = u.expand(list(cdf.shape[:-1]) + [n_samples]) 51 | else: 52 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples]) 53 | 54 | # Invert CDF 55 | u = u.contiguous() 56 | inds = torch.searchsorted(cdf, u, right=True) 57 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 58 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 59 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 60 | 61 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 62 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 63 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 64 | 65 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 66 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 67 | t = (u - cdf_g[..., 0]) / denom 68 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 69 | 70 | return samples 71 | 72 | 73 | class NeuSRenderer: 74 | def __init__(self, 75 | nerf, 76 | sdf_network, 77 | deviation_network, 78 | color_network, 79 | n_samples, 80 | n_importance, 81 | n_outside, 82 | up_sample_steps, 83 | perturb): 84 | self.nerf = nerf 85 | self.sdf_network = sdf_network 86 | self.deviation_network = deviation_network 87 | self.color_network = color_network 88 | self.n_samples = n_samples 89 | self.n_importance = n_importance 90 | self.n_outside = n_outside 91 | self.up_sample_steps = up_sample_steps 92 | self.perturb = perturb 93 | 94 | def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None): 95 | """ 96 | Render background 97 | """ 98 | batch_size, n_samples = z_vals.shape 99 | 100 | # Section length 101 | dists = z_vals[..., 1:] - z_vals[..., :-1] 102 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 103 | mid_z_vals = z_vals + dists * 0.5 104 | 105 | # Section midpoints 106 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3 107 | 108 | dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10) 109 | pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4 110 | 111 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3) 112 | 113 | pts = pts.reshape(-1, 3 + int(self.n_outside > 0)) 114 | dirs = dirs.reshape(-1, 3) 115 | 116 | density, sampled_color = nerf(pts, dirs) 117 | sampled_color = torch.sigmoid(sampled_color) 118 | alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists) 119 | alpha = alpha.reshape(batch_size, n_samples) 120 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. 
- alpha + 1e-7], -1), -1)[:, :-1] 121 | sampled_color = sampled_color.reshape(batch_size, n_samples, 3) 122 | color = (weights[:, :, None] * sampled_color).sum(dim=1) 123 | if background_rgb is not None: 124 | color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True)) 125 | 126 | return { 127 | 'color': color, 128 | 'sampled_color': sampled_color, 129 | 'alpha': alpha, 130 | 'weights': weights, 131 | } 132 | 133 | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s): 134 | """ 135 | Up sampling give a fixed inv_s 136 | """ 137 | batch_size, n_samples = z_vals.shape 138 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3 139 | radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False) 140 | inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0) 141 | sdf = sdf.reshape(batch_size, n_samples) 142 | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] 143 | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] 144 | mid_sdf = (prev_sdf + next_sdf) * 0.5 145 | cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) 146 | 147 | # ---------------------------------------------------------------------------------------------------------- 148 | # Use min value of [ cos, prev_cos ] 149 | # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more 150 | # robust when meeting situations like below: 151 | # 152 | # SDF 153 | # ^ 154 | # |\ -----x----... 155 | # | \ / 156 | # | x x 157 | # |---\----/-------------> 0 level 158 | # | \ / 159 | # | \/ 160 | # | 161 | # ---------------------------------------------------------------------------------------------------------- 162 | prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1) 163 | cos_val = torch.stack([prev_cos_val, cos_val], dim=-1) 164 | cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False) 165 | cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere 166 | 167 | dist = (next_z_vals - prev_z_vals) 168 | prev_esti_sdf = mid_sdf - cos_val * dist * 0.5 169 | next_esti_sdf = mid_sdf + cos_val * dist * 0.5 170 | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s) 171 | next_cdf = torch.sigmoid(next_esti_sdf * inv_s) 172 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) 173 | weights = alpha * torch.cumprod( 174 | torch.cat([torch.ones([batch_size, 1]), 1. 
- alpha + 1e-7], -1), -1)[:, :-1] 175 | 176 | z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() 177 | return z_samples 178 | 179 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False): 180 | batch_size, n_samples = z_vals.shape 181 | _, n_importance = new_z_vals.shape 182 | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] 183 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1) 184 | z_vals, index = torch.sort(z_vals, dim=-1) 185 | 186 | if not last: 187 | new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance) 188 | sdf = torch.cat([sdf, new_sdf], dim=-1) 189 | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) 190 | index = index.reshape(-1) 191 | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) 192 | 193 | return z_vals, sdf 194 | 195 | def render_core(self, 196 | rays_o, 197 | rays_d, 198 | z_vals, 199 | sample_dist, 200 | sdf_network, 201 | deviation_network, 202 | color_network, 203 | background_alpha=None, 204 | background_sampled_color=None, 205 | background_rgb=None, 206 | cos_anneal_ratio=0.0): # Render part in bounding sphere (core), using color net 207 | batch_size, n_samples = z_vals.shape 208 | 209 | # Section length 210 | dists = z_vals[..., 1:] - z_vals[..., :-1] 211 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1) 212 | mid_z_vals = z_vals + dists * 0.5 213 | 214 | # Section midpoints 215 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3 216 | dirs = rays_d[:, None, :].expand(pts.shape) 217 | 218 | pts = pts.reshape(-1, 3) 219 | dirs = dirs.reshape(-1, 3) 220 | 221 | sdf_nn_output = sdf_network(pts) 222 | sdf = sdf_nn_output[:, :1] 223 | feature_vector = sdf_nn_output[:, 1:] 224 | 225 | gradients = sdf_network.gradient(pts).squeeze() # gradients assumed to be normalized because of eikonal_loss 226 | sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3) 227 | 228 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter 229 | inv_s = inv_s.expand(batch_size * n_samples, 1) 230 | 231 | true_cos = (dirs * gradients).sum(-1, keepdim=True) 232 | 233 | # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes 234 | # the cos value "not dead" at the beginning training iterations, for better convergence. 
235 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
 236 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
 237 |
 238 | # Estimate signed distances at section points
 239 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
 240 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
 241 |
 242 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
 243 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
 244 |
 245 | p = prev_cdf - next_cdf
 246 | c = prev_cdf
 247 |
 248 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
 249 |
 250 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
 251 | inside_sphere = (pts_norm < 1.0).float().detach()
 252 | # relax_inside_sphere = (pts_norm < 1.2).float().detach()
 253 | relax_inside_sphere = (pts_norm < 3.0).float().detach()
 254 |
 255 | # Render with background
 256 | if background_alpha is not None:
 257 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
 258 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
 259 | sampled_color = sampled_color * inside_sphere[:, :, None] +\
 260 | background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
 261 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
 262 |
 263 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
 264 | weights_sum = weights.sum(dim=-1, keepdim=True)
 265 |
 266 | color = (sampled_color * weights[:, :, None]).sum(dim=1)
 267 | if background_rgb is not None: # Fixed background, usually black
 268 | color = color + background_rgb * (1.0 - weights_sum)
 269 |
 270 | # Eikonal loss
 271 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
 272 | dim=-1) - 1.0) ** 2
 273 | gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
 274 |
 275 | if not self.validate:
 276 | # query from NeRF Prior
 277 | # TensoRF
 278 | thres1 = 0.9 # for replica
 279 | thres1 = 0.5 # for scannet
 280 | thres2 = 1e-8
 281 | with torch.no_grad():
 282 | queried_occ, _ = query_alpha_color(self.tensorf, pts, None, prior_mode='cat')
 283 | queried_occ = queried_occ.reshape(alpha.shape)
 284 |
 285 | # NeRF
 286 | # thres1 = 0.5
 287 | # thres2 = 1e-8
 288 | # with torch.no_grad():
 289 | # queried_occ, _ = query_alpha_color_nerf(self.nerf, _xyz_sampled=pts.reshape(batch_size, n_samples, 3), viewdirs=rays_d, z_vals=z_vals)
 290 | # queried_occ = queried_occ.reshape(alpha.shape)
 291 | # queried_occ = torch.where(queried_occ > thres1, torch.Tensor([1.]).cuda(), queried_occ)
 292 | # queried_occ = torch.where(queried_occ < thres2, torch.Tensor([0.]).cuda(), queried_occ)
 293 |
 294 | # visual_red_green_points(pts[queried_occ.reshape(-1)>thres1], pts[queried_occ.reshape(-1)<thres2])
 295 |
 296 | supervise_mask = torch.zeros_like(queried_occ)
 297 | supervise_mask = torch.where(queried_occ > thres1, torch.Tensor([1]).cuda(), supervise_mask) # scannet 1, replica 2
 298 | # supervise_mask = torch.where(queried_occ < 1e-8, torch.Tensor([1]).cuda(), supervise_mask)
 299 | supervise_mask = torch.where(queried_occ < thres2, torch.Tensor([0.5]).cuda(), supervise_mask)
 300 | alpha_error = torch.abs(alpha - queried_occ) * supervise_mask
 301 | alpha_error = alpha_error.sum() / ((queried_occ > thres1).sum() + (queried_occ < thres2).sum())
 302 | else:
 303 | alpha_error = 0
 304 |
 305 |
 306 | if weights.shape[0] % 9 == 0:
 307 | pred_depth = (weights * mid_z_vals).sum(-1) # [N]
 308 | pred_normal =
(weights.unsqueeze(-1) * gradients.reshape(batch_size, n_samples, 3)).sum(1) # [N, 3] 309 | cos = (dirs.reshape(batch_size, n_samples, 3)[:,0,:] * pred_normal).sum(1) 310 | proj_depth = pred_depth * cos # [N] 311 | proj_depth = proj_depth.reshape(-1, 9) 312 | mean_depth = proj_depth.mean(1, keepdim=True) 313 | patch_depth_std = ((proj_depth - mean_depth) ** 2).mean(1).sqrt() 314 | else: 315 | patch_depth_std = torch.zeros(weights.shape[0]) 316 | 317 | 318 | return { 319 | 'color': color, 320 | 'sdf': sdf, 321 | 'dists': dists, 322 | 'gradients': gradients.reshape(batch_size, n_samples, 3), 323 | 's_val': 1.0 / inv_s, 324 | 'mid_z_vals': mid_z_vals, 325 | 'weights': weights, 326 | 'cdf': c.reshape(batch_size, n_samples), 327 | 'gradient_error': gradient_error, 328 | 'inside_sphere': inside_sphere, 329 | 'alpha_error': alpha_error, 330 | 'depth_std': patch_depth_std, 331 | 'depth': (weights * mid_z_vals).sum(-1), 332 | } 333 | 334 | def render(self, rays_o, rays_d, near, far, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0): 335 | batch_size = len(rays_o) 336 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere 337 | z_vals = torch.linspace(0.0, 1.0, self.n_samples) 338 | z_vals = near + (far - near) * z_vals[None, :] 339 | 340 | z_vals_outside = None 341 | if self.n_outside > 0: 342 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside) 343 | 344 | n_samples = self.n_samples 345 | perturb = self.perturb 346 | 347 | if perturb_overwrite >= 0: 348 | perturb = perturb_overwrite 349 | if perturb > 0: 350 | t_rand = (torch.rand([batch_size, 1]) - 0.5) 351 | z_vals = z_vals + t_rand * 2.0 / self.n_samples 352 | 353 | if self.n_outside > 0: 354 | mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1]) 355 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1) 356 | lower = torch.cat([z_vals_outside[..., :1], mids], -1) 357 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]]) 358 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand 359 | 360 | if self.n_outside > 0: 361 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples 362 | 363 | background_alpha = None 364 | background_sampled_color = None 365 | 366 | # Up sample 367 | if self.n_importance > 0: 368 | with torch.no_grad(): 369 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] 370 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples) 371 | 372 | for i in range(self.up_sample_steps): 373 | new_z_vals = self.up_sample(rays_o, 374 | rays_d, 375 | z_vals, 376 | sdf, 377 | self.n_importance // self.up_sample_steps, 378 | 64 * 2**i) 379 | z_vals, sdf = self.cat_z_vals(rays_o, 380 | rays_d, 381 | z_vals, 382 | new_z_vals, 383 | sdf, 384 | last=(i + 1 == self.up_sample_steps)) 385 | 386 | n_samples = self.n_samples + self.n_importance 387 | 388 | # Background model 389 | if self.n_outside > 0: 390 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1) 391 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1) 392 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf) 393 | 394 | background_sampled_color = ret_outside['sampled_color'] 395 | background_alpha = ret_outside['alpha'] 396 | 397 | # Render core 398 | ret_fine = self.render_core(rays_o, 399 | rays_d, 400 | z_vals, 401 | sample_dist, 402 | self.sdf_network, 403 | self.deviation_network, 404 | self.color_network, 405 | background_rgb=background_rgb, 
406 | background_alpha=background_alpha, 407 | background_sampled_color=background_sampled_color, 408 | cos_anneal_ratio=cos_anneal_ratio) 409 | 410 | color_fine = ret_fine['color'] 411 | weights = ret_fine['weights'] 412 | weights_sum = weights.sum(dim=-1, keepdim=True) 413 | gradients = ret_fine['gradients'] 414 | s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True) 415 | 416 | return { 417 | 'color_fine': color_fine, 418 | 's_val': s_val, 419 | 'cdf_fine': ret_fine['cdf'], 420 | 'weight_sum': weights_sum, 421 | 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0], 422 | 'gradients': gradients, 423 | 'weights': weights, 424 | 'gradient_error': ret_fine['gradient_error'], 425 | 'inside_sphere': ret_fine['inside_sphere'], 426 | 'alpha_error': ret_fine['alpha_error'], 427 | 'depth_std': ret_fine['depth_std'], 428 | 'depth': ret_fine['depth'] 429 | } 430 | 431 | def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0): 432 | return extract_geometry(bound_min, 433 | bound_max, 434 | resolution=resolution, 435 | threshold=threshold, 436 | query_func=lambda pts: -self.sdf_network.sdf(pts)) 437 | -------------------------------------------------------------------------------- /exp_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os 3 | import time 4 | import logging 5 | import argparse 6 | import numpy as np 7 | import cv2 as cv 8 | import trimesh 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from shutil import copyfile 13 | # from icecream import ic 14 | from tqdm import tqdm 15 | from pyhocon import ConfigFactory 16 | from models.dataset import Dataset 17 | from models.blender_swap import BlendSwapDataset 18 | from models.scannet_blender import ScanNetDataset 19 | from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF 20 | from models.renderer import NeuSRenderer 21 | from models.tensorf2neus import * 22 | from models.nerf2neus import * 23 | 24 | 25 | class Runner: 26 | def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False): 27 | self.device = torch.device('cuda') 28 | 29 | # Configuration 30 | self.conf_path = conf_path 31 | f = open(self.conf_path) 32 | conf_text = f.read() 33 | conf_text = conf_text.replace('CASE_NAME', case) 34 | f.close() 35 | 36 | self.conf = ConfigFactory.parse_string(conf_text) 37 | self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case) 38 | self.base_exp_dir = self.conf['general.base_exp_dir'] 39 | os.makedirs(self.base_exp_dir, exist_ok=True) 40 | 41 | # if self.conf['dataset.type'] == "Neus": 42 | # self.dataset = Dataset(self.conf['dataset']) 43 | # elif self.conf['dataset.type'] == "Blender": 44 | # self.dataset = BlendSwapDataset(self.conf['dataset']) 45 | self.dataset = ScanNetDataset(self.conf['dataset']) 46 | 47 | self.iter_step = 0 48 | 49 | # Training parameters 50 | self.end_iter = self.conf.get_int('train.end_iter') 51 | self.save_freq = self.conf.get_int('train.save_freq') 52 | self.report_freq = self.conf.get_int('train.report_freq') 53 | self.val_freq = self.conf.get_int('train.val_freq') 54 | self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') 55 | self.batch_size = self.conf.get_int('train.batch_size') 56 | self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level') 57 | self.learning_rate = self.conf.get_float('train.learning_rate') 58 | 
self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha') 59 | self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd') 60 | self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0) 61 | self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) 62 | 63 | # Weights 64 | self.igr_weight = self.conf.get_float('train.igr_weight') 65 | self.mask_weight = self.conf.get_float('train.mask_weight') 66 | self.is_continue = is_continue 67 | self.mode = mode 68 | self.model_list = [] 69 | self.writer = None 70 | 71 | # Networks 72 | params_to_train = [] 73 | self.nerf_outside = NeRF(**self.conf['model.nerf']).to(self.device) 74 | self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device) # sdf net 75 | self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) # train for parameter s in paper 76 | self.color_network = RenderingNetwork(**self.conf['model.rendering_network']).to(self.device) # color net, same as nerf 77 | params_to_train += list(self.nerf_outside.parameters()) 78 | params_to_train += list(self.sdf_network.parameters()) 79 | params_to_train += list(self.deviation_network.parameters()) 80 | params_to_train += list(self.color_network.parameters()) 81 | 82 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate) 83 | 84 | self.renderer = NeuSRenderer(self.nerf_outside, 85 | self.sdf_network, 86 | self.deviation_network, 87 | self.color_network, 88 | **self.conf['model.neus_renderer']) 89 | 90 | # Load checkpoint 91 | latest_model_name = None 92 | if is_continue: 93 | model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints')) 94 | model_list = [] 95 | for model_name in model_list_raw: 96 | if model_name[-3:] == 'pth' and int(model_name[5:-4]) <= self.end_iter: 97 | model_list.append(model_name) 98 | model_list.sort() 99 | latest_model_name = model_list[-1] 100 | # latest_model_name = 'ckpt_050000.pth' 101 | 102 | if latest_model_name is not None: 103 | logging.info('Find checkpoint: {}'.format(latest_model_name)) 104 | self.load_checkpoint(latest_model_name) 105 | 106 | # Backup codes and configs for debug 107 | if self.mode[:5] == 'train': 108 | self.file_backup() 109 | 110 | # TODO: 111 | # tensorf 112 | tensorf = load_tensorf('data/tensorf/scan0050_00.th', torch.device('cuda:0')) 113 | reso_cur = list(tensorf.gridSize.cpu().detach().numpy()) 114 | tensorf.nSamples = cal_n_samples(reso_cur, 0.5) 115 | self.renderer.tensorf = tensorf 116 | 117 | # nerf 118 | # nerf = load_nerf('data/nerf/130000.tar') 119 | # self.renderer.nerf = nerf 120 | # # validate_mesh(nerf) 121 | 122 | def train(self): 123 | self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs')) 124 | self.update_learning_rate() 125 | res_step = self.end_iter - self.iter_step 126 | image_perm = self.get_image_perm() 127 | # self.validate_image(resolution_level=4) 128 | self.validate_mesh() 129 | self.renderer.validate = False 130 | 131 | for iter_i in tqdm(range(res_step)): 132 | # if self.iter_step > 150000: 133 | # random_mode = 'patch' 134 | # else: 135 | # random_mode = 'batch' 136 | # data = self.dataset.gen_random_rays_at(image_perm[self.iter_step % len(image_perm)], self.batch_size, mode=random_mode) 137 | data = self.dataset.gen_random_rays_at(image_perm[self.iter_step % len(image_perm)], self.batch_size) 138 | 139 | rays_o, rays_d, true_rgb, mask = data[:, :3], data[:, 3: 6], data[:, 6: 9], data[:, 9: 10] 140 | near, far = 
self.dataset.near_far_from_sphere(rays_o, rays_d) 141 | 142 | background_rgb = None 143 | if self.use_white_bkgd: 144 | background_rgb = torch.ones([1, 3]) 145 | 146 | if self.mask_weight > 0.0: 147 | mask = (mask > 0.5).float() 148 | else: 149 | mask = torch.ones_like(mask) 150 | 151 | mask_sum = mask.sum() + 1e-5 152 | render_out = self.renderer.render(rays_o, rays_d, near, far, 153 | background_rgb=background_rgb, 154 | cos_anneal_ratio=self.get_cos_anneal_ratio()) 155 | 156 | color_fine = render_out['color_fine'] 157 | s_val = render_out['s_val'] 158 | cdf_fine = render_out['cdf_fine'] 159 | gradient_error = render_out['gradient_error'] 160 | weight_max = render_out['weight_max'] 161 | weight_sum = render_out['weight_sum'] 162 | alpha_error = render_out['alpha_error'] 163 | 164 | # Loss 165 | color_error = (color_fine - true_rgb) * mask 166 | color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error), reduction='sum') / mask_sum 167 | psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt()) 168 | 169 | eikonal_loss = gradient_error 170 | 171 | mask_loss = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask) 172 | 173 | # TODO 174 | # if random_mode == 'patch': 175 | # patch_rgb_std = data[:, -1].reshape(-1, 9)[:, 4] 176 | # rgb_mask = patch_rgb_std == 0 177 | # depth_loss = torch.abs(render_out['depth_std'])[rgb_mask] 178 | # depth_loss = depth_loss.mean() 179 | # elif random_mode == 'batch': 180 | # depth_loss = 0 181 | depth_loss = 0 182 | 183 | alpha_error_weight = 0.5 * (1-min(1, self.iter_step/100000)) 184 | loss = color_fine_loss +\ 185 | eikonal_loss * self.igr_weight +\ 186 | mask_loss * self.mask_weight + \ 187 | alpha_error * alpha_error_weight +\ 188 | depth_loss * 0.5 # TODO 189 | 190 | self.optimizer.zero_grad() 191 | loss.backward() 192 | self.optimizer.step() 193 | 194 | self.iter_step += 1 195 | 196 | self.writer.add_scalar('Loss/loss', loss, self.iter_step) 197 | self.writer.add_scalar('Loss/color_loss', color_fine_loss, self.iter_step) 198 | self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step) 199 | self.writer.add_scalar('Loss/alpha_error', render_out['alpha_error'], self.iter_step) 200 | self.writer.add_scalar('Loss/depth_std', depth_loss, self.iter_step) 201 | self.writer.add_scalar('Statistics/s_val', s_val.mean(), self.iter_step) 202 | self.writer.add_scalar('Statistics/cdf', (cdf_fine[:, :1] * mask).sum() / mask_sum, self.iter_step) 203 | self.writer.add_scalar('Statistics/weight_max', (weight_max * mask).sum() / mask_sum, self.iter_step) 204 | self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step) 205 | 206 | if self.iter_step % self.report_freq == 0: 207 | print(self.base_exp_dir) 208 | print('iter:{:8>d} loss={:.3f} alpha={:.3f} alpha_weight={:.3f}' 209 | .format(self.iter_step, loss, alpha_error, alpha_error_weight)) 210 | 211 | if self.iter_step % self.save_freq == 0: 212 | self.save_checkpoint() 213 | 214 | if self.iter_step % self.val_freq == 0: 215 | self.validate_image(resolution_level=4) 216 | 217 | if self.iter_step % self.val_mesh_freq == 0: 218 | self.validate_mesh() 219 | 220 | self.update_learning_rate() 221 | 222 | if self.iter_step % len(image_perm) == 0: 223 | image_perm = self.get_image_perm() 224 | 225 | def get_image_perm(self): 226 | return torch.randperm(self.dataset.n_images) 227 | 228 | def get_cos_anneal_ratio(self): 229 | if self.anneal_end == 0.0: 230 | return 1.0 231 | else: 232 | return np.min([1.0, self.iter_step / self.anneal_end]) 
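The loss assembled in the training loop above combines the photometric, eikonal, mask and NeRF-prior alpha terms, where the prior term is weighted by alpha_error_weight = 0.5 * (1 - min(1, iter_step / 100000)) and the cos annealing ratio ramps up until anneal_end. A minimal standalone sketch of these two schedules (the function names are mine; the constants are the ones used in the code above):

import numpy as np

def alpha_error_weight(iter_step, ramp_iters=100000):
    # Weight on the NeRF-prior alpha supervision: decays linearly from 0.5 to 0
    # over the first ramp_iters iterations, then stays at 0.
    return 0.5 * (1 - min(1, iter_step / ramp_iters))

def cos_anneal_ratio(iter_step, anneal_end):
    # Ramps linearly from 0 to 1 until anneal_end, then stays at 1
    # (mirrors get_cos_anneal_ratio above).
    return 1.0 if anneal_end == 0.0 else np.min([1.0, iter_step / anneal_end])

print(alpha_error_weight(0), alpha_error_weight(50000), alpha_error_weight(200000))  # 0.5 0.25 0.0
print(cos_anneal_ratio(25000, 50000))  # 0.5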
233 | 234 | def update_learning_rate(self): 235 | if self.iter_step < self.warm_up_end: 236 | learning_factor = self.iter_step / self.warm_up_end 237 | else: 238 | alpha = self.learning_rate_alpha 239 | progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end) 240 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha 241 | 242 | for g in self.optimizer.param_groups: 243 | g['lr'] = self.learning_rate * learning_factor 244 | 245 | def file_backup(self): 246 | dir_lis = self.conf['general.recording'] 247 | os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True) 248 | for dir_name in dir_lis: 249 | cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name) 250 | os.makedirs(cur_dir, exist_ok=True) 251 | files = os.listdir(dir_name) 252 | for f_name in files: 253 | if f_name[-3:] == '.py': 254 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) 255 | 256 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) 257 | 258 | def load_checkpoint(self, checkpoint_name): 259 | checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device) 260 | self.nerf_outside.load_state_dict(checkpoint['nerf']) 261 | self.sdf_network.load_state_dict(checkpoint['sdf_network_fine']) 262 | self.deviation_network.load_state_dict(checkpoint['variance_network_fine']) 263 | self.color_network.load_state_dict(checkpoint['color_network_fine']) 264 | self.optimizer.load_state_dict(checkpoint['optimizer']) 265 | self.iter_step = checkpoint['iter_step'] 266 | 267 | logging.info('End') 268 | 269 | def save_checkpoint(self): 270 | checkpoint = { 271 | 'nerf': self.nerf_outside.state_dict(), 272 | 'sdf_network_fine': self.sdf_network.state_dict(), 273 | 'variance_network_fine': self.deviation_network.state_dict(), 274 | 'color_network_fine': self.color_network.state_dict(), 275 | 'optimizer': self.optimizer.state_dict(), 276 | 'iter_step': self.iter_step, 277 | } 278 | 279 | os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True) 280 | torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step))) 281 | 282 | def validate_image(self, idx=-1, resolution_level=-1): 283 | self.renderer.validate = True 284 | if idx < 0: 285 | idx = np.random.randint(self.dataset.n_images) 286 | 287 | print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx)) 288 | 289 | if resolution_level < 0: 290 | resolution_level = self.validate_resolution_level 291 | rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level) 292 | H, W, _ = rays_o.shape 293 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 294 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 295 | 296 | out_rgb_fine = [] 297 | out_normal_fine = [] 298 | 299 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 300 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 301 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 302 | 303 | render_out = self.renderer.render(rays_o_batch, 304 | rays_d_batch, 305 | near, 306 | far, 307 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 308 | background_rgb=background_rgb) 309 | 310 | def feasible(key): 311 | return (key in render_out) and (render_out[key] is not None) 312 | 313 | if feasible('color_fine'): 314 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) 315 | if feasible('gradients') and 
feasible('weights'): 316 | n_samples = self.renderer.n_samples + self.renderer.n_importance 317 | normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None] 318 | if feasible('inside_sphere'): 319 | normals = normals * render_out['inside_sphere'][..., None] 320 | normals = normals.sum(dim=1).detach().cpu().numpy() 321 | out_normal_fine.append(normals) 322 | del render_out 323 | 324 | img_fine = None 325 | if len(out_rgb_fine) > 0: 326 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255) 327 | 328 | normal_img = None 329 | if len(out_normal_fine) > 0: 330 | normal_img = np.concatenate(out_normal_fine, axis=0) 331 | rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy()) 332 | normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None]) 333 | .reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255) 334 | 335 | os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True) 336 | os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True) 337 | 338 | for i in range(img_fine.shape[-1]): 339 | if len(out_rgb_fine) > 0: 340 | cv.imwrite(os.path.join(self.base_exp_dir, 341 | 'validations_fine', 342 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), 343 | np.concatenate([img_fine[..., i], 344 | self.dataset.image_at(idx, resolution_level=resolution_level)])) 345 | if len(out_normal_fine) > 0: 346 | cv.imwrite(os.path.join(self.base_exp_dir, 347 | 'normals', 348 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)), 349 | normal_img[..., i]) 350 | 351 | self.renderer.validate = False 352 | 353 | def render_novel_image(self, idx_0, idx_1, ratio, resolution_level): 354 | """ 355 | Interpolate view between two cameras. 356 | """ 357 | rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level) 358 | H, W, _ = rays_o.shape 359 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size) 360 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size) 361 | 362 | out_rgb_fine = [] 363 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): 364 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch) 365 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None 366 | 367 | render_out = self.renderer.render(rays_o_batch, 368 | rays_d_batch, 369 | near, 370 | far, 371 | cos_anneal_ratio=self.get_cos_anneal_ratio(), 372 | background_rgb=background_rgb) 373 | 374 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) 375 | 376 | del render_out 377 | 378 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8) 379 | return img_fine 380 | 381 | def validate_mesh(self, world_space=False, resolution=64, threshold=0.0): 382 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32) 383 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32) 384 | 385 | vertices, triangles =\ 386 | self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold) 387 | os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True) 388 | 389 | if world_space: 390 | vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None] 391 | 392 | mesh = trimesh.Trimesh(vertices, triangles) 393 | mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}.ply'.format(self.iter_step))) 394 | 395 | logging.info('End') 396 | 397 | def interpolate_view(self, img_idx_0, img_idx_1): 398 
| images = [] 399 | n_frames = 60 400 | for i in range(n_frames): 401 | print(i) 402 | images.append(self.render_novel_image(img_idx_0, 403 | img_idx_1, 404 | np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5, 405 | resolution_level=4)) 406 | for i in range(n_frames): 407 | images.append(images[n_frames - i - 1]) 408 | 409 | fourcc = cv.VideoWriter_fourcc(*'mp4v') 410 | video_dir = os.path.join(self.base_exp_dir, 'render') 411 | os.makedirs(video_dir, exist_ok=True) 412 | h, w, _ = images[0].shape 413 | writer = cv.VideoWriter(os.path.join(video_dir, 414 | '{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)), 415 | fourcc, 30, (w, h)) 416 | 417 | for image in images: 418 | writer.write(image) 419 | 420 | writer.release() 421 | 422 | 423 | if __name__ == '__main__': 424 | print('Hello Wooden') 425 | 426 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 427 | 428 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" 429 | logging.basicConfig(level=logging.DEBUG, format=FORMAT) 430 | 431 | parser = argparse.ArgumentParser() 432 | parser.add_argument('--conf', type=str, default='./confs/base.conf') 433 | parser.add_argument('--mode', type=str, default='train') 434 | parser.add_argument('--mcube_threshold', type=float, default=0.0) 435 | parser.add_argument('--is_continue', default=False, action="store_true") 436 | parser.add_argument('--gpu', type=int, default=0) 437 | parser.add_argument('--case', type=str, default='') 438 | 439 | args = parser.parse_args() 440 | 441 | torch.cuda.set_device(args.gpu) 442 | runner = Runner(args.conf, args.mode, args.case, args.is_continue) 443 | 444 | if args.mode == 'train': 445 | runner.train() 446 | elif args.mode == 'validate_mesh': 447 | runner.validate_mesh(world_space=True, resolution=512, threshold=args.mcube_threshold) 448 | elif args.mode.startswith('interpolate'): # Interpolate views given two image indices 449 | _, img_idx_0, img_idx_1 = args.mode.split('_') 450 | img_idx_0 = int(img_idx_0) 451 | img_idx_1 = int(img_idx_1) 452 | runner.interpolate_view(img_idx_0, img_idx_1) 453 | -------------------------------------------------------------------------------- /models/tensoRF.py: -------------------------------------------------------------------------------- 1 | from models.tensorBase import * 2 | 3 | 4 | class TensorVM(TensorBase): 5 | def __init__(self, aabb, gridSize, device, **kargs): 6 | super(TensorVM, self).__init__(aabb, gridSize, device, **kargs) 7 | 8 | def init_svd_volume(self, res, device): 9 | self.plane_coef = torch.nn.Parameter( 10 | 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, res), device=device)) 11 | self.line_coef = torch.nn.Parameter( 12 | 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, 1), device=device)) 13 | self.basis_mat = torch.nn.Linear(self.app_n_comp * 3, self.app_dim, bias=False, device=device) 14 | 15 | def get_optparam_groups(self, lr_init_spatialxyz=0.02, lr_init_network=0.001): 16 | grad_vars = [{'params': self.line_coef, 'lr': lr_init_spatialxyz}, 17 | {'params': self.plane_coef, 'lr': lr_init_spatialxyz}, 18 | {'params': self.basis_mat.parameters(), 'lr': lr_init_network}] 19 | if isinstance(self.renderModule, torch.nn.Module): 20 | grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_init_network}] 21 | return grad_vars 22 | 23 | def compute_features(self, xyz_sampled): 24 | 25 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], 26 | xyz_sampled[..., 
self.matMode[2]])).detach() 27 | coordinate_line = torch.stack( 28 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 29 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach() 30 | 31 | plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, 32 | align_corners=True).view( 33 | -1, *xyz_sampled.shape[:1]) 34 | line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view( 35 | -1, *xyz_sampled.shape[:1]) 36 | 37 | sigma_feature = torch.sum(plane_feats * line_feats, dim=0) 38 | 39 | plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view( 40 | 3 * self.app_n_comp, -1) 41 | line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view( 42 | 3 * self.app_n_comp, -1) 43 | 44 | app_features = self.basis_mat((plane_feats * line_feats).T) 45 | 46 | return sigma_feature, app_features 47 | 48 | def compute_densityfeature(self, xyz_sampled): 49 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], 50 | xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2) 51 | coordinate_line = torch.stack( 52 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 53 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 54 | 1, 2) 55 | 56 | plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, 57 | align_corners=True).view( 58 | -1, *xyz_sampled.shape[:1]) 59 | line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view( 60 | -1, *xyz_sampled.shape[:1]) 61 | 62 | sigma_feature = torch.sum(plane_feats * line_feats, dim=0) 63 | 64 | return sigma_feature 65 | 66 | def compute_appfeature(self, xyz_sampled): 67 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], 68 | xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2) 69 | coordinate_line = torch.stack( 70 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 71 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 72 | 1, 2) 73 | 74 | plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view( 75 | 3 * self.app_n_comp, -1) 76 | line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view( 77 | 3 * self.app_n_comp, -1) 78 | 79 | app_features = self.basis_mat((plane_feats * line_feats).T) 80 | 81 | return app_features 82 | 83 | def vectorDiffs(self, vector_comps): 84 | total = 0 85 | 86 | for idx in range(len(vector_comps)): 87 | # print(self.line_coef.shape, vector_comps[idx].shape) 88 | n_comp, n_size = vector_comps[idx].shape[:-1] 89 | 90 | dotp = torch.matmul(vector_comps[idx].view(n_comp, n_size), 91 | vector_comps[idx].view(n_comp, n_size).transpose(-1, -2)) 92 | # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape) 93 | non_diagonal = dotp.view(-1)[1:].view(n_comp - 1, n_comp + 1)[..., :-1] 94 | # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, 
dotp.shape,non_diagonal.shape) 95 | total = total + torch.mean(torch.abs(non_diagonal)) 96 | return total 97 | 98 | def vector_comp_diffs(self): 99 | 100 | return self.vectorDiffs(self.line_coef[:, -self.density_n_comp:]) + self.vectorDiffs( 101 | self.line_coef[:, :self.app_n_comp]) 102 | 103 | @torch.no_grad() 104 | def up_sampling_VM(self, plane_coef, line_coef, res_target): 105 | 106 | for i in range(len(self.vecMode)): 107 | vec_id = self.vecMode[i] 108 | mat_id_0, mat_id_1 = self.matMode[i] 109 | 110 | plane_coef[i] = torch.nn.Parameter( 111 | F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear', 112 | align_corners=True)) 113 | line_coef[i] = torch.nn.Parameter( 114 | F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True)) 115 | 116 | # plane_coef[0] = torch.nn.Parameter( 117 | # F.interpolate(plane_coef[0].data, size=(res_target[1], res_target[0]), mode='bilinear', 118 | # align_corners=True)) 119 | # line_coef[0] = torch.nn.Parameter( 120 | # F.interpolate(line_coef[0].data, size=(res_target[2], 1), mode='bilinear', align_corners=True)) 121 | # plane_coef[1] = torch.nn.Parameter( 122 | # F.interpolate(plane_coef[1].data, size=(res_target[2], res_target[0]), mode='bilinear', 123 | # align_corners=True)) 124 | # line_coef[1] = torch.nn.Parameter( 125 | # F.interpolate(line_coef[1].data, size=(res_target[1], 1), mode='bilinear', align_corners=True)) 126 | # plane_coef[2] = torch.nn.Parameter( 127 | # F.interpolate(plane_coef[2].data, size=(res_target[2], res_target[1]), mode='bilinear', 128 | # align_corners=True)) 129 | # line_coef[2] = torch.nn.Parameter( 130 | # F.interpolate(line_coef[2].data, size=(res_target[0], 1), mode='bilinear', align_corners=True)) 131 | 132 | return plane_coef, line_coef 133 | 134 | @torch.no_grad() 135 | def upsample_volume_grid(self, res_target): 136 | # self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target) 137 | # self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target) 138 | 139 | scale = res_target[0] / self.line_coef.shape[2] # assuming xyz have the same scale 140 | plane_coef = F.interpolate(self.plane_coef.detach().data, scale_factor=scale, mode='bilinear', 141 | align_corners=True) 142 | line_coef = F.interpolate(self.line_coef.detach().data, size=(res_target[0], 1), mode='bilinear', 143 | align_corners=True) 144 | self.plane_coef, self.line_coef = torch.nn.Parameter(plane_coef), torch.nn.Parameter(line_coef) 145 | self.compute_stepSize(res_target) 146 | print(f'upsamping to {res_target}') 147 | 148 | 149 | class TensorVMSplit(TensorBase): 150 | def __init__(self, aabb, gridSize, device, **kargs): 151 | super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs) 152 | 153 | def init_svd_volume(self, res, device): 154 | self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device) 155 | self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device) 156 | self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device) 157 | 158 | def init_one_svd(self, n_component, gridSize, scale, device): 159 | plane_coef, line_coef = [], [] 160 | for i in range(len(self.vecMode)): 161 | vec_id = self.vecMode[i] 162 | mat_id_0, mat_id_1 = self.matMode[i] 163 | plane_coef.append(torch.nn.Parameter( 164 | scale * torch.randn((1, n_component[i], gridSize[mat_id_1], 
gridSize[mat_id_0])))) # 165 | line_coef.append( 166 | torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1)))) 167 | 168 | return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device) 169 | 170 | def get_optparam_groups(self, lr_init_spatialxyz=0.02, lr_init_network=0.001): 171 | grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, 172 | {'params': self.density_plane, 'lr': lr_init_spatialxyz}, 173 | {'params': self.app_line, 'lr': lr_init_spatialxyz}, 174 | {'params': self.app_plane, 'lr': lr_init_spatialxyz}, 175 | {'params': self.basis_mat.parameters(), 'lr': lr_init_network}] 176 | if isinstance(self.renderModule, torch.nn.Module): 177 | grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_init_network}] 178 | return grad_vars 179 | 180 | def vectorDiffs(self, vector_comps): 181 | total = 0 182 | 183 | for idx in range(len(vector_comps)): 184 | n_comp, n_size = vector_comps[idx].shape[1:-1] 185 | 186 | dotp = torch.matmul(vector_comps[idx].view(n_comp, n_size), 187 | vector_comps[idx].view(n_comp, n_size).transpose(-1, -2)) 188 | non_diagonal = dotp.view(-1)[1:].view(n_comp - 1, n_comp + 1)[..., :-1] 189 | total = total + torch.mean(torch.abs(non_diagonal)) 190 | return total 191 | 192 | def vector_comp_diffs(self): 193 | return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line) 194 | 195 | def density_L1(self): 196 | total = 0 197 | for idx in range(len(self.density_plane)): 198 | total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[ 199 | idx])) # + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx])) 200 | return total 201 | 202 | def TV_loss_density(self, reg): 203 | total = 0 204 | for idx in range(len(self.density_plane)): 205 | total = total + reg(self.density_plane[idx]) * 1e-2 # + reg(self.density_line[idx]) * 1e-3 206 | return total 207 | 208 | def TV_loss_app(self, reg): 209 | total = 0 210 | for idx in range(len(self.app_plane)): 211 | total = total + reg(self.app_plane[idx]) * 1e-2 # + reg(self.app_line[idx]) * 1e-3 212 | return total 213 | 214 | def compute_densityfeature(self, xyz_sampled, requires_grad=False): 215 | 216 | # plane + line basis 217 | if requires_grad: 218 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], 219 | xyz_sampled[..., self.matMode[2]])).view(3, -1, 1, 2) 220 | else: 221 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], 222 | xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2) 223 | coordinate_line = torch.stack( 224 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 225 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 226 | 1, 2) 227 | 228 | sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device) 229 | for idx_plane in range(len(self.density_plane)): 230 | plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]], 231 | align_corners=True).view(-1, *xyz_sampled.shape[:1]) 232 | line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]], 233 | align_corners=True).view(-1, *xyz_sampled.shape[:1]) 234 | sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0) 235 | 236 | return sigma_feature 
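compute_densityfeature above evaluates the vector-matrix factorization: for each of the three axis-aligned plane/line pairs, the 2D plane and the 1D line are grid-sampled at the projected coordinates, multiplied component-wise, and summed. A small self-contained sketch of a single plane/line pair with toy shapes (names and sizes here are illustrative, not code from this repository):

import torch
import torch.nn.functional as F

n_comp, res, N = 4, 8, 5
plane = torch.randn(1, n_comp, res, res)   # 2D feature plane (e.g. the XY plane)
line = torch.randn(1, n_comp, res, 1)      # matching 1D feature line (the Z axis)
xyz = torch.rand(N, 3) * 2 - 1             # N query points already normalized to [-1, 1]

coord_plane = xyz[:, [0, 1]].view(1, -1, 1, 2)                                    # (x, y) projections
coord_line = torch.stack([torch.zeros(N), xyz[:, 2]], dim=-1).view(1, -1, 1, 2)   # z projections

plane_feat = F.grid_sample(plane, coord_plane, align_corners=True).view(n_comp, N)
line_feat = F.grid_sample(line, coord_line, align_corners=True).view(n_comp, N)

# Density feature per point: sum over components of plane * line;
# the full model adds the analogous terms for the other two plane/line pairs.
sigma_feature = (plane_feat * line_feat).sum(dim=0)
print(sigma_feature.shape)  # torch.Size([5])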
237 | 238 | def compute_appfeature(self, xyz_sampled): 239 | 240 | # plane + line basis 241 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], 242 | xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2) 243 | coordinate_line = torch.stack( 244 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 245 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 246 | 1, 2) 247 | 248 | plane_coef_point, line_coef_point = [], [] 249 | for idx_plane in range(len(self.app_plane)): 250 | plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]], 251 | align_corners=True).view(-1, *xyz_sampled.shape[:1])) 252 | line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]], 253 | align_corners=True).view(-1, *xyz_sampled.shape[:1])) 254 | plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point) 255 | 256 | return self.basis_mat((plane_coef_point * line_coef_point).T) 257 | 258 | @torch.no_grad() 259 | def up_sampling_VM(self, plane_coef, line_coef, res_target): 260 | 261 | for i in range(len(self.vecMode)): 262 | vec_id = self.vecMode[i] 263 | mat_id_0, mat_id_1 = self.matMode[i] 264 | plane_coef[i] = torch.nn.Parameter( 265 | F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear', 266 | align_corners=True)) 267 | line_coef[i] = torch.nn.Parameter( 268 | F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True)) 269 | 270 | return plane_coef, line_coef 271 | 272 | @torch.no_grad() 273 | def upsample_volume_grid(self, res_target): 274 | self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target) 275 | self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target) 276 | 277 | self.update_stepSize(res_target) 278 | print(f'upsamping to {res_target}') 279 | 280 | @torch.no_grad() 281 | def shrink(self, new_aabb): 282 | print("====> shrinking ...") 283 | xyz_min, xyz_max = new_aabb 284 | t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units 285 | # print(new_aabb, self.aabb) 286 | # print(t_l, b_r,self.alphaMask.alpha_volume.shape) 287 | t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1 288 | b_r = torch.stack([b_r, self.gridSize]).amin(0) 289 | 290 | for i in range(len(self.vecMode)): 291 | mode0 = self.vecMode[i] 292 | self.density_line[i] = torch.nn.Parameter( 293 | self.density_line[i].data[..., t_l[mode0]:b_r[mode0], :] 294 | ) 295 | self.app_line[i] = torch.nn.Parameter( 296 | self.app_line[i].data[..., t_l[mode0]:b_r[mode0], :] 297 | ) 298 | mode0, mode1 = self.matMode[i] 299 | self.density_plane[i] = torch.nn.Parameter( 300 | self.density_plane[i].data[..., t_l[mode1]:b_r[mode1], t_l[mode0]:b_r[mode0]] 301 | ) 302 | self.app_plane[i] = torch.nn.Parameter( 303 | self.app_plane[i].data[..., t_l[mode1]:b_r[mode1], t_l[mode0]:b_r[mode0]] 304 | ) 305 | 306 | if not torch.all(self.alphaMask.gridSize == self.gridSize): 307 | t_l_r, b_r_r = t_l / (self.gridSize - 1), (b_r - 1) / (self.gridSize - 1) 308 | correct_aabb = torch.zeros_like(new_aabb) 309 | correct_aabb[0] = (1 - t_l_r) * self.aabb[0] + t_l_r * self.aabb[1] 310 | correct_aabb[1] = (1 - b_r_r) * self.aabb[0] + b_r_r * self.aabb[1] 311 | 
print("aabb", new_aabb, "\ncorrect aabb", correct_aabb) 312 | new_aabb = correct_aabb 313 | 314 | newSize = b_r - t_l 315 | self.aabb = new_aabb 316 | self.update_stepSize((newSize[0], newSize[1], newSize[2])) 317 | 318 | 319 | class TensorCP(TensorBase): 320 | def __init__(self, aabb, gridSize, device, **kargs): 321 | super(TensorCP, self).__init__(aabb, gridSize, device, **kargs) 322 | 323 | def init_svd_volume(self, res, device): 324 | self.density_line = self.init_one_svd(self.density_n_comp[0], self.gridSize, 0.2, device) 325 | self.app_line = self.init_one_svd(self.app_n_comp[0], self.gridSize, 0.2, device) 326 | self.basis_mat = torch.nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device) 327 | 328 | def init_one_svd(self, n_component, gridSize, scale, device): 329 | line_coef = [] 330 | for i in range(len(self.vecMode)): 331 | vec_id = self.vecMode[i] 332 | line_coef.append( 333 | torch.nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1)))) 334 | return torch.nn.ParameterList(line_coef).to(device) 335 | 336 | def get_optparam_groups(self, lr_init_spatialxyz=0.02, lr_init_network=0.001): 337 | grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, 338 | {'params': self.app_line, 'lr': lr_init_spatialxyz}, 339 | {'params': self.basis_mat.parameters(), 'lr': lr_init_network}] 340 | if isinstance(self.renderModule, torch.nn.Module): 341 | grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_init_network}] 342 | return grad_vars 343 | 344 | def compute_densityfeature(self, xyz_sampled): 345 | 346 | coordinate_line = torch.stack( 347 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 348 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 349 | 1, 2) 350 | 351 | line_coef_point = F.grid_sample(self.density_line[0], coordinate_line[[0]], 352 | align_corners=True).view(-1, *xyz_sampled.shape[:1]) 353 | line_coef_point = line_coef_point * F.grid_sample(self.density_line[1], coordinate_line[[1]], 354 | align_corners=True).view(-1, *xyz_sampled.shape[:1]) 355 | line_coef_point = line_coef_point * F.grid_sample(self.density_line[2], coordinate_line[[2]], 356 | align_corners=True).view(-1, *xyz_sampled.shape[:1]) 357 | sigma_feature = torch.sum(line_coef_point, dim=0) 358 | 359 | return sigma_feature 360 | 361 | def compute_appfeature(self, xyz_sampled): 362 | 363 | coordinate_line = torch.stack( 364 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) 365 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 366 | 1, 2) 367 | 368 | line_coef_point = F.grid_sample(self.app_line[0], coordinate_line[[0]], 369 | align_corners=True).view(-1, *xyz_sampled.shape[:1]) 370 | line_coef_point = line_coef_point * F.grid_sample(self.app_line[1], coordinate_line[[1]], 371 | align_corners=True).view(-1, *xyz_sampled.shape[:1]) 372 | line_coef_point = line_coef_point * F.grid_sample(self.app_line[2], coordinate_line[[2]], 373 | align_corners=True).view(-1, *xyz_sampled.shape[:1]) 374 | 375 | return self.basis_mat(line_coef_point.T) 376 | 377 | @torch.no_grad() 378 | def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target): 379 | 380 | for i in range(len(self.vecMode)): 381 | vec_id = self.vecMode[i] 382 | density_line_coef[i] = torch.nn.Parameter( 383 | F.interpolate(density_line_coef[i].data, 
size=(res_target[vec_id], 1), mode='bilinear', 384 | align_corners=True)) 385 | app_line_coef[i] = torch.nn.Parameter( 386 | F.interpolate(app_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True)) 387 | 388 | return density_line_coef, app_line_coef 389 | 390 | @torch.no_grad() 391 | def upsample_volume_grid(self, res_target): 392 | self.density_line, self.app_line = self.up_sampling_Vector(self.density_line, self.app_line, res_target) 393 | 394 | self.update_stepSize(res_target) 395 | print(f'upsamping to {res_target}') 396 | 397 | @torch.no_grad() 398 | def shrink(self, new_aabb): 399 | print("====> shrinking ...") 400 | xyz_min, xyz_max = new_aabb 401 | t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units 402 | 403 | t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1 404 | b_r = torch.stack([b_r, self.gridSize]).amin(0) 405 | 406 | for i in range(len(self.vecMode)): 407 | mode0 = self.vecMode[i] 408 | self.density_line[i] = torch.nn.Parameter( 409 | self.density_line[i].data[..., t_l[mode0]:b_r[mode0], :] 410 | ) 411 | self.app_line[i] = torch.nn.Parameter( 412 | self.app_line[i].data[..., t_l[mode0]:b_r[mode0], :] 413 | ) 414 | 415 | if not torch.all(self.alphaMask.gridSize == self.gridSize): 416 | t_l_r, b_r_r = t_l / (self.gridSize - 1), (b_r - 1) / (self.gridSize - 1) 417 | correct_aabb = torch.zeros_like(new_aabb) 418 | correct_aabb[0] = (1 - t_l_r) * self.aabb[0] + t_l_r * self.aabb[1] 419 | correct_aabb[1] = (1 - b_r_r) * self.aabb[0] + b_r_r * self.aabb[1] 420 | print("aabb", new_aabb, "\ncorrect aabb", correct_aabb) 421 | new_aabb = correct_aabb 422 | 423 | newSize = b_r - t_l 424 | self.aabb = new_aabb 425 | self.update_stepSize((newSize[0], newSize[1], newSize[2])) 426 | 427 | def density_L1(self): 428 | total = 0 429 | for idx in range(len(self.density_line)): 430 | total = total + torch.mean(torch.abs(self.density_line[idx])) 431 | return total 432 | 433 | def TV_loss_density(self, reg): 434 | total = 0 435 | for idx in range(len(self.density_line)): 436 | total = total + reg(self.density_line[idx]) * 1e-3 437 | return total 438 | 439 | def TV_loss_app(self, reg): 440 | total = 0 441 | for idx in range(len(self.app_line)): 442 | total = total + reg(self.app_line[idx]) * 1e-3 443 | return total --------------------------------------------------------------------------------
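For reference, a minimal sketch of driving the training entry point in exp_runner.py programmatically rather than via the command line, mirroring its __main__ block. The config path is the argparse default and the case name is the placeholder default; both would normally be replaced, and note that Runner's constructor also loads the TensoRF prior checkpoint hard-coded above (data/tensorf/scan0050_00.th), so that file must exist.

import torch
from exp_runner import Runner

# Mirrors the __main__ block of exp_runner.py: select a GPU, build a Runner, and train.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.cuda.set_device(0)

runner = Runner('./confs/base.conf', mode='train', case='CASE_NAME', is_continue=False)
runner.train()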