├── media
│   └── overview.png
├── requirements.txt
├── models
│   ├── __pycache__
│   │   ├── dataset.cpython-37.pyc
│   │   ├── fields.cpython-310.pyc
│   │   ├── fields.cpython-37.pyc
│   │   ├── dataset.cpython-310.pyc
│   │   ├── embedder.cpython-310.pyc
│   │   ├── embedder.cpython-37.pyc
│   │   ├── ray_utils.cpython-310.pyc
│   │   ├── renderer.cpython-310.pyc
│   │   ├── renderer.cpython-37.pyc
│   │   └── blender_swap.cpython-310.pyc
│   ├── embedder.py
│   ├── sh.py
│   ├── scannet_blender.py
│   ├── dataset.py
│   ├── fields.py
│   ├── nerf2neus.py
│   ├── ray_utils.py
│   ├── tensorf2neus.py
│   ├── blender_swap.py
│   ├── tensorBase.py
│   ├── renderer.py
│   └── tensoRF.py
├── confs
│   ├── replica.conf
│   └── blendswap.conf
├── README.md
├── mesh_metrics.py
├── cull_mesh.py
└── exp_runner.py
/media/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/media/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | open3d
2 | pytorch3d
3 | trimesh
4 | matplotlib
5 | opencv-python
6 | tqdm
7 | scikit-image
--------------------------------------------------------------------------------
/models/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/fields.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/fields.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/fields.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/fields.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/embedder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/embedder.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/embedder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/embedder.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ray_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/ray_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/renderer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/renderer.cpython-310.pyc
--------------------------------------------------------------------------------
/models/__pycache__/renderer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/renderer.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/blender_swap.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wen-yuan-zhang/NeRFPrior/HEAD/models/__pycache__/blender_swap.cpython-310.pyc
--------------------------------------------------------------------------------
/models/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
6 | class Embedder:
7 | def __init__(self, **kwargs):
8 | self.kwargs = kwargs
9 | self.create_embedding_fn()
10 |
11 | def create_embedding_fn(self):
12 | embed_fns = []
13 | d = self.kwargs['input_dims']
14 | out_dim = 0
15 | if self.kwargs['include_input']:
16 | embed_fns.append(lambda x: x)
17 | out_dim += d
18 |
19 | max_freq = self.kwargs['max_freq_log2']
20 | N_freqs = self.kwargs['num_freqs']
21 |
22 | if self.kwargs['log_sampling']:
23 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
24 | else:
25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
26 |
27 | for freq in freq_bands:
28 | for p_fn in self.kwargs['periodic_fns']:
29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
30 | out_dim += d
31 |
32 | self.embed_fns = embed_fns
33 | self.out_dim = out_dim
34 |
35 | def embed(self, inputs):
36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
37 |
38 |
39 | def get_embedder(multires, input_dims=3):
40 | embed_kwargs = {
41 | 'include_input': True,
42 | 'input_dims': input_dims,
43 | 'max_freq_log2': multires-1,
44 | 'num_freqs': multires,
45 | 'log_sampling': True,
46 | 'periodic_fns': [torch.sin, torch.cos],
47 | }
48 |
49 | embedder_obj = Embedder(**embed_kwargs)
50 | def embed(x, eo=embedder_obj): return eo.embed(x)
51 | return embed, embedder_obj.out_dim
52 |
--------------------------------------------------------------------------------
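
A quick usage sketch of the positional-encoding helper above; the batch size and `multires` value are illustrative only.

```python
import torch
from models.embedder import get_embedder

# With multires=10 and 3-D inputs, the embedding keeps the raw coordinates and adds
# sin/cos at 10 frequencies, so out_dim = 3 + 3 * 2 * 10 = 63.
embed_fn, out_dim = get_embedder(10, input_dims=3)
pts = torch.rand(1024, 3)        # illustrative batch of 3-D points
encoded = embed_fn(pts)          # (1024, 63)
assert encoded.shape == (1024, out_dim)
```
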
/confs/replica.conf:
--------------------------------------------------------------------------------
1 | general {
2 | base_exp_dir = log/room0_v2
3 | recording = [
4 | ./,
5 | ./models
6 | ]
7 | }
8 |
9 | dataset {
10 | type = Blender
11 | data_dir = data/Replica
12 | scene = room0
13 | }
14 |
15 | train {
16 | learning_rate = 5e-4
17 | learning_rate_alpha = 0.05
18 | end_iter = 400000
19 |
20 | batch_size = 512
21 | validate_resolution_level = 1
22 | warm_up_end = 5000
23 | anneal_end = 50000
24 | use_white_bkgd = False
25 |
26 | save_freq = 10000
27 | val_freq = 2500
28 | val_mesh_freq = 5000
29 | report_freq = 100
30 |
31 | igr_weight = 0.1
32 | mask_weight = 0.0
33 | }
34 |
35 | model {
36 | nerf {
37 | D = 8,
38 | d_in = 4,
39 | d_in_view = 3,
40 | W = 256,
41 | multires = 10,
42 | multires_view = 4,
43 | output_ch = 4,
44 | skips=[4],
45 | use_viewdirs=True
46 | }
47 |
48 | sdf_network {
49 | d_out = 257
50 | d_in = 3
51 | d_hidden = 256
52 | n_layers = 8
53 | skip_in = [4]
54 | multires = 6
55 | bias = 0.5
56 | scale = 3.0
57 | geometric_init = True
58 | weight_norm = True
59 | }
60 |
61 | variance_network {
62 | init_val = 0.3
63 | }
64 |
65 | rendering_network {
66 | d_feature = 256
67 | mode = idr
68 | d_in = 9
69 | d_out = 3
70 | d_hidden = 256
71 | n_layers = 4
72 | weight_norm = True
73 | multires_view = 4
74 | squeeze_out = True
75 | }
76 |
77 | neus_renderer {
78 | n_samples = 64
79 | n_importance = 64
80 | n_outside = 0
81 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling
82 | perturb = 1.0
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
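
The configuration files above are HOCON. Below is a minimal sketch of how a NeuS-style runner typically reads them with pyhocon; `exp_runner.py` itself is not shown in this dump, so the exact usage is an assumption.

```python
from pyhocon import ConfigFactory

conf = ConfigFactory.parse_file('confs/replica.conf')
data_dir  = conf.get_string('dataset.data_dir')              # 'data/Replica'
end_iter  = conf.get_int('train.end_iter')                   # 400000
n_samples = conf.get_int('model.neus_renderer.n_samples')    # 64
print(data_dir, end_iter, n_samples)
```
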
/confs/blendswap.conf:
--------------------------------------------------------------------------------
1 | general {
2 | base_exp_dir = log/breakfast_room
3 | recording = [
4 | ./,
5 | ./models
6 | ]
7 | }
8 |
9 | dataset {
10 | type = Blender
11 | data_dir = /data/blendswap
12 | scene = breakfast_room
13 | }
14 |
15 | train {
16 | learning_rate = 5e-4
17 | learning_rate_alpha = 0.05
18 | end_iter = 300000
19 |
20 | batch_size = 512
21 | validate_resolution_level = 1
22 | warm_up_end = 5000
23 | anneal_end = 50000
24 | use_white_bkgd = False
25 |
26 | save_freq = 10000
27 | val_freq = 2500
28 | val_mesh_freq = 5000
29 | report_freq = 100
30 |
31 | igr_weight = 0.1
32 | mask_weight = 0.0
33 | }
34 |
35 | model {
36 | nerf {
37 | D = 8,
38 | d_in = 4,
39 | d_in_view = 3,
40 | W = 256,
41 | multires = 10,
42 | multires_view = 4,
43 | output_ch = 4,
44 | skips=[4],
45 | use_viewdirs=True
46 | }
47 |
48 | sdf_network {
49 | d_out = 257
50 | d_in = 3
51 | d_hidden = 256
52 | n_layers = 8
53 | skip_in = [4]
54 | multires = 6
55 | bias = 0.5
56 | scale = 3.0
57 | geometric_init = True
58 | weight_norm = True
59 | }
60 |
61 | variance_network {
62 | init_val = 0.3
63 | }
64 |
65 | rendering_network {
66 | d_feature = 256
67 | mode = idr
68 | d_in = 9
69 | d_out = 3
70 | d_hidden = 256
71 | n_layers = 4
72 | weight_norm = True
73 | multires_view = 4
74 | squeeze_out = True
75 | }
76 |
77 | neus_renderer {
78 | n_samples = 64
79 | n_importance = 64
80 | n_outside = 0
81 | up_sample_steps = 4 # 1 for simple coarse-to-fine sampling
82 | perturb = 1.0
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # NeRFPrior: Learning Neural Radiance Field as a Prior for Indoor Scene Reconstruction
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | In this paper, we introduce NeRFPrior. Given multi-view images of a scene as input, we first train a grid-based NeRF to obtain the density field and color field as priors. We then learn a signed distance function by imposing a multi-view consistency constraint using volume rendering. For each sampled point on the ray, we query the prior density and prior color as additional supervision of the predicted density and color, respectively. To improve the smoothness and completeness of textureless areas in the scene, we propose a depth consistency loss, which forces surface points in the same textureless plane to have similar depths.
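
A rough sketch of the extra supervision described above. This is illustrative pseudocode only, not the actual training code; the tensor shapes, loss form, and weights are assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative only: the frozen grid-based NeRF provides per-sample priors that
# supervise the SDF branch's predictions along each ray (512 rays x 64 samples here).
pred_density,  pred_color  = torch.rand(512, 64), torch.rand(512, 64, 3)   # from the SDF renderer
prior_density, prior_color = torch.rand(512, 64), torch.rand(512, 64, 3)   # queried from the NeRF prior
w_density, w_color = 0.1, 0.1                                              # hypothetical weights
prior_loss = w_density * F.l1_loss(pred_density, prior_density) \
           + w_color   * F.l1_loss(pred_color,   prior_color)
```
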
13 |
14 |
15 | # Preprocessed Datasets
16 |
17 | Our preprocessed ScanNet and Replica datasets are provided at [this link](https://huggingface.co/datasets/zParquet/MonoInstance/tree/main).
18 |
19 |
20 | # Setup
21 |
22 | ## Installation
23 |
24 | Clone the repository and create a conda environment using
25 | ```shell
26 | git clone https://github.com/wen-yuan-zhang/NeRFPrior.git
27 | cd NeRFPrior
28 |
29 | conda create -n nerfprior python=3.10
30 | conda activate nerfprior
31 |
32 | conda install pytorch=1.13.0 torchvision=0.14.0 cudatoolkit=11.7 -c pytorch
33 | conda install cudatoolkit-dev=11.7 -c conda-forge
34 |
35 | pip install -r requirements.txt
36 | ```
37 |
38 | You should also clone TensoRF, which is used to obtain the NeRF prior.
39 | ```shell
40 | git clone https://github.com/apchenstu/TensoRF.git
41 | ```
42 |
43 |
44 | # Training
45 |
46 | First, train TensoRF on the given scene (following the instructions in the TensoRF repository) to obtain the ```.th``` checkpoint used for the subsequent neural implicit surface reconstruction.
47 |
48 | To train on the BlendSwap dataset, use
49 | ```shell
50 | CUDA_VISIBLE_DEVICES=1 python exp_runner.py --conf confs/blendswap.conf
51 | ```
52 | To train on the Replica dataset, use
53 | ```shell
54 | CUDA_VISIBLE_DEVICES=1 python exp_runner.py --conf confs/replica.conf
55 | ```
56 |
57 |
58 |
59 |
60 | # Evaluation
61 |
62 | To evaluate the reconstructed meshes, first use ```cull_mesh.py``` to cull them according to the view frustums. Then set the mesh paths in ```mesh_metrics.py``` and run it to compute the metrics.
63 | ```shell
64 | python cull_mesh.py
65 | python mesh_metrics.py
66 | ```
67 |
68 |
69 |
70 |
71 | # Acknowledgements
72 |
73 | This project is built upon [NeuS](https://lingjie0206.github.io/papers/NeuS/) and [TensoRF](https://apchenstu.github.io/TensoRF/). We thank all the authors for their great repos.
74 |
75 |
76 | # Citation
77 |
78 | If you find our code or paper useful, please consider citing
79 | ```bibtex
80 | @inproceedings{zhang2025nerfprior,
81 | title={{NeRFPrior}: Learning neural radiance field as a prior for indoor scene reconstruction},
82 | author={Zhang, Wenyuan and Jia, Emily Yue-ting and Zhou, Junsheng and Ma, Baorui and Shi, Kanle and Liu, Yu-Shen and Han, Zhizhong},
83 | booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
84 | pages={11317--11327},
85 | year={2025}
86 | }
87 | ```
88 |
89 |
--------------------------------------------------------------------------------
/models/sh.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | ################## sh function ##################
4 | C0 = 0.28209479177387814
5 | C1 = 0.4886025119029199
6 | C2 = [
7 | 1.0925484305920792,
8 | -1.0925484305920792,
9 | 0.31539156525252005,
10 | -1.0925484305920792,
11 | 0.5462742152960396
12 | ]
13 | C3 = [
14 | -0.5900435899266435,
15 | 2.890611442640554,
16 | -0.4570457994644658,
17 | 0.3731763325901154,
18 | -0.4570457994644658,
19 | 1.445305721320277,
20 | -0.5900435899266435
21 | ]
22 | C4 = [
23 | 2.5033429417967046,
24 | -1.7701307697799304,
25 | 0.9461746957575601,
26 | -0.6690465435572892,
27 | 0.10578554691520431,
28 | -0.6690465435572892,
29 | 0.47308734787878004,
30 | -1.7701307697799304,
31 | 0.6258357354491761,
32 | ]
33 |
34 | def eval_sh(deg, sh, dirs):
35 | """
36 | Evaluate spherical harmonics at unit directions
37 | using hardcoded SH polynomials.
38 | Works with torch/np/jnp.
39 | "..." denotes 0 or more batch dimensions.
40 | :param deg: int SH max degree. Currently, 0-4 supported
41 | :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
42 | :param dirs: torch.Tensor unit directions (..., 3)
43 | :return: (..., C)
44 | """
45 | assert deg <= 4 and deg >= 0
46 | assert (deg + 1) ** 2 == sh.shape[-1]
47 | C = sh.shape[-2]
48 |
49 | result = C0 * sh[..., 0]
50 | if deg > 0:
51 | x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
52 | result = (result -
53 | C1 * y * sh[..., 1] +
54 | C1 * z * sh[..., 2] -
55 | C1 * x * sh[..., 3])
56 | if deg > 1:
57 | xx, yy, zz = x * x, y * y, z * z
58 | xy, yz, xz = x * y, y * z, x * z
59 | result = (result +
60 | C2[0] * xy * sh[..., 4] +
61 | C2[1] * yz * sh[..., 5] +
62 | C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
63 | C2[3] * xz * sh[..., 7] +
64 | C2[4] * (xx - yy) * sh[..., 8])
65 |
66 | if deg > 2:
67 | result = (result +
68 | C3[0] * y * (3 * xx - yy) * sh[..., 9] +
69 | C3[1] * xy * z * sh[..., 10] +
70 | C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
71 | C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
72 | C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
73 | C3[5] * z * (xx - yy) * sh[..., 14] +
74 | C3[6] * x * (xx - 3 * yy) * sh[..., 15])
75 | if deg > 3:
76 | result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
77 | C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
78 | C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
79 | C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
80 | C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
81 | C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
82 | C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
83 | C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
84 | C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
85 | return result
86 |
87 | def eval_sh_bases(deg, dirs):
88 | """
89 | Evaluate spherical harmonics bases at unit directions,
90 | without taking linear combination.
91 | At each point, the final result may then be
92 | obtained through simple multiplication.
93 | :param deg: int SH max degree. Currently, 0-4 supported
94 | :param dirs: torch.Tensor (..., 3) unit directions
95 | :return: torch.Tensor (..., (deg+1) ** 2)
96 | """
97 | assert deg <= 4 and deg >= 0
98 | result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
99 | result[..., 0] = C0
100 | if deg > 0:
101 | x, y, z = dirs.unbind(-1)
102 | result[..., 1] = -C1 * y;
103 | result[..., 2] = C1 * z;
104 | result[..., 3] = -C1 * x;
105 | if deg > 1:
106 | xx, yy, zz = x * x, y * y, z * z
107 | xy, yz, xz = x * y, y * z, x * z
108 | result[..., 4] = C2[0] * xy;
109 | result[..., 5] = C2[1] * yz;
110 | result[..., 6] = C2[2] * (2.0 * zz - xx - yy);
111 | result[..., 7] = C2[3] * xz;
112 | result[..., 8] = C2[4] * (xx - yy);
113 |
114 | if deg > 2:
115 | result[..., 9] = C3[0] * y * (3 * xx - yy);
116 | result[..., 10] = C3[1] * xy * z;
117 | result[..., 11] = C3[2] * y * (4 * zz - xx - yy);
118 | result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy);
119 | result[..., 13] = C3[4] * x * (4 * zz - xx - yy);
120 | result[..., 14] = C3[5] * z * (xx - yy);
121 | result[..., 15] = C3[6] * x * (xx - 3 * yy);
122 |
123 | if deg > 3:
124 | result[..., 16] = C4[0] * xy * (xx - yy);
125 | result[..., 17] = C4[1] * yz * (3 * xx - yy);
126 | result[..., 18] = C4[2] * xy * (7 * zz - 1);
127 | result[..., 19] = C4[3] * yz * (7 * zz - 3);
128 | result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3);
129 | result[..., 21] = C4[5] * xz * (7 * zz - 3);
130 | result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1);
131 | result[..., 23] = C4[7] * xz * (xx - 3 * yy);
132 | result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy));
133 | return result
134 |
--------------------------------------------------------------------------------
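
A small usage sketch of the two helpers above; the degree, batch size, and SH coefficients are purely illustrative.

```python
import torch
from models.sh import eval_sh, eval_sh_bases

dirs = torch.randn(1024, 3)
dirs = dirs / dirs.norm(dim=-1, keepdim=True)        # unit view directions
basis = eval_sh_bases(2, dirs)                        # (1024, 9) SH basis values
sh_coeffs = torch.randn(1024, 3, 9)                   # illustrative per-point RGB SH coefficients
rgb_a = (sh_coeffs * basis[:, None, :]).sum(-1)       # combine bases manually
rgb_b = eval_sh(2, sh_coeffs, dirs)                   # same result via eval_sh
assert torch.allclose(rgb_a, rgb_b, atol=1e-4)
```
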
/mesh_metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import trimesh
3 | from scipy.spatial import cKDTree
4 | import numpy as np
5 | import open3d as o3d
6 | from matplotlib import pyplot as plt  # needed by get_colored_pcd below
7 |
8 |
9 | def compute_iou(mesh_pred, mesh_target):
10 | res = 0.05
11 | v_pred = mesh_pred.voxelized(pitch=res)
12 | v_target = mesh_target.voxelized(pitch=res)
13 | v_target_mesh = v_target.as_boxes()
14 | v_pred_mesh = v_pred.as_boxes()
15 |
16 | v_pred_filled = set(tuple(np.round(x, 4)) for x in v_pred.points)
17 | v_target_filled = set(tuple(np.round(x, 4)) for x in v_target.points)
18 | inter = v_pred_filled.intersection(v_target_filled)
19 | union = v_pred_filled.union(v_target_filled)
20 | iou = len(inter) / len(union)
21 | return iou, v_target_mesh, v_pred_mesh
22 |
23 |
24 | def get_colored_pcd(pcd, metric):
25 | cmap = plt.cm.get_cmap("jet")
26 | color = cmap(metric / 0.10)[..., :3]
27 | pcd_o3d = o3d.geometry.PointCloud()
28 | pcd_o3d.points = o3d.utility.Vector3dVector(pcd)
29 | pcd_o3d.colors = o3d.utility.Vector3dVector(color)
30 | return pcd_o3d
31 |
32 |
33 | def compute_metrics(mesh_pred, mesh_target):
34 | # mesh_pred = trimesh.load_mesh(path_pred)
35 | # mesh_target = trimesh.load_mesh(path_target)
36 | area_pred = int(mesh_pred.area * 1e4)
37 | area_tgt = int(mesh_target.area * 1e4)
38 | print("pred: {}, target: {}".format(area_pred, area_tgt))
39 |
40 | iou, v_gt, v_pred = compute_iou(mesh_pred, mesh_target)
41 |
42 | pointcloud_pred, idx = mesh_pred.sample(area_pred, return_index=True)
43 | pointcloud_pred = pointcloud_pred.astype(np.float32)
44 | normals_pred = mesh_pred.face_normals[idx]
45 |
46 | pointcloud_tgt, idx = mesh_target.sample(area_tgt, return_index=True)
47 | pointcloud_tgt = pointcloud_tgt.astype(np.float32)
48 | normals_tgt = mesh_target.face_normals[idx]
49 |
50 | thresholds = np.array([0.05])
51 |
52 | # for every point in gt compute the min distance to points in pred
53 | completeness, completeness_normals = distance_p2p(
54 | pointcloud_tgt, normals_tgt, pointcloud_pred, normals_pred
55 | )
56 | recall = get_threshold_percentage(completeness, thresholds)
57 | completeness2 = completeness ** 2
58 |
59 | # color gt_point_cloud using completion
60 | # com_mesh = get_colored_pcd(pointcloud_tgt, completeness)
61 |
62 | completeness = completeness.mean()
63 | completeness2 = completeness2.mean()
64 | completeness_normals = completeness_normals.mean()
65 |
66 | # Accuracy: how far are the points of the predicted pointcloud
67 | # from the target pointcloud
68 | accuracy, accuracy_normals = distance_p2p(
69 | pointcloud_pred, normals_pred, pointcloud_tgt, normals_tgt
70 | )
71 | precision = get_threshold_percentage(accuracy, thresholds)
72 | accuracy2 = accuracy ** 2
73 |
74 | # color pred_point_cloud using completion
75 | # acc_mesh = get_colored_pcd(pointcloud_pred, accuracy)
76 |
77 | accuracy = accuracy.mean()
78 | accuracy2 = accuracy2.mean()
79 | accuracy_normals = accuracy_normals.mean()
80 |
81 | # Chamfer distance
82 | chamferL2 = 0.5 * (completeness2 + accuracy2)
83 | normals_correctness = (
84 | 0.5 * completeness_normals + 0.5 * accuracy_normals
85 | )
86 | chamferL1 = 0.5 * (completeness + accuracy)
87 |
88 | # F-Score
89 | F = [
90 | 2 * precision[i] * recall[i] / (precision[i] + recall[i])
91 | for i in range(len(precision))
92 | ]
93 | rst = {
94 | "IoU": iou,
95 | # "Acc": accuracy,
96 | # "Comp": completeness,
97 | "C-L1": chamferL1,
98 | "NC": normals_correctness,
99 | 'precision': precision[0],
100 | 'recall': recall[0],
101 | "F-score": F[0]
102 | }
103 |
104 | return rst
105 |
106 |
107 | def distance_p2p(points_src, normals_src, points_tgt, normals_tgt):
108 | """ Computes minimal distances of each point in points_src to points_tgt.
109 | Args:
110 | points_src (numpy array): source points
111 | normals_src (numpy array): source normals
112 | points_tgt (numpy array): target points
113 | normals_tgt (numpy array): target normals
114 | """
115 | kdtree = cKDTree(points_tgt)
116 | dist, idx = kdtree.query(points_src)
117 |
118 | if normals_src is not None and normals_tgt is not None:
119 | normals_src = \
120 | normals_src / np.linalg.norm(normals_src, axis=-1, keepdims=True)
121 | normals_tgt = \
122 | normals_tgt / np.linalg.norm(normals_tgt, axis=-1, keepdims=True)
123 |
124 | normals_dot_product = (normals_tgt[idx] * normals_src).sum(axis=-1)
125 | # Handle normals that point into wrong direction gracefully
126 | # (mostly due to the method not caring about this in generation)
127 | normals_dot_product = np.abs(normals_dot_product)
128 | else:
129 | normals_dot_product = np.array(
130 | [np.nan] * points_src.shape[0], dtype=np.float32)
131 | return dist, normals_dot_product
132 |
133 |
134 | def get_threshold_percentage(dist, thresholds):
135 | """ Evaluates a point cloud.
136 | Args:
137 | dist (numpy array): calculated distance
138 | thresholds (numpy array): threshold values for the F-score calculation
139 | """
140 | in_threshold = [
141 | (dist <= t).astype(np.float32).mean() for t in thresholds
142 | ]
143 | return in_threshold
144 |
145 |
146 | def save_meshes(meshes, mesh_dir, save_name):
147 | for key in meshes:
148 | mesh = meshes[key]
149 | if isinstance(mesh, o3d.geometry.PointCloud):
150 | o3d.io.write_point_cloud(os.path.join(mesh_dir, "{}_{}.ply".format(key, save_name)), mesh)
151 | elif isinstance(mesh, trimesh.Trimesh):
152 | mesh.export(os.path.join(mesh_dir, "{}_{}.ply".format(key, save_name)))
153 |
154 |
155 | if __name__ == '__main__':
156 | gt_mesh_name = 'out_meshes/room0/gt_cropped_culled_transformed.ply'
157 | ours_mesh_name = 'out_meshes/room0/neus_cropped_culled_transformed.ply'
158 |
159 | gt_mesh = trimesh.load_mesh(gt_mesh_name)
160 | # neuralrgbd_mesh = trimesh.load_mesh(neuralrgbd_mesh_name)
161 | ours_mesh = trimesh.load(ours_mesh_name)
162 |
163 | # iou1, target_mesh, pred_mesh = compute_iou(neuralrgbd_mesh, gt_mesh)
164 | # iou2, target_mesh, pred_mesh = compute_iou(ours_mesh, gt_mesh)
165 | # print(iou1)
166 | # print(iou2)
167 |
168 | # ret = compute_metrics(neuralrgbd_mesh, gt_mesh)
169 | # print(ret)
170 | ret = compute_metrics(ours_mesh, gt_mesh)
171 | print(ret)
172 |
--------------------------------------------------------------------------------
/models/scannet_blender.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision import transforms as T
9 |
10 | import torch
11 | from kornia import create_meshgrid
12 |
13 | os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
14 |
15 | def get_ray_directions(H, W, focal, center=None):
16 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5
17 |
18 | i, j = grid.unbind(-1)
19 | # the direction here is without +0.5 pixel centering as calibration is not so accurate
20 | # see https://github.com/bmild/nerf/issues/24
21 | cent = center if center is not None else [W / 2, H / 2]
22 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
23 |
24 | return directions
25 |
26 | def get_rays(directions, c2w):
27 | # Rotate ray directions from camera coordinate to the world coordinate
28 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3)
29 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
30 | # The origin of all rays is the camera origin in world coordinate
31 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3)
32 |
33 | rays_d = rays_d.view(-1, 3)
34 | rays_o = rays_o.view(-1, 3)
35 |
36 | return rays_o, rays_d
37 |
38 | class ScanNetDataset(Dataset):
39 | def __init__(self, conf, split='train', N_vis=-1):
40 | self.device = torch.device('cuda')
41 | self.N_vis = N_vis
42 | self.root_dir = conf.get_string('data_dir')
43 | scene = conf.get_string('scene')
44 |
45 | self.split = split
46 | self.is_stack = False
47 | self.downsample = 1.0
48 | self.transform = T.ToTensor()
49 |
50 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
51 |
52 | self.poses = []
53 | self.all_rays = []
54 | self.all_rgbs = []
55 | self.all_depth = []
56 | self.directions = []
57 | self.read_meta(scene)
58 | self.n_images = len(self.all_rgbs)
59 | self.white_bg = True
60 |
61 | # for Neus exp_runner
62 | self.pose_all = self.poses
63 | self.object_bbox_min = np.array([-1.01, -1.01, -1.01]) # only used in extract
64 | self.object_bbox_max = np.array([ 1.01, 1.01, 1.01])
65 | self.object_bbox_max = np.array([0.6, 0.6, 0.6])
66 | self.object_bbox_min = np.array([-0.6, -0.6, -0.6])
67 | self.scale_mats_np = [np.eye(4)]
68 |
69 | def read_meta(self, scene):
70 | root_dir = os.path.join(self.root_dir, scene)
71 | with open(os.path.join(root_dir, f"transforms_{self.split}.json"), 'r') as f:
72 | self.meta = json.load(f)
73 |
74 | w, h = int(self.meta['w'] / self.downsample), int(self.meta['h'] / self.downsample)
75 | self.img_wh = [w, h]
76 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
77 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length
78 | self.cx, self.cy = self.meta['cx'], self.meta['cy']
79 |
80 | # ray directions for all pixels, same for all images (same H, W, focal)
81 | direction = get_ray_directions(h, w, [self.focal_x, self.focal_y], center=[self.cx, self.cy]) # (h, w, 3)
82 | self.directions.append(direction)
83 | # self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True) # TODO
84 | self.intrinsics = torch.tensor([[self.focal_x, 0, self.cx], [0, self.focal_y, self.cy], [0, 0, 1]]).float()
85 |
86 | idxs = list(range(0, len(self.meta['frames'])))
87 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:#
88 | frame = self.meta['frames'][i]
89 | pose = np.array(frame['transform_matrix'])
90 | pose = pose @ self.blender2opencv
91 | c2w = torch.FloatTensor(pose)
92 | self.poses.append(c2w)
93 | image_path = os.path.join(root_dir, f"{frame['file_path']}")
94 | image_id = int(os.path.basename(image_path)[:-4])
95 | img = Image.open(image_path)
96 | img = self.transform(img) # (4, h, w)
97 | img = img.permute(1, 2, 0) # (h*w, 4) RGBA
98 | if img.shape[-1] == 4:
99 | img = img[..., :3] * img[..., -1:] + (1 - img[..., -1:]) # blend A to RGB
100 | self.all_rgbs.append(img)
101 |
102 |
103 | self.poses = torch.stack(self.poses) #(N, 4, 4)
104 |
105 | rays_o = self.poses[:, :3, 3]
106 |
107 | # self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 6)
108 | # self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
109 |
110 | self.h, self.w = h, w
111 |
112 | # def define_proj_mat(self):
113 | # self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]
114 |
115 | def __len__(self):
116 | return len(self.all_rgbs)
117 |
118 | # multi-scene version of gen_random_rays_at
119 | def gen_random_rays_at(self, img_idx, batch_size):
120 | scene_id = 0
121 | pixels_x = torch.randint(low=0, high=self.w, size=[batch_size])
122 | pixels_y = torch.randint(low=0, high=self.h, size=[batch_size])
123 | color = self.all_rgbs[img_idx] # [h, w, 3]
124 | color = color[(pixels_y, pixels_x)].cpu() # [batch_size, 3]
125 | # depth = self.all_depth[img_idx].cuda()[(pixels_y, pixels_x)] # [batch_size,]
126 | # depth = depth.unsqueeze(-1).cpu()
127 | depth = torch.ones(batch_size, 1).cpu()
128 | mask = torch.ones_like(color, dtype=torch.float)
129 | directions = self.directions[scene_id]
130 | batch_direction = directions[(pixels_y, pixels_x)]
131 | rays_o, rays_d = get_rays(batch_direction, self.poses[img_idx]) # both (batch_size, 3)
132 | rand_rays = torch.cat([rays_o, rays_d], -1)
133 | return torch.cat([rand_rays, color, mask[:, :1]], dim=-1).to(self.device)
134 |
135 | def near_far_from_sphere(self, rays_o, rays_d):
136 | near = torch.zeros(rays_o.shape[0], 1).cuda()
137 | far = torch.ones(rays_o.shape[0], 1).cuda() * 2
138 | return near, far
139 |
140 | def gen_rays_at(self, img_idx, resolution_level=1):
141 | tx = torch.linspace(0, self.w-resolution_level, self.w // resolution_level)
142 | ty = torch.linspace(0, self.h-resolution_level, self.h // resolution_level)
143 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
144 | directions = self.directions[0]
145 | batch_direction = directions[(pixels_y.long(), pixels_x.long())]
146 | rays_o, rays_d = get_rays(batch_direction, self.poses[img_idx]) # both (batch_size, 3)
147 | rays_o = rays_o.reshape(self.w//resolution_level, self.h//resolution_level, 3).to(self.device).transpose(0,1)
148 | rays_d = rays_d.reshape(self.w//resolution_level, self.h//resolution_level, 3).to(self.device).transpose(0,1)
149 | return rays_o, rays_d
150 |
151 | def image_at(self, idx, resolution_level):
152 | img = self.all_rgbs[idx].cpu().numpy() * 255
153 | img = cv2.resize(img, (self.w//resolution_level, self.h//resolution_level))
154 | return img[:,:,[2,1,0]]
155 |
156 | def get_scene_id(self, img_idx):
157 | return 0
158 |
159 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
160 | # only used in novel view synthesis
161 | raise NotImplementedError()
162 |
163 |
164 |
--------------------------------------------------------------------------------
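
A minimal usage sketch for the loader above, assuming a Blender-format scene with a `transforms_train.json`; the `data_dir` and `scene` values are placeholders, real data on disk and a CUDA device are required.

```python
from pyhocon import ConfigFactory
from models.scannet_blender import ScanNetDataset

conf = ConfigFactory.parse_string("""
    data_dir = data/scannet
    scene = scene0050_00
""")
dataset = ScanNetDataset(conf, split='train')
batch = dataset.gen_random_rays_at(img_idx=0, batch_size=512)   # (512, 10): rays_o, rays_d, rgb, mask
near, far = dataset.near_far_from_sphere(batch[:, :3], batch[:, 3:6])
```
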
/models/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import cv2 as cv
4 | import numpy as np
5 | import os
6 | from glob import glob
7 | # from icecream import ic
8 | from scipy.spatial.transform import Rotation as Rot
9 | from scipy.spatial.transform import Slerp
10 |
11 |
12 | # This function is borrowed from IDR: https://github.com/lioryariv/idr
13 | def load_K_Rt_from_P(filename, P=None):
14 | if P is None:
15 | lines = open(filename).read().splitlines()
16 | if len(lines) == 4:
17 | lines = lines[1:]
18 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
19 | P = np.asarray(lines).astype(np.float32).squeeze()
20 |
21 | out = cv.decomposeProjectionMatrix(P)
22 | K = out[0]
23 | R = out[1]
24 | t = out[2]
25 |
26 | K = K / K[2, 2]
27 | intrinsics = np.eye(4)
28 | intrinsics[:3, :3] = K
29 |
30 | pose = np.eye(4, dtype=np.float32)
31 | pose[:3, :3] = R.transpose()
32 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
33 |
34 | return intrinsics, pose
35 |
36 |
37 | class Dataset:
38 | def __init__(self, conf):
39 | super(Dataset, self).__init__()
40 | print('Load data: Begin')
41 | self.device = torch.device('cuda')
42 | self.conf = conf
43 |
44 | self.data_dir = conf.get_string('data_dir')
45 | self.render_cameras_name = conf.get_string('render_cameras_name')
46 | self.object_cameras_name = conf.get_string('object_cameras_name')
47 |
48 | self.camera_outside_sphere = conf.get_bool('camera_outside_sphere', default=True)
49 | self.scale_mat_scale = conf.get_float('scale_mat_scale', default=1.1)
50 |
51 | camera_dict = np.load(os.path.join(self.data_dir, self.render_cameras_name))
52 | self.camera_dict = camera_dict
53 | self.images_lis = sorted(glob(os.path.join(self.data_dir, 'image/*.png')))
54 | self.n_images = len(self.images_lis)
55 | self.images_np = np.stack([cv.imread(im_name) for im_name in self.images_lis]) / 256.0
56 | self.masks_lis = sorted(glob(os.path.join(self.data_dir, 'mask/*.png')))
57 | self.masks_np = np.stack([cv.imread(im_name) for im_name in self.masks_lis]) / 256.0
58 |
59 | # world_mat is a projection matrix from world to image
60 | self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
61 |
62 | self.scale_mats_np = []
63 |
64 | # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
65 | self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
66 |
67 | self.intrinsics_all = []
68 | self.pose_all = []
69 |
70 | for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np):
71 | P = world_mat @ scale_mat
72 | P = P[:3, :4]
73 | intrinsics, pose = load_K_Rt_from_P(None, P)
74 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
75 | self.pose_all.append(torch.from_numpy(pose).float())
76 |
77 | self.images = torch.from_numpy(self.images_np.astype(np.float32)).cpu() # [n_images, H, W, 3]
78 | self.masks = torch.from_numpy(self.masks_np.astype(np.float32)).cpu() # [n_images, H, W, 3]
79 | self.intrinsics_all = torch.stack(self.intrinsics_all).to(self.device) # [n_images, 4, 4]
80 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) # [n_images, 4, 4]
81 | self.focal = self.intrinsics_all[0][0, 0]
82 | self.pose_all = torch.stack(self.pose_all).to(self.device) # [n_images, 4, 4]
83 | self.H, self.W = self.images.shape[1], self.images.shape[2]
84 | self.image_pixels = self.H * self.W
85 |
86 | object_bbox_min = np.array([-1.01, -1.01, -1.01, 1.0])
87 | object_bbox_max = np.array([ 1.01, 1.01, 1.01, 1.0])
88 | # Object scale mat: region of interest to **extract mesh**
89 | object_scale_mat = np.load(os.path.join(self.data_dir, self.object_cameras_name))['scale_mat_0']
90 | object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
91 | object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
92 | self.object_bbox_min = object_bbox_min[:3, 0]
93 | self.object_bbox_max = object_bbox_max[:3, 0]
94 |
95 | print('Load data: End')
96 |
97 | def gen_rays_at(self, img_idx, resolution_level=1):
98 | """
99 | Generate rays at world space from one camera.
100 | """
101 | l = resolution_level
102 | tx = torch.linspace(0, self.W - 1, self.W // l)
103 | ty = torch.linspace(0, self.H - 1, self.H // l)
104 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
105 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
106 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
107 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
108 | rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
109 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape) # W, H, 3
110 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
111 |
112 | def gen_random_rays_at(self, img_idx, batch_size):
113 | """
114 | Generate random rays at world space from one camera.
115 | """
116 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size])
117 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size])
118 | color = self.images[img_idx][(pixels_y, pixels_x)] # batch_size, 3
119 | mask = self.masks[img_idx][(pixels_y, pixels_x)] # batch_size, 3
120 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float() # batch_size, 3
121 | p = torch.matmul(self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]).squeeze() # batch_size, 3
122 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3
123 | rays_v = torch.matmul(self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]).squeeze() # batch_size, 3
124 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(rays_v.shape) # batch_size, 3
125 | return torch.cat([rays_o.cpu(), rays_v.cpu(), color, mask[:, :1]], dim=-1).cuda() # batch_size, 10
126 |
127 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
128 | """
129 | Interpolate pose between two cameras.
130 | """
131 | l = resolution_level
132 | tx = torch.linspace(0, self.W - 1, self.W // l)
133 | ty = torch.linspace(0, self.H - 1, self.H // l)
134 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
135 | p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
136 | p = torch.matmul(self.intrinsics_all_inv[0, None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3
137 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
138 | trans = self.pose_all[idx_0, :3, 3] * (1.0 - ratio) + self.pose_all[idx_1, :3, 3] * ratio
139 | pose_0 = self.pose_all[idx_0].detach().cpu().numpy()
140 | pose_1 = self.pose_all[idx_1].detach().cpu().numpy()
141 | pose_0 = np.linalg.inv(pose_0)
142 | pose_1 = np.linalg.inv(pose_1)
143 | rot_0 = pose_0[:3, :3]
144 | rot_1 = pose_1[:3, :3]
145 | rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
146 | key_times = [0, 1]
147 | slerp = Slerp(key_times, rots)
148 | rot = slerp(ratio)
149 | pose = np.diag([1.0, 1.0, 1.0, 1.0])
150 | pose = pose.astype(np.float32)
151 | pose[:3, :3] = rot.as_matrix()
152 | pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
153 | pose = np.linalg.inv(pose)
154 | rot = torch.from_numpy(pose[:3, :3]).cuda()
155 | trans = torch.from_numpy(pose[:3, 3]).cuda()
156 | rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3
157 | rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3
158 | return rays_o.transpose(0, 1), rays_v.transpose(0, 1)
159 |
160 | def near_far_from_sphere(self, rays_o, rays_d):
161 | a = torch.sum(rays_d**2, dim=-1, keepdim=True)
162 | b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
163 | mid = 0.5 * (-b) / a
164 | near = mid - 1.0
165 | far = mid + 1.0
166 | return near, far
167 |
168 | def image_at(self, idx, resolution_level):
169 | img = cv.imread(self.images_lis[idx])
170 | return (cv.resize(img, (self.W // resolution_level, self.H // resolution_level))).clip(0, 255)
171 |
172 |
--------------------------------------------------------------------------------
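
A minimal usage sketch for this IDR/NeuS-style loader, assuming a preprocessed scene with `cameras_sphere.npz`-style camera files; the directory and file names below are placeholders, and real preprocessed data plus a CUDA device are required.

```python
from pyhocon import ConfigFactory
from models.dataset import Dataset

conf = ConfigFactory.parse_string("""
    data_dir = data/some_scene
    render_cameras_name = cameras_sphere.npz
    object_cameras_name = cameras_sphere.npz
""")
dataset = Dataset(conf)
rays_o, rays_v = dataset.gen_rays_at(img_idx=0, resolution_level=4)   # (H/4, W/4, 3) each
near, far = dataset.near_far_from_sphere(rays_o.reshape(-1, 3), rays_v.reshape(-1, 3))
```
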
/models/fields.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from models.embedder import get_embedder
6 |
7 |
8 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
9 | class SDFNetwork(nn.Module):
10 | def __init__(self,
11 | d_in,
12 | d_out,
13 | d_hidden,
14 | n_layers,
15 | skip_in=(4,),
16 | multires=0,
17 | bias=0.5,
18 | scale=1,
19 | geometric_init=True,
20 | weight_norm=True,
21 | inside_outside=False):
22 | super(SDFNetwork, self).__init__()
23 |
24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
25 |
26 | self.embed_fn_fine = None
27 |
28 | if multires > 0:
29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
30 | self.embed_fn_fine = embed_fn
31 | dims[0] = input_ch
32 |
33 | self.num_layers = len(dims)
34 | self.skip_in = skip_in
35 | self.scale = scale
36 |
37 | for l in range(0, self.num_layers - 1):
38 | if l + 1 in self.skip_in:
39 | out_dim = dims[l + 1] - dims[0]
40 | else:
41 | out_dim = dims[l + 1]
42 |
43 | lin = nn.Linear(dims[l], out_dim)
44 |
45 | if geometric_init:
46 | if l == self.num_layers - 2:
47 | if not inside_outside:
48 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
49 | torch.nn.init.constant_(lin.bias, -bias)
50 | else:
51 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
52 | torch.nn.init.constant_(lin.bias, bias)
53 | elif multires > 0 and l == 0:
54 | torch.nn.init.constant_(lin.bias, 0.0)
55 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
56 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
57 | elif multires > 0 and l in self.skip_in:
58 | torch.nn.init.constant_(lin.bias, 0.0)
59 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
60 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
61 | else:
62 | torch.nn.init.constant_(lin.bias, 0.0)
63 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
64 |
65 | if weight_norm:
66 | lin = nn.utils.weight_norm(lin)
67 |
68 | setattr(self, "lin" + str(l), lin)
69 |
70 | self.activation = nn.Softplus(beta=100)
71 |
72 | def forward(self, inputs):
73 | inputs = inputs * self.scale
74 | if self.embed_fn_fine is not None:
75 | inputs = self.embed_fn_fine(inputs)
76 |
77 | x = inputs
78 | for l in range(0, self.num_layers - 1):
79 | lin = getattr(self, "lin" + str(l))
80 |
81 | if l in self.skip_in:
82 | x = torch.cat([x, inputs], 1) / np.sqrt(2)
83 |
84 | x = lin(x)
85 |
86 | if l < self.num_layers - 2:
87 | x = self.activation(x)
88 | return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1)
89 |
90 | def sdf(self, x):
91 | return self.forward(x)[:, :1]
92 |
93 | def sdf_hidden_appearance(self, x):
94 | return self.forward(x)
95 |
96 | def gradient(self, x): # gradients assumed to be normalized because of eikonal_loss
97 | x.requires_grad_(True)
98 | y = self.sdf(x)
99 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
100 | gradients = torch.autograd.grad(
101 | outputs=y,
102 | inputs=x,
103 | grad_outputs=d_output,
104 | create_graph=True,
105 | retain_graph=True,
106 | only_inputs=True)[0]
107 | return gradients.unsqueeze(1)
108 |
109 |
110 | # This implementation is borrowed from IDR: https://github.com/lioryariv/idr
111 | class RenderingNetwork(nn.Module):
112 | def __init__(self,
113 | d_feature,
114 | mode,
115 | d_in,
116 | d_out,
117 | d_hidden,
118 | n_layers,
119 | weight_norm=True,
120 | multires_view=0,
121 | squeeze_out=True):
122 | super().__init__()
123 |
124 | self.mode = mode
125 | self.squeeze_out = squeeze_out
126 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
127 |
128 | self.embedview_fn = None
129 | if multires_view > 0:
130 | embedview_fn, input_ch = get_embedder(multires_view)
131 | self.embedview_fn = embedview_fn
132 | dims[0] += (input_ch - 3)
133 |
134 | self.num_layers = len(dims)
135 |
136 | for l in range(0, self.num_layers - 1):
137 | out_dim = dims[l + 1]
138 | lin = nn.Linear(dims[l], out_dim)
139 |
140 | if weight_norm:
141 | lin = nn.utils.weight_norm(lin)
142 |
143 | setattr(self, "lin" + str(l), lin)
144 |
145 | self.relu = nn.ReLU()
146 |
147 | def forward(self, points, normals, view_dirs, feature_vectors):
148 | if self.embedview_fn is not None:
149 | view_dirs = self.embedview_fn(view_dirs)
150 |
151 | rendering_input = None
152 |
153 | if self.mode == 'idr':
154 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
155 | elif self.mode == 'no_view_dir':
156 | rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
157 | elif self.mode == 'no_normal':
158 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
159 |
160 | x = rendering_input
161 |
162 | for l in range(0, self.num_layers - 1):
163 | lin = getattr(self, "lin" + str(l))
164 |
165 | x = lin(x)
166 |
167 | if l < self.num_layers - 2:
168 | x = self.relu(x)
169 |
170 | if self.squeeze_out:
171 | x = torch.sigmoid(x)
172 | return x
173 |
174 |
175 | # This implementation is borrowed from nerf-pytorch: https://github.com/yenchenlin/nerf-pytorch
176 | class NeRF(nn.Module):
177 | def __init__(self,
178 | D=8,
179 | W=256,
180 | d_in=3,
181 | d_in_view=3,
182 | multires=0,
183 | multires_view=0,
184 | output_ch=4,
185 | skips=[4],
186 | use_viewdirs=False):
187 | super(NeRF, self).__init__()
188 | self.D = D
189 | self.W = W
190 | self.d_in = d_in
191 | self.d_in_view = d_in_view
192 | self.input_ch = 3
193 | self.input_ch_view = 3
194 | self.embed_fn = None
195 | self.embed_fn_view = None
196 |
197 | if multires > 0:
198 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
199 | self.embed_fn = embed_fn
200 | self.input_ch = input_ch
201 |
202 | if multires_view > 0:
203 | embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view)
204 | self.embed_fn_view = embed_fn_view
205 | self.input_ch_view = input_ch_view
206 |
207 | self.skips = skips
208 | self.use_viewdirs = use_viewdirs
209 |
210 | self.pts_linears = nn.ModuleList(
211 | [nn.Linear(self.input_ch, W)] +
212 | [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) for i in range(D - 1)])
213 |
214 | ### Implementation according to the official code release
215 | ### (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
216 | self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)])
217 |
218 | ### Implementation according to the paper
219 | # self.views_linears = nn.ModuleList(
220 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
221 |
222 | if use_viewdirs:
223 | self.feature_linear = nn.Linear(W, W)
224 | self.alpha_linear = nn.Linear(W, 1)
225 | self.rgb_linear = nn.Linear(W // 2, 3)
226 | else:
227 | self.output_linear = nn.Linear(W, output_ch)
228 |
229 | def forward(self, input_pts, input_views):
230 | if self.embed_fn is not None:
231 | input_pts = self.embed_fn(input_pts)
232 | if self.embed_fn_view is not None:
233 | input_views = self.embed_fn_view(input_views)
234 |
235 | h = input_pts
236 | for i, l in enumerate(self.pts_linears):
237 | h = self.pts_linears[i](h)
238 | h = F.relu(h)
239 | if i in self.skips:
240 | h = torch.cat([input_pts, h], -1)
241 |
242 | if self.use_viewdirs:
243 | alpha = self.alpha_linear(h)
244 | feature = self.feature_linear(h)
245 | h = torch.cat([feature, input_views], -1)
246 |
247 | for i, l in enumerate(self.views_linears):
248 | h = self.views_linears[i](h)
249 | h = F.relu(h)
250 |
251 | rgb = self.rgb_linear(h)
252 | return alpha, rgb
253 | else:
254 | assert False
255 |
256 |
257 | class SingleVarianceNetwork(nn.Module):
258 | def __init__(self, init_val):
259 | super(SingleVarianceNetwork, self).__init__()
260 | self.register_parameter('variance', nn.Parameter(torch.tensor(init_val)))
261 |
262 | def forward(self, x):
263 | return torch.ones([len(x), 1]) * torch.exp(self.variance * 10.0)
264 |
--------------------------------------------------------------------------------
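
A short sketch of querying the SDF network above and its gradient for an eikonal term; the hyper-parameters mirror `confs/replica.conf`, and the point batch is illustrative.

```python
import torch
from models.fields import SDFNetwork

sdf = SDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8,
                 skip_in=[4], multires=6, bias=0.5, scale=3.0)
pts = torch.rand(1024, 3) * 2 - 1                    # illustrative points in [-1, 1]^3
sdf_val = sdf.sdf(pts)                               # (1024, 1) signed distances
grad = sdf.gradient(pts).squeeze(1)                  # (1024, 3) spatial gradients
eikonal = ((grad.norm(dim=-1) - 1.0) ** 2).mean()    # the usual eikonal regularizer
```
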
/models/nerf2neus.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import mcubes
3 | import torch
4 | from models.tensoRF import TensorVMSplit
5 | import numpy as np
6 | import skimage
7 | import plyfile
8 | import os
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class PriorNeRF(nn.Module):
14 | def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
15 | """
16 | """
17 | super(PriorNeRF, self).__init__()
18 | self.D = D
19 | self.W = W
20 | self.input_ch = input_ch
21 | self.input_ch_views = input_ch_views
22 | self.skips = skips
23 | self.use_viewdirs = use_viewdirs
24 |
25 | self.pts_linears = nn.ModuleList(
26 | [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in
27 | range(D - 1)])
28 |
29 | ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
30 | self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W // 2)])
31 |
32 | ### Implementation according to the paper
33 | # self.views_linears = nn.ModuleList(
34 | # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
35 |
36 | if use_viewdirs:
37 | self.feature_linear = nn.Linear(W, W)
38 | self.alpha_linear = nn.Linear(W, 1)
39 | self.rgb_linear = nn.Linear(W // 2, 3)
40 | else:
41 | self.output_linear = nn.Linear(W, output_ch)
42 |
43 | def forward(self, x):
44 | input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
45 | h = input_pts
46 | for i, l in enumerate(self.pts_linears):
47 | h = self.pts_linears[i](h)
48 | h = F.relu(h)
49 | if i in self.skips:
50 | h = torch.cat([input_pts, h], -1)
51 |
52 | if self.use_viewdirs:
53 | alpha = self.alpha_linear(h)
54 | feature = self.feature_linear(h)
55 | h = torch.cat([feature, input_views], -1)
56 |
57 | for i, l in enumerate(self.views_linears):
58 | h = self.views_linears[i](h)
59 | h = F.relu(h)
60 |
61 | rgb = self.rgb_linear(h)
62 | outputs = torch.cat([rgb, alpha], -1)
63 | else:
64 | outputs = self.output_linear(h)
65 |
66 | return outputs
67 |
68 | def load_weights_from_keras(self, weights):
69 | assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
70 |
71 | # Load pts_linears
72 | for i in range(self.D):
73 | idx_pts_linears = 2 * i
74 | self.pts_linears[i].weight.data = torch.from_numpy(np.transpose(weights[idx_pts_linears]))
75 | self.pts_linears[i].bias.data = torch.from_numpy(np.transpose(weights[idx_pts_linears + 1]))
76 |
77 | # Load feature_linear
78 | idx_feature_linear = 2 * self.D
79 | self.feature_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_feature_linear]))
80 | self.feature_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_feature_linear + 1]))
81 |
82 | # Load views_linears
83 | idx_views_linears = 2 * self.D + 2
84 | self.views_linears[0].weight.data = torch.from_numpy(np.transpose(weights[idx_views_linears]))
85 | self.views_linears[0].bias.data = torch.from_numpy(np.transpose(weights[idx_views_linears + 1]))
86 |
87 | # Load rgb_linear
88 | idx_rbg_linear = 2 * self.D + 4
89 | self.rgb_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear]))
90 | self.rgb_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_rbg_linear + 1]))
91 |
92 | # Load alpha_linear
93 | idx_alpha_linear = 2 * self.D + 6
94 | self.alpha_linear.weight.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear]))
95 | self.alpha_linear.bias.data = torch.from_numpy(np.transpose(weights[idx_alpha_linear + 1]))
96 |
97 |
98 | class Embedder:
99 | def __init__(self, **kwargs):
100 | self.kwargs = kwargs
101 | self.create_embedding_fn()
102 |
103 | def create_embedding_fn(self):
104 | embed_fns = []
105 | d = self.kwargs['input_dims']
106 | out_dim = 0
107 | if self.kwargs['include_input']:
108 | embed_fns.append(lambda x: x)
109 | out_dim += d
110 |
111 | max_freq = self.kwargs['max_freq_log2']
112 | N_freqs = self.kwargs['num_freqs']
113 |
114 | if self.kwargs['log_sampling']:
115 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
116 | else:
117 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)
118 |
119 | for freq in freq_bands:
120 | for p_fn in self.kwargs['periodic_fns']:
121 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
122 | out_dim += d
123 |
124 | self.embed_fns = embed_fns
125 | self.out_dim = out_dim
126 |
127 | def embed(self, inputs):
128 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
129 |
130 |
131 | def get_embedder(multires, i=0):
132 | if i == -1:
133 | return nn.Identity(), 3
134 |
135 | embed_kwargs = {
136 | 'include_input': True,
137 | 'input_dims': 3,
138 | 'max_freq_log2': multires - 1,
139 | 'num_freqs': multires,
140 | 'log_sampling': True,
141 | 'periodic_fns': [torch.sin, torch.cos],
142 | }
143 |
144 | embedder_obj = Embedder(**embed_kwargs)
145 | embed = lambda x, eo=embedder_obj: eo.embed(x)
146 | return embed, embedder_obj.out_dim
147 |
148 |
149 |
150 | def query_alpha_color_nerf(nerf: PriorNeRF, _xyz_sampled, viewdirs=None, z_vals=None):
151 | # xyz_sampled: [N, 3]
152 | # viewdirs: [N_rays, 3]
153 | xyz_sampled = _xyz_sampled * 2
154 | batch_size, n_samples, _ = xyz_sampled.shape
155 | inputs_flat = torch.reshape(xyz_sampled, [-1, 3])
156 | embedded = nerf.embed_fn(inputs_flat)
157 |
158 | if viewdirs is None:
159 | viewdirs = torch.zeros(batch_size,3,dtype=torch.float32).cuda()
160 | input_dirs = viewdirs[:, None].expand(xyz_sampled.shape)
161 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
162 | embedded_dirs = nerf.embeddirs_fn(input_dirs_flat)
163 | embedded = torch.cat([embedded, embedded_dirs], -1)
164 |
165 | outputs_flat = nerf(embedded)
166 |
167 | raw2alpha = lambda raw, dists, act_fn=F.relu: 1. - torch.exp(-act_fn(raw) * dists)
168 |
169 | dists = z_vals[..., 1:] - z_vals[..., :-1]
170 | dists = torch.cat([dists, dists[0][0].expand(dists[..., :1].shape)], -1)
171 |
172 | # dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
173 | outputs_flat = outputs_flat.reshape(batch_size, n_samples, -1)
174 | alpha = raw2alpha(outputs_flat[..., 3], dists) # [N_rays, N_samples]
175 |
176 | return alpha, None
177 |
178 | def query_density_nerf(nerf, _xyz_sampled):
179 | # xyz_sampled: [N, 3]
180 | # viewdirs: [N_rays, 3]
181 | xyz_sampled = _xyz_sampled * 2
182 | batch_size, n_samples, _ = xyz_sampled.shape
183 | inputs_flat = torch.reshape(xyz_sampled, [-1, 3])
184 | embedded = nerf.embed_fn(inputs_flat)
185 |
186 | viewdirs = torch.zeros(batch_size, 3, dtype=torch.float32).cuda()
187 | input_dirs = viewdirs[:, None].expand(xyz_sampled.shape)
188 | input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
189 | embedded_dirs = nerf.embeddirs_fn(input_dirs_flat)
190 | embedded = torch.cat([embedded, embedded_dirs], -1)
191 |
192 | outputs_flat = nerf(embedded)
193 |
194 | return outputs_flat[..., 3]
195 |
196 |
197 | def load_nerf(ckpt_path):
198 | embed_fn, input_ch = get_embedder(10, 0)
199 | embeddirs_fn, input_ch_views = get_embedder(4, 0)
200 | model_fine = PriorNeRF(D=8, W=256, input_ch=input_ch, output_ch=5, skips=[4],
201 | input_ch_views=input_ch_views, use_viewdirs=True).cuda()
202 | print('Reloading from', ckpt_path)
203 | ckpt = torch.load(ckpt_path)
204 | model_fine.load_state_dict(ckpt['network_fine_state_dict'])
205 | model_fine.embed_fn = embed_fn
206 | model_fine.embeddirs_fn = embeddirs_fn
207 | return model_fine
208 |
209 |
210 | def validate_mesh(nerf):
211 | object_bbox_min = np.array([-1.01, -1.01, -1.01]) # only used in extract
212 | object_bbox_max = np.array([1.01, 1.01, 1.01])
213 | bound_min = torch.tensor(object_bbox_min, dtype=torch.float32)
214 | bound_max = torch.tensor(object_bbox_max, dtype=torch.float32)
215 |
216 | resolution = 256
217 | N = 64
218 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
219 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
220 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
221 |
222 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
223 | with torch.no_grad():
224 | for xi, xs in enumerate(X):
225 | for yi, ys in enumerate(Y):
226 | for zi, zs in enumerate(Z):
227 | xx, yy, zz = torch.meshgrid(xs, ys, zs)
228 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
229 | val = query_density_nerf(nerf,pts.reshape(1,-1,3)).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
230 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
231 |
232 | vertices, triangles = mcubes.marching_cubes(u, 30)
233 | b_max_np = bound_max.detach().cpu().numpy()
234 | b_min_np = bound_min.detach().cpu().numpy()
235 |
236 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
237 |
238 | os.makedirs(os.path.join('debug'), exist_ok=True)
239 | mesh = trimesh.Trimesh(vertices, triangles)
240 | mesh.export(os.path.join('debug/nerf.ply'))
--------------------------------------------------------------------------------
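
A minimal sketch of using the helpers above to extract a mesh from a NeRF prior; the checkpoint path is a placeholder, and a checkpoint containing `network_fine_state_dict` plus a CUDA device are assumed.

```python
from models.nerf2neus import load_nerf, validate_mesh

nerf = load_nerf('log/room0_nerf/200000.tar')   # placeholder path to a vanilla-NeRF checkpoint
validate_mesh(nerf)                             # marching cubes on the prior density -> debug/nerf.ply
```
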
/models/ray_utils.py:
--------------------------------------------------------------------------------
1 | import torch, re
2 | import numpy as np
3 | from torch import searchsorted
4 | from kornia import create_meshgrid
5 |
6 |
7 | # from utils import index_point_feature
8 |
9 | def depth2dist(z_vals, cos_angle):
10 | # z_vals: [N_ray N_sample]
11 | device = z_vals.device
12 | dists = z_vals[..., 1:] - z_vals[..., :-1]
13 | dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples]
14 | dists = dists * cos_angle.unsqueeze(-1)
15 | return dists
16 |
17 |
18 | def ndc2dist(ndc_pts, cos_angle):
19 | dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
20 | dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples]
21 | return dists
22 |
23 |
24 | def get_ray_directions(H, W, focal, center=None):
25 | """
26 | Get ray directions for all pixels in camera coordinate.
27 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
28 | ray-tracing-generating-camera-rays/standard-coordinate-systems
29 | Inputs:
30 | H, W, focal: image height, width and focal length
31 | Outputs:
32 | directions: (H, W, 3), the direction of the rays in camera coordinate
33 | """
34 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5
35 |
36 | i, j = grid.unbind(-1)
37 |     # note: +0.5 pixel centering is applied when building the grid above
38 | # see https://github.com/bmild/nerf/issues/24
39 | cent = center if center is not None else [W / 2, H / 2]
40 | directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
41 |
42 | return directions
43 |
44 |
45 | def get_ray_directions_blender(H, W, focal, center=None):
46 | """
47 | Get ray directions for all pixels in camera coordinate.
48 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
49 | ray-tracing-generating-camera-rays/standard-coordinate-systems
50 | Inputs:
51 | H, W, focal: image height, width and focal length
52 | Outputs:
53 | directions: (H, W, 3), the direction of the rays in camera coordinate
54 | """
55 | grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5
56 | i, j = grid.unbind(-1)
57 |     # note: +0.5 pixel centering is applied when building the grid above
58 | # see https://github.com/bmild/nerf/issues/24
59 | cent = center if center is not None else [W / 2, H / 2]
60 | directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],
61 | -1) # (H, W, 3)
62 |
63 | return directions
64 |
65 |
66 | def get_rays(directions, c2w):
67 | """
68 | Get ray origin and normalized directions in world coordinate for all pixels in one image.
69 | Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
70 | ray-tracing-generating-camera-rays/standard-coordinate-systems
71 | Inputs:
72 | directions: (H, W, 3) precomputed ray directions in camera coordinate
73 | c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
74 | Outputs:
75 | rays_o: (H*W, 3), the origin of the rays in world coordinate
76 | rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
77 | """
78 | # Rotate ray directions from camera coordinate to the world coordinate
79 | rays_d = directions @ c2w[:3, :3].T # (H, W, 3)
80 | # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
81 | # The origin of all rays is the camera origin in world coordinate
82 | rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3)
83 |
84 | rays_d = rays_d.view(-1, 3)
85 | rays_o = rays_o.view(-1, 3)
86 |
87 | return rays_o, rays_d
88 |
89 |
90 | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
91 | # Shift ray origins to near plane
92 | t = -(near + rays_o[..., 2]) / rays_d[..., 2]
93 | rays_o = rays_o + t[..., None] * rays_d
94 |
95 | # Projection
96 | o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
97 | o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
98 | o2 = 1. + 2. * near / rays_o[..., 2]
99 |
100 | d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
101 | d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
102 | d2 = -2. * near / rays_o[..., 2]
103 |
104 | rays_o = torch.stack([o0, o1, o2], -1)
105 | rays_d = torch.stack([d0, d1, d2], -1)
106 |
107 | return rays_o, rays_d
108 |
109 | def ndc_rays(H, W, focal, near, rays_o, rays_d):
110 | # Shift ray origins to near plane
111 | t = (near - rays_o[..., 2]) / rays_d[..., 2]
112 | rays_o = rays_o + t[..., None] * rays_d
113 |
114 | # Projection
115 | o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
116 | o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
117 | o2 = 1. - 2. * near / rays_o[..., 2]
118 |
119 | d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
120 | d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
121 | d2 = 2. * near / rays_o[..., 2]
122 |
123 | rays_o = torch.stack([o0, o1, o2], -1)
124 | rays_d = torch.stack([d0, d1, d2], -1)
125 |
126 | return rays_o, rays_d
127 |
128 | # Hierarchical sampling (section 5.2)
129 | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
130 | device = weights.device
131 | # Get pdf
132 | weights = weights + 1e-5 # prevent nans
133 | pdf = weights / torch.sum(weights, -1, keepdim=True)
134 | cdf = torch.cumsum(pdf, -1)
135 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
136 |
137 | # Take uniform samples
138 | if det:
139 | u = torch.linspace(0., 1., steps=N_samples, device=device)
140 | u = u.expand(list(cdf.shape[:-1]) + [N_samples])
141 | else:
142 | u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)
143 |
144 | # Pytest, overwrite u with numpy's fixed random numbers
145 | if pytest:
146 | np.random.seed(0)
147 | new_shape = list(cdf.shape[:-1]) + [N_samples]
148 | if det:
149 | u = np.linspace(0., 1., N_samples)
150 | u = np.broadcast_to(u, new_shape)
151 | else:
152 | u = np.random.rand(*new_shape)
153 | u = torch.Tensor(u)
154 |
155 | # Invert CDF
156 | u = u.contiguous()
157 | inds = searchsorted(cdf.detach(), u, right=True)
158 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
159 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
160 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
161 |
162 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
163 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
164 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
165 |
166 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
167 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
168 | t = (u - cdf_g[..., 0]) / denom
169 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
170 |
171 | return samples
172 |
173 |
174 | def dda(rays_o, rays_d, bbox_3D):
175 | inv_ray_d = 1.0 / (rays_d + 1e-6)
176 | t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3
177 | t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
178 | t = torch.stack((t_min, t_max)) # 2 N_rays 3
179 | t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
180 | t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
181 | return t_min, t_max
182 |
183 |
184 | def ray_marcher(rays,
185 | N_samples=64,
186 | lindisp=False,
187 | perturb=0,
188 | bbox_3D=None):
189 | """
190 | sample points along the rays
191 | Inputs:
192 |         rays: (N_rays, 8), each row is [rays_o (3), rays_d (3), near (1), far (1)]
193 | 
194 |     Returns:
195 |         xyz_coarse_sampled (N_rays, N_samples, 3), rays_o, rays_d, z_vals
196 | """
197 |
198 | # Decompose the inputs
199 | N_rays = rays.shape[0]
200 | rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
201 | near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)
202 |
203 | if bbox_3D is not None:
204 |         # compute near/far from the ray/AABB intersection
205 | near, far = dda(rays_o, rays_d, bbox_3D)
206 |
207 | # Sample depth points
208 | z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples)
209 | if not lindisp: # use linear sampling in depth space
210 | z_vals = near * (1 - z_steps) + far * z_steps
211 | else: # use linear sampling in disparity space
212 | z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)
213 |
214 | z_vals = z_vals.expand(N_rays, N_samples)
215 |
216 | if perturb > 0: # perturb sampling depths (z_vals)
217 | z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points
218 | # get intervals between samples
219 | upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
220 | lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)
221 |
222 | perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
223 | z_vals = lower + (upper - lower) * perturb_rand
224 |
225 | xyz_coarse_sampled = rays_o.unsqueeze(1) + \
226 | rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3)
227 |
228 | return xyz_coarse_sampled, rays_o, rays_d, z_vals
229 |
230 |
231 | def read_pfm(filename):
232 | file = open(filename, 'rb')
233 | color = None
234 | width = None
235 | height = None
236 | scale = None
237 | endian = None
238 |
239 | header = file.readline().decode('utf-8').rstrip()
240 | if header == 'PF':
241 | color = True
242 | elif header == 'Pf':
243 | color = False
244 | else:
245 | raise Exception('Not a PFM file.')
246 |
247 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
248 | if dim_match:
249 | width, height = map(int, dim_match.groups())
250 | else:
251 | raise Exception('Malformed PFM header.')
252 |
253 | scale = float(file.readline().rstrip())
254 | if scale < 0: # little-endian
255 | endian = '<'
256 | scale = -scale
257 | else:
258 | endian = '>' # big-endian
259 |
260 | data = np.fromfile(file, endian + 'f')
261 | shape = (height, width, 3) if color else (height, width)
262 |
263 | data = np.reshape(data, shape)
264 | data = np.flipud(data)
265 | file.close()
266 | return data, scale
267 |
268 |
269 | def ndc_bbox(all_rays):
270 | near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]
271 | near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]
272 | far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]
273 | far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]
274 | print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
275 | return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))
--------------------------------------------------------------------------------
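
A minimal sketch of how the ray helpers above compose; the image size, focal length and pose are tiny placeholder values, not dataset parameters:

    import torch
    H, W, focal = 4, 6, 3.0                                          # placeholder camera intrinsics
    directions = get_ray_directions_blender(H, W, [focal, focal])    # (H, W, 3) camera-space directions
    c2w = torch.eye(4)[:3]                                           # placeholder camera-to-world pose (3, 4)
    rays_o, rays_d = get_rays(directions, c2w)                       # (H*W, 3) each
    near = torch.zeros_like(rays_o[:, :1])
    far = torch.ones_like(near) * 3.0
    rays = torch.cat([rays_o, rays_d, near, far], dim=1)             # (H*W, 8), the layout ray_marcher expects
    pts, rays_o, rays_d, z_vals = ray_marcher(rays, N_samples=64)    # pts: (H*W, 64, 3)
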
/cull_mesh.py:
--------------------------------------------------------------------------------
1 | # import argparse
2 |
3 | import numpy as np
4 | import torch
5 | import trimesh
6 | from pyhocon import ConfigFactory
7 | from tqdm import tqdm
8 | import os
9 | from pytorch3d.structures import Meshes
10 | from pytorch3d.renderer.mesh import rasterizer
11 | from pytorch3d.renderer.cameras import PerspectiveCameras
12 |
13 |
14 | H = 680
15 | W = 1200
16 |
17 |
18 |
19 | translation_dict = {
20 | # BlendSwap
21 | "breakfast_room": [0.0, -1.42, 0.0],
22 | "kitchen": [0.3, -3.42, -0.12],
23 | "green_room": [-1, -0.38, 0.3],
24 | "complete_kitchen": [1.5, -2.25, 0.0],
25 | "grey_white_room": [-0.12, -1.94, -0.69],
26 | "morning_apartment": [-0.22, -0.43, 0.0],
27 | "staircase": [0.0, -2.42, 0.0],
28 | "whiteroom": [0.3, -3.42, -0.12],
29 | # Replica
30 | 'office0': [-0.1944, 0.6488, -0.3271],
31 | 'office1': [-0.585, -0.4703, -0.3507],
32 | 'office2': [0.1909, -1.2262, -0.1574],
33 | 'office3': [0.7893, 1.3371, -0.3305],
34 | 'office4': [-2.0684, -0.9268, -0.1993],
35 | 'room0': [-3.00, -1.1631, 0.1235],
36 | 'room1': [2.0795, 0.1747, 0.0314],
37 | 'room2': [-2.5681, 0.7727, 1.1110],
38 | }
39 |
40 | scale_dict = {
41 | # BlendSwap
42 | "breakfast_room": 0.4,
43 | "kitchen": 0.25,
44 | "green_room": 0.25,
45 | "complete_kitchen": 0.20,
46 | "grey_white_room": 0.25,
47 | "morning_apartment": 0.5,
48 | "staircase": 0.25,
49 | "whiteroom": 0.25,
50 | # Replica
51 | 'office0': 0.4,
52 | 'office1': 0.41,
53 | 'office2': 0.24,
54 | 'office3': 0.21,
55 | 'office4': 0.30,
56 | 'room0': 0.25,
57 | 'room1': 0.30,
58 | 'room2': 0.29,
59 | }
60 |
61 | scene_bounds_dict = {
62 | # BlendSwap
63 | "whiteroom": np.array([[-2.46, -0.1, 0.36],
64 | [3.06, 3.3, 8.2]]),
65 | "kitchen": np.array([[-3.12, -0.1, -3.18],
66 | [3.75, 3.3, 5.45]]),
67 | "breakfast_room": np.array([[-2.23, -0.5, -1.7],
68 | [1.85, 2.77, 3.0]]),
69 | "staircase":np.array([[-4.14, -0.1, -5.25],
70 | [2.52, 3.43, 1.08]]),
71 | "complete_kitchen":np.array([[-5.55, 0.0, -6.45],
72 | [3.65, 3.1, 3.5]]),
73 | "green_room":np.array([[-2.5, -0.1, 0.4],
74 | [5.4, 2.8, 5.0]]),
75 | "grey_white_room":np.array([[-0.55, -0.1, -3.75],
76 | [5.3, 3.0, 0.65]]),
77 | "morning_apartment":np.array([[-1.38, -0.1, -2.2],
78 | [2.1, 2.1, 1.75]]),
79 | "thin_geometry":np.array([[-2.15, 0.0, 0.0],
80 | [0.77, 0.75, 3.53]]),
81 | # Replica
82 | 'office0': np.array([[-2.0056, -3.1537, -1.1689],
83 | [2.3944, 1.8561, 1.8230]]),
84 | 'office1': np.array([[-1.8204, -1.5824, -1.0477],
85 | [2.9904, 2.5231, 1.7491]]),
86 | 'office2': np.array([[-3.4272, -2.8455, -1.2265],
87 | [3.0453, 5.2980, 1.5414]]),
88 | 'office3': np.array([[-5.1116, -5.9395, -1.2207],
89 | [3.5329, 3.2652, 1.8816]]),
90 | 'office4': np.array([[-1.2047, -2.3258, -1.2093],
91 | [5.3415, 4.1794, 1.6078]]),
92 | 'room0': np.array([[-0.8794, -1.1860, -1.5274],
93 | [6.8852, 3.5123, 1.2804]]),
94 | 'room1': np.array([[-5.4027, -3.0385, -1.4080],
95 | [1.2436, 2.6891, 1.3452]]),
96 | 'room2': np.array([[-0.8171, -3.2454, -2.9081],
97 | [5.9533, 1.7000, 0.6861]]),
98 | }
99 |
100 |
101 |
102 | def clean_invisible_vertices(mesh, train_dataset):
103 |
104 | poses = train_dataset.poses
105 | n_imgs = train_dataset.__len__()
106 | pc = mesh.vertices
107 | faces = mesh.faces
108 | xyz = torch.Tensor(pc)
109 | xyz = xyz.reshape(1, -1, 3)
110 | xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1)
111 |
112 | # delete mesh vertices that are not inside any camera's viewing frustum
113 |     whole_mask = np.ones(pc.shape[0]).astype(bool)
114 | for i in tqdm(range(0, n_imgs, 1), desc='clean_vertices'):
115 | intrinsics = train_dataset.intrinsics
116 | pose = poses[i]
117 | # adjusted for blender
118 | camera_pos = torch.einsum('abj,ij->abi', xyz_h, pose.inverse())
119 | projections = torch.einsum('ij, abj->abi', intrinsics, camera_pos[..., :3]) # [W, H, 3]
120 | pixel_locations = projections[..., :2] / torch.clamp(projections[..., 2:3], min=1e-8) - 0.5
121 | pixel_locations = pixel_locations[:, :, [1, 0]]
122 | pixel_locations = torch.clamp(pixel_locations, min=-1e6, max=1e6)
123 | uv = pixel_locations.reshape(-1, 2)
124 |         z = projections[..., 2:3] + 1e-5  # depth in the camera frame
125 | z = z.reshape(-1)
126 | edge = 0
127 | mask = (0 <= z) & (uv[:, 0] < H - edge) & (uv[:, 0] > edge) & (uv[:, 1] < W-edge) & (uv[:, 1] > edge)
128 | whole_mask &= ~mask.cpu().numpy()
129 |
130 | pc = mesh.vertices
131 | faces = mesh.faces
132 | face_mask = whole_mask[mesh.faces].all(axis=1)
133 | mesh.update_faces(~face_mask)
134 |
135 | return mesh
136 |
137 | # correction from pytorch3d (v0.5.0)
138 | def corrected_cameras_from_opencv_projection( R, tvec, camera_matrix, image_size):
139 | focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
140 | principal_point = camera_matrix[:, :2, 2]
141 |
142 | # Retype the image_size correctly and flip to width, height.
143 | image_size_wh = image_size.to(R).flip(dims=(1,))
144 |
145 | # Get the PyTorch3D focal length and principal point.
146 | s = (image_size_wh).min(dim=1).values
147 |
148 | focal_pytorch3d = focal_length / (0.5 * s)
149 | p0_pytorch3d = -(principal_point - image_size_wh / 2) * 2 / s
150 |
151 | # For R, T we flip x, y axes (opencv screen space has an opposite
152 | # orientation of screen axes).
153 | # We also transpose R (opencv multiplies points from the opposite=left side).
154 | R_pytorch3d = R.clone().permute(0, 2, 1)
155 | # R_pytorch3d = R.clone()
156 | T_pytorch3d = tvec.clone()
157 | R_pytorch3d[:, :, :2] *= -1
158 | T_pytorch3d[:, :2] *= -1
159 | # T_pytorch3d[:, 0] *= -1
160 |
161 | return PerspectiveCameras(
162 | R=R_pytorch3d,
163 | T=T_pytorch3d,
164 | focal_length=focal_pytorch3d,
165 | principal_point=p0_pytorch3d,
166 | )
167 |
168 |
169 | def clean_triangle_faces(mesh, train_dataset):
170 | # returns a mask of triangles that reprojects on at least nb_visible images
171 | num_view = train_dataset.__len__()
172 | K = train_dataset.intrinsics[:3, :3].unsqueeze(0).repeat([num_view, 1, 1])
173 | R = train_dataset.poses[:, :3, :3].transpose(2, 1)
174 | t = - train_dataset.poses[:, :3, :3].transpose(2, 1) @ train_dataset.poses[:, :3, 3:]
175 | sizes = torch.Tensor([[train_dataset.w, train_dataset.h]]).repeat([num_view, 1])
176 | cams = [K, R, t, sizes]
177 | num_faces = len(mesh.faces)
178 | nb_visible = 1
179 | count = torch.zeros(num_faces, device="cuda")
180 | K, R, t, sizes = cams[:4]
181 |
182 | n = len(K)
183 | with torch.no_grad():
184 | for i in tqdm(range(n), desc="clean_faces"):
185 | intr = torch.zeros(1, 4, 4).cuda() #
186 | intr[:, :3, :3] = K[i:i + 1]
187 | intr[:, 3, 3] = 1
188 | vertices = torch.from_numpy(mesh.vertices).cuda().float() #
189 | faces = torch.from_numpy(mesh.faces).cuda().long() #
190 | meshes = Meshes(verts=[vertices],
191 | faces=[faces])
192 |
193 | cam = corrected_cameras_from_opencv_projection(camera_matrix=intr, R=R[i:i + 1].cuda(), #
194 | tvec=t[i:i + 1].squeeze(2).cuda(), #
195 | image_size=sizes[i:i + 1, [1, 0]].cuda()) #
196 | cam = cam.cuda() #
197 | raster_settings = rasterizer.RasterizationSettings(image_size=tuple(sizes[i, [1, 0]].long().tolist()),
198 | faces_per_pixel=1)
199 | meshRasterizer = rasterizer.MeshRasterizer(cam, raster_settings)
200 |
201 | with torch.no_grad():
202 | ret = meshRasterizer(meshes)
203 | pix_to_face = ret.pix_to_face
204 | # pix_to_face, zbuf, bar, pixd =
205 |
206 | visible_faces = pix_to_face.view(-1).unique()
207 | count[visible_faces[visible_faces > -1]] += 1
208 |
209 | pred_visible_mask = (count >= nb_visible).cpu()
210 |
211 | mesh.update_faces(pred_visible_mask)
212 | return mesh
213 |
214 | def cull_by_bounds(points, scene_bounds):
215 | eps = 0.02
216 | inside_mask = np.all(points >= (scene_bounds[0] - eps), axis=1) & np.all(points <= (scene_bounds[1] + eps), axis=1)
217 | return inside_mask
218 |
219 |
220 |
221 | def crop_mesh(scene, mesh, subdivide=True, max_edge=0.015):
222 | vertices = mesh.vertices
223 | triangles = mesh.faces
224 |
225 | if subdivide:
226 | vertices, triangles = trimesh.remesh.subdivide_to_size(vertices, triangles, max_edge=max_edge, max_iter=10)
227 |
228 | # Cull with the bounding box first
229 | inside_mask = None
230 | scene_bounds = scene_bounds_dict[scene]
231 | if scene_bounds is not None:
232 | inside_mask = cull_by_bounds(vertices, scene_bounds)
233 |
234 | inside_mask = inside_mask[triangles[:, 0]] | inside_mask[triangles[:, 1]] | inside_mask[triangles[:, 2]]
235 | triangles = triangles[inside_mask, :]
236 | print("Processed culling by bound")
237 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
238 | # we don't need subdivided mesh to render depth
239 | mesh = trimesh.Trimesh(vertices, triangles, process=False)
240 | mesh.remove_unreferenced_vertices()
241 | return mesh
242 |
243 |
244 | def transform_mesh(scene, mesh):
245 | pc = mesh.vertices
246 | faces = mesh.faces
247 | pc = (pc / scale_dict[scene]) - np.array([translation_dict[scene]])
248 | mesh = trimesh.Trimesh(pc, faces, process=False)
249 | return mesh
250 |
251 |
252 | def detransform_mesh(scene, mesh):
253 | pc = mesh.vertices
254 | faces = mesh.faces
255 | pc = (pc + np.array([translation_dict[scene]])) * scale_dict[scene]
256 | mesh = trimesh.Trimesh(pc, faces, process=False)
257 | return mesh
258 |
259 | if __name__ == '__main__':
260 | from models.blender_swap import BlendSwapDataset
261 |
262 | out_dir_pat = 'out_meshes/%s'
263 | scene = 'room0'
264 |
265 | f = open('confs/room0.conf')
266 | conf_text = f.read()
267 | conf_text = conf_text.replace('CASE_NAME', '')
268 | f.close()
269 | conf = ConfigFactory.parse_string(conf_text)
270 | conf['dataset.data_dir'] = conf['dataset.data_dir'].replace('CASE_NAME', '')
271 | train_dataset = BlendSwapDataset(conf['dataset'])
272 |
273 | exp_name = 'room0'
274 | mesh_name = 'neus'
275 | dir_pth = out_dir_pat % (exp_name)
276 | mesh_pth = os.path.join(dir_pth, mesh_name+'.ply')
277 | print(dir_pth, mesh_pth)
278 | mesh = trimesh.load_mesh(mesh_pth)
279 |
280 | mesh.vertices = mesh.vertices[:, [0,2,1]] * np.array([[1, -1, 1]])
281 | mesh = transform_mesh(scene, mesh)
282 | mesh = crop_mesh(scene, mesh)
283 | # mesh.export(os.path.join(dir_pth, '%s_cropped.ply' % mesh_name))
284 | mesh = detransform_mesh(scene, mesh)
285 | mesh.vertices = (mesh.vertices * np.array([[1, -1, 1]]))[:, [0, 2, 1]]
286 | mesh = clean_invisible_vertices(mesh, train_dataset)
287 | mesh = clean_triangle_faces(mesh, train_dataset)
288 | # mesh.export(os.path.join(dir_pth, '%s_cropped_culled.ply' % mesh_name))
289 | mesh.vertices = mesh.vertices[:, [0, 2, 1]] * np.array([[1, -1, 1]])
290 | mesh = transform_mesh(scene, mesh)
291 | mesh.export(os.path.join(dir_pth, '%s_cropped_culled_transformed.ply' % mesh_name))
292 |
--------------------------------------------------------------------------------
/models/tensorf2neus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.tensoRF import TensorVMSplit
3 | import numpy as np
4 | import skimage
5 | import plyfile
6 | import os
7 | # from test_ours import *
8 |
9 |
10 | def cal_n_samples(reso, step_ratio=0.5):
11 | return int(np.linalg.norm(reso)/step_ratio)
12 |
13 |
14 | def query_alpha_color(tensorf: TensorVMSplit, xyz_sampled, viewdirs=None, prior_mode='cat', requires_grad=False):
15 | # xyz_sampled: [N, 3]
16 | # viewdirs: [N_rays, 3]
17 |
18 | dim = 3
19 | if len(xyz_sampled.shape) == 3:
20 | N_rays, N_samples, _ = xyz_sampled.shape
21 | xyz_sampled = xyz_sampled.reshape(-1, 3)
22 | # if len(xyz_sampled.shape) == 2:
23 | # dim = 2
24 | # xyz_sampled = xyz_sampled.unsqueeze(0)
25 |
26 | # neighboring 8 vertices
27 | # after: xyz_samples: [N, 9, 3] or [N, 1, 3]
28 | if prior_mode.startswith('local'):
29 | xyz_sampled = query_grid_vertices(tensorf, xyz_sampled)
30 | else:
31 | xyz_sampled = xyz_sampled.unsqueeze(1)
32 |
33 | mask_outbbox = ((tensorf.aabb[0] > xyz_sampled) | (xyz_sampled > tensorf.aabb[1])).any(dim=-1)
34 | ray_valid = ~mask_outbbox
35 |
36 | if tensorf.alphaMask is not None:
37 | alphas = tensorf.alphaMask.sample_alpha(xyz_sampled[ray_valid])
38 | alpha_mask = alphas > 0
39 | ray_invalid = ~ray_valid
40 | ray_invalid[ray_valid] |= (~alpha_mask)
41 | ray_valid = ~ray_invalid
42 |
43 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
44 |
45 | if ray_valid.any():
46 | xyz_sampled = tensorf.normalize_coord(xyz_sampled)
47 | sigma_feature = tensorf.compute_densityfeature(xyz_sampled[ray_valid], requires_grad=requires_grad)
48 |
49 | validsigma = tensorf.feature2density(sigma_feature)
50 | sigma[ray_valid] = validsigma
51 |
52 | # alpha, weight, bg_weight = raw2alpha(sigma, tensorf.stepSize * tensorf.distance_scale)
53 | # app_mask = weight > tensorf.rayMarch_weight_thres
54 | alpha = 1. - torch.exp(-sigma * tensorf.stepSize * tensorf.distance_scale)
55 | # alpha = 1-alpha
56 |
57 | if prior_mode == 'local_mean':
58 | alpha = torch.mean(alpha, dim=1, keepdim=True)
59 |
60 | if viewdirs is None:
61 | # alpha: [N, 1] or [N, 9]
62 | return alpha, None
63 |
64 |
65 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)
66 | viewdirs = viewdirs.reshape(-1, 1, 3).expand(xyz_sampled.shape)
67 | app_mask = ray_valid
68 | if app_mask.any():
69 | app_features = tensorf.compute_appfeature(xyz_sampled[app_mask])
70 | valid_rgbs = tensorf.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
71 | rgb[app_mask] = valid_rgbs
72 |
73 | # alpha: [N, 1] or [N, 9]
74 | # rgb: [N, 3] or [N, 9*3]
75 | if prior_mode == 'local_mean':
76 | rgb = torch.mean(rgb, dim=1, keepdim=False)
77 | elif prior_mode == 'local_cat':
78 | rgb = rgb.reshape(rgb.shape[0], -1)
79 | else:
80 | rgb = rgb.reshape(-1, 3)
81 | return alpha, rgb
82 |
83 |
84 | def load_tensorf(ckpt, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
85 | ckpt = torch.load(ckpt, map_location=device)
86 | kwargs = ckpt['kwargs']
87 | kwargs.update({'device': device})
88 | tensorf = TensorVMSplit(**kwargs)
89 | tensorf.load(ckpt)
90 |
91 |     # preload the coordinates of the dense voxel grid
92 | gridSize = tensorf.gridSize
93 | samples = torch.stack(torch.meshgrid(
94 | torch.linspace(0, 1, gridSize[0]),
95 | torch.linspace(0, 1, gridSize[1]),
96 | torch.linspace(0, 1, gridSize[2]),
97 | ), -1).to(tensorf.device)
98 | dense_xyz = tensorf.aabb[0] * (1 - samples) + tensorf.aabb[1] * samples
99 | tensorf.dense_xyz = dense_xyz # [X, Y, Z, 3]
100 |
101 | return tensorf
102 |
103 |
104 | def construct_alpha_color_grid(tensorf: TensorVMSplit):
105 | alpha, dense_xyz = tensorf.getDenseAlpha()
106 | tensorf.alpha_grid = alpha
107 |
108 |     rgb = torch.zeros((*dense_xyz.shape[:3], 3))  # per-voxel RGB placeholder, not filled in yet
109 |
110 |
111 | def query_grid_vertices(tensorf: TensorVMSplit, pts):
112 | # pts: [N, 3]
113 | # return: [N, 9, 3]
114 | bbox = tensorf.aabb
115 | pts_normalized = (pts-bbox[0]) / (bbox[1] - bbox[0]) # \in [0, 1]
116 | pts_gridded = pts_normalized * (tensorf.gridSize-1) # \in [0, 256]
117 | pts_lower = torch.floor(pts_gridded) # [N, 3]
118 | verts = pts_lower.unsqueeze(1).repeat([1,8,1])
119 | # x-, y-, z-
120 |
121 | # x+, y-, z-
122 | verts[:, 1, 0] += 1
123 | # x+, y+, z-
124 | verts[:, 2, :2] += 1
125 | # x-, y+, z-
126 | verts[:, 3, 1] += 1
127 | # x-, y-, z+
128 | verts[:, 4, 2] += 1
129 | # x+, y-, z+
130 | verts[:, 5, 0] += 1
131 | verts[:, 5, 2] += 1
132 | # x+, y+, z+
133 | verts[:, 6, :] += 1
134 | # x-, y+, z+
135 | verts[:, 7, 1:] += 1
136 |
137 | mask_outbbox = ((pts_gridded <= 0) | (pts_gridded >= (tensorf.gridSize-1))).any(dim=-1) # [N,]
138 |     mask_outbbox = mask_outbbox.unsqueeze(-1).repeat([1, 8]) # [N, 8]
139 | mask_inbbox = ~mask_outbbox
140 |
141 | # [N,8,3]
142 | grid_vertices = pts.unsqueeze(1).repeat([1,8,1])
143 | verts = verts.long()[mask_inbbox]
144 | grid_vertices[mask_inbbox] = tensorf.dense_xyz[verts[:,0], verts[:,1], verts[:,2]]
145 | # grid_vertices = tensorf.dense_xyz[verts[:,:,0], verts[:,:,1], verts[:,:,2]]
146 | all_vertices = torch.cat([pts.unsqueeze(1), grid_vertices], 1)
147 |
148 |
149 | return all_vertices
150 |
151 |
152 | def query_occupancy_confidence(model, pts):
153 | # model: rendering.network
154 | # pts: [N, 3]
155 | N, _ = pts.shape
156 | grid_points = query_grid_vertices(model.tensorf, pts) # [N, 9, 3]
157 | grid_points = grid_points.reshape(N*9, 3)
158 |
159 | g = []
160 | occs = []
161 | chunk = N
162 | for i in range(0, N*9, chunk):
163 | with torch.enable_grad():
164 | p = grid_points[i:i+N]
165 | p.requires_grad_(True)
166 | queried_occ, _ = query_alpha_color(model.tensorf, p, None, prior_mode='cat', requires_grad=True)
167 | y = queried_occ
168 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
169 | gradients = torch.autograd.grad(
170 | outputs=y,
171 | inputs=p,
172 | grad_outputs=d_output,
173 | create_graph=True,
174 | retain_graph=True,
175 | only_inputs=True, allow_unused=True)[0]
176 | _g = gradients.unsqueeze(1)
177 | grid_points.requires_grad_(False)
178 | g.append(_g.detach())
179 | occs.append(queried_occ.squeeze())
180 | del _g
181 |
182 | g = torch.cat(g, 0).squeeze().reshape(N, 9, 3) # [N, 9, 3]
183 | occs = torch.cat(occs, 0).reshape(N, 9) # only for test
184 | normals_ = g[:, :, :] / (g[:, :, :].norm(2, dim=2).unsqueeze(-1) + 10 ** (-5))
185 | normals_ = normals_.reshape(N, 9, 3)
186 |
187 |
188 | confidence = torch.var(normals_, 1).sum(1)
189 | length = g[:, :, :].norm(2, dim=2)
190 | confidence = torch.var(length, 1)
191 | confidence = torch.var(occs, 1)
192 | ref_pts = normals_[:, 0:1, :].repeat([1,8,1])
193 | grid_pts = normals_[:, 1:, :]
194 | cosine = torch.mul(ref_pts, grid_pts).sum(-1)
195 | confidence = torch.var(cosine, 1)
196 |
197 |
198 | return confidence, normals_[:, 0, :]
199 |
200 |
201 | def query_entropy(occ, entropy_num=5):
202 | # pts: [N_rays, N_samples, 3], occ: [N_rays, N_samples]
203 | # pts_4_ = pts[:, :-4, :]
204 | # pts_3_ = pts[:, 1:-3, :]
205 | expand = (entropy_num-1)//2
206 | occs = []
207 | for i in range(expand):
208 | occ_ = occ[:, i:(-expand-(expand-i))]
209 | occs.append(occ_)
210 | occs.append(occ[:, expand:-expand])
211 | for i in range(expand):
212 | if i != expand -1:
213 | occ_ = occ[:, expand+i+1:(-expand+i+1)]
214 | else:
215 | occ_ = occ[:, expand+i+1:]
216 | occs.append(occ_)
217 | # occ_2_ = occ[:, 0:-4]
218 | # occ_1_ = occ[:, 1:-3]
219 | # occ_ = occ[:, 2:-2]
220 | # occ_1 = occ[:, 3:-1]
221 | # occ_2 = occ[:, 4:]
222 | # occ_5around = torch.stack([occ_2_, occ_1_, occ_, occ_1, occ_2], 2) # [N_rays, N_samples, 5]
223 | occ_around = torch.stack(occs, 2)
224 | occ_around = torch.softmax(occ_around, 2)
225 | log_occ_around = torch.log(occ_around+1e-8)
226 | entropy = (-occ_around * log_occ_around).sum(2) # [N_rays, N_samples]
227 |
228 | return entropy
229 |
230 |
231 | def tensorf_volume_rendering(xyz_sampled, z_vals, camera_world, viewdirs, tensorf):
232 | # tensorf volume rendering
233 | # xyz_sampled: [N_rays, N_samples, 3]
234 | # z_vals: [N_rays, N_samples]
235 | # viewdirs: [N_rays, 3]
236 |
237 | # from test_ours import visual_ray_points
238 | # visual_ray_points(xyz_sampled)
239 |
240 | # xyz_sampled, z_vals, ray_valid = tensorf.sample_ray(camera_world, viewdirs, is_train=False, N_samples=tensorf.nSamples)
241 |
242 | mask_outbbox = ((tensorf.aabb[0] > xyz_sampled) | (xyz_sampled > tensorf.aabb[1])).any(dim=-1)
243 | ray_valid = ~mask_outbbox
244 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
245 |
246 |
247 | viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
248 |
249 | if tensorf.alphaMask is not None:
250 | alphas = tensorf.alphaMask.sample_alpha(xyz_sampled[ray_valid])
251 | alpha_mask = alphas > 0
252 | ray_invalid = ~ray_valid
253 | ray_invalid[ray_valid] |= (~alpha_mask)
254 | ray_valid = ~ray_invalid
255 |
256 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
257 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)
258 |
259 | if ray_valid.any():
260 | xyz_sampled = tensorf.normalize_coord(xyz_sampled)
261 | sigma_feature = tensorf.compute_densityfeature(xyz_sampled[ray_valid])
262 |
263 | validsigma = tensorf.feature2density(sigma_feature)
264 | sigma[ray_valid] = validsigma
265 |
266 | alpha, weight, bg_weight = raw2alpha(sigma, dists * tensorf.distance_scale)
267 |
268 | app_mask = weight > tensorf.rayMarch_weight_thres
269 |
270 | if app_mask.any():
271 | app_features = tensorf.compute_appfeature(xyz_sampled[app_mask])
272 | valid_rgbs = tensorf.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
273 | rgb[app_mask] = valid_rgbs
274 |
275 | acc_map = torch.sum(weight, -1)
276 | rgb_map = torch.sum(weight[..., None] * rgb, -2)
277 |
278 | return rgb_map
279 |
280 |
281 | ### deprecated
282 | def raw2alpha(sigma, dist):
283 | # sigma, dist [N_rays, N_samples]
284 | alpha = 1. - torch.exp(-sigma*dist)
285 |
286 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)
287 |
288 | weights = alpha * T[:, :-1] # [N_rays, N_samples]
289 | return alpha, weights, T[:,-1:]
290 |
291 | def convert_sdf_samples_to_ply(
292 | pytorch_3d_sdf_tensor,
293 | ply_filename_out,
294 | bbox,
295 | level=0.5,
296 | offset=None,
297 | scale=None,
298 | ):
299 | """
300 | Convert sdf samples to .ply
301 |
302 |     :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n, n, n)
303 |     :param ply_filename_out: string, path of the filename to save to
304 |     :param bbox: (2, 3) array holding the min/max corners of the sampled volume
305 |     :param level: iso-value passed to marching cubes
306 |
307 | This function adapted from: https://github.com/RobotLocomotion/spartan
308 | """
309 |
310 | numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
311 | voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape))
312 |
313 | verts, faces, normals, values = skimage.measure.marching_cubes(
314 | numpy_3d_sdf_tensor, level=level, spacing=voxel_size
315 | )
316 | faces = faces[...,::-1] # inverse face orientation
317 |
318 | # transform from voxel coordinates to camera coordinates
319 | # note x and y are flipped in the output of marching_cubes
320 | mesh_points = np.zeros_like(verts)
321 | mesh_points[:, 0] = bbox[0,0] + verts[:, 0]
322 | mesh_points[:, 1] = bbox[0,1] + verts[:, 1]
323 | mesh_points[:, 2] = bbox[0,2] + verts[:, 2]
324 |
325 | # apply additional offset and scale
326 | if scale is not None:
327 | mesh_points = mesh_points / scale
328 | if offset is not None:
329 | mesh_points = mesh_points - offset
330 |
331 | # try writing to the ply file
332 |
333 | num_verts = verts.shape[0]
334 | num_faces = faces.shape[0]
335 |
336 | verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
337 |
338 | for i in range(0, num_verts):
339 | verts_tuple[i] = tuple(mesh_points[i, :])
340 |
341 | faces_building = []
342 | for i in range(0, num_faces):
343 | faces_building.append(((faces[i, :].tolist(),)))
344 | faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
345 |
346 | el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
347 | el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
348 |
349 | ply_data = plyfile.PlyData([el_verts, el_faces])
350 | print("saving mesh to %s" % (ply_filename_out))
351 | ply_data.write(ply_filename_out)
--------------------------------------------------------------------------------
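
A minimal sketch of querying the TensoRF prior defined above; the checkpoint path is hypothetical and the sample points are assumed to fall inside the model's aabb:

    import torch
    tensorf = load_tensorf('logs/tensorf_room0.th')                   # hypothetical TensoRF checkpoint
    pts = torch.rand(1024, 64, 3, device=tensorf.device) * 2.0 - 1.0  # [N_rays, N_samples, 3], assuming aabb ~ [-1, 1]^3
    dirs = torch.nn.functional.normalize(
        torch.randn(1024, 64, 3, device=tensorf.device), dim=-1)      # per-sample view directions
    alpha, rgb = query_alpha_color(tensorf, pts, viewdirs=dirs, prior_mode='cat')
    # alpha: [1024*64, 1], rgb: [1024*64, 3]
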
/models/blender_swap.py:
--------------------------------------------------------------------------------
1 | import torch, cv2
2 | from torch.utils.data import Dataset
3 | import json
4 | from tqdm import tqdm
5 | import os
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision import transforms as T
9 |
10 | from .ray_utils import *
11 |
12 |
13 | class BlendSwapDataset(Dataset):
14 | def __init__(self, conf, split='train', N_vis=-1):
15 | self.device = torch.device('cuda')
16 | self.N_vis = N_vis
17 | self.root_dir = conf.get_string('data_dir')
18 | scene = conf.get_string('scene')
19 | self.root_dir = os.path.join(self.root_dir, scene)
20 |
21 | self.split = split
22 | self.is_stack = False
23 | self.downsample = 1.0
24 | self.define_transforms()
25 |
26 | self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
27 | self.read_meta()
28 | # self.define_proj_mat()
29 |
30 | self.white_bg = True
31 |
32 | def read_meta(self):
33 |
34 | with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
35 | self.meta = json.load(f)
36 |
37 | w, h = int(self.meta['w'] / self.downsample), int(self.meta['h'] / self.downsample)
38 | self.img_wh = [w, h]
39 | self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
40 | self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length
41 | self.cx, self.cy = self.meta['cx'], self.meta['cy']
42 |
43 | # ray directions for all pixels, same for all images (same H, W, focal)
44 | self.directions = get_ray_directions(h, w, [self.focal_x, self.focal_y], center=[self.cx, self.cy]) # (h, w, 3)
45 | self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
46 | self.intrinsics = torch.tensor([[self.focal_x, 0, self.cx], [0, self.focal_y, self.cy], [0, 0, 1]]).float()
47 |
48 | self.image_paths = []
49 | self.poses = []
50 | self.all_rays = []
51 | self.all_rgbs = []
52 | self.all_pretrained_rgbs = []
53 | self.all_depth = []
54 | self.all_rgb_std = []
55 | img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
56 | idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
57 | for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'): # img_list:#
58 | frame = self.meta['frames'][i]
59 | pose = np.array(frame['transform_matrix'])
60 | pose = pose @ self.blender2opencv
61 | c2w = torch.FloatTensor(pose)
62 | self.poses.append(c2w)
63 |
64 | image_path = os.path.join(self.root_dir, f"{frame['file_path']}")
65 | self.image_paths += [image_path]
66 | img = Image.open(image_path)
67 |
68 | img = self.transform(img) # (4, h, w)
69 |
70 | rgb_std = self.cal_rgb_std(img.permute(1, 2, 0))
71 | rgb_std = np.where(rgb_std > 10 / 255.0, 1.0, 0.0)
72 | self.all_rgb_std.append(torch.Tensor(rgb_std).cpu())
73 |
74 | img = img.view(-1, w * h).permute(1, 0) # (h*w, 4) RGBA
75 | if img.shape[-1] == 4:
76 | img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
77 | self.all_rgbs.append(img)
78 | rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
79 | self.all_rays.append(torch.cat([rays_o, rays_d], 1)) # (h*w, 6)
80 |
81 | self.poses = torch.stack(self.poses) #(N, 4, 4)
82 |         self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames']), h*w, 6)
83 |         self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames']), h, w, 3)
84 |
85 | self.h, self.w = h, w
86 | # [w*h, 9]
87 | self.all_neighbor_idx = self.query_neighbor_idx(torch.arange(w*h))
88 |
89 | # for Neus exp_runner
90 | self.n_images = len(self.image_paths)
91 | self.pose_all = self.poses
92 | self.object_bbox_min = np.array([-1.01, -1.01, -1.01]) # only used in extract
93 | self.object_bbox_max = np.array([ 1.01, 1.01, 1.01])
94 | self.scale_mats_np = [np.eye(4)]
95 |
96 | def define_transforms(self):
97 | self.transform = T.ToTensor()
98 |
99 | # def define_proj_mat(self):
100 | # self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]
101 |
102 | def __len__(self):
103 | return len(self.all_rgbs)
104 |
105 | def cal_rgb_std(self, img):
106 | """
107 |         :param img: [h, w, C], values in [0, 1] (output of T.ToTensor)
108 |         :return: [h, w], values in [0, 1]
109 | """
110 | img = np.array(img, np.float64)
111 | kernel = (3, 3)
112 | kernel = (9, 9)
113 | E_square = cv2.blur(img ** 2, kernel)
114 | square_E = cv2.blur(img, kernel) ** 2
115 | sharp_img = np.sqrt(np.abs(E_square - square_E))
116 |
117 | gray_img = sharp_img.max(2)
118 | gray_min = np.min(gray_img)
119 | gray_max = np.max(gray_img)
120 | gray_img = np.clip(gray_img, gray_min, gray_max)
121 | gray_img = (gray_img - gray_min) / (gray_max - gray_min + 1e-6)
122 | return gray_img
123 |
124 | def query_neighbor_idx(self, idx):
125 | # query idx's neighbors. idx: [N,]
126 | # 0 1 2
127 | # 3 * 4
128 | # 5 6 7
129 | # return: [N, 9]
130 | pad = 4
131 | row, col = idx // self.w, idx % self.w
132 | r0 = r1 = r2 = torch.maximum(row-pad, torch.zeros_like(row))
133 | r3 = r4 = row
134 | r5 = r6 = r7 = torch.minimum(row+pad, torch.ones_like(row)*(self.h-1))
135 | c0 = c3 = c5 = torch.maximum(col-pad, torch.zeros_like(col))
136 | c1 = c6 = col
137 | c2 = c4 = c7 = torch.minimum(col+pad, torch.ones_like(col)*(self.w-1))
138 |
139 | idx0 = r0 * self.w + c0
140 | idx1 = r1 * self.w + c1
141 | idx2 = r2 * self.w + c2
142 | idx3 = r3 * self.w + c3
143 | idx4 = r4 * self.w + c4
144 | idx5 = r5 * self.w + c5
145 | idx6 = r6 * self.w + c6
146 | idx7 = r7 * self.w + c7
147 | neighbor_idx = torch.stack([idx0, idx1, idx2, idx3, idx, idx4, idx5, idx6, idx7], 1) # [N, 9]
148 | border_mask = (idx % self.w == 0) | (idx % self.w == (self.w - 1)) | (idx // self.w == 0) | (
149 | idx // self.w == (self.h - 1))
150 | neighbor_idx[border_mask] = idx[border_mask].unsqueeze(1).repeat([1, 9])
151 | return neighbor_idx
152 |
153 |
154 |
155 | # def gen_random_rays_at(self, img_idx, batch_size):
156 | # pixels_x = torch.randint(low=0, high=self.w, size=[batch_size])
157 | # pixels_y = torch.randint(low=0, high=self.h, size=[batch_size])
158 | # color = self.all_rgbs[img_idx] # [h, w, 3]
159 | # color = color[(pixels_y, pixels_x)] # [batch_size, 3]
160 | # mask = torch.ones_like(color, dtype=torch.float)
161 | # all_rays = self.all_rays[img_idx].reshape(self.h, self.w, 6) # [h, w, 6]
162 | # rand_rays = all_rays[(pixels_y, pixels_x)] # [batch_size, 6]
163 | # return torch.cat([rand_rays, color, mask[:, :1]], dim=-1).to(self.device)
164 |
165 | def gen_random_rays_at(self, img_idx, batch_size, mode):
166 | if mode == 'batch':
167 | pixels_x = torch.randint(low=0, high=self.w, size=[batch_size])
168 | pixels_y = torch.randint(low=0, high=self.h, size=[batch_size])
169 |
170 | color = self.all_rgbs[img_idx] # [h, w, 3]
171 | color = color[(pixels_y, pixels_x)] # [batch_size, 3]
172 | mask = torch.ones_like(color, dtype=torch.float)
173 | all_rays = self.all_rays[img_idx].reshape(self.h, self.w, 6) # [h, w, 6]
174 | rand_rays = all_rays[(pixels_y, pixels_x)] # [batch_size, 6]
175 | return torch.cat([rand_rays, color, mask[:, :1]], dim=-1).to(self.device)
176 | elif mode == 'patch':
177 | pixels_x = torch.randint(low=0, high=self.w, size=[batch_size // 9])
178 | pixels_y = torch.randint(low=0, high=self.h, size=[batch_size // 9])
179 |             pixel_idx = pixels_y * self.w + pixels_x  # row-major index into the (h, w) image
180 | rand_idx = self.all_neighbor_idx[pixel_idx].reshape(-1)
181 | pixels_x = rand_idx % self.w
182 | pixels_y = rand_idx // self.w
183 | patch_rgb_std = self.all_rgb_std[img_idx].reshape(-1, 1)[rand_idx]
184 | color = self.all_rgbs[img_idx] # [h, w, 3]
185 | color = color[(pixels_y, pixels_x)] # [batch_size, 3]
186 | mask = torch.ones_like(color, dtype=torch.float)
187 | all_rays = self.all_rays[img_idx].reshape(self.h, self.w, 6) # [h, w, 6]
188 | rand_rays = all_rays[(pixels_y, pixels_x)] # [batch_size, 6]
189 | return torch.cat([rand_rays, color, mask[:, :1], patch_rgb_std], dim=-1).to(self.device)
190 |
191 |
192 | def near_far_from_sphere(self, rays_o, rays_d):
193 | # copied from dataset.py
194 | # a = torch.sum(rays_d**2, dim=-1, keepdim=True)
195 | # b = 2.0 * torch.sum(rays_o * rays_d, dim=-1, keepdim=True)
196 | # mid = 0.5 * (-b) / a
197 | # near = mid - 1.0
198 | # far = mid + 1.0
199 |
200 | near = torch.zeros(rays_o.shape[0], 1).cuda()
201 | far = torch.ones(rays_o.shape[0], 1).cuda() * 3
202 | return near, far
203 |
204 | def gen_rays_at(self, img_idx, resolution_level=1):
205 | all_rays = self.all_rays[img_idx].reshape(self.h, self.w, 6) # [h, w, 6]
206 | rays_o = all_rays[:, :, :3].to(self.device)
207 | rays_d = all_rays[:, :, 3:].to(self.device)
208 | return rays_o, rays_d
209 |
210 | def image_at(self, idx, resolution_level):
211 | img = cv2.imread(self.image_paths[idx])
212 | return (cv2.resize(img, (self.w // resolution_level, self.h // resolution_level))).clip(0, 255)
213 |
214 |
215 | def gen_rays_between(self, idx_0, idx_1, ratio, resolution_level=1):
216 | # only used in novel view synthesis
217 | raise NotImplementedError()
218 |
219 | def __getitem__(self, idx):
220 | sample = {}
221 | img_raynum = self.h*self.w
222 | if self.split == 'train':
223 | imgid = self.__len__() - 1 - idx
224 | # imgid = 100
225 | all_rays = self.all_rays[imgid]
226 | rand_idx = np.random.choice(img_raynum, self.batch_size, replace=False)
227 | # rand_idx = np.arange(10740, 10840)
228 |
229 | # batch_rand_idx = np.random.choice(img_raynum, self.batch_size // 9, replace=False)
230 | # # batch_rand_idx = np.array([64500, 64510, 64520, 64530, 64540, 64550, 64560, 64570, 64580, 64590, 64560]) # debug
231 | # rand_idx = self.all_neighbor_idx[batch_rand_idx].reshape(-1) # [N*9]
232 | # sample.update({'patch_rgb_std': self.all_rgb_std[imgid].reshape(-1)[batch_rand_idx]})
233 | sample.update({'patch_rgb_std': self.all_rgb_std[imgid].reshape(-1)[rand_idx]})
234 |
235 | sample.update({'rays_o': all_rays[rand_idx, :3],
236 | 'rays_d': all_rays[rand_idx, 3:6],
237 | 'rgb': self.all_rgbs[imgid].reshape(-1, 3)[rand_idx],
238 | # 'all_trg_rays_o': all_rays[:, :3],
239 | # 'all_trg_rays_d': all_rays[:, 3:6],
240 | # 'trg_rgb': self.all_rgbs[imgid],
241 | # 'trg_extrinsics': self.poses[imgid],
242 | 'idx': idx,
243 | 'imgid': imgid,
244 | 'h': self.h,
245 | 'w': self.w,
246 | 'intrinsics': self.intrinsics,
247 | 'pixel_idx': rand_idx,
248 | })
249 |
250 | # debug
251 | # rays_d = sample['rays_d'].unsqueeze(0)
252 | # rays_o = sample['rays_o'].unsqueeze(0)
253 | # c2w = self.poses[imgid]
254 | # world_pos = rays_d * 0.5 + rays_o # [W, H, 3]
255 | # world_pos = torch.cat([world_pos, torch.ones([rays_d.shape[0], rays_d.shape[1], 1])], 2) # [h*w, 4]
256 | # # K * R^-1 * xyz
257 | # camera_pos = torch.einsum('ij,abj->abi', c2w.inverse(), world_pos) # [W, H, 4]
258 | # camera_pos = torch.einsum('abj,ij->abi', world_pos, c2w.inverse())
259 | # # camera_pos[..., 1] *= -1
260 | # # camera_pos[..., 2] *= -1
261 | # uv1 = torch.einsum('ij, abj->abi', self.intrinsics, camera_pos[..., :3]) # [W, H, 3]
262 | # uv1[..., :] /= uv1[..., 2:]
263 | # uv1 = uv1[..., :2] - 0.5
264 | # uv1 = uv1[:,:,[1,0]]
265 |
266 |
267 | # newview
268 | src_view_num = 3
269 | src_offsets = np.random.choice(np.concatenate((np.arange(-60, -9), np.arange(10, 61))), src_view_num, replace=False)
270 | src_idx = []
271 | for offset in src_offsets:
272 | _idx = imgid + offset
273 | if _idx >= len(self.all_rgbs) or _idx < 0:
274 | _idx = imgid - offset
275 | src_idx.append(_idx)
276 | src_imgid = src_idx
277 | src_all_rays = self.all_rays[src_imgid]
278 | # sample.update({'src_rgb': self.all_rgbs[src_imgid],
279 | # 'src_extrinsics': self.poses[src_imgid],
280 | # 'src_rays_o': src_all_rays[..., :3].reshape(src_view_num, self.h, self.w, 3),
281 | # 'src_rays_d': src_all_rays[..., 3:6].reshape(src_view_num, self.h, self.w, 3),
282 | # 'offsets': src_offsets,})
283 |
284 | return sample
285 | elif self.split == 'test':
286 | idx = np.random.randint(self.__len__())
287 | # idx = 901
288 | all_rays = self.all_rays[idx]
289 | sample = {'rays_o': all_rays[:, :3],
290 | 'rays_d': all_rays[:, 3:6],
291 | 'rgb': self.all_rgbs[idx].reshape(-1,3),
292 | 'idx': idx}
293 | return sample
294 |
--------------------------------------------------------------------------------
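
A minimal sketch of drawing a ray batch from the dataset above, mirroring the conf handling in cull_mesh.py; the scene name and the CASE_NAME placeholder in confs/blendswap.conf are assumptions:

    from pyhocon import ConfigFactory
    with open('confs/blendswap.conf') as f:
        conf_text = f.read().replace('CASE_NAME', 'breakfast_room')   # assumed scene name
    conf = ConfigFactory.parse_string(conf_text)
    dataset = BlendSwapDataset(conf['dataset'])
    batch = dataset.gen_random_rays_at(img_idx=0, batch_size=512, mode='batch')
    # batch: [512, 10] laid out as rays_o (3) | rays_d (3) | rgb (3) | mask (1), on the CUDA device
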
/models/tensorBase.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn
3 | import torch.nn.functional as F
4 | from models.sh import eval_sh_bases
5 | import numpy as np
6 | import time
7 |
8 |
9 | def positional_encoding(positions, freqs):
10 | freq_bands = (2 ** torch.arange(freqs).float()).to(positions.device) # (F,)
11 | pts = (positions[..., None] * freq_bands).reshape(
12 | positions.shape[:-1] + (freqs * positions.shape[-1],)) # (..., DF)
13 | pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
14 | return pts
15 |
16 |
17 | def raw2alpha(sigma, dist):
18 | # sigma, dist [N_rays, N_samples]
19 | alpha = 1. - torch.exp(-sigma * dist)
20 |
21 | T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)
22 |
23 | weights = alpha * T[:, :-1] # [N_rays, N_samples]
24 | return alpha, weights, T[:, -1:]
25 |
26 |
27 | def SHRender(xyz_sampled, viewdirs, features):
28 | sh_mult = eval_sh_bases(2, viewdirs)[:, None]
29 | rgb_sh = features.view(-1, 3, sh_mult.shape[-1])
30 | rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)
31 | return rgb
32 |
33 |
34 | def RGBRender(xyz_sampled, viewdirs, features):
35 | rgb = features
36 | return rgb
37 |
38 |
39 | class AlphaGridMask(torch.nn.Module):
40 | def __init__(self, device, aabb, alpha_volume):
41 | super(AlphaGridMask, self).__init__()
42 | self.device = device
43 |
44 | self.aabb = aabb.to(self.device)
45 | self.aabbSize = self.aabb[1] - self.aabb[0]
46 | self.invgridSize = 1.0 / self.aabbSize * 2
47 | self.alpha_volume = alpha_volume.view(1, 1, *alpha_volume.shape[-3:])
48 | self.gridSize = torch.LongTensor([alpha_volume.shape[-1], alpha_volume.shape[-2], alpha_volume.shape[-3]]).to(
49 | self.device)
50 |
51 | def sample_alpha(self, xyz_sampled):
52 | xyz_sampled = self.normalize_coord(xyz_sampled)
53 | alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1, -1, 1, 1, 3), align_corners=True).view(-1)
54 |
55 | return alpha_vals
56 |
57 | def normalize_coord(self, xyz_sampled):
58 | return (xyz_sampled - self.aabb[0]) * self.invgridSize - 1
59 |
60 |
61 | class MLPRender_Fea(torch.nn.Module):
62 | def __init__(self, inChanel, viewpe=6, feape=6, featureC=128):
63 | super(MLPRender_Fea, self).__init__()
64 |
65 | self.in_mlpC = 2 * viewpe * 3 + 2 * feape * inChanel + 3 + inChanel
66 | self.viewpe = viewpe
67 | self.feape = feape
68 | layer1 = torch.nn.Linear(self.in_mlpC, featureC)
69 | layer2 = torch.nn.Linear(featureC, featureC)
70 | layer3 = torch.nn.Linear(featureC, 3)
71 |
72 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
73 | torch.nn.init.constant_(self.mlp[-1].bias, 0)
74 |
75 | def forward(self, pts, viewdirs, features):
76 | indata = [features, viewdirs]
77 | if self.feape > 0:
78 | indata += [positional_encoding(features, self.feape)]
79 | if self.viewpe > 0:
80 | indata += [positional_encoding(viewdirs, self.viewpe)]
81 | mlp_in = torch.cat(indata, dim=-1)
82 | rgb = self.mlp(mlp_in)
83 | rgb = torch.sigmoid(rgb)
84 |
85 | return rgb
86 |
87 |
88 | class MLPRender_PE(torch.nn.Module):
89 | def __init__(self, inChanel, viewpe=6, pospe=6, featureC=128):
90 | super(MLPRender_PE, self).__init__()
91 |
92 | self.in_mlpC = (3 + 2 * viewpe * 3) + (3 + 2 * pospe * 3) + inChanel #
93 | self.viewpe = viewpe
94 | self.pospe = pospe
95 | layer1 = torch.nn.Linear(self.in_mlpC, featureC)
96 | layer2 = torch.nn.Linear(featureC, featureC)
97 | layer3 = torch.nn.Linear(featureC, 3)
98 |
99 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
100 | torch.nn.init.constant_(self.mlp[-1].bias, 0)
101 |
102 | def forward(self, pts, viewdirs, features):
103 | indata = [features, viewdirs]
104 | if self.pospe > 0:
105 | indata += [positional_encoding(pts, self.pospe)]
106 | if self.viewpe > 0:
107 | indata += [positional_encoding(viewdirs, self.viewpe)]
108 | mlp_in = torch.cat(indata, dim=-1)
109 | rgb = self.mlp(mlp_in)
110 | rgb = torch.sigmoid(rgb)
111 |
112 | return rgb
113 |
114 |
115 | class MLPRender(torch.nn.Module):
116 | def __init__(self, inChanel, viewpe=6, featureC=128):
117 | super(MLPRender, self).__init__()
118 |
119 | self.in_mlpC = (3 + 2 * viewpe * 3) + inChanel
120 | self.viewpe = viewpe
121 |
122 | layer1 = torch.nn.Linear(self.in_mlpC, featureC)
123 | layer2 = torch.nn.Linear(featureC, featureC)
124 | layer3 = torch.nn.Linear(featureC, 3)
125 |
126 | self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
127 | torch.nn.init.constant_(self.mlp[-1].bias, 0)
128 |
129 | def forward(self, pts, viewdirs, features):
130 | indata = [features, viewdirs]
131 | if self.viewpe > 0:
132 | indata += [positional_encoding(viewdirs, self.viewpe)]
133 | mlp_in = torch.cat(indata, dim=-1)
134 | rgb = self.mlp(mlp_in)
135 | rgb = torch.sigmoid(rgb)
136 |
137 | return rgb
138 |
139 |
140 | class TensorBase(torch.nn.Module):
141 | def __init__(self, aabb, gridSize, device, density_n_comp=8, appearance_n_comp=24, app_dim=27,
142 | shadingMode='MLP_PE', alphaMask=None, near_far=[2.0, 6.0],
143 | density_shift=-10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,
144 | pos_pe=6, view_pe=6, fea_pe=6, featureC=128, step_ratio=2.0,
145 | fea2denseAct='softplus'):
146 | super(TensorBase, self).__init__()
147 |
148 | self.density_n_comp = density_n_comp
149 | self.app_n_comp = appearance_n_comp
150 | self.app_dim = app_dim
151 | self.aabb = aabb
152 | self.alphaMask = alphaMask
153 | self.device = device
154 |
155 | self.density_shift = density_shift
156 | self.alphaMask_thres = alphaMask_thres
157 | self.distance_scale = distance_scale
158 | self.rayMarch_weight_thres = rayMarch_weight_thres
159 | self.fea2denseAct = fea2denseAct
160 |
161 | self.near_far = near_far
162 | self.step_ratio = step_ratio
163 |
164 | self.update_stepSize(gridSize)
165 |
166 | self.matMode = [[0, 1], [0, 2], [1, 2]]
167 | self.vecMode = [2, 1, 0]
168 | self.comp_w = [1, 1, 1]
169 |
170 | self.init_svd_volume(gridSize[0], device)
171 |
172 | self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC
173 | self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device)
174 |
175 | def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device):
176 | if shadingMode == 'MLP_PE':
177 | self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device)
178 | elif shadingMode == 'MLP_Fea':
179 | self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device)
180 | elif shadingMode == 'MLP':
181 | self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device)
182 | elif shadingMode == 'SH':
183 | self.renderModule = SHRender
184 | elif shadingMode == 'RGB':
185 | assert self.app_dim == 3
186 | self.renderModule = RGBRender
187 | else:
188 | print("Unrecognized shading module")
189 | exit()
190 | print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe)
191 | print(self.renderModule)
192 |
193 | def update_stepSize(self, gridSize):
194 | print("aabb", self.aabb.view(-1))
195 | print("grid size", gridSize)
196 | self.aabbSize = self.aabb[1] - self.aabb[0]
197 | self.invaabbSize = 2.0 / self.aabbSize
198 | self.gridSize = torch.LongTensor(gridSize).to(self.device)
199 | self.units = self.aabbSize / (self.gridSize - 1)
200 | self.stepSize = torch.mean(self.units) * self.step_ratio
201 | self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
202 | self.nSamples = int((self.aabbDiag / self.stepSize).item()) + 1
203 | print("sampling step size: ", self.stepSize)
204 | print("sampling number: ", self.nSamples)
205 |
206 | def init_svd_volume(self, res, device):
207 | pass
208 |
209 | def compute_features(self, xyz_sampled):
210 | pass
211 |
212 | def compute_densityfeature(self, xyz_sampled):
213 | pass
214 |
215 | def compute_appfeature(self, xyz_sampled):
216 | pass
217 |
218 | def normalize_coord(self, xyz_sampled):
219 | return (xyz_sampled - self.aabb[0]) * self.invaabbSize - 1
220 |
221 | def get_optparam_groups(self, lr_init_spatial=0.02, lr_init_network=0.001):
222 | pass
223 |
224 | def get_kwargs(self):
225 | return {
226 | 'aabb': self.aabb,
227 | 'gridSize': self.gridSize.tolist(),
228 | 'density_n_comp': self.density_n_comp,
229 | 'appearance_n_comp': self.app_n_comp,
230 | 'app_dim': self.app_dim,
231 |
232 | 'density_shift': self.density_shift,
233 | 'alphaMask_thres': self.alphaMask_thres,
234 | 'distance_scale': self.distance_scale,
235 | 'rayMarch_weight_thres': self.rayMarch_weight_thres,
236 | 'fea2denseAct': self.fea2denseAct,
237 |
238 | 'near_far': self.near_far,
239 | 'step_ratio': self.step_ratio,
240 |
241 | 'shadingMode': self.shadingMode,
242 | 'pos_pe': self.pos_pe,
243 | 'view_pe': self.view_pe,
244 | 'fea_pe': self.fea_pe,
245 | 'featureC': self.featureC
246 | }
247 |
248 | def save(self, path):
249 | kwargs = self.get_kwargs()
250 | ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}
251 | if self.alphaMask is not None:
252 | alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
253 | ckpt.update({'alphaMask.shape': alpha_volume.shape})
254 | ckpt.update({'alphaMask.mask': np.packbits(alpha_volume.reshape(-1))})
255 | ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
256 | torch.save(ckpt, path)
257 |
258 | def load(self, ckpt):
259 | if 'alphaMask.aabb' in ckpt.keys():
260 | length = np.prod(ckpt['alphaMask.shape'])
261 | alpha_volume = torch.from_numpy(
262 | np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
263 | self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device),
264 | alpha_volume.float().to(self.device))
265 | self.load_state_dict(ckpt['state_dict'])
266 |
267 | def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
268 | N_samples = N_samples if N_samples > 0 else self.nSamples
269 | near, far = self.near_far
270 | interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
271 | if is_train:
272 | interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)
273 |
274 | rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
275 | mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
276 | return rays_pts, interpx, ~mask_outbbox
277 |
278 | def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
279 | N_samples = N_samples if N_samples > 0 else self.nSamples
280 | stepsize = self.stepSize
281 | near, far = self.near_far
282 | vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
283 | rate_a = (self.aabb[1] - rays_o) / vec
284 | rate_b = (self.aabb[0] - rays_o) / vec
285 | t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)
286 |
287 | rng = torch.arange(N_samples)[None].float()
288 | if is_train:
289 | rng = rng.repeat(rays_d.shape[-2], 1)
290 | rng += torch.rand_like(rng[:, [0]])
291 | step = stepsize * rng.to(rays_o.device)
292 | interpx = (t_min[..., None] + step)
293 |
294 | rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
295 | mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
296 |
297 | return rays_pts, interpx, ~mask_outbbox
298 |
299 | def shrink(self, new_aabb, voxel_size):
300 | pass
301 |
302 | @torch.no_grad()
303 | def getDenseAlpha(self, gridSize=None):
304 | gridSize = self.gridSize if gridSize is None else gridSize
305 |
306 | samples = torch.stack(torch.meshgrid(
307 | torch.linspace(0, 1, gridSize[0]),
308 | torch.linspace(0, 1, gridSize[1]),
309 | torch.linspace(0, 1, gridSize[2]),
310 | ), -1).to(self.device)
311 | dense_xyz = self.aabb[0] * (1 - samples) + self.aabb[1] * samples
312 |
313 | # dense_xyz = dense_xyz
314 | # print(self.stepSize, self.distance_scale*self.aabbDiag)
315 | alpha = torch.zeros_like(dense_xyz[..., 0])
316 | for i in range(gridSize[0]):
317 | alpha[i] = self.compute_alpha(dense_xyz[i].view(-1, 3), self.stepSize).view((gridSize[1], gridSize[2]))
318 | return alpha, dense_xyz
319 |
320 | @torch.no_grad()
321 | def updateAlphaMask(self, gridSize=(200, 200, 200)):
322 |
323 | alpha, dense_xyz = self.getDenseAlpha(gridSize)
324 | dense_xyz = dense_xyz.transpose(0, 2).contiguous()
325 | alpha = alpha.clamp(0, 1).transpose(0, 2).contiguous()[None, None]
326 | total_voxels = gridSize[0] * gridSize[1] * gridSize[2]
327 |
328 | ks = 3
329 | alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])
330 | alpha[alpha >= self.alphaMask_thres] = 1
331 | alpha[alpha < self.alphaMask_thres] = 0
332 |
333 | self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha)
334 |
335 | valid_xyz = dense_xyz[alpha > 0.5]
336 |
337 | xyz_min = valid_xyz.amin(0)
338 | xyz_max = valid_xyz.amax(0)
339 |
340 | new_aabb = torch.stack((xyz_min, xyz_max))
341 |
342 | total = torch.sum(alpha)
343 |         print(f"bbox: {xyz_min, xyz_max} alpha rest %{total / total_voxels * 100:f}")
344 | return new_aabb
345 |
346 | @torch.no_grad()
347 | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240 * 5, bbox_only=False):
348 | print('========> filtering rays ...')
349 | tt = time.time()
350 |
351 | N = torch.tensor(all_rays.shape[:-1]).prod()
352 |
353 | mask_filtered = []
354 | idx_chunks = torch.split(torch.arange(N), chunk)
355 | for idx_chunk in idx_chunks:
356 | rays_chunk = all_rays[idx_chunk].to(self.device)
357 |
358 | rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
359 | if bbox_only:
360 | vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
361 | rate_a = (self.aabb[1] - rays_o) / vec
362 | rate_b = (self.aabb[0] - rays_o) / vec
363 | t_min = torch.minimum(rate_a, rate_b).amax(-1) # .clamp(min=near, max=far)
364 | t_max = torch.maximum(rate_a, rate_b).amin(-1) # .clamp(min=near, max=far)
365 | mask_inbbox = t_max > t_min
366 |
367 | else:
368 | xyz_sampled, _, _ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)
369 | mask_inbbox = (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)
370 |
371 | mask_filtered.append(mask_inbbox.cpu())
372 |
373 | mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])
374 |
375 | print(f'Ray filtering done! takes {time.time() - tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}')
376 | return all_rays[mask_filtered], all_rgbs[mask_filtered]
377 |
378 | def feature2density(self, density_features):
379 | if self.fea2denseAct == "softplus":
380 | return F.softplus(density_features + self.density_shift)
381 | elif self.fea2denseAct == "relu":
382 | return F.relu(density_features)
383 |
384 | def compute_alpha(self, xyz_locs, length=1):
385 |
386 | if self.alphaMask is not None:
387 | alphas = self.alphaMask.sample_alpha(xyz_locs)
388 | alpha_mask = alphas > 0
389 | else:
390 | alpha_mask = torch.ones_like(xyz_locs[:, 0], dtype=bool)
391 |
392 | sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)
393 |
394 | if alpha_mask.any():
395 | xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
396 | sigma_feature = self.compute_densityfeature(xyz_sampled)
397 | validsigma = self.feature2density(sigma_feature)
398 | sigma[alpha_mask] = validsigma
399 |
400 | alpha = 1 - torch.exp(-sigma * length).view(xyz_locs.shape[:-1])
401 |
402 | return alpha
403 |
404 | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
405 |
406 | # sample points
407 | viewdirs = rays_chunk[:, 3:6]
408 | if ndc_ray:
409 | xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,
410 | N_samples=N_samples)
411 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
412 | rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
413 | dists = dists * rays_norm
414 | viewdirs = viewdirs / rays_norm
415 | else:
416 | xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,
417 | N_samples=N_samples)
418 | dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
419 | viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
420 |
421 | if self.alphaMask is not None:
422 | alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
423 | alpha_mask = alphas > 0
424 | ray_invalid = ~ray_valid
425 | ray_invalid[ray_valid] |= (~alpha_mask)
426 | ray_valid = ~ray_invalid
427 |
428 | sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
429 | rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)
430 |
431 | if ray_valid.any():
432 | xyz_sampled = self.normalize_coord(xyz_sampled)
433 | sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
434 |
435 | validsigma = self.feature2density(sigma_feature)
436 | sigma[ray_valid] = validsigma
437 |
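        # raw2alpha implements standard volume rendering: per-sample opacity
        # alpha_i = 1 - exp(-sigma_i * dist_i), accumulated weights
        # weight_i = alpha_i * prod_{j<i}(1 - alpha_j), and bg_weight is the remaining transmittance.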
438 | alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)
439 |
440 | app_mask = weight > self.rayMarch_weight_thres
441 |
442 | if app_mask.any():
443 | app_features = self.compute_appfeature(xyz_sampled[app_mask])
444 | valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
445 | rgb[app_mask] = valid_rgbs
446 |
447 | acc_map = torch.sum(weight, -1)
448 | rgb_map = torch.sum(weight[..., None] * rgb, -2)
449 |
450 | if white_bg or (is_train and torch.rand((1,)) < 0.5):
451 | rgb_map = rgb_map + (1. - acc_map[..., None])
452 |
453 | rgb_map = rgb_map.clamp(0, 1)
454 |
455 | with torch.no_grad():
456 | depth_map = torch.sum(weight * z_vals, -1)
457 | depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]
458 |
459 | return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight
460 |
461 |
--------------------------------------------------------------------------------
/models/renderer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import logging
6 | import mcubes
7 | from models.tensorf2neus import *
8 | from models.nerf2neus import *
9 |
10 |
11 | def extract_fields(bound_min, bound_max, resolution, query_func):
12 | N = 64
13 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
14 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
15 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)
16 |
17 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
18 | with torch.no_grad():
19 | for xi, xs in enumerate(X):
20 | for yi, ys in enumerate(Y):
21 | for zi, zs in enumerate(Z):
22 | xx, yy, zz = torch.meshgrid(xs, ys, zs)
23 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)
24 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
25 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
26 | return u
27 |
28 |
29 | def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
30 | print('threshold: {}'.format(threshold))
31 | u = extract_fields(bound_min, bound_max, resolution, query_func)
32 | vertices, triangles = mcubes.marching_cubes(u, threshold)
33 | b_max_np = bound_max.detach().cpu().numpy()
34 | b_min_np = bound_min.detach().cpu().numpy()
35 |
36 | vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
37 | return vertices, triangles
38 |
39 |
40 | def sample_pdf(bins, weights, n_samples, det=False):
41 | # This implementation is from NeRF
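    # Hierarchical (inverse-CDF) sampling: build a CDF from the coarse weights, draw uniform
    # samples u (stratified when det=True), and map them through the inverse CDF so the new
    # z-values concentrate where the coarse weights are large.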
42 | # Get pdf
43 | weights = weights + 1e-5 # prevent nans
44 | pdf = weights / torch.sum(weights, -1, keepdim=True)
45 | cdf = torch.cumsum(pdf, -1)
46 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
47 | # Take uniform samples
48 | if det:
49 | u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples)
50 | u = u.expand(list(cdf.shape[:-1]) + [n_samples])
51 | else:
52 | u = torch.rand(list(cdf.shape[:-1]) + [n_samples])
53 |
54 | # Invert CDF
55 | u = u.contiguous()
56 | inds = torch.searchsorted(cdf, u, right=True)
57 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
58 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
59 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
60 |
61 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
62 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
63 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
64 |
65 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
66 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
67 | t = (u - cdf_g[..., 0]) / denom
68 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
69 |
70 | return samples
71 |
72 |
73 | class NeuSRenderer:
74 | def __init__(self,
75 | nerf,
76 | sdf_network,
77 | deviation_network,
78 | color_network,
79 | n_samples,
80 | n_importance,
81 | n_outside,
82 | up_sample_steps,
83 | perturb):
84 | self.nerf = nerf
85 | self.sdf_network = sdf_network
86 | self.deviation_network = deviation_network
87 | self.color_network = color_network
88 | self.n_samples = n_samples
89 | self.n_importance = n_importance
90 | self.n_outside = n_outside
91 | self.up_sample_steps = up_sample_steps
92 | self.perturb = perturb
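        # NOTE: `validate` and the NeRF prior (`tensorf` / `nerf`) are attached to this renderer
        # externally by the Runner in exp_runner.py; render_core assumes they are set before it is called.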
93 |
94 | def render_core_outside(self, rays_o, rays_d, z_vals, sample_dist, nerf, background_rgb=None):
95 | """
96 | Render background
97 | """
98 | batch_size, n_samples = z_vals.shape
99 |
100 | # Section length
101 | dists = z_vals[..., 1:] - z_vals[..., :-1]
102 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
103 | mid_z_vals = z_vals + dists * 0.5
104 |
105 | # Section midpoints
106 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # batch_size, n_samples, 3
107 |
108 | dis_to_center = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).clip(1.0, 1e10)
109 | pts = torch.cat([pts / dis_to_center, 1.0 / dis_to_center], dim=-1) # batch_size, n_samples, 4
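        # NeRF++-style inverted-sphere parameterization: background points are encoded as
        # (x/r, y/r, z/r, 1/r), so the outside NeRF receives bounded inputs even for far-away points.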
110 |
111 | dirs = rays_d[:, None, :].expand(batch_size, n_samples, 3)
112 |
113 | pts = pts.reshape(-1, 3 + int(self.n_outside > 0))
114 | dirs = dirs.reshape(-1, 3)
115 |
116 | density, sampled_color = nerf(pts, dirs)
117 | sampled_color = torch.sigmoid(sampled_color)
118 | alpha = 1.0 - torch.exp(-F.softplus(density.reshape(batch_size, n_samples)) * dists)
119 | alpha = alpha.reshape(batch_size, n_samples)
120 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
121 | sampled_color = sampled_color.reshape(batch_size, n_samples, 3)
122 | color = (weights[:, :, None] * sampled_color).sum(dim=1)
123 | if background_rgb is not None:
124 | color = color + background_rgb * (1.0 - weights.sum(dim=-1, keepdim=True))
125 |
126 | return {
127 | 'color': color,
128 | 'sampled_color': sampled_color,
129 | 'alpha': alpha,
130 | 'weights': weights,
131 | }
132 |
133 | def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_s):
134 | """
135 |         Up sampling given a fixed inv_s
136 | """
137 | batch_size, n_samples = z_vals.shape
138 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3
139 | radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=False)
140 | inside_sphere = (radius[:, :-1] < 1.0) | (radius[:, 1:] < 1.0)
141 | sdf = sdf.reshape(batch_size, n_samples)
142 | prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:]
143 | prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:]
144 | mid_sdf = (prev_sdf + next_sdf) * 0.5
145 | cos_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
146 |
147 | # ----------------------------------------------------------------------------------------------------------
148 | # Use min value of [ cos, prev_cos ]
149 | # Though it makes the sampling (not rendering) a little bit biased, this strategy can make the sampling more
150 | # robust when meeting situations like below:
151 | #
152 | # SDF
153 | # ^
154 | # |\ -----x----...
155 | # | \ /
156 | # | x x
157 | # |---\----/-------------> 0 level
158 | # | \ /
159 | # | \/
160 | # |
161 | # ----------------------------------------------------------------------------------------------------------
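        # With the mid-point SDF and the clipped cos estimate, the per-section opacity below follows
        # the NeuS formulation: alpha_i = (Phi_s(prev_sdf_i) - Phi_s(next_sdf_i)) / Phi_s(prev_sdf_i),
        # where Phi_s(x) = sigmoid(inv_s * x) and prev/next SDF are extrapolated half a step from mid_sdf.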
162 | prev_cos_val = torch.cat([torch.zeros([batch_size, 1]), cos_val[:, :-1]], dim=-1)
163 | cos_val = torch.stack([prev_cos_val, cos_val], dim=-1)
164 | cos_val, _ = torch.min(cos_val, dim=-1, keepdim=False)
165 | cos_val = cos_val.clip(-1e3, 0.0) * inside_sphere
166 |
167 | dist = (next_z_vals - prev_z_vals)
168 | prev_esti_sdf = mid_sdf - cos_val * dist * 0.5
169 | next_esti_sdf = mid_sdf + cos_val * dist * 0.5
170 | prev_cdf = torch.sigmoid(prev_esti_sdf * inv_s)
171 | next_cdf = torch.sigmoid(next_esti_sdf * inv_s)
172 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
173 | weights = alpha * torch.cumprod(
174 | torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
175 |
176 | z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach()
177 | return z_samples
178 |
179 | def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, last=False):
180 | batch_size, n_samples = z_vals.shape
181 | _, n_importance = new_z_vals.shape
182 | pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None]
183 | z_vals = torch.cat([z_vals, new_z_vals], dim=-1)
184 | z_vals, index = torch.sort(z_vals, dim=-1)
185 |
186 | if not last:
187 | new_sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, n_importance)
188 | sdf = torch.cat([sdf, new_sdf], dim=-1)
189 | xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1)
190 | index = index.reshape(-1)
191 | sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance)
192 |
193 | return z_vals, sdf
194 |
195 | def render_core(self,
196 | rays_o,
197 | rays_d,
198 | z_vals,
199 | sample_dist,
200 | sdf_network,
201 | deviation_network,
202 | color_network,
203 | background_alpha=None,
204 | background_sampled_color=None,
205 | background_rgb=None,
206 | cos_anneal_ratio=0.0): # Render part in bounding sphere (core), using color net
207 | batch_size, n_samples = z_vals.shape
208 |
209 | # Section length
210 | dists = z_vals[..., 1:] - z_vals[..., :-1]
211 | dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape)], -1)
212 | mid_z_vals = z_vals + dists * 0.5
213 |
214 | # Section midpoints
215 | pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3
216 | dirs = rays_d[:, None, :].expand(pts.shape)
217 |
218 | pts = pts.reshape(-1, 3)
219 | dirs = dirs.reshape(-1, 3)
220 |
221 | sdf_nn_output = sdf_network(pts)
222 | sdf = sdf_nn_output[:, :1]
223 | feature_vector = sdf_nn_output[:, 1:]
224 |
225 | gradients = sdf_network.gradient(pts).squeeze() # gradients assumed to be normalized because of eikonal_loss
226 | sampled_color = color_network(pts, gradients, dirs, feature_vector).reshape(batch_size, n_samples, 3)
227 |
228 | inv_s = deviation_network(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) # Single parameter
229 | inv_s = inv_s.expand(batch_size * n_samples, 1)
230 |
231 | true_cos = (dirs * gradients).sum(-1, keepdim=True)
232 |
233 |         # "cos_anneal_ratio" grows from 0 to 1 over the early training iterations. The annealing below keeps
234 |         # the cos value from going "dead" (being zeroed out) in the early iterations, for better convergence.
235 | iter_cos = -(F.relu(-true_cos * 0.5 + 0.5) * (1.0 - cos_anneal_ratio) +
236 | F.relu(-true_cos) * cos_anneal_ratio) # always non-positive
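        # For example, at cos_anneal_ratio = 0 this gives iter_cos = 0.5 * true_cos - 0.5 (a softened,
        # non-positive cosine), while at cos_anneal_ratio = 1 it reduces to -relu(-true_cos),
        # i.e. the true cosine clamped to be non-positive.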
237 |
238 | # Estimate signed distances at section points
239 | estimated_next_sdf = sdf + iter_cos * dists.reshape(-1, 1) * 0.5
240 | estimated_prev_sdf = sdf - iter_cos * dists.reshape(-1, 1) * 0.5
241 |
242 | prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s)
243 | next_cdf = torch.sigmoid(estimated_next_sdf * inv_s)
244 |
245 | p = prev_cdf - next_cdf
246 | c = prev_cdf
247 |
248 | alpha = ((p + 1e-5) / (c + 1e-5)).reshape(batch_size, n_samples).clip(0.0, 1.0)
249 |
250 | pts_norm = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(batch_size, n_samples)
251 | inside_sphere = (pts_norm < 1.0).float().detach()
252 | # relax_inside_sphere = (pts_norm < 1.2).float().detach()
253 | relax_inside_sphere = (pts_norm < 3.0).float().detach()
254 |
255 | # Render with background
256 | if background_alpha is not None:
257 | alpha = alpha * inside_sphere + background_alpha[:, :n_samples] * (1.0 - inside_sphere)
258 | alpha = torch.cat([alpha, background_alpha[:, n_samples:]], dim=-1)
259 | sampled_color = sampled_color * inside_sphere[:, :, None] +\
260 | background_sampled_color[:, :n_samples] * (1.0 - inside_sphere)[:, :, None]
261 | sampled_color = torch.cat([sampled_color, background_sampled_color[:, n_samples:]], dim=1)
262 |
263 | weights = alpha * torch.cumprod(torch.cat([torch.ones([batch_size, 1]), 1. - alpha + 1e-7], -1), -1)[:, :-1]
264 | weights_sum = weights.sum(dim=-1, keepdim=True)
265 |
266 | color = (sampled_color * weights[:, :, None]).sum(dim=1)
267 | if background_rgb is not None: # Fixed background, usually black
268 | color = color + background_rgb * (1.0 - weights_sum)
269 |
270 | # Eikonal loss
271 | gradient_error = (torch.linalg.norm(gradients.reshape(batch_size, n_samples, 3), ord=2,
272 | dim=-1) - 1.0) ** 2
273 | gradient_error = (relax_inside_sphere * gradient_error).sum() / (relax_inside_sphere.sum() + 1e-5)
274 |
275 | if not self.validate:
276 | # query from NeRF Prior
277 | # TensoRF
278 | thres1 = 0.9 # for replica
279 | thres1 = 0.5 # for scannet
280 | thres2 = 1e-8
281 | with torch.no_grad():
282 | queried_occ, _ = query_alpha_color(self.tensorf, pts, None, prior_mode='cat')
283 | queried_occ = queried_occ.reshape(alpha.shape)
284 |
285 | # NeRF
286 | # thres1 = 0.5
287 | # thres2 = 1e-8
288 | # with torch.no_grad():
289 | # queried_occ, _ = query_alpha_color_nerf(self.nerf, _xyz_sampled=pts.reshape(batch_size, n_samples, 3), viewdirs=rays_d, z_vals=z_vals)
290 | # queried_occ = queried_occ.reshape(alpha.shape)
291 | # queried_occ = torch.where(queried_occ > thres1, torch.Tensor([1.]).cuda(), queried_occ)
292 | # queried_occ = torch.where(queried_occ < thres2, torch.Tensor([0.]).cuda(), queried_occ)
293 |
294 |             # visual_red_green_points(pts[queried_occ.reshape(-1) > thres1], pts[queried_occ.reshape(-1) < thres2])
295 |
296 |             supervise_mask = torch.zeros_like(queried_occ)
297 |             supervise_mask = torch.where(queried_occ > thres1, torch.Tensor([1]).cuda(), supervise_mask)  # scannet 1, replica 2
298 | # supervise_mask = torch.where(queried_occ < 1e-8, torch.Tensor([1]).cuda(), supervise_mask)
299 | supervise_mask = torch.where(queried_occ < thres2, torch.Tensor([0.5]).cuda(), supervise_mask)
300 | alpha_error = torch.abs(alpha - queried_occ) * supervise_mask
301 | alpha_error = alpha_error.sum() / ((queried_occ > thres1).sum() + (queried_occ < thres2).sum())
302 | else:
303 | alpha_error = 0
304 |
305 |
306 | if weights.shape[0] % 9 == 0:
307 | pred_depth = (weights * mid_z_vals).sum(-1) # [N]
308 | pred_normal = (weights.unsqueeze(-1) * gradients.reshape(batch_size, n_samples, 3)).sum(1) # [N, 3]
309 | cos = (dirs.reshape(batch_size, n_samples, 3)[:,0,:] * pred_normal).sum(1)
310 | proj_depth = pred_depth * cos # [N]
311 | proj_depth = proj_depth.reshape(-1, 9)
312 | mean_depth = proj_depth.mean(1, keepdim=True)
313 | patch_depth_std = ((proj_depth - mean_depth) ** 2).mean(1).sqrt()
314 | else:
315 | patch_depth_std = torch.zeros(weights.shape[0])
316 |
317 |
318 | return {
319 | 'color': color,
320 | 'sdf': sdf,
321 | 'dists': dists,
322 | 'gradients': gradients.reshape(batch_size, n_samples, 3),
323 | 's_val': 1.0 / inv_s,
324 | 'mid_z_vals': mid_z_vals,
325 | 'weights': weights,
326 | 'cdf': c.reshape(batch_size, n_samples),
327 | 'gradient_error': gradient_error,
328 | 'inside_sphere': inside_sphere,
329 | 'alpha_error': alpha_error,
330 | 'depth_std': patch_depth_std,
331 | 'depth': (weights * mid_z_vals).sum(-1),
332 | }
333 |
334 | def render(self, rays_o, rays_d, near, far, perturb_overwrite=-1, background_rgb=None, cos_anneal_ratio=0.0):
335 | batch_size = len(rays_o)
336 | sample_dist = 2.0 / self.n_samples # Assuming the region of interest is a unit sphere
337 | z_vals = torch.linspace(0.0, 1.0, self.n_samples)
338 | z_vals = near + (far - near) * z_vals[None, :]
339 |
340 | z_vals_outside = None
341 | if self.n_outside > 0:
342 | z_vals_outside = torch.linspace(1e-3, 1.0 - 1.0 / (self.n_outside + 1.0), self.n_outside)
343 |
344 | n_samples = self.n_samples
345 | perturb = self.perturb
346 |
347 | if perturb_overwrite >= 0:
348 | perturb = perturb_overwrite
349 | if perturb > 0:
350 | t_rand = (torch.rand([batch_size, 1]) - 0.5)
351 | z_vals = z_vals + t_rand * 2.0 / self.n_samples
352 |
353 | if self.n_outside > 0:
354 | mids = .5 * (z_vals_outside[..., 1:] + z_vals_outside[..., :-1])
355 | upper = torch.cat([mids, z_vals_outside[..., -1:]], -1)
356 | lower = torch.cat([z_vals_outside[..., :1], mids], -1)
357 | t_rand = torch.rand([batch_size, z_vals_outside.shape[-1]])
358 | z_vals_outside = lower[None, :] + (upper - lower)[None, :] * t_rand
359 |
360 | if self.n_outside > 0:
361 | z_vals_outside = far / torch.flip(z_vals_outside, dims=[-1]) + 1.0 / self.n_samples
362 |
363 | background_alpha = None
364 | background_sampled_color = None
365 |
366 | # Up sample
367 | if self.n_importance > 0:
368 | with torch.no_grad():
369 | pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None]
370 | sdf = self.sdf_network.sdf(pts.reshape(-1, 3)).reshape(batch_size, self.n_samples)
371 |
372 | for i in range(self.up_sample_steps):
373 | new_z_vals = self.up_sample(rays_o,
374 | rays_d,
375 | z_vals,
376 | sdf,
377 | self.n_importance // self.up_sample_steps,
378 | 64 * 2**i)
379 | z_vals, sdf = self.cat_z_vals(rays_o,
380 | rays_d,
381 | z_vals,
382 | new_z_vals,
383 | sdf,
384 | last=(i + 1 == self.up_sample_steps))
385 |
386 | n_samples = self.n_samples + self.n_importance
387 |
388 | # Background model
389 | if self.n_outside > 0:
390 | z_vals_feed = torch.cat([z_vals, z_vals_outside], dim=-1)
391 | z_vals_feed, _ = torch.sort(z_vals_feed, dim=-1)
392 | ret_outside = self.render_core_outside(rays_o, rays_d, z_vals_feed, sample_dist, self.nerf)
393 |
394 | background_sampled_color = ret_outside['sampled_color']
395 | background_alpha = ret_outside['alpha']
396 |
397 | # Render core
398 | ret_fine = self.render_core(rays_o,
399 | rays_d,
400 | z_vals,
401 | sample_dist,
402 | self.sdf_network,
403 | self.deviation_network,
404 | self.color_network,
405 | background_rgb=background_rgb,
406 | background_alpha=background_alpha,
407 | background_sampled_color=background_sampled_color,
408 | cos_anneal_ratio=cos_anneal_ratio)
409 |
410 | color_fine = ret_fine['color']
411 | weights = ret_fine['weights']
412 | weights_sum = weights.sum(dim=-1, keepdim=True)
413 | gradients = ret_fine['gradients']
414 | s_val = ret_fine['s_val'].reshape(batch_size, n_samples).mean(dim=-1, keepdim=True)
415 |
416 | return {
417 | 'color_fine': color_fine,
418 | 's_val': s_val,
419 | 'cdf_fine': ret_fine['cdf'],
420 | 'weight_sum': weights_sum,
421 | 'weight_max': torch.max(weights, dim=-1, keepdim=True)[0],
422 | 'gradients': gradients,
423 | 'weights': weights,
424 | 'gradient_error': ret_fine['gradient_error'],
425 | 'inside_sphere': ret_fine['inside_sphere'],
426 | 'alpha_error': ret_fine['alpha_error'],
427 | 'depth_std': ret_fine['depth_std'],
428 | 'depth': ret_fine['depth']
429 | }
430 |
431 | def extract_geometry(self, bound_min, bound_max, resolution, threshold=0.0):
432 | return extract_geometry(bound_min,
433 | bound_max,
434 | resolution=resolution,
435 | threshold=threshold,
436 | query_func=lambda pts: -self.sdf_network.sdf(pts))
437 |
--------------------------------------------------------------------------------
/exp_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os
3 | import time
4 | import logging
5 | import argparse
6 | import numpy as np
7 | import cv2 as cv
8 | import trimesh
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from shutil import copyfile
13 | # from icecream import ic
14 | from tqdm import tqdm
15 | from pyhocon import ConfigFactory
16 | from models.dataset import Dataset
17 | from models.blender_swap import BlendSwapDataset
18 | from models.scannet_blender import ScanNetDataset
19 | from models.fields import RenderingNetwork, SDFNetwork, SingleVarianceNetwork, NeRF
20 | from models.renderer import NeuSRenderer
21 | from models.tensorf2neus import *
22 | from models.nerf2neus import *
23 |
24 |
25 | class Runner:
26 | def __init__(self, conf_path, mode='train', case='CASE_NAME', is_continue=False):
27 | self.device = torch.device('cuda')
28 |
29 | # Configuration
30 | self.conf_path = conf_path
31 | f = open(self.conf_path)
32 | conf_text = f.read()
33 | conf_text = conf_text.replace('CASE_NAME', case)
34 | f.close()
35 |
36 | self.conf = ConfigFactory.parse_string(conf_text)
37 | self.conf['dataset.data_dir'] = self.conf['dataset.data_dir'].replace('CASE_NAME', case)
38 | self.base_exp_dir = self.conf['general.base_exp_dir']
39 | os.makedirs(self.base_exp_dir, exist_ok=True)
40 |
41 | # if self.conf['dataset.type'] == "Neus":
42 | # self.dataset = Dataset(self.conf['dataset'])
43 | # elif self.conf['dataset.type'] == "Blender":
44 | # self.dataset = BlendSwapDataset(self.conf['dataset'])
45 | self.dataset = ScanNetDataset(self.conf['dataset'])
46 |
47 | self.iter_step = 0
48 |
49 | # Training parameters
50 | self.end_iter = self.conf.get_int('train.end_iter')
51 | self.save_freq = self.conf.get_int('train.save_freq')
52 | self.report_freq = self.conf.get_int('train.report_freq')
53 | self.val_freq = self.conf.get_int('train.val_freq')
54 | self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq')
55 | self.batch_size = self.conf.get_int('train.batch_size')
56 | self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level')
57 | self.learning_rate = self.conf.get_float('train.learning_rate')
58 | self.learning_rate_alpha = self.conf.get_float('train.learning_rate_alpha')
59 | self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd')
60 | self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0)
61 | self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0)
62 |
63 | # Weights
64 | self.igr_weight = self.conf.get_float('train.igr_weight')
65 | self.mask_weight = self.conf.get_float('train.mask_weight')
66 | self.is_continue = is_continue
67 | self.mode = mode
68 | self.model_list = []
69 | self.writer = None
70 |
71 | # Networks
72 | params_to_train = []
73 | self.nerf_outside = NeRF(**self.conf['model.nerf']).to(self.device)
74 | self.sdf_network = SDFNetwork(**self.conf['model.sdf_network']).to(self.device) # sdf net
75 | self.deviation_network = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) # train for parameter s in paper
76 | self.color_network = RenderingNetwork(**self.conf['model.rendering_network']).to(self.device) # color net, same as nerf
77 | params_to_train += list(self.nerf_outside.parameters())
78 | params_to_train += list(self.sdf_network.parameters())
79 | params_to_train += list(self.deviation_network.parameters())
80 | params_to_train += list(self.color_network.parameters())
81 |
82 | self.optimizer = torch.optim.Adam(params_to_train, lr=self.learning_rate)
83 |
84 | self.renderer = NeuSRenderer(self.nerf_outside,
85 | self.sdf_network,
86 | self.deviation_network,
87 | self.color_network,
88 | **self.conf['model.neus_renderer'])
89 |
90 | # Load checkpoint
91 | latest_model_name = None
92 | if is_continue:
93 | model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints'))
94 | model_list = []
95 | for model_name in model_list_raw:
96 | if model_name[-3:] == 'pth' and int(model_name[5:-4]) <= self.end_iter:
97 | model_list.append(model_name)
98 | model_list.sort()
99 | latest_model_name = model_list[-1]
100 | # latest_model_name = 'ckpt_050000.pth'
101 |
102 | if latest_model_name is not None:
103 | logging.info('Find checkpoint: {}'.format(latest_model_name))
104 | self.load_checkpoint(latest_model_name)
105 |
106 | # Backup codes and configs for debug
107 | if self.mode[:5] == 'train':
108 | self.file_backup()
109 |
110 | # TODO:
111 | # tensorf
112 | tensorf = load_tensorf('data/tensorf/scan0050_00.th', torch.device('cuda:0'))
113 | reso_cur = list(tensorf.gridSize.cpu().detach().numpy())
114 | tensorf.nSamples = cal_n_samples(reso_cur, 0.5)
115 | self.renderer.tensorf = tensorf
116 |
117 | # nerf
118 | # nerf = load_nerf('data/nerf/130000.tar')
119 | # self.renderer.nerf = nerf
120 | # # validate_mesh(nerf)
121 |
122 | def train(self):
123 | self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs'))
124 | self.update_learning_rate()
125 | res_step = self.end_iter - self.iter_step
126 | image_perm = self.get_image_perm()
127 | # self.validate_image(resolution_level=4)
128 | self.validate_mesh()
129 | self.renderer.validate = False
130 |
131 | for iter_i in tqdm(range(res_step)):
132 | # if self.iter_step > 150000:
133 | # random_mode = 'patch'
134 | # else:
135 | # random_mode = 'batch'
136 | # data = self.dataset.gen_random_rays_at(image_perm[self.iter_step % len(image_perm)], self.batch_size, mode=random_mode)
137 | data = self.dataset.gen_random_rays_at(image_perm[self.iter_step % len(image_perm)], self.batch_size)
138 |
139 | rays_o, rays_d, true_rgb, mask = data[:, :3], data[:, 3: 6], data[:, 6: 9], data[:, 9: 10]
140 | near, far = self.dataset.near_far_from_sphere(rays_o, rays_d)
141 |
142 | background_rgb = None
143 | if self.use_white_bkgd:
144 | background_rgb = torch.ones([1, 3])
145 |
146 | if self.mask_weight > 0.0:
147 | mask = (mask > 0.5).float()
148 | else:
149 | mask = torch.ones_like(mask)
150 |
151 | mask_sum = mask.sum() + 1e-5
152 | render_out = self.renderer.render(rays_o, rays_d, near, far,
153 | background_rgb=background_rgb,
154 | cos_anneal_ratio=self.get_cos_anneal_ratio())
155 |
156 | color_fine = render_out['color_fine']
157 | s_val = render_out['s_val']
158 | cdf_fine = render_out['cdf_fine']
159 | gradient_error = render_out['gradient_error']
160 | weight_max = render_out['weight_max']
161 | weight_sum = render_out['weight_sum']
162 | alpha_error = render_out['alpha_error']
163 |
164 | # Loss
165 | color_error = (color_fine - true_rgb) * mask
166 | color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error), reduction='sum') / mask_sum
167 | psnr = 20.0 * torch.log10(1.0 / (((color_fine - true_rgb)**2 * mask).sum() / (mask_sum * 3.0)).sqrt())
168 |
169 | eikonal_loss = gradient_error
170 |
171 | mask_loss = F.binary_cross_entropy(weight_sum.clip(1e-3, 1.0 - 1e-3), mask)
172 |
173 | # TODO
174 | # if random_mode == 'patch':
175 | # patch_rgb_std = data[:, -1].reshape(-1, 9)[:, 4]
176 | # rgb_mask = patch_rgb_std == 0
177 | # depth_loss = torch.abs(render_out['depth_std'])[rgb_mask]
178 | # depth_loss = depth_loss.mean()
179 | # elif random_mode == 'batch':
180 | # depth_loss = 0
181 | depth_loss = 0
182 |
183 | alpha_error_weight = 0.5 * (1-min(1, self.iter_step/100000))
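            # The NeRF-prior supervision is annealed: its weight decays linearly from 0.5 at iteration 0
            # to 0 at iteration 100k, so the prior mainly guides the early stage of optimization.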
184 | loss = color_fine_loss +\
185 | eikonal_loss * self.igr_weight +\
186 | mask_loss * self.mask_weight + \
187 | alpha_error * alpha_error_weight +\
188 | depth_loss * 0.5 # TODO
189 |
190 | self.optimizer.zero_grad()
191 | loss.backward()
192 | self.optimizer.step()
193 |
194 | self.iter_step += 1
195 |
196 | self.writer.add_scalar('Loss/loss', loss, self.iter_step)
197 | self.writer.add_scalar('Loss/color_loss', color_fine_loss, self.iter_step)
198 | self.writer.add_scalar('Loss/eikonal_loss', eikonal_loss, self.iter_step)
199 | self.writer.add_scalar('Loss/alpha_error', render_out['alpha_error'], self.iter_step)
200 | self.writer.add_scalar('Loss/depth_std', depth_loss, self.iter_step)
201 | self.writer.add_scalar('Statistics/s_val', s_val.mean(), self.iter_step)
202 | self.writer.add_scalar('Statistics/cdf', (cdf_fine[:, :1] * mask).sum() / mask_sum, self.iter_step)
203 | self.writer.add_scalar('Statistics/weight_max', (weight_max * mask).sum() / mask_sum, self.iter_step)
204 | self.writer.add_scalar('Statistics/psnr', psnr, self.iter_step)
205 |
206 | if self.iter_step % self.report_freq == 0:
207 | print(self.base_exp_dir)
208 |                 print('iter:{:>8d} loss={:.3f} alpha={:.3f} alpha_weight={:.3f}'
209 | .format(self.iter_step, loss, alpha_error, alpha_error_weight))
210 |
211 | if self.iter_step % self.save_freq == 0:
212 | self.save_checkpoint()
213 |
214 | if self.iter_step % self.val_freq == 0:
215 | self.validate_image(resolution_level=4)
216 |
217 | if self.iter_step % self.val_mesh_freq == 0:
218 | self.validate_mesh()
219 |
220 | self.update_learning_rate()
221 |
222 | if self.iter_step % len(image_perm) == 0:
223 | image_perm = self.get_image_perm()
224 |
225 | def get_image_perm(self):
226 | return torch.randperm(self.dataset.n_images)
227 |
228 | def get_cos_anneal_ratio(self):
229 | if self.anneal_end == 0.0:
230 | return 1.0
231 | else:
232 | return np.min([1.0, self.iter_step / self.anneal_end])
233 |
234 | def update_learning_rate(self):
235 | if self.iter_step < self.warm_up_end:
236 | learning_factor = self.iter_step / self.warm_up_end
237 | else:
238 | alpha = self.learning_rate_alpha
239 | progress = (self.iter_step - self.warm_up_end) / (self.end_iter - self.warm_up_end)
240 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha
241 |
242 | for g in self.optimizer.param_groups:
243 | g['lr'] = self.learning_rate * learning_factor
244 |
245 | def file_backup(self):
246 | dir_lis = self.conf['general.recording']
247 | os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True)
248 | for dir_name in dir_lis:
249 | cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name)
250 | os.makedirs(cur_dir, exist_ok=True)
251 | files = os.listdir(dir_name)
252 | for f_name in files:
253 | if f_name[-3:] == '.py':
254 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name))
255 |
256 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))
257 |
258 | def load_checkpoint(self, checkpoint_name):
259 | checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device)
260 | self.nerf_outside.load_state_dict(checkpoint['nerf'])
261 | self.sdf_network.load_state_dict(checkpoint['sdf_network_fine'])
262 | self.deviation_network.load_state_dict(checkpoint['variance_network_fine'])
263 | self.color_network.load_state_dict(checkpoint['color_network_fine'])
264 | self.optimizer.load_state_dict(checkpoint['optimizer'])
265 | self.iter_step = checkpoint['iter_step']
266 |
267 | logging.info('End')
268 |
269 | def save_checkpoint(self):
270 | checkpoint = {
271 | 'nerf': self.nerf_outside.state_dict(),
272 | 'sdf_network_fine': self.sdf_network.state_dict(),
273 | 'variance_network_fine': self.deviation_network.state_dict(),
274 | 'color_network_fine': self.color_network.state_dict(),
275 | 'optimizer': self.optimizer.state_dict(),
276 | 'iter_step': self.iter_step,
277 | }
278 |
279 | os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True)
280 | torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step)))
281 |
282 | def validate_image(self, idx=-1, resolution_level=-1):
283 | self.renderer.validate = True
284 | if idx < 0:
285 | idx = np.random.randint(self.dataset.n_images)
286 |
287 | print('Validate: iter: {}, camera: {}'.format(self.iter_step, idx))
288 |
289 | if resolution_level < 0:
290 | resolution_level = self.validate_resolution_level
291 | rays_o, rays_d = self.dataset.gen_rays_at(idx, resolution_level=resolution_level)
292 | H, W, _ = rays_o.shape
293 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
294 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size)
295 |
296 | out_rgb_fine = []
297 | out_normal_fine = []
298 |
299 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
300 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
301 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
302 |
303 | render_out = self.renderer.render(rays_o_batch,
304 | rays_d_batch,
305 | near,
306 | far,
307 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
308 | background_rgb=background_rgb)
309 |
310 | def feasible(key):
311 | return (key in render_out) and (render_out[key] is not None)
312 |
313 | if feasible('color_fine'):
314 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
315 | if feasible('gradients') and feasible('weights'):
316 | n_samples = self.renderer.n_samples + self.renderer.n_importance
317 | normals = render_out['gradients'] * render_out['weights'][:, :n_samples, None]
318 | if feasible('inside_sphere'):
319 | normals = normals * render_out['inside_sphere'][..., None]
320 | normals = normals.sum(dim=1).detach().cpu().numpy()
321 | out_normal_fine.append(normals)
322 | del render_out
323 |
324 | img_fine = None
325 | if len(out_rgb_fine) > 0:
326 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3, -1]) * 256).clip(0, 255)
327 |
328 | normal_img = None
329 | if len(out_normal_fine) > 0:
330 | normal_img = np.concatenate(out_normal_fine, axis=0)
331 | rot = np.linalg.inv(self.dataset.pose_all[idx, :3, :3].detach().cpu().numpy())
332 | normal_img = (np.matmul(rot[None, :, :], normal_img[:, :, None])
333 | .reshape([H, W, 3, -1]) * 128 + 128).clip(0, 255)
334 |
335 | os.makedirs(os.path.join(self.base_exp_dir, 'validations_fine'), exist_ok=True)
336 | os.makedirs(os.path.join(self.base_exp_dir, 'normals'), exist_ok=True)
337 |
338 | for i in range(img_fine.shape[-1]):
339 | if len(out_rgb_fine) > 0:
340 | cv.imwrite(os.path.join(self.base_exp_dir,
341 | 'validations_fine',
342 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
343 | np.concatenate([img_fine[..., i],
344 | self.dataset.image_at(idx, resolution_level=resolution_level)]))
345 | if len(out_normal_fine) > 0:
346 | cv.imwrite(os.path.join(self.base_exp_dir,
347 | 'normals',
348 | '{:0>8d}_{}_{}.png'.format(self.iter_step, i, idx)),
349 | normal_img[..., i])
350 |
351 | self.renderer.validate = False
352 |
353 | def render_novel_image(self, idx_0, idx_1, ratio, resolution_level):
354 | """
355 | Interpolate view between two cameras.
356 | """
357 | rays_o, rays_d = self.dataset.gen_rays_between(idx_0, idx_1, ratio, resolution_level=resolution_level)
358 | H, W, _ = rays_o.shape
359 | rays_o = rays_o.reshape(-1, 3).split(self.batch_size)
360 | rays_d = rays_d.reshape(-1, 3).split(self.batch_size)
361 |
362 | out_rgb_fine = []
363 | for rays_o_batch, rays_d_batch in zip(rays_o, rays_d):
364 | near, far = self.dataset.near_far_from_sphere(rays_o_batch, rays_d_batch)
365 | background_rgb = torch.ones([1, 3]) if self.use_white_bkgd else None
366 |
367 | render_out = self.renderer.render(rays_o_batch,
368 | rays_d_batch,
369 | near,
370 | far,
371 | cos_anneal_ratio=self.get_cos_anneal_ratio(),
372 | background_rgb=background_rgb)
373 |
374 | out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy())
375 |
376 | del render_out
377 |
378 | img_fine = (np.concatenate(out_rgb_fine, axis=0).reshape([H, W, 3]) * 256).clip(0, 255).astype(np.uint8)
379 | return img_fine
380 |
381 | def validate_mesh(self, world_space=False, resolution=64, threshold=0.0):
382 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32)
383 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32)
384 |
385 | vertices, triangles =\
386 | self.renderer.extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold)
387 | os.makedirs(os.path.join(self.base_exp_dir, 'meshes'), exist_ok=True)
388 |
389 | if world_space:
390 | vertices = vertices * self.dataset.scale_mats_np[0][0, 0] + self.dataset.scale_mats_np[0][:3, 3][None]
391 |
392 | mesh = trimesh.Trimesh(vertices, triangles)
393 | mesh.export(os.path.join(self.base_exp_dir, 'meshes', '{:0>8d}.ply'.format(self.iter_step)))
394 |
395 | logging.info('End')
396 |
397 | def interpolate_view(self, img_idx_0, img_idx_1):
398 | images = []
399 | n_frames = 60
400 | for i in range(n_frames):
401 | print(i)
402 | images.append(self.render_novel_image(img_idx_0,
403 | img_idx_1,
404 | np.sin(((i / n_frames) - 0.5) * np.pi) * 0.5 + 0.5,
405 | resolution_level=4))
406 | for i in range(n_frames):
407 | images.append(images[n_frames - i - 1])
408 |
409 | fourcc = cv.VideoWriter_fourcc(*'mp4v')
410 | video_dir = os.path.join(self.base_exp_dir, 'render')
411 | os.makedirs(video_dir, exist_ok=True)
412 | h, w, _ = images[0].shape
413 | writer = cv.VideoWriter(os.path.join(video_dir,
414 | '{:0>8d}_{}_{}.mp4'.format(self.iter_step, img_idx_0, img_idx_1)),
415 | fourcc, 30, (w, h))
416 |
417 | for image in images:
418 | writer.write(image)
419 |
420 | writer.release()
421 |
422 |
423 | if __name__ == '__main__':
424 | print('Hello Wooden')
425 |
426 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
427 |
428 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
429 | logging.basicConfig(level=logging.DEBUG, format=FORMAT)
430 |
431 | parser = argparse.ArgumentParser()
432 | parser.add_argument('--conf', type=str, default='./confs/base.conf')
433 | parser.add_argument('--mode', type=str, default='train')
434 | parser.add_argument('--mcube_threshold', type=float, default=0.0)
435 | parser.add_argument('--is_continue', default=False, action="store_true")
436 | parser.add_argument('--gpu', type=int, default=0)
437 | parser.add_argument('--case', type=str, default='')
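    # Example invocation (config/case names are assumptions based on the provided confs/ directory):
    #   python exp_runner.py --mode train --conf ./confs/replica.conf --case <CASE_NAME> --gpu 0
    #   python exp_runner.py --mode validate_mesh --conf ./confs/replica.conf --case <CASE_NAME> --is_continue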
438 |
439 | args = parser.parse_args()
440 |
441 | torch.cuda.set_device(args.gpu)
442 | runner = Runner(args.conf, args.mode, args.case, args.is_continue)
443 |
444 | if args.mode == 'train':
445 | runner.train()
446 | elif args.mode == 'validate_mesh':
447 | runner.validate_mesh(world_space=True, resolution=512, threshold=args.mcube_threshold)
448 | elif args.mode.startswith('interpolate'): # Interpolate views given two image indices
449 | _, img_idx_0, img_idx_1 = args.mode.split('_')
450 | img_idx_0 = int(img_idx_0)
451 | img_idx_1 = int(img_idx_1)
452 | runner.interpolate_view(img_idx_0, img_idx_1)
453 |
--------------------------------------------------------------------------------
/models/tensoRF.py:
--------------------------------------------------------------------------------
1 | from models.tensorBase import *
2 |
3 |
4 | class TensorVM(TensorBase):
5 | def __init__(self, aabb, gridSize, device, **kargs):
6 | super(TensorVM, self).__init__(aabb, gridSize, device, **kargs)
7 |
8 | def init_svd_volume(self, res, device):
9 | self.plane_coef = torch.nn.Parameter(
10 | 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, res), device=device))
11 | self.line_coef = torch.nn.Parameter(
12 | 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, 1), device=device))
13 | self.basis_mat = torch.nn.Linear(self.app_n_comp * 3, self.app_dim, bias=False, device=device)
14 |
15 | def get_optparam_groups(self, lr_init_spatialxyz=0.02, lr_init_network=0.001):
16 | grad_vars = [{'params': self.line_coef, 'lr': lr_init_spatialxyz},
17 | {'params': self.plane_coef, 'lr': lr_init_spatialxyz},
18 | {'params': self.basis_mat.parameters(), 'lr': lr_init_network}]
19 | if isinstance(self.renderModule, torch.nn.Module):
20 | grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_init_network}]
21 | return grad_vars
22 |
23 | def compute_features(self, xyz_sampled):
24 |
25 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
26 | xyz_sampled[..., self.matMode[2]])).detach()
27 | coordinate_line = torch.stack(
28 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
29 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach()
30 |
31 | plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane,
32 | align_corners=True).view(
33 | -1, *xyz_sampled.shape[:1])
34 | line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view(
35 | -1, *xyz_sampled.shape[:1])
36 |
37 | sigma_feature = torch.sum(plane_feats * line_feats, dim=0)
38 |
39 | plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(
40 | 3 * self.app_n_comp, -1)
41 | line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(
42 | 3 * self.app_n_comp, -1)
43 |
44 | app_features = self.basis_mat((plane_feats * line_feats).T)
45 |
46 | return sigma_feature, app_features
47 |
48 | def compute_densityfeature(self, xyz_sampled):
49 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
50 | xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
51 | coordinate_line = torch.stack(
52 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
53 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
54 | 1, 2)
55 |
56 | plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane,
57 | align_corners=True).view(
58 | -1, *xyz_sampled.shape[:1])
59 | line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view(
60 | -1, *xyz_sampled.shape[:1])
61 |
62 | sigma_feature = torch.sum(plane_feats * line_feats, dim=0)
63 |
64 | return sigma_feature
65 |
66 | def compute_appfeature(self, xyz_sampled):
67 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
68 | xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
69 | coordinate_line = torch.stack(
70 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
71 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
72 | 1, 2)
73 |
74 | plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(
75 | 3 * self.app_n_comp, -1)
76 | line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(
77 | 3 * self.app_n_comp, -1)
78 |
79 | app_features = self.basis_mat((plane_feats * line_feats).T)
80 |
81 | return app_features
82 |
83 | def vectorDiffs(self, vector_comps):
84 | total = 0
85 |
86 | for idx in range(len(vector_comps)):
87 | # print(self.line_coef.shape, vector_comps[idx].shape)
88 | n_comp, n_size = vector_comps[idx].shape[:-1]
89 |
90 | dotp = torch.matmul(vector_comps[idx].view(n_comp, n_size),
91 | vector_comps[idx].view(n_comp, n_size).transpose(-1, -2))
92 | # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape)
93 | non_diagonal = dotp.view(-1)[1:].view(n_comp - 1, n_comp + 1)[..., :-1]
94 | # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape,non_diagonal.shape)
95 | total = total + torch.mean(torch.abs(non_diagonal))
96 | return total
97 |
98 | def vector_comp_diffs(self):
99 |
100 | return self.vectorDiffs(self.line_coef[:, -self.density_n_comp:]) + self.vectorDiffs(
101 | self.line_coef[:, :self.app_n_comp])
102 |
103 | @torch.no_grad()
104 | def up_sampling_VM(self, plane_coef, line_coef, res_target):
105 |
106 | for i in range(len(self.vecMode)):
107 | vec_id = self.vecMode[i]
108 | mat_id_0, mat_id_1 = self.matMode[i]
109 |
110 | plane_coef[i] = torch.nn.Parameter(
111 | F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
112 | align_corners=True))
113 | line_coef[i] = torch.nn.Parameter(
114 | F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
115 |
116 | # plane_coef[0] = torch.nn.Parameter(
117 | # F.interpolate(plane_coef[0].data, size=(res_target[1], res_target[0]), mode='bilinear',
118 | # align_corners=True))
119 | # line_coef[0] = torch.nn.Parameter(
120 | # F.interpolate(line_coef[0].data, size=(res_target[2], 1), mode='bilinear', align_corners=True))
121 | # plane_coef[1] = torch.nn.Parameter(
122 | # F.interpolate(plane_coef[1].data, size=(res_target[2], res_target[0]), mode='bilinear',
123 | # align_corners=True))
124 | # line_coef[1] = torch.nn.Parameter(
125 | # F.interpolate(line_coef[1].data, size=(res_target[1], 1), mode='bilinear', align_corners=True))
126 | # plane_coef[2] = torch.nn.Parameter(
127 | # F.interpolate(plane_coef[2].data, size=(res_target[2], res_target[1]), mode='bilinear',
128 | # align_corners=True))
129 | # line_coef[2] = torch.nn.Parameter(
130 | # F.interpolate(line_coef[2].data, size=(res_target[0], 1), mode='bilinear', align_corners=True))
131 |
132 | return plane_coef, line_coef
133 |
134 | @torch.no_grad()
135 | def upsample_volume_grid(self, res_target):
136 | # self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
137 | # self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)
138 |
139 | scale = res_target[0] / self.line_coef.shape[2] # assuming xyz have the same scale
140 | plane_coef = F.interpolate(self.plane_coef.detach().data, scale_factor=scale, mode='bilinear',
141 | align_corners=True)
142 | line_coef = F.interpolate(self.line_coef.detach().data, size=(res_target[0], 1), mode='bilinear',
143 | align_corners=True)
144 | self.plane_coef, self.line_coef = torch.nn.Parameter(plane_coef), torch.nn.Parameter(line_coef)
145 | self.compute_stepSize(res_target)
146 |         print(f'upsampling to {res_target}')
147 |
148 |
149 | class TensorVMSplit(TensorBase):
150 | def __init__(self, aabb, gridSize, device, **kargs):
151 | super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)
152 |
153 | def init_svd_volume(self, res, device):
154 | self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device)
155 | self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device)
156 | self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device)
157 |
158 | def init_one_svd(self, n_component, gridSize, scale, device):
159 | plane_coef, line_coef = [], []
160 | for i in range(len(self.vecMode)):
161 | vec_id = self.vecMode[i]
162 | mat_id_0, mat_id_1 = self.matMode[i]
163 | plane_coef.append(torch.nn.Parameter(
164 | scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0])))) #
165 | line_coef.append(
166 | torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1))))
167 |
168 | return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device)
169 |
170 | def get_optparam_groups(self, lr_init_spatialxyz=0.02, lr_init_network=0.001):
171 | grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
172 | {'params': self.density_plane, 'lr': lr_init_spatialxyz},
173 | {'params': self.app_line, 'lr': lr_init_spatialxyz},
174 | {'params': self.app_plane, 'lr': lr_init_spatialxyz},
175 | {'params': self.basis_mat.parameters(), 'lr': lr_init_network}]
176 | if isinstance(self.renderModule, torch.nn.Module):
177 | grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_init_network}]
178 | return grad_vars
179 |
180 | def vectorDiffs(self, vector_comps):
181 | total = 0
182 |
183 | for idx in range(len(vector_comps)):
184 | n_comp, n_size = vector_comps[idx].shape[1:-1]
185 |
186 | dotp = torch.matmul(vector_comps[idx].view(n_comp, n_size),
187 | vector_comps[idx].view(n_comp, n_size).transpose(-1, -2))
188 | non_diagonal = dotp.view(-1)[1:].view(n_comp - 1, n_comp + 1)[..., :-1]
189 | total = total + torch.mean(torch.abs(non_diagonal))
190 | return total
191 |
192 | def vector_comp_diffs(self):
193 | return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line)
194 |
195 | def density_L1(self):
196 | total = 0
197 | for idx in range(len(self.density_plane)):
198 | total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[
199 | idx])) # + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx]))
200 | return total
201 |
202 | def TV_loss_density(self, reg):
203 | total = 0
204 | for idx in range(len(self.density_plane)):
205 | total = total + reg(self.density_plane[idx]) * 1e-2 # + reg(self.density_line[idx]) * 1e-3
206 | return total
207 |
208 | def TV_loss_app(self, reg):
209 | total = 0
210 | for idx in range(len(self.app_plane)):
211 | total = total + reg(self.app_plane[idx]) * 1e-2 # + reg(self.app_line[idx]) * 1e-3
212 | return total
213 |
214 | def compute_densityfeature(self, xyz_sampled, requires_grad=False):
215 |
216 | # plane + line basis
217 | if requires_grad:
218 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
219 | xyz_sampled[..., self.matMode[2]])).view(3, -1, 1, 2)
220 | else:
221 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
222 | xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
223 | coordinate_line = torch.stack(
224 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
225 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
226 | 1, 2)
227 |
228 | sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)
229 | for idx_plane in range(len(self.density_plane)):
230 | plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]],
231 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
232 | line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]],
233 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
234 | sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0)
235 |
236 | return sigma_feature
237 |
238 | def compute_appfeature(self, xyz_sampled):
239 |
240 | # plane + line basis
241 | coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]],
242 | xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
243 | coordinate_line = torch.stack(
244 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
245 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
246 | 1, 2)
247 |
248 | plane_coef_point, line_coef_point = [], []
249 | for idx_plane in range(len(self.app_plane)):
250 | plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]],
251 | align_corners=True).view(-1, *xyz_sampled.shape[:1]))
252 | line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]],
253 | align_corners=True).view(-1, *xyz_sampled.shape[:1]))
254 | plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)
255 |
256 | return self.basis_mat((plane_coef_point * line_coef_point).T)
257 |
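    # up_sampling_VM bilinearly interpolates each plane and line factor to the target
    # resolution; it runs under torch.no_grad() and re-registers the results as Parameters.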
258 | @torch.no_grad()
259 | def up_sampling_VM(self, plane_coef, line_coef, res_target):
260 |
261 | for i in range(len(self.vecMode)):
262 | vec_id = self.vecMode[i]
263 | mat_id_0, mat_id_1 = self.matMode[i]
264 | plane_coef[i] = torch.nn.Parameter(
265 | F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
266 | align_corners=True))
267 | line_coef[i] = torch.nn.Parameter(
268 | F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
269 |
270 | return plane_coef, line_coef
271 |
272 | @torch.no_grad()
273 | def upsample_volume_grid(self, res_target):
274 | self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
275 | self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)
276 |
277 | self.update_stepSize(res_target)
278 |         print(f'upsampling to {res_target}')
279 |
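    # shrink crops the factorized grids to a tighter bounding box: the new bounds are
    # converted to voxel indices, each line/plane factor is sliced accordingly, and when
    # the alpha-mask grid differs from the current grid the aabb is snapped back to the
    # voxel boundaries that were actually kept.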
280 | @torch.no_grad()
281 | def shrink(self, new_aabb):
282 | print("====> shrinking ...")
283 | xyz_min, xyz_max = new_aabb
284 | t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
285 | # print(new_aabb, self.aabb)
286 | # print(t_l, b_r,self.alphaMask.alpha_volume.shape)
287 |         t_l, b_r = torch.round(t_l).long(), torch.round(b_r).long() + 1
288 | b_r = torch.stack([b_r, self.gridSize]).amin(0)
289 |
290 | for i in range(len(self.vecMode)):
291 | mode0 = self.vecMode[i]
292 | self.density_line[i] = torch.nn.Parameter(
293 | self.density_line[i].data[..., t_l[mode0]:b_r[mode0], :]
294 | )
295 | self.app_line[i] = torch.nn.Parameter(
296 | self.app_line[i].data[..., t_l[mode0]:b_r[mode0], :]
297 | )
298 | mode0, mode1 = self.matMode[i]
299 | self.density_plane[i] = torch.nn.Parameter(
300 | self.density_plane[i].data[..., t_l[mode1]:b_r[mode1], t_l[mode0]:b_r[mode0]]
301 | )
302 | self.app_plane[i] = torch.nn.Parameter(
303 | self.app_plane[i].data[..., t_l[mode1]:b_r[mode1], t_l[mode0]:b_r[mode0]]
304 | )
305 |
306 | if not torch.all(self.alphaMask.gridSize == self.gridSize):
307 | t_l_r, b_r_r = t_l / (self.gridSize - 1), (b_r - 1) / (self.gridSize - 1)
308 | correct_aabb = torch.zeros_like(new_aabb)
309 | correct_aabb[0] = (1 - t_l_r) * self.aabb[0] + t_l_r * self.aabb[1]
310 | correct_aabb[1] = (1 - b_r_r) * self.aabb[0] + b_r_r * self.aabb[1]
311 | print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
312 | new_aabb = correct_aabb
313 |
314 | newSize = b_r - t_l
315 | self.aabb = new_aabb
316 | self.update_stepSize((newSize[0], newSize[1], newSize[2]))
317 |
318 |
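# TensorCP is the CP-decomposition variant: density and appearance are each represented
# by three 1D line factors only (no plane factors), so features are formed as element-wise
# products of the three sampled lines.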
319 | class TensorCP(TensorBase):
320 | def __init__(self, aabb, gridSize, device, **kargs):
321 | super(TensorCP, self).__init__(aabb, gridSize, device, **kargs)
322 |
323 | def init_svd_volume(self, res, device):
324 | self.density_line = self.init_one_svd(self.density_n_comp[0], self.gridSize, 0.2, device)
325 | self.app_line = self.init_one_svd(self.app_n_comp[0], self.gridSize, 0.2, device)
326 | self.basis_mat = torch.nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device)
327 |
328 | def init_one_svd(self, n_component, gridSize, scale, device):
329 | line_coef = []
330 | for i in range(len(self.vecMode)):
331 | vec_id = self.vecMode[i]
332 | line_coef.append(
333 | torch.nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1))))
334 | return torch.nn.ParameterList(line_coef).to(device)
335 |
336 | def get_optparam_groups(self, lr_init_spatialxyz=0.02, lr_init_network=0.001):
337 | grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
338 | {'params': self.app_line, 'lr': lr_init_spatialxyz},
339 | {'params': self.basis_mat.parameters(), 'lr': lr_init_network}]
340 | if isinstance(self.renderModule, torch.nn.Module):
341 | grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_init_network}]
342 | return grad_vars
343 |
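    # With the CP decomposition, the density feature is the element-wise product of the
    # three sampled line factors, summed over the component dimension.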
344 | def compute_densityfeature(self, xyz_sampled):
345 |
346 | coordinate_line = torch.stack(
347 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
348 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
349 | 1, 2)
350 |
351 | line_coef_point = F.grid_sample(self.density_line[0], coordinate_line[[0]],
352 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
353 | line_coef_point = line_coef_point * F.grid_sample(self.density_line[1], coordinate_line[[1]],
354 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
355 | line_coef_point = line_coef_point * F.grid_sample(self.density_line[2], coordinate_line[[2]],
356 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
357 | sigma_feature = torch.sum(line_coef_point, dim=0)
358 |
359 | return sigma_feature
360 |
361 | def compute_appfeature(self, xyz_sampled):
362 |
363 | coordinate_line = torch.stack(
364 | (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
365 | coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1,
366 | 1, 2)
367 |
368 | line_coef_point = F.grid_sample(self.app_line[0], coordinate_line[[0]],
369 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
370 | line_coef_point = line_coef_point * F.grid_sample(self.app_line[1], coordinate_line[[1]],
371 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
372 | line_coef_point = line_coef_point * F.grid_sample(self.app_line[2], coordinate_line[[2]],
373 | align_corners=True).view(-1, *xyz_sampled.shape[:1])
374 |
375 | return self.basis_mat(line_coef_point.T)
376 |
377 | @torch.no_grad()
378 | def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target):
379 |
380 | for i in range(len(self.vecMode)):
381 | vec_id = self.vecMode[i]
382 | density_line_coef[i] = torch.nn.Parameter(
383 | F.interpolate(density_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear',
384 | align_corners=True))
385 | app_line_coef[i] = torch.nn.Parameter(
386 | F.interpolate(app_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
387 |
388 | return density_line_coef, app_line_coef
389 |
390 | @torch.no_grad()
391 | def upsample_volume_grid(self, res_target):
392 | self.density_line, self.app_line = self.up_sampling_Vector(self.density_line, self.app_line, res_target)
393 |
394 | self.update_stepSize(res_target)
395 |         print(f'upsampling to {res_target}')
396 |
397 | @torch.no_grad()
398 | def shrink(self, new_aabb):
399 | print("====> shrinking ...")
400 | xyz_min, xyz_max = new_aabb
401 | t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
402 |
403 |         t_l, b_r = torch.round(t_l).long(), torch.round(b_r).long() + 1
404 | b_r = torch.stack([b_r, self.gridSize]).amin(0)
405 |
406 | for i in range(len(self.vecMode)):
407 | mode0 = self.vecMode[i]
408 | self.density_line[i] = torch.nn.Parameter(
409 | self.density_line[i].data[..., t_l[mode0]:b_r[mode0], :]
410 | )
411 | self.app_line[i] = torch.nn.Parameter(
412 | self.app_line[i].data[..., t_l[mode0]:b_r[mode0], :]
413 | )
414 |
415 | if not torch.all(self.alphaMask.gridSize == self.gridSize):
416 | t_l_r, b_r_r = t_l / (self.gridSize - 1), (b_r - 1) / (self.gridSize - 1)
417 | correct_aabb = torch.zeros_like(new_aabb)
418 | correct_aabb[0] = (1 - t_l_r) * self.aabb[0] + t_l_r * self.aabb[1]
419 | correct_aabb[1] = (1 - b_r_r) * self.aabb[0] + b_r_r * self.aabb[1]
420 | print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
421 | new_aabb = correct_aabb
422 |
423 | newSize = b_r - t_l
424 | self.aabb = new_aabb
425 | self.update_stepSize((newSize[0], newSize[1], newSize[2]))
426 |
427 | def density_L1(self):
428 | total = 0
429 | for idx in range(len(self.density_line)):
430 | total = total + torch.mean(torch.abs(self.density_line[idx]))
431 | return total
432 |
433 | def TV_loss_density(self, reg):
434 | total = 0
435 | for idx in range(len(self.density_line)):
436 | total = total + reg(self.density_line[idx]) * 1e-3
437 | return total
438 |
439 | def TV_loss_app(self, reg):
440 | total = 0
441 | for idx in range(len(self.app_line)):
442 | total = total + reg(self.app_line[idx]) * 1e-3
443 | return total
--------------------------------------------------------------------------------