├── LICENSE
├── README.md
├── code
│   ├── confs
│   │   ├── replica_objsdfplus.conf
│   │   ├── scannet_objsdfplus.conf
│   │   └── scannet_objsdfplus_mlp.conf
│   ├── datasets
│   │   └── scene_dataset.py
│   ├── hashencoder
│   │   ├── __init__.py
│   │   ├── backend.py
│   │   ├── hashgrid.py
│   │   └── src
│   │       ├── bindings.cpp
│   │       ├── hashencoder.cu
│   │       └── hashencoder.h
│   ├── model
│   │   ├── density.py
│   │   ├── embedder.py
│   │   ├── loss.py
│   │   ├── network.py
│   │   └── ray_sampler.py
│   ├── training
│   │   ├── exp_runner.py
│   │   └── objectsdfplus_train.py
│   └── utils
│       ├── general.py
│       ├── plots.py
│       ├── rend_util.py
│       └── sem_util.py
├── media
│   └── teaser.gif
├── preprocess
│   ├── README.md
│   ├── extract_monocular_cues.py
│   ├── replica_to_objsdfpp.py
│   └── scannet_to_objsdfpp.py
├── replica_eval
│   ├── avg_metric.py
│   ├── cull_mesh.py
│   ├── cull_obj_gt.py
│   ├── eval_3D_obj.py
│   ├── eval_recon.py
│   ├── evaluate.py
│   ├── evaluate_single_scene.py
│   └── metrics.py
├── requirements.txt
├── scannet_eval
│   └── evaluate.py
└── scripts
    ├── download_dataset.sh
    └── download_meshes.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Qianyi Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # ObjectSDF++: Improved Object-Compositional Neural Implicit Surfaces
3 |
4 | Qianyi Wu · Kaisiyuan Wang · Kejie Li · Jianmin Zheng · Jianfei Cai
5 |
6 | ICCV 2023
7 |
8 | ![teaser](media/teaser.gif)
9 |
28 | TL;DR: We propose an occlusion-aware opacity rendering formulation to make better use of instance mask supervision. Together with an object-distinction regularization term, the proposed ObjectSDF++ produces more accurate surface reconstruction at both the scene and object levels.
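For intuition, here is a minimal schematic of what occlusion-aware object opacity rendering computes along a single ray (names and shapes are illustrative only; the actual implementation lives in `code/model/network.py`): object k accumulates its own alphas, but attenuated by the transmittance of the whole scene, so occluders correctly suppress the opacity of hidden objects.

```python
import torch

def render_object_opacity(sigma_scene, sigma_k, dists):
    # sigma_scene: density from the scene SDF (minimum over all object SDFs), shape [N]
    # sigma_k:     density derived from object k's own SDF,                   shape [N]
    # dists:       distances between consecutive samples along the ray,       shape [N]
    alpha_scene = 1.0 - torch.exp(-sigma_scene * dists)          # scene alpha per sample
    trans = torch.cumprod(
        torch.cat([torch.ones(1), 1.0 - alpha_scene + 1e-10])[:-1], dim=0
    )                                                            # transmittance before each sample
    alpha_k = 1.0 - torch.exp(-sigma_k * dists)                  # object-k alpha per sample
    return (trans * alpha_k).sum()                               # occlusion-aware opacity of object k
```

In `code/model/loss.py`, these per-object opacities are compared against one-hot instance masks with a binary cross-entropy (`ObjectSDFPlusLoss.object_opacity_loss`), and the object-distinction regularization (`object_distinct_loss`) penalizes points where another object's SDF drops below the negated minimum SDF, discouraging objects from overlapping.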
29 |
30 |
31 |
32 | # Setup
33 |
34 | ## Installation
35 | This code has been tested on Ubuntu 22.04 with PyTorch 2.0 and CUDA 11.7 on an RTX 3090.
36 | Clone the repository and create an anaconda environment named `objsdf`:
37 | ```
38 | git clone https://github.com/QianyiWu/objectsdf_plus.git
39 | cd objectsdf_plus
40 |
41 | conda create -y -n objsdf python=3.9
42 | conda activate objsdf
43 |
44 | pip install -r requirements.txt
45 | ```
46 | The hash encoder will be compiled on the fly when running the code.
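If you want to trigger the build ahead of time and check your CUDA setup, a quick smoke test (our suggestion, assuming it is run from `./code` on a CUDA-capable GPU) is:

```python
# Importing and calling HashEncoder JIT-compiles the CUDA extension (into ./tmp_build/).
import torch
from hashencoder import HashEncoder

enc = HashEncoder(input_dim=3, num_levels=16, level_dim=2,
                  base_resolution=16, log2_hashmap_size=19,
                  desired_resolution=2048).cuda()
pts = torch.rand(8, 3, device='cuda') * 2 - 1   # points in [-1, 1]
print(enc(pts, size=1).shape)                   # expected: torch.Size([8, 32])
```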
47 |
48 | ## Dataset
49 | To download the preprocessed data, run the script below. The Replica and ScanNet data are adapted from [MonoSDF](https://github.com/autonomousvision/monosdf) and [vMAP](https://github.com/kxhit/vMAP).
50 | ```
51 | bash scripts/download_dataset.sh
52 | ```
53 | # Training
54 |
55 | Run the following command to train ObjectSDF++:
56 | ```
57 | cd ./code
58 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 training/exp_runner.py --conf CONFIG --scan_id SCAN_ID
59 | ```
60 | where CONFIG is the config file in `code/confs`, and SCAN_ID is the id of the scene to reconstruct.
61 |
62 | We provide example commands for training on the Replica and ScanNet datasets:
63 | ```
64 | # Replica scan 1 (room0)
65 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 training/exp_runner.py --conf confs/replica_objsdfplus.conf --scan_id 1
66 |
67 | # ScanNet scan 1 (scene_0050_00)
68 | CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 training/exp_runner.py --conf confs/scannet_objsdfplus.conf --scan_id 1
69 |
70 | ```
71 | The intermediate results and checkpoints will be saved in the `exps` folder.
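Note that the dotted class names in the config files (`dataset_class`, `model_class`, `loss_class`, `rgb_loss`, ...) are resolved to Python classes at runtime via `utils.get_class` (the repo's helper lives in `code/utils/general.py`; the sketch below is a simplified stand-in, not the repo's code):

```python
import importlib

def get_class(dotted_path: str):
    # e.g. 'torch.nn.L1Loss' -> the torch.nn.L1Loss class
    module_name, class_name = dotted_path.rsplit('.', 1)
    return getattr(importlib.import_module(module_name), class_name)

rgb_loss = get_class('torch.nn.L1Loss')(reduction='mean')  # as configured in replica_objsdfplus.conf
```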
72 |
73 | # Evaluations
74 |
75 | ## Replica
76 | Evaluate a single scene (take scan 1, room0, for example):
77 | ```
78 | cd replica_eval
79 | python evaluate_single_scene.py --input_mesh replica_scan1_mesh.ply --scan_id 1 --output_dir replica_scan1
80 | ```
81 |
82 | We also provide scripts for evaluating all Replica scenes and objects:
83 | ```
84 | cd replica_eval
85 | python evaluate.py # scene-level evaluation
86 | python eval_3D_obj.py # object-level evaluation
87 | ```
88 | Please check the scripts for more details. To obtain the object ground truth, refer to the [vMAP dataset instructions](https://github.com/kxhit/vMAP#dataset).
89 |
90 | ## ScanNet
91 | ```
92 | cd scannet_eval
93 | python evaluate.py
94 | ```
95 | Please check the script for more details.
96 |
97 | # Acknowledgements
98 | This project is built upon [MonoSDF](https://github.com/autonomousvision/monosdf). The monocular depth and normal maps are obtained with [Omnidata](https://omnidata.vision). The object reconstruction evaluation is inspired by [vMAP](https://github.com/kxhit/vMAP). The CUDA implementation of multi-resolution hash encoding is heavily based on [torch-ngp](https://github.com/ashawkey/torch-ngp). Kudos to these researchers.
99 |
100 |
101 | # Citation
102 | If you find our code or paper useful, please cite the series of ObjectSDF works.
103 | ```BibTeX
104 | @inproceedings{wu2022object,
105 | title = {Object-compositional neural implicit surfaces},
106 | author = {Wu, Qianyi and Liu, Xian and Chen, Yuedong and Li, Kejie and Zheng, Chuanxia and Cai, Jianfei and Zheng, Jianmin},
107 | booktitle = {European Conference on Computer Vision},
108 | year = {2022},
109 | }
110 |
111 | @inproceedings{wu2023objsdfplus,
112 | author = {Wu, Qianyi and Wang, Kaisiyuan and Li, Kejie and Zheng, Jianmin and Cai, Jianfei},
113 | title = {ObjectSDF++: Improved Object-Compositional Neural Implicit Surfaces},
114 | booktitle = {ICCV},
115 | year = {2023},
116 | }
117 | ```
118 |
119 |
--------------------------------------------------------------------------------
/code/confs/replica_objsdfplus.conf:
--------------------------------------------------------------------------------
1 | train{
2 | expname = objectsdfplus_replica
3 | dataset_class = datasets.scene_dataset.SceneDatasetDN_segs
4 | model_class = model.network.ObjectSDFPlusNetwork
5 | loss_class = model.loss.ObjectSDFPlusLoss
6 | learning_rate = 5.0e-4
7 | lr_factor_for_grid = 20.0
8 | num_pixels = 1024
9 | checkpoint_freq = 100
10 | plot_freq = 50
11 | split_n_pixels = 1024
12 | add_objectvio_iter = 100000
13 | }
14 | plot{
15 | plot_nimgs = 1
16 | resolution = 256
17 | grid_boundary = [-1.0, 1.0]
18 | }
19 | loss{
20 | rgb_loss = torch.nn.L1Loss
21 | eikonal_weight = 0.1
22 | smooth_weight = 0.005
23 | depth_weight = 0.1
24 | normal_l1_weight = 0.05
25 | normal_cos_weight = 0.05
26 | semantic_loss = torch.nn.MSELoss
27 | use_obj_opacity = True
28 | semantic_weight = 1
29 | reg_vio_weight = 0.5
30 | }
31 | dataset{
32 | data_dir = replica
33 | img_res = [384, 384]
34 | use_mask = True
35 | center_crop_type = center_crop_for_replica
36 | }
37 | model{
38 | feature_vector_size = 256
39 | scene_bounding_sphere = 1.0
40 |
41 | Grid_MLP = True
42 |
43 | implicit_network
44 | {
45 | d_in = 3
46 | d_out = 46
47 | dims = [256, 256]
48 | geometric_init = True
49 | bias = 0.9
50 | skip_in = [4]
51 | weight_norm = True
52 | multires = 6
53 | inside_outside = True
54 | use_grid_feature = True
55 | divide_factor = 1.0
56 | sigmoid = 10
57 | }
58 |
59 | rendering_network
60 | {
61 | mode = idr
62 | d_in = 9
63 | d_out = 3
64 | dims = [256, 256]
65 | weight_norm = True
66 | multires_view = 4
67 | per_image_code = True
68 | }
69 | density
70 | {
71 | params_init{
72 | beta = 0.1
73 | }
74 | beta_min = 0.0001
75 | }
76 | ray_sampler
77 | {
78 | near = 0.0
79 | N_samples = 64
80 | N_samples_eval = 128
81 | N_samples_extra = 32
82 | eps = 0.1
83 | beta_iters = 10
84 | max_total_iters = 5
85 | }
86 | }
87 |
--------------------------------------------------------------------------------
/code/confs/scannet_objsdfplus.conf:
--------------------------------------------------------------------------------
1 | train{
2 | expname = objectsdfplus_grid_scannet
3 | dataset_class = datasets.scene_dataset.SceneDatasetDN_segs
4 | model_class = model.network.ObjectSDFPlusNetwork
5 | loss_class = model.loss.ObjectSDFPlusLoss
6 | learning_rate = 5.0e-4
7 | lr_factor_for_grid = 20.0
8 | num_pixels = 1024
9 | checkpoint_freq = 100
10 | plot_freq = 50
11 | split_n_pixels = 1024
12 | }
13 | plot{
14 | plot_nimgs = 1
15 | resolution = 256
16 | grid_boundary = [-1.1, 1.1]
17 | }
18 | loss{
19 | rgb_loss = torch.nn.L1Loss
20 | eikonal_weight = 0.05
21 | smooth_weight = 0.005
22 | depth_weight = 0.1
23 | normal_l1_weight = 0.05
24 | normal_cos_weight = 0.05
25 | semantic_loss = torch.nn.MSELoss
26 | use_obj_opacity = True
27 | semantic_weight = 0.5
28 | reg_vio_weight = 0.1
29 | }
30 | dataset{
31 | data_dir = scannet
32 | img_res = [384, 384]
33 | center_crop_type = no_crop
34 | }
35 | model{
36 | feature_vector_size = 256
37 | scene_bounding_sphere = 1.0
38 |
39 | Grid_MLP = True
40 |
41 | implicit_network
42 | {
43 | d_in = 3
44 | d_out = 32
45 | dims = [256, 256]
46 | geometric_init = True
47 | bias = 0.9
48 | skip_in = [4]
49 | weight_norm = True
50 | multires = 6
51 | inside_outside = True
52 | use_grid_feature = True
53 | divide_factor = 1.1
54 | sigmoid = 10
55 | }
56 |
57 | rendering_network
58 | {
59 | mode = idr
60 | d_in = 9
61 | d_out = 3
62 | dims = [256, 256]
63 | weight_norm = True
64 | multires_view = 4
65 | per_image_code = True
66 | }
67 | density
68 | {
69 | params_init{
70 | beta = 0.1
71 | }
72 | beta_min = 0.0001
73 | }
74 | ray_sampler
75 | {
76 | near = 0.0
77 | N_samples = 64
78 | N_samples_eval = 128
79 | N_samples_extra = 32
80 | eps = 0.1
81 | beta_iters = 10
82 | max_total_iters = 5
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/code/confs/scannet_objsdfplus_mlp.conf:
--------------------------------------------------------------------------------
1 | train{
2 | expname = objectsdfplus_mlp_scannet
3 | dataset_class = datasets.scene_dataset.SceneDatasetDN_segs
4 | model_class = model.network.ObjectSDFPlusNetwork
5 | loss_class = model.loss.ObjectSDFPlusLoss
6 | learning_rate = 5.0e-4
7 | lr_factor_for_grid = 20.0
8 | num_pixels = 1024
9 | checkpoint_freq = 100
10 | plot_freq = 50
11 | split_n_pixels = 1024
12 | }
13 | plot{
14 | plot_nimgs = 1
15 | resolution = 256
16 | grid_boundary = [-1.1, 1.1]
17 | }
18 | loss{
19 | rgb_loss = torch.nn.L1Loss
20 | eikonal_weight = 0.05
21 | smooth_weight = 0.005
22 | depth_weight = 0.1
23 | normal_l1_weight = 0.05
24 | normal_cos_weight = 0.05
25 | semantic_loss = torch.nn.MSELoss
26 | use_obj_opacity = True
27 | semantic_weight = 0.5
28 | reg_vio_weight = 0.1
29 | }
30 | dataset{
31 | data_dir = scannet
32 | img_res = [384, 384]
33 | center_crop_type = no_crop
34 | }
35 | model{
36 | feature_vector_size = 256
37 | scene_bounding_sphere = 1.0
38 |
39 | Grid_MLP = True
40 |
41 | implicit_network
42 | {
43 | d_in = 3
44 | d_out = 32
45 | dims = [256, 256, 256, 256, 256, 256, 256, 256]
46 | geometric_init = True
47 | bias = 0.9
48 | skip_in = [4]
49 | weight_norm = True
50 | multires = 6
51 | inside_outside = True
52 | use_grid_feature = False
53 | divide_factor = 1.1
54 | sigmoid = 10
55 | }
56 |
57 | rendering_network
58 | {
59 | mode = idr
60 | d_in = 9
61 | d_out = 3
62 | dims = [256, 256]
63 | weight_norm = True
64 | multires_view = 4
65 | per_image_code = True
66 | }
67 | density
68 | {
69 | params_init{
70 | beta = 0.1
71 | }
72 | beta_min = 0.0001
73 | }
74 | ray_sampler
75 | {
76 | near = 0.0
77 | N_samples = 64
78 | N_samples_eval = 128
79 | N_samples_extra = 32
80 | eps = 0.1
81 | beta_iters = 10
82 | max_total_iters = 5
83 | }
84 | }
85 |
--------------------------------------------------------------------------------
/code/datasets/scene_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | import utils.general as utils
7 | from utils import rend_util
8 | from glob import glob
9 | import cv2
10 | import random
11 |
12 | # Dataset with monocular depth, normal and segmentation mask
13 | class SceneDatasetDN_segs(torch.utils.data.Dataset):
14 |
15 | def __init__(self,
16 | data_dir,
17 | img_res,
18 | scan_id=0,
19 | center_crop_type='xxxx',
20 | use_mask=False,
21 | num_views=-1
22 | ):
23 |
24 | self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id))
25 | print(self.instance_dir)
26 |
27 | self.total_pixels = img_res[0] * img_res[1]
28 | self.img_res = img_res
29 | self.num_views = num_views
30 | assert num_views in [-1, 3, 6, 9]
31 |
32 | assert os.path.exists(self.instance_dir), "Data directory does not exist"
33 |
34 | self.sampling_idx = None
35 |
36 | def glob_data(data_dir):
37 | data_paths = []
38 | data_paths.extend(glob(data_dir))
39 | data_paths = sorted(data_paths)
40 | return data_paths
41 |
42 | image_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_rgb.png"))
43 | depth_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_depth.npy"))
44 | normal_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_normal.npy"))
45 | # semantic_paths
46 | semantic_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "segs", "*_segs.png"))
47 |
48 | # mask is only used in the replica dataset, as some monocular depth predictions have very large errors and we ignore those pixels
49 | if use_mask:
50 | mask_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_mask.npy"))
51 | else:
52 | mask_paths = None
53 |
54 | self.n_images = len(image_paths)
55 | print('[INFO]: Dataset Size ', self.n_images)
56 |
57 | # load instance_mapping_dict
58 | self.label_mapping = None
59 | self.instance_mapping_dict= {}
60 | # with open(os.path.join(data_dir, 'label_mapping_instance.txt'), 'r') as f:
61 | # content = f.readlines()
62 | # self.label_mapping = [int(a) for a in content[0].split(',')]
63 |
64 | # using the remapped instance label for training
65 | with open(os.path.join(self.instance_dir, 'instance_mapping.txt'), 'r') as f:
66 | for l in f:
67 | (k, v_sem, v_ins) = l.split(',')
68 | self.instance_mapping_dict[int(k)] = int(v_ins)
69 | # self.label_mapping = [] # get the sorted label mapping list
70 | # for k, v in self.instance_mapping_dict.items():
71 | # if v not in self.label_mapping: # not a duplicate instance
72 | # self.label_mapping.append(v)
73 | self.label_mapping = sorted(set(self.instance_mapping_dict.values())) # get sorted label mapping. The first one is the background
74 | # sorted
75 | # print('Instance Label Mapping: ', self.label_mapping)
76 | self.obj_id = torch.from_numpy(np.array(range(len(self.label_mapping))))
77 | print(self.obj_id)
78 |
79 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir)
80 | camera_dict = np.load(self.cam_file)
81 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
82 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
83 |
84 | self.intrinsics_all = []
85 | self.pose_all = []
86 | for scale_mat, world_mat in zip(scale_mats, world_mats):
87 | P = world_mat @ scale_mat
88 | P = P[:3, :4]
89 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
90 |
91 | # because we resize and center-crop to 384x384 when using the omnidata model, we need to adjust the camera intrinsics accordingly
92 | if center_crop_type == 'center_crop_for_replica':
93 | scale = 384 / 680
94 | offset = (1200 - 680 ) * 0.5
95 | intrinsics[0, 2] -= offset
96 | intrinsics[:2, :] *= scale
97 | elif center_crop_type == 'center_crop_for_tnt':
98 | scale = 384 / 540
99 | offset = (960 - 540) * 0.5
100 | intrinsics[0, 2] -= offset
101 | intrinsics[:2, :] *= scale
102 | elif center_crop_type == 'center_crop_for_dtu':
103 | scale = 384 / 1200
104 | offset = (1600 - 1200) * 0.5
105 | intrinsics[0, 2] -= offset
106 | intrinsics[:2, :] *= scale
107 | elif center_crop_type == 'padded_for_dtu':
108 | scale = 384 / 1200
109 | offset = 0
110 | intrinsics[0, 2] -= offset
111 | intrinsics[:2, :] *= scale
112 | elif center_crop_type == 'no_crop': # for the scannet dataset we already adjust the camera intrinsics during preprocessing, so nothing to be done here
113 | pass
114 | else:
115 | raise NotImplementedError
116 |
117 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
118 | self.pose_all.append(torch.from_numpy(pose).float())
119 |
120 | self.rgb_images = []
121 | for path in image_paths:
122 | rgb = rend_util.load_rgb(path)
123 | rgb = rgb.reshape(3, -1).transpose(1, 0)
124 | self.rgb_images.append(torch.from_numpy(rgb).float())
125 |
126 | self.depth_images = []
127 | self.normal_images = []
128 |
129 | for dpath, npath in zip(depth_paths, normal_paths):
130 | depth = np.load(dpath)
131 | self.depth_images.append(torch.from_numpy(depth.reshape(-1, 1)).float())
132 |
133 | normal = np.load(npath)
134 | normal = normal.reshape(3, -1).transpose(1, 0)
135 | # important as the output of omnidata is normalized
136 | normal = normal * 2. - 1.
137 | self.normal_images.append(torch.from_numpy(normal).float())
138 |
139 | # load semantic
140 | self.semantic_images = []
141 | for spath in semantic_paths:
142 | semantic_ori = cv2.imread(spath, cv2.IMREAD_UNCHANGED).astype(np.int32)
143 | semantic = np.copy(semantic_ori)
144 | ins_list = np.unique(semantic_ori)
145 | if self.label_mapping is not None:
146 | for i in ins_list:
147 | semantic[semantic_ori ==i] = self.label_mapping.index(self.instance_mapping_dict[i])
148 | self.semantic_images.append(torch.from_numpy(semantic.reshape(-1, 1)).float())
149 |
150 | # load mask
151 | self.mask_images = []
152 | if mask_paths is None:
153 | for depth in self.depth_images:
154 | mask = torch.ones_like(depth)
155 | self.mask_images.append(mask)
156 | else:
157 | for path in mask_paths:
158 | mask = np.load(path)
159 | self.mask_images.append(torch.from_numpy(mask.reshape(-1, 1)).float())
160 |
161 | def __len__(self):
162 | return self.n_images
163 |
164 | def __getitem__(self, idx):
165 | if self.num_views >= 0:
166 | image_ids = [25, 22, 28, 40, 44, 48, 0, 8, 13][:self.num_views]
167 | idx = image_ids[random.randint(0, self.num_views - 1)]
168 |
169 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
170 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
171 | uv = uv.reshape(2, -1).transpose(1, 0)
172 |
173 | sample = {
174 | "uv": uv,
175 | "intrinsics": self.intrinsics_all[idx],
176 | "pose": self.pose_all[idx]
177 | }
178 |
179 | ground_truth = {
180 | "rgb": self.rgb_images[idx],
181 | "depth": self.depth_images[idx],
182 | "mask": self.mask_images[idx],
183 | "normal": self.normal_images[idx],
184 | "segs": self.semantic_images[idx]
185 | }
186 |
187 | if self.sampling_idx is not None:
188 | if (self.random_image_for_path is None) or (idx not in self.random_image_for_path):
189 | # print('sampling_idx:', self.sampling_idx)
190 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :]
191 | ground_truth["full_rgb"] = self.rgb_images[idx]
192 | ground_truth["normal"] = self.normal_images[idx][self.sampling_idx, :]
193 | ground_truth["depth"] = self.depth_images[idx][self.sampling_idx, :]
194 | ground_truth["full_depth"] = self.depth_images[idx]
195 | ground_truth["mask"] = self.mask_images[idx][self.sampling_idx, :]
196 | ground_truth["full_mask"] = self.mask_images[idx]
197 | ground_truth["segs"] = self.semantic_images[idx][self.sampling_idx, :]
198 |
199 | sample["uv"] = uv[self.sampling_idx, :]
200 | sample["is_patch"] = torch.tensor([False])
201 | else:
202 | # sampling a patch from the image; this could be used for training with a depth total variation loss
203 | # a fixed patch sampling, which requires sampling_size to be an H*H contiguous patch
204 | patch_size = np.floor(np.sqrt(len(self.sampling_idx))).astype(np.int32)
205 | start = np.random.randint(self.img_res[1]-patch_size +1)*self.img_res[0] + np.random.randint(self.img_res[1]-patch_size +1) # the start coordinate
206 | idx_row = torch.arange(start, start + patch_size)
207 | patch_sampling_idx = torch.cat([idx_row + self.img_res[1]*m for m in range(patch_size)])
208 | ground_truth["rgb"] = self.rgb_images[idx][patch_sampling_idx, :]
209 | ground_truth["full_rgb"] = self.rgb_images[idx]
210 | ground_truth["normal"] = self.normal_images[idx][patch_sampling_idx, :]
211 | ground_truth["depth"] = self.depth_images[idx][patch_sampling_idx, :]
212 | ground_truth["full_depth"] = self.depth_images[idx]
213 | ground_truth["mask"] = self.mask_images[idx][patch_sampling_idx, :]
214 | ground_truth["full_mask"] = self.mask_images[idx]
215 | ground_truth["segs"] = self.semantic_images[idx][patch_sampling_idx, :]
216 |
217 | sample["uv"] = uv[patch_sampling_idx, :]
218 | sample["is_patch"] = torch.tensor([True])
219 |
220 | return idx, sample, ground_truth
221 |
222 |
223 | def collate_fn(self, batch_list):
224 | # takes a list of dictionaries and returns input, ground_truth as dictionaries for all batch instances
225 | batch_list = zip(*batch_list)
226 |
227 | all_parsed = []
228 | for entry in batch_list:
229 | if type(entry[0]) is dict:
230 | # make them all into a new dict
231 | ret = {}
232 | for k in entry[0].keys():
233 | ret[k] = torch.stack([obj[k] for obj in entry])
234 | all_parsed.append(ret)
235 | else:
236 | all_parsed.append(torch.LongTensor(entry))
237 |
238 | return tuple(all_parsed)
239 |
240 | def change_sampling_idx(self, sampling_size, sampling_pattern='random'):
241 | if sampling_size == -1:
242 | self.sampling_idx = None
243 | self.random_image_for_path = None
244 | else:
245 | if sampling_pattern == 'random':
246 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size]
247 | self.random_image_for_path = None
248 | elif sampling_pattern == 'patch':
249 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size]
250 | self.random_image_for_path = torch.randperm(self.n_images, )[:int(self.n_images/10)]
251 | else:
252 | raise NotImplementedError('the sampling pattern is not implemented.')
253 |
254 | def get_scale_mat(self):
255 | return np.load(self.cam_file)['scale_mat_0']
--------------------------------------------------------------------------------
/code/hashencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .hashgrid import HashEncoder
--------------------------------------------------------------------------------
/code/hashencoder/backend.py:
--------------------------------------------------------------------------------
1 | from distutils.command.build import build
2 | import os
3 | from torch.utils.cpp_extension import load
4 | from pathlib import Path
5 |
6 | Path('./tmp_build/').mkdir(parents=True, exist_ok=True)
7 |
8 | _src_path = os.path.dirname(os.path.abspath(__file__))
9 |
10 | _backend = load(name='_hash_encoder',
11 | extra_cflags=['-O3', '-std=c++14'],
12 | extra_cuda_cflags=[
13 | '-O3', '-std=c++14', '-allow-unsupported-compiler',
14 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
15 | ],
16 | sources=[os.path.join(_src_path, 'src', f) for f in [
17 | 'hashencoder.cu',
18 | 'bindings.cpp',
19 | ]],
20 | build_directory='./tmp_build/',
21 | verbose=True,
22 | )
23 |
24 | __all__ = ['_backend']
--------------------------------------------------------------------------------
/code/hashencoder/hashgrid.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from math import ceil
3 | from cachetools import cached
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Function
9 | from torch.autograd.function import once_differentiable
10 | from torch.cuda.amp import custom_bwd, custom_fwd
11 |
12 | from .backend import _backend
13 |
14 | class _hash_encode(Function):
15 | @staticmethod
16 | @custom_fwd(cast_inputs=torch.half)
17 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False):
18 | # inputs: [B, D], float in [0, 1]
19 | # embeddings: [sO, C], float
20 | # offsets: [L + 1], int
21 | # RETURN: [B, F], float
22 |
23 | inputs = inputs.contiguous()
24 | embeddings = embeddings.contiguous()
25 | offsets = offsets.contiguous()
26 |
27 | B, D = inputs.shape # batch size, coord dim
28 | L = offsets.shape[0] - 1 # level
29 | C = embeddings.shape[1] # embedding dim for each level
30 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
31 | H = base_resolution # base resolution
32 |
33 | # L first, optimize cache for cuda kernel, but needs an extra permute later
34 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=inputs.dtype)
35 |
36 | if calc_grad_inputs:
37 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=inputs.dtype)
38 | else:
39 | dy_dx = torch.empty(1, device=inputs.device, dtype=inputs.dtype)
40 |
41 | _backend.hash_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx)
42 |
43 | # permute back to [B, L * C]
44 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
45 |
46 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
47 | ctx.dims = [B, D, C, L, S, H]
48 | ctx.calc_grad_inputs = calc_grad_inputs
49 |
50 | return outputs
51 |
52 | @staticmethod
53 | @custom_bwd
54 | def backward(ctx, grad):
55 |
56 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
57 | B, D, C, L, S, H = ctx.dims
58 | calc_grad_inputs = ctx.calc_grad_inputs
59 |
60 | # grad: [B, L * C] --> [L, B, C]
61 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
62 |
63 | grad_inputs, grad_embeddings = _hash_encode_second_backward.apply(grad, inputs, embeddings, offsets, B, D, C, L, S, H, calc_grad_inputs, dy_dx)
64 |
65 | if calc_grad_inputs:
66 | return grad_inputs, grad_embeddings, None, None, None, None
67 | else:
68 | return None, grad_embeddings, None, None, None, None
69 |
70 |
71 | class _hash_encode_second_backward(Function):
72 | @staticmethod
73 | def forward(ctx, grad, inputs, embeddings, offsets, B, D, C, L, S, H, calc_grad_inputs, dy_dx):
74 |
75 | grad_inputs = torch.zeros_like(inputs)
76 | grad_embeddings = torch.zeros_like(embeddings)
77 |
78 | ctx.save_for_backward(grad, inputs, embeddings, offsets, dy_dx, grad_inputs, grad_embeddings)
79 | ctx.dims = [B, D, C, L, S, H]
80 | ctx.calc_grad_inputs = calc_grad_inputs
81 |
82 | _backend.hash_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs)
83 |
84 | return grad_inputs, grad_embeddings
85 |
86 | @staticmethod
87 | def backward(ctx, grad_grad_inputs, grad_grad_embeddings):
88 |
89 | grad, inputs, embeddings, offsets, dy_dx, grad_inputs, grad_embeddings = ctx.saved_tensors
90 | B, D, C, L, S, H = ctx.dims
91 | calc_grad_inputs = ctx.calc_grad_inputs
92 |
93 | grad_grad = torch.zeros_like(grad)
94 | grad2_embeddings = torch.zeros_like(embeddings)
95 |
96 | _backend.hash_encode_second_backward(grad, inputs, embeddings, offsets,
97 | B, D, C, L, S, H, calc_grad_inputs, dy_dx,
98 | grad_grad_inputs,
99 | grad_grad, grad2_embeddings)
100 |
101 | return grad_grad, None, grad2_embeddings, None, None, None, None, None, None, None, None, None
102 |
103 |
104 | hash_encode = _hash_encode.apply
105 |
106 |
107 | class HashEncoder(nn.Module):
108 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None):
109 | super().__init__()
110 |
111 | # the finest resolution desired at the last level; if provided, overrides per_level_scale
112 | if desired_resolution is not None:
113 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
114 |
115 | self.input_dim = input_dim # coord dims, 2 or 3
116 | self.num_levels = num_levels # num levels, each level multiply resolution by 2
117 | self.level_dim = level_dim # encode channels per level
118 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
119 | self.log2_hashmap_size = log2_hashmap_size
120 | self.base_resolution = base_resolution
121 | self.output_dim = num_levels * level_dim
122 |
123 | if level_dim % 2 != 0:
124 | print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause a very slow backward pass if fp16 is also enabled! (maybe fix later)')
125 |
126 | # allocate parameters
127 | offsets = []
128 | offset = 0
129 | self.max_params = 2 ** log2_hashmap_size
130 | for i in range(num_levels):
131 | resolution = int(np.ceil(base_resolution * per_level_scale ** i))
132 | params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number
133 | #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible
134 | offsets.append(offset)
135 | offset += params_in_level
136 | offsets.append(offset)
137 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
138 | self.register_buffer('offsets', offsets)
139 |
140 | self.n_params = offsets[-1] * level_dim
141 |
142 | # parameters
143 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
144 |
145 | self.reset_parameters()
146 |
147 | def reset_parameters(self):
148 | std = 1e-4
149 | self.embeddings.data.uniform_(-std, std)
150 |
151 | def __repr__(self):
152 | return f"HashEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} base_resolution={self.base_resolution} per_level_scale={self.per_level_scale} params={tuple(self.embeddings.shape)}"
153 |
154 | def forward(self, inputs, size=1):
155 | # inputs: [..., input_dim], normalized real world positions in [-size, size]
156 | # return: [..., num_levels * level_dim]
157 |
158 | inputs = (inputs + size) / (2 * size) # map to [0, 1]
159 |
160 | prefix_shape = list(inputs.shape[:-1])
161 | inputs = inputs.view(-1, self.input_dim)
162 |
163 | outputs = hash_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad)
164 | outputs = outputs.view(prefix_shape + [self.output_dim])
165 |
166 | return outputs
--------------------------------------------------------------------------------
/code/hashencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "hashencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("hash_encode_forward", &hash_encode_forward, "hash encode forward (CUDA)");
7 | m.def("hash_encode_backward", &hash_encode_backward, "hash encode backward (CUDA)");
8 | m.def("hash_encode_second_backward", &hash_encode_second_backward, "hash encode second backward (CUDA)");
9 | }
--------------------------------------------------------------------------------
/code/hashencoder/src/hashencoder.h:
--------------------------------------------------------------------------------
1 | #ifndef _HASH_ENCODE_H
2 | #define _HASH_ENCODE_H
3 |
4 | #include <stdint.h>
5 | #include <torch/torch.h>
6 | #include <torch/extension.h>
7 |
8 | // inputs: [B, D], float, in [0, 1]
9 | // embeddings: [sO, C], float
10 | // offsets: [L + 1], uint32_t
11 | // outputs: [B, L * C], float
12 | // H: base resolution
13 | void hash_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx);
14 | void hash_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs);
15 | void hash_encode_second_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, const at::Tensor grad_grad_inputs, at::Tensor grad_grad, at::Tensor grad2_embeddings);
16 |
17 | #endif
--------------------------------------------------------------------------------
/code/model/density.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class Density(nn.Module):
6 | def __init__(self, params_init={}):
7 | super().__init__()
8 | for p in params_init:
9 | param = nn.Parameter(torch.tensor(params_init[p]))
10 | setattr(self, p, param)
11 |
12 | def forward(self, sdf, beta=None):
13 | return self.density_func(sdf, beta=beta)
14 |
15 |
16 | class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
17 | def __init__(self, params_init={}, beta_min=0.0001):
18 | super().__init__(params_init=params_init)
19 | self.beta_min = torch.tensor(beta_min).cuda()
20 |
21 | def density_func(self, sdf, beta=None):
22 | if beta is None:
23 | beta = self.get_beta()
24 |
25 | alpha = 1 / beta
26 | return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))
27 |
28 | def get_beta(self):
29 | beta = self.beta.abs() + self.beta_min
30 | return beta
31 |
32 |
33 | class AbsDensity(Density): # like NeRF++
34 | def density_func(self, sdf, beta=None):
35 | return torch.abs(sdf)
36 |
37 |
38 | class SimpleDensity(Density): # like NeRF
39 | def __init__(self, params_init={}, noise_std=1.0):
40 | super().__init__(params_init=params_init)
41 | self.noise_std = noise_std
42 |
43 | def density_func(self, sdf, beta=None):
44 | if self.training and self.noise_std > 0.0:
45 | noise = torch.randn(sdf.shape).cuda() * self.noise_std
46 | sdf = sdf + noise
47 | return torch.relu(sdf)
48 |
--------------------------------------------------------------------------------
/code/model/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
4 |
5 | class Embedder:
6 | def __init__(self, **kwargs):
7 | self.kwargs = kwargs
8 | self.create_embedding_fn()
9 |
10 | def create_embedding_fn(self):
11 | embed_fns = []
12 | d = self.kwargs['input_dims']
13 | out_dim = 0
14 | if self.kwargs['include_input']:
15 | embed_fns.append(lambda x: x)
16 | out_dim += d
17 |
18 | max_freq = self.kwargs['max_freq_log2']
19 | N_freqs = self.kwargs['num_freqs']
20 |
21 | if self.kwargs['log_sampling']:
22 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
23 | else:
24 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
25 |
26 | for freq in freq_bands:
27 | for p_fn in self.kwargs['periodic_fns']:
28 | embed_fns.append(lambda x, p_fn=p_fn,
29 | freq=freq: p_fn(x * freq))
30 | out_dim += d
31 |
32 | self.embed_fns = embed_fns
33 | self.out_dim = out_dim
34 |
35 | def embed(self, inputs):
36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
37 |
38 | def get_embedder(multires, input_dims=3):
39 | embed_kwargs = {
40 | 'include_input': True,
41 | 'input_dims': input_dims,
42 | 'max_freq_log2': multires-1,
43 | 'num_freqs': multires,
44 | 'log_sampling': True,
45 | 'periodic_fns': [torch.sin, torch.cos],
46 | }
47 |
48 | embedder_obj = Embedder(**embed_kwargs)
49 | def embed(x, eo=embedder_obj): return eo.embed(x)
50 | return embed, embedder_obj.out_dim
51 |
--------------------------------------------------------------------------------
/code/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import utils.general as utils
4 | import math
5 | import torch.nn.functional as F
6 |
7 | # copy from MiDaS
8 | def compute_scale_and_shift(prediction, target, mask):
9 | # system matrix: A = [[a_00, a_01], [a_10, a_11]]
10 | a_00 = torch.sum(mask * prediction * prediction, (1, 2))
11 | a_01 = torch.sum(mask * prediction, (1, 2))
12 | a_11 = torch.sum(mask, (1, 2))
13 |
14 | # right hand side: b = [b_0, b_1]
15 | b_0 = torch.sum(mask * prediction * target, (1, 2))
16 | b_1 = torch.sum(mask * target, (1, 2))
17 |
18 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
19 | x_0 = torch.zeros_like(b_0)
20 | x_1 = torch.zeros_like(b_1)
21 |
22 | det = a_00 * a_11 - a_01 * a_01
23 | valid = det.nonzero()
24 |
25 | x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
26 | x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
27 |
28 | return x_0, x_1
29 |
30 |
31 | def reduction_batch_based(image_loss, M):
32 | # average of all valid pixels of the batch
33 |
34 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
35 | divisor = torch.sum(M)
36 |
37 | if divisor == 0:
38 | return 0
39 | else:
40 | return torch.sum(image_loss) / divisor
41 |
42 |
43 | def reduction_image_based(image_loss, M):
44 | # mean of average of valid pixels of an image
45 |
46 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
47 | valid = M.nonzero()
48 |
49 | image_loss[valid] = image_loss[valid] / M[valid]
50 |
51 | return torch.mean(image_loss)
52 |
53 |
54 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based):
55 |
56 | M = torch.sum(mask, (1, 2))
57 | res = prediction - target
58 | image_loss = torch.sum(mask * res * res, (1, 2))
59 |
60 | return reduction(image_loss, 2 * M)
61 |
62 |
63 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
64 |
65 | M = torch.sum(mask, (1, 2))
66 |
67 | diff = prediction - target
68 | diff = torch.mul(mask, diff)
69 |
70 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
71 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
72 | grad_x = torch.mul(mask_x, grad_x)
73 |
74 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
75 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
76 | grad_y = torch.mul(mask_y, grad_y)
77 |
78 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
79 |
80 | return reduction(image_loss, M)
81 |
82 |
83 | def convert_number_to_digits(x):
84 | assert isinstance(x, int), 'the input value {} should be int'.format(x)
85 | v = 2**x
86 | # convert to 0-1 digits (binary representation, most significant bit first)
87 | return [int(d) for d in bin(v)[2:]]
88 |
89 |
90 |
91 | class MSELoss(nn.Module):
92 | def __init__(self, reduction='batch-based'):
93 | super().__init__()
94 |
95 | if reduction == 'batch-based':
96 | self.__reduction = reduction_batch_based
97 | else:
98 | self.__reduction = reduction_image_based
99 |
100 | def forward(self, prediction, target, mask):
101 | return mse_loss(prediction, target, mask, reduction=self.__reduction)
102 |
103 |
104 | class GradientLoss(nn.Module):
105 | def __init__(self, scales=4, reduction='batch-based'):
106 | super().__init__()
107 |
108 | if reduction == 'batch-based':
109 | self.__reduction = reduction_batch_based
110 | else:
111 | self.__reduction = reduction_image_based
112 |
113 | self.__scales = scales
114 |
115 | def forward(self, prediction, target, mask):
116 | total = 0
117 |
118 | for scale in range(self.__scales):
119 | step = pow(2, scale)
120 |
121 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step],
122 | mask[:, ::step, ::step], reduction=self.__reduction)
123 |
124 | return total
125 |
126 |
127 | class ScaleAndShiftInvariantLoss(nn.Module):
128 | def __init__(self, alpha=0.5, scales=4, reduction='batch-based'):
129 | super().__init__()
130 |
131 | self.__data_loss = MSELoss(reduction=reduction)
132 | self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
133 | self.__alpha = alpha
134 |
135 | self.__prediction_ssi = None
136 |
137 | def forward(self, prediction, target, mask):
138 |
139 | scale, shift = compute_scale_and_shift(prediction, target, mask)
140 | self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
141 |
142 | total = self.__data_loss(self.__prediction_ssi, target, mask)
143 | if self.__alpha > 0:
144 | total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)
145 |
146 | return total
147 |
148 | def __get_prediction_ssi(self):
149 | return self.__prediction_ssi
150 |
151 | prediction_ssi = property(__get_prediction_ssi)
152 | # end copy
153 |
154 |
155 | class MonoSDFLoss(nn.Module):
156 | def __init__(self, rgb_loss,
157 | eikonal_weight,
158 | smooth_weight = 0.005,
159 | depth_weight = 0.1,
160 | normal_l1_weight = 0.05,
161 | normal_cos_weight = 0.05,
162 | end_step = -1):
163 | super().__init__()
164 | self.eikonal_weight = eikonal_weight
165 | self.smooth_weight = smooth_weight
166 | self.depth_weight = depth_weight
167 | self.normal_l1_weight = normal_l1_weight
168 | self.normal_cos_weight = normal_cos_weight
169 | self.rgb_loss = utils.get_class(rgb_loss)(reduction='mean')
170 |
171 | self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1)
172 |
173 | # print(f"using weight for loss RGB_1.0 EK_{self.eikonal_weight} SM_{self.smooth_weight} Depth_{self.depth_weight} NormalL1_{self.normal_l1_weight} NormalCos_{self.normal_cos_weight}")
174 |
175 | self.step = 0
176 | self.end_step = end_step
177 |
178 | def get_rgb_loss(self,rgb_values, rgb_gt):
179 | rgb_gt = rgb_gt.reshape(-1, 3)
180 | rgb_loss = self.rgb_loss(rgb_values, rgb_gt)
181 | return rgb_loss
182 |
183 | def get_eikonal_loss(self, grad_theta):
184 | eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
185 | return eikonal_loss
186 |
187 | def get_smooth_loss(self,model_outputs):
188 | # smoothness loss as unisurf
189 | g1 = model_outputs['grad_theta']
190 | g2 = model_outputs['grad_theta_nei']
191 |
192 | normals_1 = g1 / (g1.norm(2, dim=1).unsqueeze(-1) + 1e-5)
193 | normals_2 = g2 / (g2.norm(2, dim=1).unsqueeze(-1) + 1e-5)
194 | smooth_loss = torch.norm(normals_1 - normals_2, dim=-1).mean()
195 | return smooth_loss
196 |
197 | def get_depth_loss(self, depth_pred, depth_gt, mask):
198 | # TODO remove hard-coded scaling for depth
199 | return self.depth_loss(depth_pred.reshape(1, 32, 32), (depth_gt * 50 + 0.5).reshape(1, 32, 32), mask.reshape(1, 32, 32))
200 |
201 | def get_normal_loss(self, normal_pred, normal_gt):
202 | normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1)
203 | normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1)
204 | l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean()
205 | cos = (1. - torch.sum(normal_pred * normal_gt, dim = -1)).mean()
206 | return l1, cos
207 |
208 | def forward(self, model_outputs, ground_truth):
209 | # import pdb; pdb.set_trace()
210 | rgb_gt = ground_truth['rgb'].cuda()
211 | # monocular depth and normal
212 | depth_gt = ground_truth['depth'].cuda()
213 | normal_gt = ground_truth['normal'].cuda()
214 |
215 | depth_pred = model_outputs['depth_values']
216 | normal_pred = model_outputs['normal_map'][None]
217 |
218 | rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt)
219 |
220 | if 'grad_theta' in model_outputs:
221 | eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
222 | else:
223 | eikonal_loss = torch.tensor(0.0).cuda().float()
224 |
225 | # only supervised the foreground normal
226 | mask = ((model_outputs['sdf'] > 0.).any(dim=-1) & (model_outputs['sdf'] < 0.).any(dim=-1))[None, :, None]
227 | # combine with GT
228 | mask = (ground_truth['mask'] > 0.5).cuda() & mask
229 |
230 | depth_loss = self.get_depth_loss(depth_pred, depth_gt, mask) if self.depth_weight > 0 else torch.tensor(0.0).cuda().float()
231 | if isinstance(depth_loss, float):
232 | depth_loss = torch.tensor(0.0).cuda().float()
233 |
234 | normal_l1, normal_cos = self.get_normal_loss(normal_pred * mask, normal_gt)
235 |
236 | smooth_loss = self.get_smooth_loss(model_outputs)
237 |
238 | # compute decay weights
239 | if self.end_step > 0:
240 | decay = math.exp(-self.step / self.end_step * 10.)
241 | else:
242 | decay = 1.0
243 |
244 | self.step += 1
245 |
246 | loss = rgb_loss + \
247 | self.eikonal_weight * eikonal_loss +\
248 | self.smooth_weight * smooth_loss +\
249 | decay * self.depth_weight * depth_loss +\
250 | decay * self.normal_l1_weight * normal_l1 +\
251 | decay * self.normal_cos_weight * normal_cos
252 |
253 | output = {
254 | 'loss': loss,
255 | 'rgb_loss': rgb_loss,
256 | 'eikonal_loss': eikonal_loss,
257 | 'smooth_loss': smooth_loss,
258 | 'depth_loss': depth_loss,
259 | 'normal_l1': normal_l1,
260 | 'normal_cos': normal_cos
261 | }
262 |
263 | return output
264 |
265 |
266 | class ObjectSDFPlusLoss(MonoSDFLoss):
267 | def __init__(self, rgb_loss,
268 | eikonal_weight,
269 | semantic_weight = 0.04,
270 | smooth_weight = 0.005,
271 | semantic_loss = torch.nn.CrossEntropyLoss(ignore_index = -1),
272 | depth_weight = 0.1,
273 | normal_l1_weight = 0.05,
274 | normal_cos_weight = 0.05,
275 | reg_vio_weight = 0.1,
276 | use_obj_opacity = True,
277 | bg_reg_weight = 0.1,
278 | end_step = -1):
279 | super().__init__(
280 | rgb_loss = rgb_loss,
281 | eikonal_weight = eikonal_weight,
282 | smooth_weight = smooth_weight,
283 | depth_weight = depth_weight,
284 | normal_l1_weight = normal_l1_weight,
285 | normal_cos_weight = normal_cos_weight,
286 | end_step = end_step)
287 | self.semantic_weight = semantic_weight
288 | self.bg_reg_weight = bg_reg_weight
289 | self.semantic_loss = utils.get_class(semantic_loss)(reduction='mean') if semantic_loss is not torch.nn.CrossEntropyLoss else torch.nn.CrossEntropyLoss(ignore_index = -1)
290 | self.reg_vio_weight = reg_vio_weight
291 | self.use_obj_opacity = use_obj_opacity
292 |
293 | print(f"[INFO]: using weight for loss RGB_1.0 EK_{self.eikonal_weight} SM_{self.smooth_weight} Depth_{self.depth_weight} NormalL1_{self.normal_l1_weight} NormalCos_{self.normal_cos_weight}\
294 | Semantic_{self.semantic_weight}, semantic_loss_type_{self.semantic_loss} Use_object_opacity_{self.use_obj_opacity}")
295 |
296 | def get_semantic_loss(self, semantic_value, semantic_gt):
297 | semantic_gt = semantic_gt.squeeze()
298 | semantic_loss = self.semantic_loss(semantic_value, semantic_gt)
299 | return semantic_loss
300 |
301 | # violation loss
302 | def get_violation_reg_loss(self, sdf_value):
303 | # turn to vector, sdf_value: [#rays, #objects]
304 | min_value, min_indice = torch.min(sdf_value, dim=1, keepdims=True)
305 | input = -sdf_value-min_value.detach() # add the min value for all tensor
306 | res = torch.relu(input).sum(dim=1, keepdims=True) - torch.relu(torch.gather(input, 1, min_indice))
307 | loss = res.sum()
308 | return loss
309 |
310 | def object_distinct_loss(self, sdf_value, min_sdf):
311 | _, min_indice = torch.min(sdf_value.squeeze(), dim=1, keepdims=True)
312 | input = -sdf_value.squeeze() - min_sdf.detach()
313 | res = torch.relu(input).sum(dim=1, keepdims=True) - torch.relu(torch.gather(input, 1, min_indice))
314 | loss = res.mean()
315 | return loss
316 |
317 | def object_opacity_loss(self, predict_opacity, gt_opacity, weight=None):
318 | # normalize predict_opacity
319 | # predict_opacity = torch.nn.functional.normalize(predict_opacity, p=1, dim=-1)
320 | target = torch.nn.functional.one_hot(gt_opacity.squeeze(), num_classes=predict_opacity.shape[1]).float()
321 | if weight is None:
322 | loss = F.binary_cross_entropy(predict_opacity.clamp(1e-4, 1-1e-4), target)
323 | return loss
324 |
325 | # background regularization loss following the design in Sec 3.2 of RICO (https://arxiv.org/pdf/2303.08605.pdf)
326 | def bg_tv_loss(self, depth_pred, normal_pred, gt_mask):
327 | # the depth_pred and normal_pred should form a patch in image space, depth_pred: [ray, 1], normal_pred: [ray, 3], gt_mask: [1, ray, 1]
328 | size = int(math.sqrt(gt_mask.shape[1]))
329 | mask = gt_mask.reshape(size, size, -1)
330 | depth = depth_pred.reshape(size, size, -1)
331 | normal = torch.nn.functional.normalize(normal_pred, p=2, dim=-1).reshape(size, size, -1)
332 | loss = 0
333 | for stride in [1, 2, 4]:
334 | hd_d = torch.abs(depth[:, :-stride, :] - depth[:, stride:, :])
335 | wd_d = torch.abs(depth[:-stride, :, :] - depth[stride:, :, :])
336 | hd_n = torch.abs(normal[:, :-stride, :] - normal[:, stride:, :])
337 | wd_n = torch.abs(normal[:-stride, :, :] - normal[stride:, :, :])
338 | loss+= torch.mean(hd_d*mask[:, :-stride, :]) + torch.mean(wd_d*mask[:-stride, :, :])
339 | loss+= torch.mean(hd_n*mask[:, :-stride, :]) + torch.mean(wd_n*mask[:-stride, :, :])
340 | return loss
341 |
342 |
343 | def forward(self, model_outputs, ground_truth, call_reg=False, call_bg_reg=False):
344 | output = super().forward(model_outputs, ground_truth)
345 | if 'semantic_values' in model_outputs and not self.use_obj_opacity: # ObjectSDF loss: semantic field + cross entropy
346 | semantic_gt = ground_truth['segs'].cuda().long()
347 | semantic_loss = self.get_semantic_loss(model_outputs['semantic_values'], semantic_gt)
348 | elif "object_opacity" in model_outputs and self.use_obj_opacity: # ObjectSDF++ loss: occlusion-awared object opacity + MSE
349 | semantic_gt = ground_truth['segs'].cuda().long()
350 | semantic_loss = self.object_opacity_loss(model_outputs['object_opacity'], semantic_gt)
351 | else:
352 | semantic_loss = torch.tensor(0.0).cuda().float()
353 |
354 | if "sample_sdf" in model_outputs and call_reg:
355 | sample_sdf_loss = self.object_distinct_loss(model_outputs["sample_sdf"], model_outputs["sample_minsdf"])
356 | else:
357 | sample_sdf_loss = torch.tensor(0.0).cuda().float()
358 |
359 | background_reg_loss = torch.tensor(0.0).cuda().float()
360 | # if call_bg_reg:
361 | # mask = (ground_truth['segs'] != 0).cuda()
362 | # background_reg_loss = self.bg_tv_loss(model_outputs['bg_depth_values'], model_outputs['bg_normal_map'], mask)
363 | # else:
364 | # background_reg_loss = torch.tensor(0.0).cuda().float()
365 | output['semantic_loss'] = semantic_loss
366 | output['collision_reg_loss'] = sample_sdf_loss
367 | output['background_reg_loss'] = background_reg_loss
368 | output['loss'] = output['loss'] + self.semantic_weight * semantic_loss + self.reg_vio_weight* sample_sdf_loss
369 | # + self.bg_reg_weight * background_reg_loss # <- this one is not used in the paper, but is helpful for background regularization
370 | return output
--------------------------------------------------------------------------------
/code/model/network.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from utils import rend_util
6 | from model.embedder import *
7 | from model.density import LaplaceDensity
8 | from model.ray_sampler import ErrorBoundSampler
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 |
12 | from torch import vmap
13 |
14 | class ImplicitNetwork(nn.Module):
15 | def __init__(
16 | self,
17 | feature_vector_size,
18 | sdf_bounding_sphere,
19 | d_in,
20 | d_out,
21 | dims,
22 | geometric_init=True,
23 | bias=1.0,
24 | skip_in=(),
25 | weight_norm=True,
26 | multires=0,
27 | sphere_scale=1.0,
28 | inside_outside=True,
29 | sigmoid = 10
30 | ):
31 | super().__init__()
32 |
33 | self.sdf_bounding_sphere = sdf_bounding_sphere
34 | self.sphere_scale = sphere_scale
35 | dims = [d_in] + dims + [d_out + feature_vector_size]
36 |
37 | self.embed_fn = None
38 | if multires > 0:
39 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
40 | self.embed_fn = embed_fn
41 | dims[0] = input_ch
42 | print(multires, dims)
43 | self.num_layers = len(dims)
44 | self.skip_in = skip_in
45 | self.d_out = d_out
46 | self.sigmoid = sigmoid
47 |
48 | for l in range(0, self.num_layers - 1):
49 | if l + 1 in self.skip_in:
50 | out_dim = dims[l + 1] - dims[0]
51 | else:
52 | out_dim = dims[l + 1]
53 |
54 | lin = nn.Linear(dims[l], out_dim)
55 |
56 | if geometric_init:
57 | if l == self.num_layers - 2:
58 | # Geometry initialization for compositional scene, bg SDF sign: inside +, outside -; fg SDF sign: outside +, inside -
59 | # The 0 index is the background SDF, the rest are the object SDFs
60 | # background SDF initialized with positive values inside and negative values outside
61 | torch.nn.init.normal_(lin.weight[:1, :], mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
62 | torch.nn.init.constant_(lin.bias[:1], bias)
63 | # inner objects' SDFs initialized with negative values inside and positive values outside, at ~0.6 of the background radius
64 | torch.nn.init.normal_(lin.weight[1:,:], mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
65 | torch.nn.init.constant_(lin.bias[1:], -0.6*bias)
66 |
67 | elif multires > 0 and l == 0:
68 | torch.nn.init.constant_(lin.bias, 0.0)
69 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
70 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
71 | elif multires > 0 and l in self.skip_in:
72 | torch.nn.init.constant_(lin.bias, 0.0)
73 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
74 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
75 | else:
76 | torch.nn.init.constant_(lin.bias, 0.0)
77 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
78 |
79 | if weight_norm:
80 | lin = nn.utils.weight_norm(lin)
81 |
82 | setattr(self, "lin" + str(l), lin)
83 |
84 | self.softplus = nn.Softplus(beta=100)
85 | self.pool = nn.MaxPool1d(self.d_out)
86 |
87 | def forward(self, input):
88 | if self.embed_fn is not None:
89 | input = self.embed_fn(input)
90 |
91 | x = input
92 |
93 | for l in range(0, self.num_layers - 1):
94 | lin = getattr(self, "lin" + str(l))
95 |
96 | if l in self.skip_in:
97 | x = torch.cat([x, input], 1) / np.sqrt(2)
98 |
99 | x = lin(x)
100 |
101 | if l < self.num_layers - 2:
102 | x = self.softplus(x)
103 |
104 | return x
105 |
106 | def gradient(self, x):
107 | x.requires_grad_(True)
108 | y = self.forward(x)[:,:self.d_out]
109 | d_output = torch.ones_like(y[:, :1], requires_grad=False, device=y.device)
110 | g = []
111 | for idx in range(y.shape[1]):
112 | gradients = torch.autograd.grad(
113 | outputs=y[:, idx:idx+1],
114 | inputs=x,
115 | grad_outputs=d_output,
116 | create_graph=True,
117 | retain_graph=True,
118 | only_inputs=True)[0]
119 | g.append(gradients)
120 | g = torch.cat(g)
121 | # add the gradient of minimum sdf
122 | sdf = -self.pool(-y.unsqueeze(1)).squeeze(-1)
123 | g_min_sdf = torch.autograd.grad(
124 | outputs=sdf,
125 | inputs=x,
126 | grad_outputs=d_output,
127 | create_graph=True,
128 | retain_graph=True,
129 | only_inputs=True)[0]
130 | g = torch.cat([g, g_min_sdf])
131 | return g
132 |
133 | def get_outputs(self, x, beta=None):
134 | x.requires_grad_(True)
135 | output = self.forward(x)
136 | sdf_raw = output[:,:self.d_out]
137 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
138 | if self.sdf_bounding_sphere > 0.0:
139 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
140 | sdf_raw = torch.minimum(sdf_raw, sphere_sdf.expand(sdf_raw.shape))
141 | if beta == None:
142 | semantic = self.sigmoid * torch.sigmoid(-self.sigmoid * sdf_raw)
143 | else:
144 | semantic = 0.5/beta *torch.exp(-sdf_raw.abs()/beta)
145 | sdf = -self.pool(-sdf_raw.unsqueeze(1)).squeeze(-1) # get the minimum value of the sdf
146 | feature_vectors = output[:, self.d_out:]
147 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
148 | gradients = torch.autograd.grad(
149 | outputs=sdf,
150 | inputs=x,
151 | grad_outputs=d_output,
152 | create_graph=True,
153 | retain_graph=True,
154 | only_inputs=True)[0]
155 |
156 | return sdf, feature_vectors, gradients, semantic, sdf_raw
157 |
158 | def get_sdf_vals(self, x):
159 | sdf = self.forward(x)[:,:self.d_out]
160 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
161 | # sdf = -self.pool(-sdf) # get the minimum value of the sdf if the bound is applied at the end
162 | if self.sdf_bounding_sphere > 0.0:
163 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
164 | sdf = torch.minimum(sdf, sphere_sdf.expand(sdf.shape))
165 |         sdf = -self.pool(-sdf.unsqueeze(1)).squeeze(-1) # get the minimum SDF value (bound applied before taking the min)
166 | return sdf
167 |
168 | def get_sdf_raw(self, x):
169 | return self.forward(x)[:, :self.d_out]
170 |
171 |
172 | from hashencoder.hashgrid import HashEncoder
173 | class ObjectImplicitNetworkGrid(nn.Module):
174 | def __init__(
175 | self,
176 | feature_vector_size,
177 | sdf_bounding_sphere,
178 | d_in,
179 | d_out,
180 | dims,
181 | geometric_init=True,
182 | bias=1.0, # radius of the sphere in geometric initialization
183 | skip_in=(),
184 | weight_norm=True,
185 | multires=0,
186 | sphere_scale=1.0,
187 | inside_outside=False,
188 | base_size = 16,
189 | end_size = 2048,
190 | logmap = 19,
191 | num_levels=16,
192 | level_dim=2,
193 | divide_factor = 1.5, # used to normalize the points range for multi-res grid
194 | use_grid_feature = True, # use hash grid embedding or not, if not, it is a pure MLP with sin/cos embedding
195 | sigmoid = 20
196 | ):
197 | super().__init__()
198 |
199 | self.d_out = d_out
200 | self.sigmoid = sigmoid
201 | self.sdf_bounding_sphere = sdf_bounding_sphere
202 | self.sphere_scale = sphere_scale
203 | dims = [d_in] + dims + [d_out + feature_vector_size]
204 | self.embed_fn = None
205 | self.divide_factor = divide_factor
206 | self.grid_feature_dim = num_levels * level_dim
207 | self.use_grid_feature = use_grid_feature
208 | dims[0] += self.grid_feature_dim
209 |
210 | print(f"[INFO]: using hash encoder with {num_levels} levels, each level with feature dim {level_dim}")
211 | print(f"[INFO]: resolution:{base_size} -> {end_size} with hash map size {logmap}")
212 | self.encoding = HashEncoder(input_dim=3, num_levels=num_levels, level_dim=level_dim,
213 | per_level_scale=2, base_resolution=base_size,
214 | log2_hashmap_size=logmap, desired_resolution=end_size)
215 |
216 | '''
217 | # can also use tcnn for multi-res grid as it now supports eikonal loss
218 | base_size = 16
219 | hash = True
220 | smoothstep = True
221 | self.encoding = tcnn.Encoding(3, {
222 | "otype": "HashGrid" if hash else "DenseGrid",
223 | "n_levels": 16,
224 | "n_features_per_level": 2,
225 | "log2_hashmap_size": 19,
226 | "base_resolution": base_size,
227 | "per_level_scale": 1.34,
228 | "interpolation": "Smoothstep" if smoothstep else "Linear"
229 | })
230 | '''
231 |
232 | if multires > 0:
233 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
234 | self.embed_fn = embed_fn
235 | dims[0] += input_ch - 3
236 | # print("network architecture")
237 | # print(dims)
238 |
239 | self.num_layers = len(dims)
240 | self.skip_in = skip_in
241 | for l in range(0, self.num_layers - 1):
242 | if l + 1 in self.skip_in:
243 | out_dim = dims[l + 1] - dims[0]
244 | else:
245 | out_dim = dims[l + 1]
246 |
247 | lin = nn.Linear(dims[l], out_dim)
248 |
249 | if geometric_init:
250 | if l == self.num_layers - 2:
251 |                     # Geometry initialization for a compositional scene, bg SDF sign: inside +, outside -; fg SDF sign: outside +, inside -
252 |                     # The 0 index is the background SDF, the rest are the object SDFs
253 |                     # background SDF with positive value inside and negative value outside
254 | torch.nn.init.normal_(lin.weight[:1, :], mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
255 | torch.nn.init.constant_(lin.bias[:1], bias)
256 |                     # inner objects initialized with negative SDF inside and positive outside, radius ~0.5x the background radius
257 | torch.nn.init.normal_(lin.weight[1:,:], mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
258 | torch.nn.init.constant_(lin.bias[1:], -0.5*bias)
259 |
260 | elif multires > 0 and l == 0:
261 | torch.nn.init.constant_(lin.bias, 0.0)
262 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
263 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
264 | elif multires > 0 and l in self.skip_in:
265 | torch.nn.init.constant_(lin.bias, 0.0)
266 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
267 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
268 | else:
269 | torch.nn.init.constant_(lin.bias, 0.0)
270 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
271 |
272 | if weight_norm:
273 | lin = nn.utils.weight_norm(lin)
274 |
275 | setattr(self, "lin" + str(l), lin)
276 |
277 | self.softplus = nn.Softplus(beta=100)
278 | self.cache_sdf = None
279 |
280 | self.pool = nn.MaxPool1d(d_out)
281 | self.relu = nn.ReLU()
282 |
283 | def forward(self, input):
284 | if self.use_grid_feature:
285 |             # normalize the point range, as the encoding assumes points lie in [-1, 1]
286 | # assert torch.max(input / self.divide_factor)<1 and torch.min(input / self.divide_factor)>-1, 'range out of [-1, 1], max: {}, min: {}'.format(torch.max(input / self.divide_factor), torch.min(input / self.divide_factor))
287 | feature = self.encoding(input / self.divide_factor)
288 | else:
289 | feature = torch.zeros_like(input[:, :1].repeat(1, self.grid_feature_dim))
290 |
291 | if self.embed_fn is not None:
292 | embed = self.embed_fn(input)
293 | input = torch.cat((embed, feature), dim=-1)
294 | else:
295 | input = torch.cat((input, feature), dim=-1)
296 |
297 | x = input
298 |
299 | for l in range(0, self.num_layers - 1):
300 | lin = getattr(self, "lin" + str(l))
301 |
302 | if l in self.skip_in:
303 | x = torch.cat([x, input], 1) / np.sqrt(2)
304 |
305 | x = lin(x)
306 |
307 | if l < self.num_layers - 2:
308 | x = self.softplus(x)
309 |
310 | return x
311 |
312 | def gradient(self, x):
313 | x.requires_grad_(True)
314 | y = self.forward(x)[:,:self.d_out]
315 | d_output = torch.ones_like(y[:, :1], requires_grad=False, device=y.device)
316 | f = lambda v: torch.autograd.grad(outputs=y,
317 | inputs=x,
318 | grad_outputs=v.repeat(y.shape[0], 1),
319 | create_graph=True,
320 | retain_graph=True,
321 | only_inputs=True)[0]
322 |
323 | N = torch.eye(y.shape[1], requires_grad=False).to(y.device)
324 |
325 | # start_time = time.time()
326 |         if self.use_grid_feature: # hash grid features do not support vmap yet
327 | g = torch.cat([torch.autograd.grad(
328 | outputs=y,
329 | inputs=x,
330 | grad_outputs=idx.repeat(y.shape[0], 1),
331 | create_graph=True,
332 | retain_graph=True,
333 | only_inputs=True)[0] for idx in N.unbind()])
334 | # torch.cuda.synchronize()
335 | # print("time for computing gradient by for loop: ", time.time() - start_time, "s")
336 |
337 |         # use vmap for batched gradient computation when not using the grid feature (pure MLP)
338 | else:
339 | g = vmap(f, in_dims=1)(N).reshape(-1, 3)
340 |
341 | # add the gradient of scene sdf
342 | sdf = -self.pool(-y.unsqueeze(1)).squeeze(-1)
343 | g_min_sdf = torch.autograd.grad(
344 | outputs=sdf,
345 | inputs=x,
346 | grad_outputs=d_output,
347 | create_graph=True,
348 | retain_graph=True,
349 | only_inputs=True)[0]
350 | g = torch.cat([g, g_min_sdf])
351 | return g
352 |
353 | def get_outputs(self, x, beta=None):
354 | x.requires_grad_(True)
355 | output = self.forward(x)
356 | sdf_raw = output[:,:self.d_out]
357 | # if self.sdf_bounding_sphere > 0.0:
358 | # sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
359 | # sdf_raw = torch.minimum(sdf_raw, sphere_sdf.expand(sdf_raw.shape))
360 |
361 |         if beta is None:
362 | semantic = self.sigmoid * torch.sigmoid(-self.sigmoid * sdf_raw)
363 | else:
364 |             # change the semantic to the gradient of the density
365 | semantic = 1/beta * (0.5 + 0.5 * sdf_raw.sign() * torch.expm1(-sdf_raw.abs() / beta))
366 |         sdf = -self.pool(-sdf_raw.unsqueeze(1)).squeeze(-1) # get the minimum value over all object SDFs
367 | feature_vectors = output[:, self.d_out:]
368 |
369 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
370 | gradients = torch.autograd.grad(
371 | outputs=sdf,
372 | inputs=x,
373 | grad_outputs=d_output,
374 | create_graph=True,
375 | retain_graph=True,
376 | only_inputs=True)[0]
377 |
378 | return sdf, feature_vectors, gradients, semantic, sdf_raw
379 |
380 | def get_specific_outputs(self, x, idx):
381 | x.requires_grad_(True)
382 | output = self.forward(x)
383 | sdf_raw = output[:,:self.d_out]
384 | semantic = self.sigmoid * torch.sigmoid(-self.sigmoid * sdf_raw)
385 | sdf = -self.pool(-sdf_raw.unsqueeze(1)).squeeze(-1)
386 |
387 | feature_vectors = output[:, self.d_out:]
388 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
389 | gradients = torch.autograd.grad(
390 | outputs=sdf,
391 | inputs=x,
392 | grad_outputs=d_output,
393 | create_graph=True,
394 | retain_graph=True,
395 | only_inputs=True)[0]
396 |
397 | return sdf, feature_vectors, gradients, semantic, output[:,:self.d_out]
398 |
399 |
400 | def get_sdf_vals(self, x):
401 | sdf = -self.pool(-self.forward(x)[:,:self.d_out].unsqueeze(1)).squeeze(-1)
402 | return sdf
403 |
404 | def get_sdf_raw(self, x):
405 | return self.forward(x)[:, :self.d_out]
406 |
407 |
408 | def mlp_parameters(self):
409 | parameters = []
410 | for l in range(0, self.num_layers - 1):
411 | lin = getattr(self, "lin" + str(l))
412 | parameters += list(lin.parameters())
413 | return parameters
414 |
415 | def grid_parameters(self, verbose=False):
416 | if verbose:
417 | print("[INFO]: grid parameters", len(list(self.encoding.parameters())))
418 | for p in self.encoding.parameters():
419 | print(p.shape)
420 | return self.encoding.parameters()
421 |
422 |
423 | class RenderingNetwork(nn.Module):
424 | def __init__(
425 | self,
426 | feature_vector_size,
427 | mode,
428 | d_in,
429 | d_out,
430 | dims,
431 | weight_norm=True,
432 | multires_view=0,
433 | per_image_code = False
434 | ):
435 | super().__init__()
436 |
437 | self.mode = mode
438 | dims = [d_in + feature_vector_size] + dims + [d_out]
439 |
440 | self.embedview_fn = None
441 | if multires_view > 0:
442 | embedview_fn, input_ch = get_embedder(multires_view)
443 | self.embedview_fn = embedview_fn
444 | dims[0] += (input_ch - 3)
445 |
446 | self.per_image_code = per_image_code
447 | if self.per_image_code:
448 | # nerf in the wild parameter
449 | # parameters
450 | # maximum 1024 images
451 | self.embeddings = nn.Parameter(torch.empty(1024, 32))
452 | std = 1e-4
453 | self.embeddings.data.uniform_(-std, std)
454 | dims[0] += 32
455 |
456 | # print("rendering network architecture:")
457 | # print(dims)
458 |
459 | self.num_layers = len(dims)
460 |
461 | for l in range(0, self.num_layers - 1):
462 | out_dim = dims[l + 1]
463 | lin = nn.Linear(dims[l], out_dim)
464 |
465 | if weight_norm:
466 | lin = nn.utils.weight_norm(lin)
467 |
468 | setattr(self, "lin" + str(l), lin)
469 |
470 | self.relu = nn.ReLU()
471 | self.sigmoid = torch.nn.Sigmoid()
472 |
473 | def forward(self, points, normals, view_dirs, feature_vectors, indices):
474 | if self.embedview_fn is not None:
475 | view_dirs = self.embedview_fn(view_dirs)
476 |
477 | if self.mode == 'idr':
478 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
479 | elif self.mode == 'nerf':
480 | rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1)
481 | else:
482 | raise NotImplementedError
483 |
484 | if self.per_image_code:
485 | image_code = self.embeddings[indices].expand(rendering_input.shape[0], -1)
486 | rendering_input = torch.cat([rendering_input, image_code], dim=-1)
487 |
488 | x = rendering_input
489 |
490 | for l in range(0, self.num_layers - 1):
491 | lin = getattr(self, "lin" + str(l))
492 |
493 | x = lin(x)
494 |
495 | if l < self.num_layers - 2:
496 | x = self.relu(x)
497 |
498 | x = self.sigmoid(x)
499 | return x
500 |
501 |
502 | class ObjectSDFPlusNetwork(nn.Module):
503 | def __init__(self, conf):
504 | super().__init__()
505 | self.feature_vector_size = conf.get_int('feature_vector_size')
506 | self.scene_bounding_sphere = conf.get_float('scene_bounding_sphere', default=1.0)
507 | self.white_bkgd = conf.get_bool('white_bkgd', default=False)
508 | self.bg_color = torch.tensor(conf.get_list("bg_color", default=[1.0, 1.0, 1.0])).float().cuda()
509 |
510 | Grid_MLP = conf.get_bool('Grid_MLP', default=False)
511 | self.Grid_MLP = Grid_MLP
512 | if Grid_MLP:
513 | self.implicit_network = ObjectImplicitNetworkGrid(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network'))
514 | else:
515 | self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network'))
516 |
517 | self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network'))
518 |
519 | self.density = LaplaceDensity(**conf.get_config('density'))
520 | self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, **conf.get_config('ray_sampler'))
521 |
522 | self.num_semantic = conf.get_int('implicit_network.d_out')
523 |
524 | def forward(self, input, indices):
525 | # Parse model input
526 | intrinsics = input["intrinsics"]
527 | uv = input["uv"]
528 | pose = input["pose"]
529 |
530 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics)
531 |
532 |         # use the unnormalized ray direction for depth
533 | ray_dirs_tmp, _ = rend_util.get_camera_params(uv, torch.eye(4).to(pose.device)[None], intrinsics)
534 | depth_scale = ray_dirs_tmp[0, :, 2:]
535 |
536 | batch_size, num_pixels, _ = ray_dirs.shape
537 |
538 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
539 | ray_dirs = ray_dirs.reshape(-1, 3)
540 |
541 | z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self)
542 | N_samples = z_vals.shape[1]
543 |
544 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)
545 | points_flat = points.reshape(-1, 3)
546 |
547 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1)
548 | dirs_flat = dirs.reshape(-1, 3)
549 |
550 | beta_cur = self.density.get_beta()
551 | sdf, feature_vectors, gradients, semantic, sdf_raw = self.implicit_network.get_outputs(points_flat, beta=None)
552 |
553 | rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors, indices)
554 | rgb = rgb_flat.reshape(-1, N_samples, 3)
555 |
556 | semantic = semantic.reshape(-1, N_samples, self.num_semantic)
557 | weights, transmittance, dists = self.volume_rendering(z_vals, sdf)
558 |
559 |         # render the occlusion-aware object opacity
560 | object_opacity = self.occlusion_opacity(z_vals, transmittance, dists, sdf_raw).sum(-1).transpose(0, 1)
561 |
562 |
563 | rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1)
564 | semantic_values = torch.sum(weights.unsqueeze(-1)*semantic, 1)
565 | depth_values = torch.sum(weights * z_vals, 1, keepdims=True) / (weights.sum(dim=1, keepdims=True) +1e-8)
566 |         # scale the rendered distance to depth along the z direction
567 | depth_values = depth_scale * depth_values
568 |
569 | # white background assumption
570 | if self.white_bkgd:
571 | acc_map = torch.sum(weights, -1)
572 | rgb_values = rgb_values + (1. - acc_map[..., None]) * self.bg_color.unsqueeze(0)
573 |
574 | output = {
575 | 'rgb':rgb,
576 | 'semantic_values': semantic_values, # here semantic value calculated as in ObjectSDF
577 | 'object_opacity': object_opacity,
578 | 'rgb_values': rgb_values,
579 | 'depth_values': depth_values,
580 | 'z_vals': z_vals,
581 | 'depth_vals': z_vals * depth_scale,
582 | 'sdf': sdf.reshape(z_vals.shape),
583 | 'weights': weights,
584 | }
585 |
586 | if self.training:
587 | # Sample points for the eikonal loss
588 | n_eik_points = batch_size * num_pixels
589 |
590 | eikonal_points = torch.empty(n_eik_points, 3).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere).cuda()
591 |
592 | # add some of the near surface points
593 | eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3)
594 | eikonal_points = torch.cat([eikonal_points, eik_near_points], 0)
595 |             # add some neighbouring points, as in UniSurf
596 | neighbour_points = eikonal_points + (torch.rand_like(eikonal_points) - 0.5) * 0.01
597 | eikonal_points = torch.cat([eikonal_points, neighbour_points], 0)
598 |
599 | grad_theta = self.implicit_network.gradient(eikonal_points)
600 |
601 | sample_sdf = self.implicit_network.get_sdf_raw(eikonal_points)
602 | sdf_value = self.implicit_network.get_sdf_vals(eikonal_points)
603 | output['sample_sdf'] = sample_sdf
604 | output['sample_minsdf'] = sdf_value
605 |
606 |             # split gradients into eikonal points and neighbour points
607 | output['grad_theta'] = grad_theta[:grad_theta.shape[0]//2]
608 | output['grad_theta_nei'] = grad_theta[grad_theta.shape[0]//2:]
609 |
610 | # compute normal map
611 | normals = gradients / (gradients.norm(2, -1, keepdim=True) + 1e-6)
612 | normals = normals.reshape(-1, N_samples, 3)
613 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1)
614 |
615 | # transform to local coordinate system
616 | rot = pose[0, :3, :3].permute(1, 0).contiguous()
617 | normal_map = rot @ normal_map.permute(1, 0)
618 | normal_map = normal_map.permute(1, 0).contiguous()
619 |
620 | output['normal_map'] = normal_map
621 |
622 | return output
623 |
624 | def volume_rendering(self, z_vals, sdf):
625 | density_flat = self.density(sdf)
626 | density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples
627 |
628 | dists = z_vals[:, 1:] - z_vals[:, :-1]
629 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
630 |
631 | # LOG SPACE
632 | free_energy = dists * density
633 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) # shift one step
634 |         alpha = 1 - torch.exp(-free_energy) # probability that this interval is not empty
635 |         transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability that everything up to now is empty
636 |         weights = alpha * transmittance # probability that the ray hits something here
637 |
638 | return weights, transmittance, dists
639 |
640 | def occlusion_opacity(self, z_vals, transmittance, dists, sdf_raw):
641 | obj_density = self.density(sdf_raw).transpose(0, 1).reshape(-1, dists.shape[0], dists.shape[1]) # [#object, #ray, #sample points]
642 | free_energy = dists * obj_density
643 |         alpha = 1 - torch.exp(-free_energy) # probability that this interval is not empty
644 | object_weight = alpha * transmittance
645 | return object_weight
--------------------------------------------------------------------------------
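
A note on the two rendering helpers at the end of `network.py`: `volume_rendering` turns the scene (minimum) SDF into density, alpha and transmittance, while `occlusion_opacity` reuses that shared transmittance with each object's own density and sums the result along the ray. Below is a minimal, self-contained sketch of the same computation on dummy tensors; the `laplace_density` helper, the fixed `beta`, and the toy shapes are assumptions standing in for the repository's `LaplaceDensity` module and real ray batches.

```python
import torch

def laplace_density(sdf, beta=0.05):
    # VolSDF-style Laplace-CDF density with alpha = 1 / beta (stand-in for LaplaceDensity)
    return (1.0 / beta) * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))

n_rays, n_samples, n_obj = 2, 8, 3                        # toy sizes, index 0 = background
z_vals = torch.sort(torch.rand(n_rays, n_samples), dim=-1).values
sdf_raw = torch.randn(n_rays * n_samples, n_obj)          # per-object SDFs at each sample
sdf = sdf_raw.min(dim=-1).values                          # scene SDF = minimum over objects

dists = z_vals[:, 1:] - z_vals[:, :-1]
dists = torch.cat([dists, torch.full((n_rays, 1), 1e10)], dim=-1)

# transmittance is computed once from the scene (minimum) SDF and shared by all objects
free_energy = dists * laplace_density(sdf).reshape(n_rays, n_samples)
shifted = torch.cat([torch.zeros(n_rays, 1), free_energy[:, :-1]], dim=-1)
transmittance = torch.exp(-torch.cumsum(shifted, dim=-1))

# each object contributes its own alpha, attenuated by the shared transmittance
obj_density = laplace_density(sdf_raw).transpose(0, 1).reshape(n_obj, n_rays, n_samples)
obj_alpha = 1.0 - torch.exp(-dists * obj_density)
object_opacity = (obj_alpha * transmittance).sum(-1).transpose(0, 1)   # [n_rays, n_obj]
print(object_opacity.shape)
```

Summing the per-object weights along each ray yields one opacity value per object and ray, which is what the instance-mask supervision acts on.
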
/code/model/ray_sampler.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import torch
4 |
5 | from utils import rend_util
6 |
7 | class RaySampler(metaclass=abc.ABCMeta):
8 | def __init__(self,near, far):
9 | self.near = near
10 | self.far = far
11 |
12 | @abc.abstractmethod
13 | def get_z_vals(self, ray_dirs, cam_loc, model):
14 | pass
15 |
16 | class UniformSampler(RaySampler):
17 | def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
18 | #super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R
19 |         super().__init__(near, 2.0 * scene_bounding_sphere * 1.75 if far == -1 else far)  # default far is 2*R scaled by 1.75
20 | self.N_samples = N_samples
21 | self.scene_bounding_sphere = scene_bounding_sphere
22 | self.take_sphere_intersection = take_sphere_intersection
23 |
24 | # dtu and bmvs
25 | def get_z_vals_dtu_bmvs(self, ray_dirs, cam_loc, model):
26 | if not self.take_sphere_intersection:
27 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda()
28 | else:
29 | sphere_intersections = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
30 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda()
31 | far = sphere_intersections[:,1:]
32 |
33 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda()
34 | z_vals = near * (1. - t_vals) + far * (t_vals)
35 |
36 | if model.training:
37 | # get intervals between samples
38 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
39 | upper = torch.cat([mids, z_vals[..., -1:]], -1)
40 | lower = torch.cat([z_vals[..., :1], mids], -1)
41 | # stratified samples in those intervals
42 | t_rand = torch.rand(z_vals.shape).cuda()
43 |
44 | z_vals = lower + (upper - lower) * t_rand
45 |
46 | return z_vals, near, far
47 |
48 | def near_far_from_cube(self, rays_o, rays_d, bound):
49 | tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3]
50 | tmax = (bound - rays_o) / (rays_d + 1e-15)
51 | near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0]
52 | far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0]
53 |         # if far < near there is no intersection, so set both near and far to inf (1e9 here)
54 | mask = far < near
55 | near[mask] = 1e9
56 | far[mask] = 1e9
57 | # restrict near to a minimal value
58 | near = torch.clamp(near, min=self.near)
59 | far = torch.clamp(far, max=self.far)
60 | return near, far
61 |
62 | # currently this is used for replica scannet and T&T
63 | def get_z_vals(self, ray_dirs, cam_loc, model):
64 | if not self.take_sphere_intersection:
65 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda()
66 | else:
67 | _, far = self.near_far_from_cube(cam_loc, ray_dirs, bound=self.scene_bounding_sphere)
68 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda()
69 |
70 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda()
71 | z_vals = near * (1. - t_vals) + far * (t_vals)
72 |
73 | if model.training:
74 | # get intervals between samples
75 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
76 | upper = torch.cat([mids, z_vals[..., -1:]], -1)
77 | lower = torch.cat([z_vals[..., :1], mids], -1)
78 | # stratified samples in those intervals
79 | t_rand = torch.rand(z_vals.shape).cuda()
80 |
81 | z_vals = lower + (upper - lower) * t_rand
82 |
83 | return z_vals, near, far
84 |
85 |
86 | class ErrorBoundSampler(RaySampler):
87 | def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
88 | eps, beta_iters, max_total_iters,
89 | inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=1.0e-6):
90 | #super().__init__(near, 2.0 * scene_bounding_sphere)
91 | super().__init__(near, 2.0 * scene_bounding_sphere * 1.75)
92 |
93 | self.N_samples = N_samples
94 | self.N_samples_eval = N_samples_eval
95 | self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=True) # replica scannet and T&T courtroom
96 | #self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg) # dtu and bmvs
97 |
98 | self.N_samples_extra = N_samples_extra
99 |
100 | self.eps = eps
101 | self.beta_iters = beta_iters
102 | self.max_total_iters = max_total_iters
103 | self.scene_bounding_sphere = scene_bounding_sphere
104 | self.add_tiny = add_tiny
105 |
106 | self.inverse_sphere_bg = inverse_sphere_bg
107 | if inverse_sphere_bg:
108 | self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)
109 |
110 | def get_z_vals(self, ray_dirs, cam_loc, model):
111 | beta0 = model.density.get_beta().detach()
112 |
113 | # Start with uniform sampling
114 | z_vals, near, far = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
115 | samples, samples_idx = z_vals, None
116 |
117 | # Get maximum beta from the upper bound (Lemma 2)
118 | dists = z_vals[:, 1:] - z_vals[:, :-1]
119 | bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
120 | beta = torch.sqrt(bound)
121 |
122 | total_iters, not_converge = 0, True
123 |
124 | # Algorithm 1
125 | while not_converge and total_iters < self.max_total_iters:
126 | points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
127 | points_flat = points.reshape(-1, 3)
128 |
129 | # Calculating the SDF only for the new sampled points
130 | with torch.no_grad():
131 | samples_sdf = model.implicit_network.get_sdf_vals(points_flat)
132 | if samples_idx is not None:
133 | sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
134 | samples_sdf.reshape(-1, samples.shape[1])], -1)
135 | sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
136 | else:
137 | sdf = samples_sdf
138 |
139 |
140 | # Calculating the bound d* (Theorem 1)
141 | d = sdf.reshape(z_vals.shape)
142 | dists = z_vals[:, 1:] - z_vals[:, :-1]
143 | a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
144 | first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
145 | second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
146 | d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda()
147 | d_star[first_cond] = b[first_cond]
148 | d_star[second_cond] = c[second_cond]
149 | s = (a + b + c) / 2.0
150 | area_before_sqrt = s * (s - a) * (s - b) * (s - c)
151 | mask = ~first_cond & ~second_cond & (b + c - a > 0)
152 | d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
153 | d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign
154 |
155 |
156 | # Updating beta using line search
157 | curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
158 | beta[curr_error <= self.eps] = beta0
159 | beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
160 | for j in range(self.beta_iters):
161 | beta_mid = (beta_min + beta_max) / 2.
162 | curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
163 | beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
164 | beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
165 | beta = beta_max
166 |
167 | # Upsample more points
168 | density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
169 |
170 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
171 | free_energy = dists * density
172 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1)
173 | alpha = 1 - torch.exp(-free_energy)
174 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
175 |             weights = alpha * transmittance # probability that the ray hits something here
176 |
177 | # Check if we are done and this is the last sampling
178 | total_iters += 1
179 | not_converge = beta.max() > beta0
180 |
181 | if not_converge and total_iters < self.max_total_iters:
182 | ''' Sample more points proportional to the current error bound'''
183 |
184 | N = self.N_samples_eval
185 |
186 | bins = z_vals
187 | error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
188 | error_integral = torch.cumsum(error_per_section, dim=-1)
189 | bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]
190 |
191 | pdf = bound_opacity + self.add_tiny
192 | pdf = pdf / torch.sum(pdf, -1, keepdim=True)
193 | cdf = torch.cumsum(pdf, -1)
194 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
195 |
196 | else:
197 | ''' Sample the final sample set to be used in the volume rendering integral '''
198 |
199 | N = self.N_samples
200 |
201 | bins = z_vals
202 | pdf = weights[..., :-1]
203 | pdf = pdf + 1e-5 # prevent nans
204 | pdf = pdf / torch.sum(pdf, -1, keepdim=True)
205 | cdf = torch.cumsum(pdf, -1)
206 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
207 |
208 |
209 | # Invert CDF
210 | if (not_converge and total_iters < self.max_total_iters) or (not model.training):
211 | u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1)
212 | else:
213 | u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda()
214 | u = u.contiguous()
215 |
216 | inds = torch.searchsorted(cdf, u, right=True)
217 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
218 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
219 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
220 |
221 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
222 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
223 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
224 |
225 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
226 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
227 | t = (u - cdf_g[..., 0]) / denom
228 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
229 |
230 |
231 |             # Add more samples if not converged
232 | if not_converge and total_iters < self.max_total_iters:
233 | z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)
234 |
235 |
236 | z_samples = samples
237 | #TODO Use near and far from intersection
238 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda()
239 | if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
240 | far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]
241 |
242 | if self.N_samples_extra > 0:
243 | if model.training:
244 | sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
245 | else:
246 | sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
247 | z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
248 | else:
249 | z_vals_extra = torch.cat([near, far], -1)
250 |
251 | z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)
252 |
253 | # add some of the near surface points
254 | idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda()
255 | z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))
256 |
257 | if self.inverse_sphere_bg:
258 | z_vals_inverse_sphere, _, _ = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
259 | z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
260 | z_vals = (z_vals, z_vals_inverse_sphere)
261 |
262 | return z_vals, z_samples_eik
263 |
264 | def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
265 | density = model.density(sdf.reshape(z_vals.shape), beta=beta)
266 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1)
267 | integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
268 | error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
269 | error_integral = torch.cumsum(error_per_section, dim=-1)
270 | bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
271 |
272 | return bound_opacity.max(-1)[0]
273 |
274 |
--------------------------------------------------------------------------------
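
The `# Invert CDF` block inside `ErrorBoundSampler.get_z_vals` above is a standard inverse-CDF resampling of the current bins, driven either by the error-bound PDF or by the rendering weights. A standalone sketch of just that inversion on made-up inputs is shown below; the function name `sample_pdf` and the toy shapes are illustrative and not part of the repository's API.

```python
import torch

def sample_pdf(bins, weights, n_samples, deterministic=True):
    # bins: [R, B] bin edges, weights: [R, B-1] per-section mass; returns [R, n_samples] new samples
    pdf = weights + 1e-5                                   # avoid zero-probability sections
    pdf = pdf / pdf.sum(-1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)          # [R, B]

    if deterministic:
        u = torch.linspace(0., 1., steps=n_samples).expand(cdf.shape[0], n_samples)
    else:
        u = torch.rand(cdf.shape[0], n_samples)
    u = u.contiguous()

    inds = torch.searchsorted(cdf, u, right=True)
    below = (inds - 1).clamp(min=0)
    above = inds.clamp(max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)                            # [R, n_samples, 2]

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

bins = torch.linspace(0., 4., steps=9).expand(2, 9)       # 2 rays, 8 sections each
weights = torch.rand(2, 8)
print(sample_pdf(bins, weights, n_samples=16).shape)      # torch.Size([2, 16])
```

The deterministic branch (evenly spaced `u`) mirrors what the sampler uses while it is still iterating or at evaluation time, while the random branch corresponds to the final training-time draw.
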
/code/training/exp_runner.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../code')
4 | import argparse
5 | import torch
6 |
7 | import os
8 | from training.objectsdfplus_train import ObjectSDFPlusTrainRunner
9 | import datetime
10 |
11 | if __name__ == '__main__':
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
15 | parser.add_argument('--nepoch', type=int, default=2000, help='number of epochs to train for')
16 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf')
17 | parser.add_argument('--expname', type=str, default='')
18 | parser.add_argument("--exps_folder", type=str, default="exps")
19 | #parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]')
20 | parser.add_argument('--is_continue', default=False, action="store_true",
21 | help='If set, indicates continuing from a previous run.')
22 | parser.add_argument('--timestamp', default='latest', type=str,
23 | help='The timestamp of the run to be used in case of continuing from a previous run.')
24 | parser.add_argument('--checkpoint', default='latest', type=str,
25 | help='The checkpoint epoch of the run to be used in case of continuing from a previous run.')
26 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')
27 | parser.add_argument('--cancel_vis', default=False, action="store_true",
28 | help='If set, cancel visualization in intermediate epochs.')
29 | # parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') # this is not required in torch 2.0
30 | parser.add_argument("--ft_folder", type=str, default=None, help='If set, finetune model from the given folder path')
31 |
32 | opt = parser.parse_args()
33 |
34 | '''
35 | # if using GPUtil
36 | if opt.gpu == "auto":
37 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False,
38 | excludeID=[], excludeUUID=[])
39 | gpu = deviceIDs[0]
40 | else:
41 | gpu = opt.gpu
42 | '''
43 | # gpu = opt.local_rank
44 |
45 | # set distributed training
46 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
47 | rank = int(os.environ["RANK"])
48 | world_size = int(os.environ['WORLD_SIZE'])
49 | local_rank = int(os.environ['LOCAL_RANK'])
50 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
51 | else:
52 | rank = -1
53 | world_size = -1
54 | local_rank = -1
55 |
56 | # print(opt.local_rank)
57 | torch.cuda.set_device(local_rank)
58 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank, timeout=datetime.timedelta(1, 1800))
59 | torch.distributed.barrier()
60 |
61 |
62 | trainrunner = ObjectSDFPlusTrainRunner(conf=opt.conf,
63 | batch_size=opt.batch_size,
64 | nepochs=opt.nepoch,
65 | expname=opt.expname,
66 | gpu_index=local_rank,
67 | exps_folder_name=opt.exps_folder,
68 | is_continue=opt.is_continue,
69 | timestamp=opt.timestamp,
70 | checkpoint=opt.checkpoint,
71 | scan_id=opt.scan_id,
72 | do_vis=not opt.cancel_vis,
73 | ft_folder = opt.ft_folder
74 | )
75 |
76 | trainrunner.run()
77 |
--------------------------------------------------------------------------------
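
`exp_runner.py` initializes `torch.distributed` with `init_method='env://'`, so it relies on the `RANK`, `WORLD_SIZE`, `LOCAL_RANK`, `MASTER_ADDR` and `MASTER_PORT` environment variables that `torchrun` normally exports. If the script is ever driven from another Python entry point on a single GPU, one possible workaround is to populate those variables by hand before the process group is created; this is an assumption about one launch setup, not something the repository prescribes.

```python
import os

# Hypothetical single-GPU fallback: set the variables torchrun would export,
# so that torch.distributed.init_process_group(init_method='env://') can succeed.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
```
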
/code/training/objectsdfplus_train.py:
--------------------------------------------------------------------------------
1 | import imp
2 | import os
3 | from datetime import datetime
4 | from pyhocon import ConfigFactory
5 | import sys
6 | import torch
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 | import utils.general as utils
11 | import utils.plots as plt
12 | from utils import rend_util
13 | from utils.general import get_time
14 | from torch.utils.tensorboard import SummaryWriter
15 | from model.loss import compute_scale_and_shift
16 | from utils.general import BackprojectDepth
17 |
18 | class ObjectSDFPlusTrainRunner():
19 | def __init__(self,**kwargs):
20 | torch.set_default_dtype(torch.float32)
21 | torch.set_num_threads(1)
22 |
23 | self.conf = ConfigFactory.parse_file(kwargs['conf'])
24 | self.batch_size = kwargs['batch_size']
25 | self.nepochs = kwargs['nepochs']
26 | self.exps_folder_name = kwargs['exps_folder_name']
27 | self.GPU_INDEX = kwargs['gpu_index']
28 |
29 | self.expname = self.conf.get_string('train.expname') + kwargs['expname']
30 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1)
31 | if scan_id != -1:
32 | self.expname = self.expname + '_{0}'.format(scan_id)
33 |
34 | self.finetune_folder = kwargs['ft_folder'] if kwargs['ft_folder'] is not None else None
35 | if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
36 | if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)):
37 | timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname))
38 | if (len(timestamps)) == 0:
39 | is_continue = False
40 | timestamp = None
41 | else:
42 | timestamp = sorted(timestamps)[-1]
43 | is_continue = True
44 | else:
45 | is_continue = False
46 | timestamp = None
47 | else:
48 | timestamp = kwargs['timestamp']
49 | is_continue = kwargs['is_continue']
50 |
51 | if self.GPU_INDEX == 0:
52 | utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name))
53 | self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
54 | utils.mkdir_ifnotexists(self.expdir)
55 | self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
56 | utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))
57 |
58 | self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
59 | utils.mkdir_ifnotexists(self.plots_dir)
60 |
61 | # create checkpoints dirs
62 | self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
63 | utils.mkdir_ifnotexists(self.checkpoints_path)
64 | self.model_params_subdir = "ModelParameters"
65 | self.optimizer_params_subdir = "OptimizerParameters"
66 | self.scheduler_params_subdir = "SchedulerParameters"
67 |
68 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
69 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
70 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))
71 |
72 | os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf')))
73 |
74 | # if (not self.GPU_INDEX == 'ignore'):
75 | # os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX)
76 |
77 | print('[INFO]: shell command : {0}'.format(' '.join(sys.argv)))
78 |
79 | print('[INFO]: Loading data ...')
80 |
81 | dataset_conf = self.conf.get_config('dataset')
82 | if kwargs['scan_id'] != -1:
83 | dataset_conf['scan_id'] = kwargs['scan_id']
84 |
85 | self.all_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(**dataset_conf)
86 | if hasattr(self.all_dataset, 'i_split'):
87 |             # to split the dataset into train and test, assign an 'i_split' attribute to all_dataset
88 | self.train_dataset = torch.utils.data.Subset(self.all_dataset, self.all_dataset.i_split[0])
89 | self.test_dataset = torch.utils.data.Subset(self.all_dataset, self.all_dataset.i_split[1])
90 | else:
91 | self.train_dataset = torch.utils.data.Subset(self.all_dataset, range(self.all_dataset.n_images))
92 | self.test_dataset = torch.utils.data.Subset(self.all_dataset, range(self.all_dataset.n_images))
93 |
94 | self.max_total_iters = self.conf.get_int('train.max_total_iters', default=200000)
95 | self.ds_len = len(self.train_dataset)
96 | self.nepochs = int(self.max_total_iters / self.ds_len) # update nepochs as iters/len(dataset)
97 | print('[INFO]: Finish loading data. Data-set size: {0}'.format(self.ds_len))
98 |
99 | self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
100 | batch_size=self.batch_size,
101 | shuffle=True,
102 | collate_fn=self.all_dataset.collate_fn,
103 | num_workers=8,
104 | pin_memory=True)
105 | self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
106 | batch_size=self.conf.get_int('plot.plot_nimgs'),
107 | shuffle=True,
108 | collate_fn=self.all_dataset.collate_fn
109 | )
110 |
111 | conf_model = self.conf.get_config('model')
112 | self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=conf_model)
113 |
114 | self.Grid_MLP = self.model.Grid_MLP
115 | if torch.cuda.is_available():
116 | self.model.cuda()
117 |
118 | self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))
119 |
120 | # The MLP and hash grid should have different learning rates
121 | self.lr = self.conf.get_float('train.learning_rate')
122 | self.lr_factor_for_grid = self.conf.get_float('train.lr_factor_for_grid', default=1.0)
123 |
124 | if self.Grid_MLP:
125 | self.optimizer = torch.optim.Adam([
126 | {'name': 'encoding', 'params': list(self.model.implicit_network.grid_parameters()),
127 | 'lr': self.lr * self.lr_factor_for_grid},
128 | {'name': 'net', 'params': list(self.model.implicit_network.mlp_parameters()) +\
129 | list(self.model.rendering_network.parameters()),
130 | 'lr': self.lr},
131 | {'name': 'density', 'params': list(self.model.density.parameters()),
132 | 'lr': self.lr},
133 | ], betas=(0.9, 0.99), eps=1e-15)
134 | else:
135 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
136 |
137 | # Exponential learning rate scheduler
138 | decay_rate = self.conf.get_float('train.sched_decay_rate', default=0.1)
139 | decay_steps = self.nepochs * len(self.train_dataset)
140 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, decay_rate ** (1./decay_steps))
141 |
142 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.GPU_INDEX], broadcast_buffers=False, find_unused_parameters=True)
143 |
144 | self.do_vis = kwargs['do_vis']
145 |
146 | self.start_epoch = 0
147 | # Loading a pretrained model for finetuning, the model path can be provided by self.finetune_folder
148 | if is_continue or self.finetune_folder is not None:
149 | old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints') if self.finetune_folder is None\
150 | else os.path.join(self.finetune_folder, 'checkpoints')
151 |
152 | print('[INFO]: Loading pretrained model from {}'.format(old_checkpnts_dir))
153 | saved_model_state = torch.load(
154 | os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
155 | self.model.load_state_dict(saved_model_state["model_state_dict"])
156 | self.start_epoch = saved_model_state['epoch']
157 |
158 | data = torch.load(
159 | os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
160 | self.optimizer.load_state_dict(data["optimizer_state_dict"])
161 |
162 | data = torch.load(
163 | os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
164 | self.scheduler.load_state_dict(data["scheduler_state_dict"])
165 |
166 | self.num_pixels = self.conf.get_int('train.num_pixels')
167 | self.total_pixels = self.all_dataset.total_pixels
168 | self.img_res = self.all_dataset.img_res
169 | self.n_batches = len(self.train_dataloader)
170 | self.plot_freq = self.conf.get_int('train.plot_freq')
171 | self.checkpoint_freq = self.conf.get_int('train.checkpoint_freq', default=100)
172 | self.split_n_pixels = self.conf.get_int('train.split_n_pixels', default=10000)
173 | self.plot_conf = self.conf.get_config('plot')
174 | self.backproject = BackprojectDepth(1, self.img_res[0], self.img_res[1]).cuda()
175 |
176 | self.add_objectvio_iter = self.conf.get_int('train.add_objectvio_iter', default=0)
177 |
178 | def save_checkpoints(self, epoch):
179 | torch.save(
180 | {"epoch": epoch, "model_state_dict": self.model.state_dict()},
181 | os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
182 | torch.save(
183 | {"epoch": epoch, "model_state_dict": self.model.state_dict()},
184 | os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
185 |
186 | torch.save(
187 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
188 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
189 | torch.save(
190 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
191 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
192 |
193 | torch.save(
194 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
195 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth"))
196 | torch.save(
197 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
198 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth"))
199 |
200 | def run(self):
201 | print("training...")
202 | if self.GPU_INDEX == 0 :
203 | self.writer = SummaryWriter(log_dir=os.path.join(self.plots_dir, 'logs'))
204 |
205 | self.iter_step = 0
206 | for epoch in range(self.start_epoch, self.nepochs + 1):
207 |
208 | if self.GPU_INDEX == 0 and epoch % self.checkpoint_freq == 0:
209 | self.save_checkpoints(epoch)
210 |
211 | if self.GPU_INDEX == 0 and self.do_vis and epoch % self.plot_freq == 0:
212 | self.model.eval()
213 |
214 | self.all_dataset.change_sampling_idx(-1)
215 |
216 | indices, model_input, ground_truth = next(iter(self.plot_dataloader))
217 | model_input["intrinsics"] = model_input["intrinsics"].cuda()
218 | model_input["uv"] = model_input["uv"].cuda()
219 | model_input['pose'] = model_input['pose'].cuda()
220 |
221 | split = utils.split_input(model_input, self.total_pixels, n_pixels=self.split_n_pixels)
222 | res = []
223 | for s in tqdm(split):
224 | out = self.model(s, indices)
225 | d = {'rgb_values': out['rgb_values'].detach(),
226 | 'normal_map': out['normal_map'].detach(),
227 | 'depth_values': out['depth_values'].detach()}
228 | if 'rgb_un_values' in out:
229 | d['rgb_un_values'] = out['rgb_un_values'].detach()
230 | if 'semantic_values' in out:
231 | d['semantic_values'] = torch.argmax(out['semantic_values'].detach(),dim=1)
232 | res.append(d)
233 |
234 | batch_size = ground_truth['rgb'].shape[0]
235 | model_outputs = utils.merge_output(res, self.total_pixels, batch_size)
236 | plot_data = self.get_plot_data(model_input, model_outputs, model_input['pose'], ground_truth['rgb'], ground_truth['normal'], ground_truth['depth'], ground_truth['segs'])
237 |
238 | plt.plot(self.model.module.implicit_network,
239 | indices,
240 | plot_data,
241 | self.plots_dir,
242 | epoch,
243 | self.img_res,
244 | **self.plot_conf
245 | )
246 |
247 | self.model.train()
248 | self.all_dataset.change_sampling_idx(self.num_pixels)
249 |
250 | for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):
251 | model_input["intrinsics"] = model_input["intrinsics"].cuda()
252 | model_input["uv"] = model_input["uv"].cuda()
253 | model_input['pose'] = model_input['pose'].cuda()
254 |
255 | self.optimizer.zero_grad()
256 |
257 | model_outputs = self.model(model_input, indices)
258 |
259 | loss_output = self.loss(model_outputs, ground_truth, call_reg=True) if\
260 | self.iter_step >= self.add_objectvio_iter else self.loss(model_outputs, ground_truth, call_reg=False)
261 |                 # if the pixel sampling pattern is changed to patches, a TV loss can be added to enforce a smoothness constraint
262 | loss = loss_output['loss']
263 | loss.backward()
264 | self.optimizer.step()
265 |
266 | psnr = rend_util.get_psnr(model_outputs['rgb_values'],
267 | ground_truth['rgb'].cuda().reshape(-1,3))
268 |
269 | self.iter_step += 1
270 |
271 | if self.GPU_INDEX == 0 and data_index %20 == 0:
272 | print(
273 |                         '{0}_{1} [{2}] ({3}/{4}): loss = {5}, rgb_loss = {6}, eikonal_loss = {7}, psnr = {8}, beta={9}, alpha={10}, semantic_loss = {11}, reg_loss = {12}'
274 | .format(self.expname, self.timestamp, epoch, data_index, self.n_batches, loss.item(),
275 | loss_output['rgb_loss'].item(),
276 | loss_output['eikonal_loss'].item(),
277 | psnr.item(),
278 | self.model.module.density.get_beta().item(),
279 | 1. / self.model.module.density.get_beta().item(),
280 | loss_output['semantic_loss'].item(),
281 | loss_output['collision_reg_loss'].item()))
282 |
283 | self.writer.add_scalar('Loss/loss', loss.item(), self.iter_step)
284 | self.writer.add_scalar('Loss/color_loss', loss_output['rgb_loss'].item(), self.iter_step)
285 | self.writer.add_scalar('Loss/eikonal_loss', loss_output['eikonal_loss'].item(), self.iter_step)
286 | self.writer.add_scalar('Loss/smooth_loss', loss_output['smooth_loss'].item(), self.iter_step)
287 | self.writer.add_scalar('Loss/depth_loss', loss_output['depth_loss'].item(), self.iter_step)
288 | self.writer.add_scalar('Loss/normal_l1_loss', loss_output['normal_l1'].item(), self.iter_step)
289 | self.writer.add_scalar('Loss/normal_cos_loss', loss_output['normal_cos'].item(), self.iter_step)
290 | if 'semantic_loss' in loss_output:
291 | self.writer.add_scalar('Loss/semantic_loss', loss_output['semantic_loss'].item(), self.iter_step)
292 | if 'collision_reg_loss' in loss_output:
293 | self.writer.add_scalar('Loss/collision_reg_loss', loss_output['collision_reg_loss'].item(), self.iter_step)
294 |
295 | self.writer.add_scalar('Statistics/beta', self.model.module.density.get_beta().item(), self.iter_step)
296 | self.writer.add_scalar('Statistics/alpha', 1. / self.model.module.density.get_beta().item(), self.iter_step)
297 | self.writer.add_scalar('Statistics/psnr', psnr.item(), self.iter_step)
298 |
299 | if self.Grid_MLP:
300 | self.writer.add_scalar('Statistics/lr0', self.optimizer.param_groups[0]['lr'], self.iter_step)
301 | self.writer.add_scalar('Statistics/lr1', self.optimizer.param_groups[1]['lr'], self.iter_step)
302 | self.writer.add_scalar('Statistics/lr2', self.optimizer.param_groups[2]['lr'], self.iter_step)
303 |
304 | self.all_dataset.change_sampling_idx(self.num_pixels)
305 | self.scheduler.step()
306 |
307 | if self.GPU_INDEX == 0:
308 | self.save_checkpoints(epoch)
309 |
310 |
311 | def get_plot_data(self, model_input, model_outputs, pose, rgb_gt, normal_gt, depth_gt, seg_gt):
312 | batch_size, num_samples, _ = rgb_gt.shape
313 |
314 | rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3)
315 | normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3)
316 | normal_map = (normal_map + 1.) / 2.
317 |
318 | depth_map = model_outputs['depth_values'].reshape(batch_size, num_samples)
319 | depth_gt = depth_gt.to(depth_map.device)
320 | scale, shift = compute_scale_and_shift(depth_map[..., None], depth_gt, depth_gt > 0.)
321 | depth_map = depth_map * scale + shift
322 |
323 | seg_map = model_outputs['semantic_values'].reshape(batch_size, num_samples)
324 | seg_gt = seg_gt.to(seg_map.device)
325 |
326 | # save point cloud
327 | depth = depth_map.reshape(1, 1, self.img_res[0], self.img_res[1])
328 | pred_points = self.get_point_cloud(depth, model_input, model_outputs)
329 |
330 | gt_depth = depth_gt.reshape(1, 1, self.img_res[0], self.img_res[1])
331 | gt_points = self.get_point_cloud(gt_depth, model_input, model_outputs)
332 |
333 | plot_data = {
334 | 'rgb_gt': rgb_gt,
335 | 'normal_gt': (normal_gt + 1.)/ 2.,
336 | 'depth_gt': depth_gt,
337 | 'seg_gt': seg_gt,
338 | 'pose': pose,
339 | 'rgb_eval': rgb_eval,
340 | 'normal_map': normal_map,
341 | 'depth_map': depth_map,
342 | 'seg_map': seg_map,
343 | "pred_points": pred_points,
344 | "gt_points": gt_points,
345 | }
346 |
347 | return plot_data
348 |
349 | def get_point_cloud(self, depth, model_input, model_outputs):
350 | color = model_outputs["rgb_values"].reshape(-1, 3)
351 |
352 | K_inv = torch.inverse(model_input["intrinsics"][0])[None]
353 | points = self.backproject(depth, K_inv)[0, :3, :].permute(1, 0)
354 | points = torch.cat([points, color], dim=-1)
355 | return points.detach().cpu().numpy()
356 |
--------------------------------------------------------------------------------
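
`get_plot_data` above aligns the rendered depth to the ground-truth depth with `compute_scale_and_shift` from `model/loss.py`, which is not included in this listing. As a reference for what such an alignment computes, here is a hedged sketch of the usual closed-form least-squares scale-and-shift fit; the helper name and shapes are illustrative assumptions, not the repository's implementation.

```python
import torch

def scale_and_shift_ls(prediction, target, mask):
    # closed-form least squares for (s, t) minimizing || mask * (s * prediction + t - target) ||^2
    prediction, target, mask = [x.reshape(x.shape[0], -1).float() for x in (prediction, target, mask)]
    a00 = (mask * prediction * prediction).sum(-1)
    a01 = (mask * prediction).sum(-1)
    a11 = mask.sum(-1)
    b0 = (mask * prediction * target).sum(-1)
    b1 = (mask * target).sum(-1)
    det = (a00 * a11 - a01 * a01).clamp(min=1e-8)
    scale = (a11 * b0 - a01 * b1) / det
    shift = (a00 * b1 - a01 * b0) / det
    return scale, shift

pred = torch.rand(1, 384 * 384)
gt = 2.5 * pred + 0.3
s, t = scale_and_shift_ls(pred, gt, gt > 0)
print(s, t)   # approximately 2.5 and 0.3
```

With an exact linear relation between prediction and target, the recovered scale and shift match it up to floating-point error.
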
/code/utils/general.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import time
7 | from torchvision import transforms
8 | import numpy as np
9 |
10 | def mkdir_ifnotexists(directory):
11 | if not os.path.exists(directory):
12 | os.mkdir(directory)
13 |
14 | def get_class(kls):
15 | parts = kls.split('.')
16 | module = ".".join(parts[:-1])
17 | m = __import__(module)
18 | for comp in parts[1:]:
19 | m = getattr(m, comp)
20 | return m
21 |
22 | def glob_imgs(path):
23 | imgs = []
24 | for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']:
25 | imgs.extend(glob(os.path.join(path, ext)))
26 | return imgs
27 |
28 | def split_input(model_input, total_pixels, n_pixels=10000):
29 | '''
30 |     Split the input to fit into CUDA memory at large resolutions.
31 |     Decrease n_pixels in case of a CUDA out-of-memory error.
32 | '''
33 | split = []
34 | for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)):
35 | data = model_input.copy()
36 | data['uv'] = torch.index_select(model_input['uv'], 1, indx)
37 | if 'object_mask' in data:
38 | data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
39 | if 'depth' in data:
40 | data['depth'] = torch.index_select(model_input['depth'], 1, indx)
41 | split.append(data)
42 | return split
43 |
44 | def merge_output(res, total_pixels, batch_size):
45 | ''' Merge the split output. '''
46 |
47 | model_outputs = {}
48 | for entry in res[0]:
49 | if res[0][entry] is None:
50 | continue
51 | if len(res[0][entry].shape) == 1:
52 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
53 | 1).reshape(batch_size * total_pixels)
54 | else:
55 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
56 | 1).reshape(batch_size * total_pixels, -1)
57 |
58 | return model_outputs
59 |
60 | def concat_home_dir(path):
61 | return os.path.join(os.environ['HOME'],'data',path)
62 |
63 | def get_time():
64 | torch.cuda.synchronize()
65 | return time.time()
66 |
67 | trans_topil = transforms.ToPILImage()
68 |
69 |
70 | class BackprojectDepth(nn.Module):
71 | """Layer to transform a depth image into a point cloud
72 | """
73 | def __init__(self, batch_size, height, width):
74 | super(BackprojectDepth, self).__init__()
75 |
76 | self.batch_size = batch_size
77 | self.height = height
78 | self.width = width
79 |
80 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
81 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
82 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
83 | requires_grad=False)
84 |
85 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
86 | requires_grad=False)
87 |
88 | self.pix_coords = torch.unsqueeze(torch.stack(
89 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
90 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
91 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
92 | requires_grad=False)
93 |
94 | def forward(self, depth, inv_K):
95 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
96 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points
97 | cam_points = torch.cat([cam_points, self.ones], 1)
98 | return cam_points
99 |
--------------------------------------------------------------------------------
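
A small usage sketch for `BackprojectDepth` above, showing how a depth map and the inverse intrinsics become homogeneous camera-space points. The toy camera values are made up for illustration, and the import assumes `code/` is on `PYTHONPATH` as in the rest of the repository.

```python
import torch
from utils.general import BackprojectDepth   # assumes code/ is on PYTHONPATH

H, W = 4, 6
K = torch.tensor([[100., 0., W / 2., 0.],
                  [0., 100., H / 2., 0.],
                  [0., 0., 1., 0.],
                  [0., 0., 0., 1.]])
backproject = BackprojectDepth(batch_size=1, height=H, width=W)
depth = torch.ones(1, 1, H, W)                        # a flat plane at depth 1
cam_points = backproject(depth, torch.inverse(K)[None])
print(cam_points.shape)                               # torch.Size([1, 4, 24]) homogeneous points
```
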
/code/utils/rend_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | import skimage
4 | import cv2
5 | import torch
6 | from torch.nn import functional as F
7 |
8 |
9 | def get_psnr(img1, img2, normalize_rgb=False):
10 | if normalize_rgb: # [-1,1] --> [0,1]
11 | img1 = (img1 + 1.) / 2.
12 | img2 = (img2 + 1. ) / 2.
13 |
14 | mse = torch.mean((img1 - img2) ** 2)
15 | psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())
16 |
17 | return psnr
18 |
19 |
20 | def load_rgb(path, normalize_rgb = False):
21 | img = imageio.imread(path)
22 | img = skimage.img_as_float32(img)
23 |
24 | if normalize_rgb: # [-1,1] --> [0,1]
25 | img -= 0.5
26 | img *= 2.
27 | img = img.transpose(2, 0, 1)
28 | return img
29 |
30 |
31 | def load_K_Rt_from_P(filename, P=None):
32 | if P is None:
33 | lines = open(filename).read().splitlines()
34 | if len(lines) == 4:
35 | lines = lines[1:]
36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
37 | P = np.asarray(lines).astype(np.float32).squeeze()
38 |
39 | out = cv2.decomposeProjectionMatrix(P)
40 | K = out[0]
41 | R = out[1]
42 | t = out[2]
43 |
44 | # import pdb; pdb.set_trace()
45 |
46 | K = K/K[2,2]
47 | intrinsics = np.eye(4)
48 | intrinsics[:3, :3] = K
49 |
50 | pose = np.eye(4, dtype=np.float32)
51 | pose[:3, :3] = R.transpose()
52 | pose[:3,3] = (t[:3] / t[3])[:,0]
53 |
54 | return intrinsics, pose
55 |
56 |
57 | def get_camera_params(uv, pose, intrinsics):
58 | if pose.shape[1] == 7: #In case of quaternion vector representation
59 | cam_loc = pose[:, 4:]
60 | R = quat_to_rot(pose[:,:4])
61 | p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
62 | p[:, :3, :3] = R
63 | p[:, :3, 3] = cam_loc
64 | else: # In case of pose matrix representation
65 | cam_loc = pose[:, :3, 3]
66 | p = pose
67 |
68 | batch_size, num_samples, _ = uv.shape
69 |
70 | depth = torch.ones((batch_size, num_samples)).cuda()
71 | x_cam = uv[:, :, 0].view(batch_size, -1)
72 | y_cam = uv[:, :, 1].view(batch_size, -1)
73 | z_cam = depth.view(batch_size, -1)
74 |
75 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
76 |
77 | # permute for batch matrix product
78 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
79 |
80 | # import pdb; pdb.set_trace();
81 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
82 | ray_dirs = world_coords - cam_loc[:, None, :]
83 | ray_dirs = F.normalize(ray_dirs, dim=2)
84 |
85 | return ray_dirs, cam_loc
86 |
87 |
88 | def get_camera_for_plot(pose):
89 | if pose.shape[1] == 7: #In case of quaternion vector representation
90 | cam_loc = pose[:, 4:].detach()
91 | R = quat_to_rot(pose[:,:4].detach())
92 | else: # In case of pose matrix representation
93 | cam_loc = pose[:, :3, 3]
94 | R = pose[:, :3, :3]
95 | cam_dir = R[:, :3, 2]
96 | return cam_loc, cam_dir
97 |
98 |
99 | def lift(x, y, z, intrinsics):
100 | # parse intrinsics
101 | intrinsics = intrinsics.cuda()
102 | fx = intrinsics[:, 0, 0]
103 | fy = intrinsics[:, 1, 1]
104 | cx = intrinsics[:, 0, 2]
105 | cy = intrinsics[:, 1, 2]
106 | sk = intrinsics[:, 0, 1]
107 |
108 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
109 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
110 |
111 | # homogeneous
112 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)
113 |
114 |
115 | def quat_to_rot(q):
116 | batch_size, _ = q.shape
117 | q = F.normalize(q, dim=1)
118 | R = torch.ones((batch_size, 3,3)).cuda()
119 | qr=q[:,0]
120 | qi = q[:, 1]
121 | qj = q[:, 2]
122 | qk = q[:, 3]
123 | R[:, 0, 0]=1-2 * (qj**2 + qk**2)
124 | R[:, 0, 1] = 2 * (qj *qi -qk*qr)
125 | R[:, 0, 2] = 2 * (qi * qk + qr * qj)
126 | R[:, 1, 0] = 2 * (qj * qi + qk * qr)
127 | R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
128 | R[:, 1, 2] = 2*(qj*qk - qi*qr)
129 | R[:, 2, 0] = 2 * (qk * qi-qj * qr)
130 | R[:, 2, 1] = 2 * (qj*qk + qi*qr)
131 | R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
132 | return R
133 |
134 |
135 | def rot_to_quat(R):
136 | batch_size, _,_ = R.shape
137 | q = torch.ones((batch_size, 4)).cuda()
138 |
139 | R00 = R[:, 0,0]
140 | R01 = R[:, 0, 1]
141 | R02 = R[:, 0, 2]
142 | R10 = R[:, 1, 0]
143 | R11 = R[:, 1, 1]
144 | R12 = R[:, 1, 2]
145 | R20 = R[:, 2, 0]
146 | R21 = R[:, 2, 1]
147 | R22 = R[:, 2, 2]
148 |
149 | q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2
150 | q[:, 1]=(R21-R12)/(4*q[:,0])
151 | q[:, 2] = (R02 - R20) / (4 * q[:, 0])
152 | q[:, 3] = (R10 - R01) / (4 * q[:, 0])
153 | return q
154 |
155 |
156 | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0):
157 | # Input: n_rays x 3 ; n_rays x 3
158 | # Output: n_rays x 1, n_rays x 1 (close and far)
159 |
160 | ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),
161 | cam_loc.view(-1, 3, 1)).squeeze(-1)
162 | under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)
163 |
164 | # sanity check
165 | if (under_sqrt <= 0).sum() > 0:
166 | print('BOUNDING SPHERE PROBLEM!')
167 | exit()
168 |
169 | sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot
170 | sphere_intersections = sphere_intersections.clamp_min(0.0)
171 |
172 | return sphere_intersections
173 |
--------------------------------------------------------------------------------
/code/utils/sem_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | COLOR_MAP = {
5 | 0: [0, 0, 0],
6 | 1: [204, 0, 0],
7 | 2: [76, 153, 0],
8 | 3: [204, 204, 0],
9 | 4: [51, 51, 255],
10 | 5: [204, 0, 204],
11 | 6: [0, 255, 255],
12 | 7: [255, 204, 204],
13 | 8: [102, 51, 0],
14 | 9: [255, 0, 0],
15 | 10: [102, 204, 0],
16 | 11: [255, 255, 0],
17 | 12: [0, 0, 153],
18 | 13: [0, 0, 204],
19 | 14: [255, 51, 153],
20 | 15: [0, 204, 204],
21 | 16: [0, 51, 0],
22 | 17: [255, 153, 51],
23 | 18: [0, 204, 0]}
24 |
25 |
26 | COLOR_MAP_COMPLETE = {
27 | 0: [0, 0, 0],
28 | 1: [204, 0, 0],
29 | 2: [76, 153, 0],
30 | 3: [204, 204, 0],
31 | 4: [51, 51, 255],
32 | 5: [204, 0, 204],
33 | 6: [0, 255, 255],
34 | 7: [255, 204, 204],
35 | 8: [102, 51, 0],
36 | 9: [255, 0, 0],
37 | 10: [102, 204, 0],
38 | 11: [255, 255, 0],
39 | 12: [0, 0, 153],
40 | 13: [0, 0, 204],
41 | 14: [255, 51, 153],
42 | 15: [0, 204, 204],
43 | 16: [0, 51, 0],
44 | 17: [255, 153, 51],
45 | 18: [0, 204, 0]}
46 |
47 |
48 |
49 | def mask2color(masks, is_argmax=True):
50 | if is_argmax is True:
51 | masks = torch.argmax(masks, dim=1).float()
52 | # import pdb; pdb.set_trace()
53 | masks = masks.squeeze(1)
54 | sample_mask = torch.zeros((masks.shape[0], masks.shape[1], masks.shape[2], 3), dtype=torch.float)
55 | for key in COLOR_MAP:
56 | sample_mask[masks==key] = torch.tensor(COLOR_MAP[key], dtype=torch.float)
57 | sample_mask = sample_mask.permute(0,3,1,2)
58 | return sample_mask
--------------------------------------------------------------------------------
/media/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QianyiWu/objectsdf_plus/4c93f016d553ed3251a527ef1a32821cf53af90c/media/teaser.gif
--------------------------------------------------------------------------------
/preprocess/README.md:
--------------------------------------------------------------------------------
1 | # Data Preprocess
2 | Here we provide scripts to preprocess the data if you start from scratch. Take the ScanNet dataset as an example: you can download one scene from the ScanNet dataset, including RGB images, camera poses/intrinsics, and semantic/instance segmentations. Please refer to [ScanNet dataformat](https://github.com/ScanNet/ScanNet#data-organization) and [here](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python) for more details.
3 |
4 |
5 | First, you need to run the Omnidata model (please install the [omnidata model](https://github.com/EPFL-VILAB/omnidata) before running the commands) to predict monocular cues for each image. Note that Omnidata is trained on 384x384 inputs, so we follow MonoSDF and apply a center crop to the original images when extracting this information.
6 | ```
7 | cd preprocess
8 | python extract_monocular_cues.py --task depth --img_path PATH_to_your_image --output_path PATH_to_SAVE --omnidata_path YOUR_OMNIDATA_PATH --pretrained_models PRETRAINED_MODELS
9 | python extract_monocular_cues.py --task normal --img_path PATH_to_your_image --output_path PATH_to_SAVE --omnidata_path YOUR_OMNIDATA_PATH --pretrained_models PRETRAINED_MODELS
10 | ```
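The dataset conversion scripts below apply a matching center crop and resize to the RGB, depth, and instance images. As an illustrative, condensed view (the actual transforms live in `scannet_to_objsdfpp.py`; masks use nearest-neighbour interpolation so the label values stay intact):
```python
import PIL
from torchvision import transforms

image_size = 384  # Omnidata input resolution

# RGB: center-crop to 2*384 = 768 pixels, then downsample to 384x384
trans_totensor = transforms.Compose([
    transforms.CenterCrop(image_size * 2),
    transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
])

# instance masks: same crop/resize, but nearest-neighbour to preserve label ids
seg_trans_totensor = transforms.Compose([
    transforms.CenterCrop(image_size * 2),
    transforms.Resize(image_size, interpolation=PIL.Image.NEAREST),
])
```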
11 |
12 | Now you have the monocular supervision. Next, you need to organize the dataset to make it ready for training.
13 |
14 | ```
15 | python scannet_to_objsdfpp.py
16 | ```
17 |
18 | You can perform a similar operation on the Replica dataset to get the instance label mapping file.
19 | ```
20 | python replica_to_objsdfpp.py
21 | ```
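The resulting mapping file is a plain text file with one comma-separated line per instance: `instance_id,semantic_label,remapped_instance_label`. As a minimal sketch of how it can be read back (this mirrors the parsing in `replica_eval/eval_3D_obj.py`):
```python
# Read an instance mapping file written by the scripts above.
# Each line is "instance_id,semantic_label,remapped_instance_label".
instance_mapping = {}
with open('instance_mapping.txt', 'r') as f:
    for line in f:
        inst_id, sem_label, new_inst = line.strip().split(',')
        instance_mapping[int(inst_id)] = int(new_inst)
```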
22 |
23 | Here are some notes:
24 | 1. We merge some instance classes (such as ceiling, wall, ...) into the background. You can edit the `instance_mapping.txt` file to define the objects you want.
25 | 2. The center and scale parameters are mainly used to normalize the entire scene into a cube. This is widely used in many NeRF projects to obtain normalized camera poses for training (a short sketch of this normalization is given after these notes).
26 | 3. We assume that the masks are view-consistent (i.e., the index of each instance does not change across frames). This can be achieved with a front-end segmentation algorithm (e.g., a video segmentation model).
27 |
28 | Please check these scripts for more details.
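As referenced in note 2, here is a minimal sketch of what the center/scale normalization boils down to; the same logic lives in `scannet_to_objsdfpp.py`, and the resulting matrix is stored as `scale_mat_i` in `cameras.npz`:
```python
import numpy as np

def compute_scale_mat(poses, valid_poses):
    """Shift and scale the scene so the valid camera centers fit inside a cube,
    then return the inverse transform (this is what gets saved in cameras.npz)."""
    centers = poses[:, :3, 3][valid_poses]                     # camera positions
    center = (centers.min(axis=0) + centers.max(axis=0)) / 2.
    scale = 2. / (np.max(centers.max(axis=0) - centers.min(axis=0)) + 3.)
    scale_mat = np.eye(4, dtype=np.float32)
    scale_mat[:3, 3] = -center
    scale_mat[:3] *= scale
    return np.linalg.inv(scale_mat)
```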
--------------------------------------------------------------------------------
/preprocess/extract_monocular_cues.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/EPFL-VILAB/omnidata
2 | import torch
3 | import torch.nn.functional as F
4 | from torchvision import transforms
5 |
6 | import PIL
7 | from PIL import Image
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 |
11 | import argparse
12 | import os.path
13 | from pathlib import Path
14 | import glob
15 | import sys
16 |
17 |
18 | parser = argparse.ArgumentParser(description='Visualize output for depth or surface normals')
19 |
20 | parser.add_argument('--omnidata_path', dest='omnidata_path', help="path to omnidata model")
21 | parser.set_defaults(omnidata_path='/media/hdd/omnidata/omnidata_tools/torch/')
22 |
23 | parser.add_argument('--pretrained_models', dest='pretrained_models', help="path to pretrained models")
24 | parser.set_defaults(pretrained_models='/media/hdd/omnidata/omnidata_tools/torch/pretrained_models/')
25 |
26 | parser.add_argument('--task', dest='task', help="normal or depth")
27 | parser.set_defaults(task='NONE')
28 |
29 | parser.add_argument('--img_path', dest='img_path', help="path to rgb image")
30 | parser.set_defaults(img_path='NONE')
31 |
32 | parser.add_argument('--output_path', dest='output_path', help="path to where output image should be stored")
33 | parser.set_defaults(output_path='NONE')
34 |
35 | args = parser.parse_args()
36 |
37 | root_dir = args.pretrained_models
38 | omnidata_path = args.omnidata_path
39 |
40 | sys.path.append(args.omnidata_path)
41 | print(sys.path)
42 | from modules.unet import UNet
43 | from modules.midas.dpt_depth import DPTDepthModel
44 | from data.transforms import get_transform
45 |
46 | trans_topil = transforms.ToPILImage()
47 | os.system(f"mkdir -p {args.output_path}")
48 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu')
49 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
50 |
51 |
52 | # get target task and model
53 | if args.task == 'normal':
54 | image_size = 384
55 |
56 | ## Version 1 model
57 | # pretrained_weights_path = root_dir + 'omnidata_unet_normal_v1.pth'
58 | # model = UNet(in_channels=3, out_channels=3)
59 | # checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
60 |
61 | # if 'state_dict' in checkpoint:
62 | # state_dict = {}
63 | # for k, v in checkpoint['state_dict'].items():
64 | # state_dict[k.replace('model.', '')] = v
65 | # else:
66 | # state_dict = checkpoint
67 |
68 |
69 | pretrained_weights_path = root_dir + '/omnidata_dpt_normal_v2.ckpt'
70 | model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) # DPT Hybrid
71 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
72 | if 'state_dict' in checkpoint:
73 | state_dict = {}
74 | for k, v in checkpoint['state_dict'].items():
75 | state_dict[k[6:]] = v
76 | else:
77 | state_dict = checkpoint
78 |
79 | model.load_state_dict(state_dict)
80 | model.to(device)
81 | trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
82 | transforms.CenterCrop(image_size),
83 | get_transform('rgb', image_size=None)])
84 |
85 | elif args.task == 'depth':
86 | image_size = 384
87 | pretrained_weights_path = root_dir + '/omnidata_dpt_depth_v2.ckpt' # 'omnidata_dpt_depth_v1.ckpt'
88 | # model = DPTDepthModel(backbone='vitl16_384') # DPT Large
89 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid
90 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location)
91 | if 'state_dict' in checkpoint:
92 | state_dict = {}
93 | for k, v in checkpoint['state_dict'].items():
94 | state_dict[k[6:]] = v
95 | else:
96 | state_dict = checkpoint
97 | model.load_state_dict(state_dict)
98 | model.to(device)
99 | trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
100 | transforms.CenterCrop(image_size),
101 | transforms.ToTensor(),
102 | transforms.Normalize(mean=0.5, std=0.5)])
103 |
104 | else:
105 | print("task should be one of the following: normal, depth")
106 | sys.exit()
107 |
108 | trans_rgb = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
109 | transforms.CenterCrop(image_size),
110 | ])
111 |
112 |
113 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1):
114 | if mask_valid is not None:
115 | img[~mask_valid] = torch.nan
116 | sorted_img = torch.sort(torch.flatten(img))[0]
117 | # Remove nan, nan at the end of sort
118 | num_nan = sorted_img.isnan().sum()
119 | if num_nan > 0:
120 | sorted_img = sorted_img[:-num_nan]
121 | # Remove outliers
122 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))]
123 | trunc_mean = trunc_img.mean()
124 | trunc_var = trunc_img.var()
125 | eps = 1e-6
126 | # Replace nan by mean
127 | img = torch.nan_to_num(img, nan=trunc_mean)
128 | # Standardize
129 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps)
130 | return img
131 |
132 |
133 | def save_outputs(img_path, output_file_name):
134 | with torch.no_grad():
135 | save_path = os.path.join(args.output_path, f'{output_file_name}_{args.task}.png')
136 |
137 | print(f'Reading input {img_path} ...')
138 | img = Image.open(img_path)
139 | # import pdb; pdb.set_trace();
140 | img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(device)
141 |
142 | rgb_path = os.path.join(args.output_path, f'{output_file_name}_rgb.png')
143 | trans_rgb(img).save(rgb_path)
144 |
145 | if img_tensor.shape[1] == 1:
146 | img_tensor = img_tensor.repeat_interleave(3,1)
147 |
148 | output = model(img_tensor).clamp(min=0, max=1)
149 |
150 | if args.task == 'depth':
151 | #output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0)
152 | output = output.clamp(0,1)
153 |
154 | np.save(save_path.replace('.png', '.npy'), output.detach().cpu().numpy()[0])
155 |
156 | #output = 1 - output
157 | # output = standardize_depth_map(output)
158 | plt.imsave(save_path, output.detach().cpu().squeeze(),cmap='viridis')
159 |
160 | else:
161 | #import pdb; pdb.set_trace()
162 | np.save(save_path.replace('.png', '.npy'), output.detach().cpu().numpy()[0])
163 | trans_topil(output[0]).save(save_path)
164 |
165 | print(f'Writing output {save_path} ...')
166 |
167 |
168 | img_path = Path(args.img_path)
169 | if img_path.is_file():
170 | save_outputs(args.img_path, os.path.splitext(os.path.basename(args.img_path))[0])
171 | elif img_path.is_dir():
172 | for f in glob.glob(args.img_path+'/*'):
173 | save_outputs(f, os.path.splitext(os.path.basename(f))[0])
174 | else:
175 | print("invalid file path!")
176 | sys.exit()
177 |
--------------------------------------------------------------------------------
/preprocess/replica_to_objsdfpp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 | import os
5 | from scipy.spatial.transform import Slerp
6 | from scipy.interpolate import interp1d
7 | from scipy.spatial.transform import Rotation as R
8 | import json
9 | import trimesh
10 | import glob
11 | import PIL
12 | from PIL import Image
13 | from torchvision import transforms
14 | import matplotlib.pyplot as plt
15 | import imageio
16 |
17 | # For Replica dataset, we adopt the camera intrinsic/extrinsic/rgb/depth/normal from MonoSDF dataset
18 | # For instance label, we use the instance label from vMAP processed dataset.
19 |
20 | # map the instance segmentation result to semantic segmentation result
21 |
22 | image_size = 384
23 | # trans_totensor = transforms.Compose([
24 | # transforms.CenterCrop(image_size),
25 | # transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
26 | # ])
27 | # depth_trans_totensor = transforms.Compose([
28 | # # transforms.Resize([680, 1200], interpolation=PIL.Image.NEAREST),
29 | # transforms.CenterCrop(image_size*2),
30 | # transforms.Resize(image_size, interpolation=PIL.Image.NEAREST),
31 | # ])
32 |
33 | seg_trans_totensor = transforms.Compose([
34 | transforms.CenterCrop(680),
35 | transforms.Resize(image_size, interpolation=PIL.Image.NEAREST),
36 | ])
37 |
38 | out_path_prefix = '../data/replica/Replica/'
39 | data_root = lambda x: '/media/hdd/Replica-Dataset/vmap/{}/imap/00'.format(x)
40 | # scenes = ['scene0050_00', 'scene0084_00', 'scene0580_00', 'scene0616_00']
41 | scenes = ['room_0', 'room_1', 'room_2', 'office_0', 'office_1', 'office_2', 'office_3', 'office_4']
42 | out_names = ['scan1', 'scan2', 'scan3', 'scan4', 'scan5', 'scan6', 'scan7', 'scan8']
43 |
44 | background_cls_list = [5,12,30,31,40,60,92,93,95,97,98,79]
45 | # merge more class into background
46 | background_cls_list.append(37) # door
47 | # background_cls_list.append(0) # undefined: 0 for this class, we manually organize the data and the result can be found in each instance-mapping.txt
48 | background_cls_list.append(56) # panel
49 | background_cls_list.append(62) # pipe
50 |
51 | for scene, out_name in zip(scenes, out_names):
52 | # for scene, out_name in zip()
53 | out_path = os.path.join(out_path_prefix, out_name)
54 | os.makedirs(out_path, exist_ok=True)
55 | print(out_path)
56 |
57 | # folders = ["image", "mask", "depth", "segs"]
58 | folders = ['segs']
59 | for folder in folders:
60 | out_folder = os.path.join(out_path, folder)
61 | os.makedirs(out_folder, exist_ok=True)
62 |
63 | # process segmentation
64 | segs_path = os.path.join(data_root(scene), 'semantic_instance')
65 | segs_paths = sorted(glob.glob(os.path.join(segs_path, 'semantic_instance_*.png')),
66 | key=lambda x: int(os.path.basename(x).split('.')[0].split('_')[-1]))
67 | print(segs_paths)
68 |
69 | labels_path = os.path.join(data_root(scene), 'semantic_class')
70 | labels_paths = sorted(glob.glob(os.path.join(labels_path, 'semantic_class_*.png')),
71 | key = lambda x: int(os.path.basename(x).split('.')[0].split('_')[-1]))
72 | print(labels_paths)
73 |
74 | instance_semantic_mapping = {}
75 | instance_label_mapping = {}
76 | label_instance_counts = {}
77 |
78 | out_index = 0
79 | for idx, (seg_path, label_path) in enumerate(zip(segs_paths, labels_paths)):
80 | if idx % 20 !=0: continue
81 | print(idx, seg_path, label_path)
82 | # save segs
83 | target_image = os.path.join(out_path, 'segs', '{:06d}_segs.png'.format(out_index))
84 | print(target_image)
85 | seg = Image.open(seg_path)
86 | seg_tensor = seg_trans_totensor(seg)
87 | seg_tensor.save(target_image)
88 |
89 | # label_mapping
90 | label_np = cv2.imread(label_path, cv2.IMREAD_UNCHANGED).astype(np.int32).transpose(1,0)
91 | # if 30 in np.unique(_np):
92 | # import pdb; pdb.set_trace()
93 | # label_np = label_np.
94 |
95 | segs_np = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED).astype(np.int32).transpose(1,0)
96 | insts = np.unique(segs_np)
97 | for inst_id in insts:
98 | inst_mask = segs_np == inst_id
99 | sem_cls = int(np.unique(label_np[inst_mask]))
100 | # import pdb; pdb.set_trace()
101 | instance_semantic_mapping[inst_id] = sem_cls
102 | if sem_cls in background_cls_list:
103 | instance_semantic_mapping[inst_id] = 0
104 | instance_label_mapping[inst_id] = 0
105 | # assert sem_cls.shape[0]!=0
106 | elif sem_cls in label_instance_counts:
107 | if inst_id not in instance_label_mapping:
108 | inst_count = label_instance_counts[sem_cls] + 1
109 | label_instance_counts[sem_cls] = inst_count
110 |                 # change the instance label mapping index to 100*label + inst_count
111 | instance_label_mapping[inst_id] = sem_cls * 100 + inst_count
112 | else:
113 | continue # already saved
114 | else:
115 | inst_count = 1
116 | label_instance_counts[sem_cls] = inst_count
117 | instance_label_mapping[inst_id] = sem_cls * 100 + inst_count
118 |
119 | out_index += 1
120 |
121 | # save the instance mapping file to output path
122 | print({k: v for k, v in sorted(instance_label_mapping.items())})
123 | with open(os.path.join(out_path, 'instance_mapping_new.txt'), 'w') as f:
124 | # f.write(str(sorted(label_set)).strip('[').strip(']'))
125 | for k, v in instance_label_mapping.items():
126 | # instance_id, semantic_label, updated_instance_label (according to the number in this semantic class)
127 | f.write(str(k)+','+str(instance_semantic_mapping[k])+','+str(v)+'\n')
128 | f.close()
129 |
130 | #np.savez(os.path.join(out_path, "cameras_sphere.npz"), **cameras)
131 | # np.savez(os.path.join(out_path, "cameras.npz"), **cameras)
132 |
--------------------------------------------------------------------------------
/preprocess/scannet_to_objsdfpp.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 | import os
5 | from scipy.spatial.transform import Slerp
6 | from scipy.interpolate import interp1d
7 | from scipy.spatial.transform import Rotation as R
8 | import json
9 | import trimesh
10 | import glob
11 | import PIL
12 | from PIL import Image
13 | from torchvision import transforms
14 | import matplotlib.pyplot as plt
15 | import imageio
16 |
17 | import csv
18 | # Code from ScanNet script to convert instance images from the *_2d-instance.zip or *_2d-instance-filt.zip data for each scan.
19 | def read_label_mapping(filename, label_from='id', label_to='nyu40id'):
20 | assert os.path.isfile(filename)
21 | mapping = dict()
22 | with open(filename) as csvfile:
23 | reader = csv.DictReader(csvfile, delimiter='\t')
24 | for row in reader:
25 | mapping[row[label_from]] = int(row[label_to])
26 | # if ints convert
27 | # if represents_int(mapping.keys()[0]):
28 | mapping = {int(k):v for k,v in mapping.items()}
29 | return mapping
30 |
31 | def map_label_image(image, label_mapping):
32 | mapped = np.copy(image)
33 | for k,v in label_mapping.items():
34 | mapped[image==k] = v
35 | # merge some label like bg, wall, ceiling, floor
36 | # bg: 0, wall: 1, floor: 2, ceiling: 22, door: 8
37 | mapped[mapped==1] = 0
38 | mapped[mapped==2] = 0
39 | mapped[mapped==22] = 0
40 | # add windows
41 | mapped[mapped==9] = 0
42 | # add door
43 | mapped[mapped==8] = 0
44 | # add mirror; curtain to 0
45 | mapped[mapped==19] = 0 # mirror
46 | mapped[mapped==16] = 0 # curtain
47 | return mapped.astype(np.uint8)
48 |
49 | def make_instance_image(label_image, instance_image):
50 | output = np.zeros_like(instance_image, dtype=np.uint16)
51 | # oldinst2labelinst = {}
52 | label_instance_counts = {}
53 | old_insts = np.unique(instance_image)
54 | for inst in old_insts:
55 | label = label_image[instance_image==inst][0]
56 | if label in label_instance_counts and label !=0:
57 | inst_count = label_instance_counts[label] + 1
58 | label_instance_counts[label] = inst_count
59 | else:
60 | inst_count = 1
61 | label_instance_counts[label] = inst_count
62 | # oldinst2labelinst[inst] = (label, inst_count)
63 | output[instance_image==inst] = label * 1000 + inst_count
64 | return output
65 |
66 | image_size = 384
67 | trans_totensor = transforms.Compose([
68 | transforms.CenterCrop(image_size*2),
69 | transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR),
70 | ])
71 | depth_trans_totensor = transforms.Compose([
72 | transforms.Resize([968, 1296], interpolation=PIL.Image.NEAREST),
73 | transforms.CenterCrop(image_size*2),
74 | transforms.Resize(image_size, interpolation=PIL.Image.NEAREST),
75 | ])
76 |
77 | seg_trans_totensor = transforms.Compose([
78 | transforms.CenterCrop(image_size*2),
79 | transforms.Resize(image_size, interpolation=PIL.Image.NEAREST),
80 | ])
81 |
82 |
83 |
84 | out_path_prefix = '../data/custom'
85 | data_root = '/media/hdd/monodata_ours/scans'
86 | scenes = ['scene0050_00', 'scene0084_00', 'scene0580_00', 'scene0616_00']
87 | # scenes = ['scene0050_00']
88 | # out_names = ['scan1']
89 | out_names = ['scan1', 'scan2', 'scan3', 'scan4']
90 |
91 | label_map = read_label_mapping('/media/hdd/scannet/tasks/scannetv2-labels.combined.tsv') # path to scannetv2-labels.combined.tsv
92 |
93 | for scene, out_name in zip(scenes, out_names):
94 | out_path = os.path.join(out_path_prefix, out_name)
95 | os.makedirs(out_path, exist_ok=True)
96 | print(out_path)
97 |
98 | folders = ["image", "mask", "depth", "segs"]
99 | for folder in folders:
100 | out_folder = os.path.join(out_path, folder)
101 | os.makedirs(out_folder, exist_ok=True)
102 |
103 | # load color
104 | color_path = os.path.join(data_root, scene, 'color')
105 | color_paths = sorted(glob.glob(os.path.join(color_path, '*.jpg')),
106 | key=lambda x: int(os.path.basename(x)[:-4]))
107 | print(color_paths)
108 |
109 | # load depth
110 | depth_path = os.path.join(data_root, scene, 'depth')
111 | depth_paths = sorted(glob.glob(os.path.join(depth_path, '*.png')),
112 | key=lambda x: int(os.path.basename(x)[:-4]))
113 | print(depth_paths)
114 |
115 | segs_path = os.path.join(data_root, scene, 'segs', 'instance-filt')
116 | segs_paths = sorted(glob.glob(os.path.join(segs_path, '*.png')),
117 | key=lambda x: int(os.path.basename(x)[:-4]))
118 | print(segs_paths)
119 |
120 | labels_path = os.path.join(data_root, scene, 'segs', 'label-filt')
121 | labels_paths = sorted(glob.glob(os.path.join(labels_path, '*.png')),
122 | key = lambda x: int(os.path.basename(x)[:-4]))
123 | print(labels_paths)
124 |
125 | # load intrinsic
126 | intrinsic_path = os.path.join(data_root, scene, 'intrinsic', 'intrinsic_color.txt')
127 | camera_intrinsic = np.loadtxt(intrinsic_path)
128 | print(camera_intrinsic)
129 |
130 | # load pose
131 | pose_path = os.path.join(data_root, scene, 'pose')
132 | poses = []
133 | pose_paths = sorted(glob.glob(os.path.join(pose_path, '*.txt')),
134 | key=lambda x: int(os.path.basename(x)[:-4]))
135 | for pose_path in pose_paths:
136 | c2w = np.loadtxt(pose_path)
137 | poses.append(c2w)
138 | poses = np.array(poses)
139 |
140 | # deal with invalid poses
141 | valid_poses = np.isfinite(poses).all(axis=2).all(axis=1)
142 | min_vertices = poses[:, :3, 3][valid_poses].min(axis=0)
143 | max_vertices = poses[:, :3, 3][valid_poses].max(axis=0)
144 |
145 | center = (min_vertices + max_vertices) / 2.
146 | scale = 2. / (np.max(max_vertices - min_vertices) + 3.)
147 | print(center, scale)
148 |
149 |     # normalize the scene into a unit cube
150 | scale_mat = np.eye(4).astype(np.float32)
151 | scale_mat[:3, 3] = -center
152 | scale_mat[:3 ] *= scale
153 | scale_mat = np.linalg.inv(scale_mat)
154 |
155 | # copy image
156 | out_index = 0
157 | cameras = {}
158 | pcds = []
159 | H, W = 968, 1296
160 |
161 | # center crop by 2 * image_size
162 | offset_x = (W - image_size * 2) * 0.5
163 | offset_y = (H - image_size * 2) * 0.5
164 | camera_intrinsic[0, 2] -= offset_x
165 | camera_intrinsic[1, 2] -= offset_y
166 | # resize from 384*2 to 384
167 | resize_factor = 0.5
168 | camera_intrinsic[:2, :] *= resize_factor
169 |
170 | K = camera_intrinsic
171 | print(K)
172 |
173 | instance_semantic_mapping = {}
174 | instance_label_mapping = {}
175 | label_instance_counts = {}
176 |
177 | for idx, (valid, pose, depth_path, image_path, seg_path, label_path) in enumerate(zip(valid_poses, poses, depth_paths, color_paths, segs_paths, labels_paths)):
178 | print(idx, valid)
179 | if idx % 10 != 0: continue
180 | if not valid : continue
181 |
182 | target_image = os.path.join(out_path, "image/%06d.png"%(out_index))
183 | print(target_image)
184 | img = Image.open(image_path)
185 | img_tensor = trans_totensor(img)
186 | img_tensor.save(target_image)
187 |
188 | mask = (np.ones((image_size, image_size, 3)) * 255.).astype(np.uint8)
189 |
190 | target_image = os.path.join(out_path, "mask/%06d_mask.png"%(out_index))
191 | cv2.imwrite(target_image, mask)
192 |
193 | # load depth
194 | target_image = os.path.join(out_path, "depth/%06d_depth.png"%(out_index))
195 | depth = cv2.imread(depth_path, -1).astype(np.float32) / 1000.
196 | #import pdb; pdb.set_trace()
197 | depth_PIL = Image.fromarray(depth)
198 | new_depth = depth_trans_totensor(depth_PIL)
199 | new_depth = np.asarray(new_depth)
200 | plt.imsave(target_image, new_depth, cmap='viridis')
201 | np.save(target_image.replace(".png", ".npy"), new_depth)
202 |
203 | # segs
204 | target_image = os.path.join(out_path, "segs/%06d_segs.png"%(out_index))
205 | print(target_image)
206 | seg = Image.open(seg_path)
207 | # seg_tensor = trans_totensor(seg)
208 | seg_tensor = seg_trans_totensor(seg)
209 | seg_tensor.save(target_image)
210 | # np.save(target_image)
211 |
212 | # label_mapping
213 | # label = Image.open(label_path)
214 | # label_tensor = trans_totensor(label)
215 | label_np = imageio.imread(label_path)
216 | segs_np = imageio.imread(seg_path)
217 | # import pdb; pdb.set_trace()
218 | mapped_labels = map_label_image(label_np, label_map)
219 | old_insts = np.unique(segs_np)
220 | for inst in old_insts:
221 | label = mapped_labels[segs_np==inst][0]
222 | # import pdb; pdb.set_trace()
223 | instance_semantic_mapping[inst] = label
224 | if label == 0:
225 | instance_label_mapping[inst] = 0
226 | elif label in label_instance_counts:
227 | if inst not in instance_label_mapping:
228 |                 inst_count = label_instance_counts[label] + 1 # increment the instance count for this label
229 | label_instance_counts[label] = inst_count
230 | # change the instance label mapping index of this instance to 1000*label + counted number
231 | instance_label_mapping[inst] = label * 1000 + inst_count
232 | else:
233 | continue
234 | else:
235 | inst_count = 1
236 |             label_instance_counts[label] = inst_count # this label has not appeared before, so start its count at 1 in label_instance_counts
237 | instance_label_mapping[inst] = label * 1000 + inst_count
238 |
239 |
240 |
241 |
242 | # save pose
243 | pcds.append(pose[:3, 3])
244 | pose = K @ np.linalg.inv(pose)
245 |
246 | #cameras["scale_mat_%d"%(out_index)] = np.eye(4).astype(np.float32)
247 | cameras["scale_mat_%d"%(out_index)] = scale_mat
248 | cameras["world_mat_%d"%(out_index)] = pose
249 |
250 | out_index += 1
251 |
252 | # save the instance mapping file to output path
253 | print({k: v for k, v in sorted(instance_label_mapping.items())})
254 | with open(os.path.join(out_path, 'instance_mapping.txt'), 'w') as f:
255 | # f.write(str(sorted(label_set)).strip('[').strip(']'))
256 | for k, v in instance_label_mapping.items():
257 | f.write(str(k)+','+str(instance_semantic_mapping[k])+','+str(v)+'\n')
258 | f.close()
259 |
260 | #np.savez(os.path.join(out_path, "cameras_sphere.npz"), **cameras)
261 | np.savez(os.path.join(out_path, "cameras.npz"), **cameras)
262 |
--------------------------------------------------------------------------------
/replica_eval/avg_metric.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 |
5 | root_dir = './evaluation_results' # path to the evaluation results folder
6 |
7 | all_metric = []
8 | for fname in os.listdir(root_dir):
9 | metric_result = np.load(os.path.join(root_dir, fname, 'metrics_3D_obj.npy'))
10 |     if np.isnan(metric_result).any():  # report scenes whose metrics contain NaN
11 | print(fname)
12 | # print(metric_result.shape)
13 | all_metric.append(metric_result)
14 | print(fname, 0.5*(metric_result.mean(1)[0] + metric_result.mean(1)[1]))
15 | result = np.concatenate(all_metric, 1)
16 | # # recompute the f-score from precision and recall
17 | precision_1, completion_1 = result[4], result[2]
18 | _sum = precision_1 + completion_1
19 | _prod = 2*precision_1 * completion_1
20 |
21 |
22 | print('var F_score 5cm: {}'.format(result.var(1)[7]))
23 | print('var Chamfer: {}'.format((0.5*(result[0]+result[1])).var(0)))
24 | # print('chamfer mean each scene: {}'.format(0.5*(result.mean(1)+result.mean(2))))
25 | print('Acc mean: {}, Comp: {}, chamfer: {}, Ratio 1cm: {}, Ratio 5cm: {}, F_score 1cm: {}, F_score 5cm: {}'.format(result.mean(1)[0], result.mean(1)[1], 0.5*(result.mean(1)[0]+result.mean(1)[1]), result.mean(1)[2], result.mean(1)[3], result.mean(1)[6], result.mean(1)[7]))
26 |
--------------------------------------------------------------------------------
/replica_eval/cull_mesh.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/cvg/nice-slam
2 | import os
3 | import numpy as np
4 | import argparse
5 | import pickle
6 | import os
7 | import glob
8 | import open3d as o3d
9 | import matplotlib.pyplot as plt
10 | import torch
11 | import trimesh
12 |
13 |
14 | def load_poses(path):
15 | poses = []
16 | with open(path, "r") as f:
17 | lines = f.readlines()
18 | for line in lines:
19 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4)
20 | c2w[:3, 1] *= -1
21 | c2w[:3, 2] *= -1
22 | c2w = torch.from_numpy(c2w).float()
23 | poses.append(c2w)
24 | return poses
25 |
26 |
27 | parser = argparse.ArgumentParser(
28 | description='Arguments to cull the mesh.'
29 | )
30 |
31 | parser.add_argument('--input_mesh', type=str, help='path to the mesh to be culled')
32 | parser.add_argument('--input_scalemat', type=str, help='path to the scale mat')
33 | parser.add_argument('--traj', type=str, help='path to the trajectory')
34 | parser.add_argument('--output_mesh', type=str, help='path to the output mesh')
35 | args = parser.parse_args()
36 |
37 | H = 680
38 | W = 1200
39 | fx = 600.0
40 | fy = 600.0
41 | fx = 600.0
42 | cx = 599.5
43 | cy = 339.5
44 | # scale = 6553.5
45 |
46 | poses = load_poses(args.traj)
47 | n_imgs = len(poses)
48 | mesh = trimesh.load(args.input_mesh, process=False)
49 | # mesh.export(args.output_mesh+'_raw.ply')
50 |
51 | # transform to original coordinate system with scale mat
52 | if args.input_scalemat is not None:
53 | scalemat = np.load(args.input_scalemat)['scale_mat_0']
54 | mesh.vertices = mesh.vertices @ scalemat[:3, :3].T + scalemat[:3, 3]
55 | else:
56 | # print('not input scalemat')
57 | scalemat = np.eye(4)
58 | mesh.vertices = mesh.vertices @ scalemat[:3, :3].T + scalemat[:3, 3]
59 |
60 | pc = mesh.vertices
61 | faces = mesh.faces
62 |
63 |
64 | # delete mesh vertices that are not inside any camera's viewing frustum
65 | whole_mask = np.ones(pc.shape[0]).astype(bool)
66 | for i in range(0, n_imgs, 1):
67 | c2w = poses[i]
68 | points = pc.copy()
69 | points = torch.from_numpy(points).cuda()
70 | w2c = np.linalg.inv(c2w)
71 | w2c = torch.from_numpy(w2c).cuda().float()
72 | K = torch.from_numpy(
73 | np.array([[fx, .0, cx], [.0, fy, cy], [.0, .0, 1.0]]).reshape(3, 3)).cuda()
74 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda()
75 | homo_points = torch.cat(
76 | [points, ones], dim=1).reshape(-1, 4, 1).cuda().float()
77 | cam_cord_homo = w2c@homo_points
78 | cam_cord = cam_cord_homo[:, :3]
79 |
80 | cam_cord[:, 0] *= -1
81 | uv = K.float()@cam_cord.float()
82 | z = uv[:, -1:]+1e-5
83 | uv = uv[:, :2]/z
84 | uv = uv.float().squeeze(-1).cpu().numpy()
85 | edge = 0
86 | mask = (0 <= -z[:, 0, 0].cpu().numpy()) & (uv[:, 0] < W -
87 | edge) & (uv[:, 0] > edge) & (uv[:, 1] < H-edge) & (uv[:, 1] > edge)
88 | whole_mask &= ~mask
89 | pc = mesh.vertices
90 | faces = mesh.faces
91 | face_mask = whole_mask[mesh.faces].all(axis=1)
92 | mesh.update_faces(~face_mask)
93 | mesh.export(args.output_mesh)
94 |
--------------------------------------------------------------------------------
/replica_eval/cull_obj_gt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 |
5 | scans = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
6 |
7 | gt_data_dir = '/media/hdd/Replica-Dataset/vmap/'
8 |
9 | for idx, exp in enumerate(scans):
10 | idx = idx + 1
11 | folder_name = os.path.join(gt_data_dir, exp[:-1]+'_'+exp[-1], 'habitat')
12 | files = list(filter(os.path.isfile, glob.glob(os.path.join(folder_name, 'mesh_semantic.*.ply'))))
13 | # print(files)
14 | # exit()
15 | print(files[0].split('/')[-1])
16 | os.makedirs(os.path.join(folder_name, 'cull_object_mesh'), exist_ok=True)
17 | for name in files:
18 | cull_mesh_out = os.path.join(folder_name, 'cull_object_mesh', 'cull_'+name.split('/')[-1])
19 | # print(cull_mesh_out)
20 | cmd = f"python cull_mesh.py --input_mesh {name} --traj ../data/replica/Replica/scan{idx}/traj.txt --output_mesh {cull_mesh_out}"
21 | print(cmd)
22 | os.system(cmd)
23 | # exit()
24 |
--------------------------------------------------------------------------------
/replica_eval/eval_3D_obj.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import trimesh
4 | from metrics import accuracy, completion, completion_ratio
5 | import os
6 | import json
7 | import glob
8 |
9 | def calc_3d_metric(mesh_rec, mesh_gt, N=200000):
10 | """
11 | 3D reconstruction metric.
12 | """
13 | metrics = [[] for _ in range(8)]
14 | transform, extents = trimesh.bounds.oriented_bounds(mesh_gt)
15 |     extents = extents / 0.9 if N != 200000 else extents # enlarge the bounding box (divide extents by 0.9) for objects
16 | # extents = extents *1.2 if N != 200000 else extents # enlarge 0.9 for objects
17 | box = trimesh.creation.box(extents=extents, transform=np.linalg.inv(transform))
18 | mesh_rec = mesh_rec.slice_plane(box.facets_origin, -box.facets_normal)
19 | if mesh_rec.vertices.shape[0] == 0:
20 | print("no mesh found")
21 | return
22 | rec_pc = trimesh.sample.sample_surface(mesh_rec, N)
23 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0])
24 |
25 | gt_pc = trimesh.sample.sample_surface(mesh_gt, N)
26 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0])
27 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices)
28 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices)
29 | completion_ratio_rec = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.05)
30 | completion_ratio_rec_1 = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.01)
31 |
32 | precision_ratio_rec = completion_ratio(rec_pc_tri.vertices, gt_pc_tri.vertices, 0.05)
33 | precision_ratio_rec_1 = completion_ratio(rec_pc_tri.vertices, gt_pc_tri.vertices, 0.01)
34 |
35 | f_score = 2*precision_ratio_rec*completion_ratio_rec / (completion_ratio_rec + precision_ratio_rec)
36 | f_score_1 = 2 * precision_ratio_rec_1*completion_ratio_rec_1 / (completion_ratio_rec_1 + precision_ratio_rec_1)
37 |
38 |
39 | # accuracy_rec *= 100 # convert to cm
40 | # completion_rec *= 100 # convert to cm
41 | # completion_ratio_rec *= 100 # convert to %
42 | # print('accuracy: ', accuracy_rec)
43 | # print('completion: ', completion_rec)
44 | # print('completion ratio: ', completion_ratio_rec)
45 | # print("completion_ratio_rec_1cm ", completion_ratio_rec_1)
46 | metrics[0].append(accuracy_rec)
47 | metrics[1].append(completion_rec)
48 | metrics[2].append(completion_ratio_rec_1)
49 | metrics[3].append(completion_ratio_rec)
50 | metrics[4].append(precision_ratio_rec_1)
51 | metrics[5].append(precision_ratio_rec)
52 | metrics[6].append(np.nan_to_num(f_score_1))
53 | metrics[7].append(np.nan_to_num(f_score))
54 |
55 | return metrics
56 |
57 | def get_gt_bg_mesh(gt_dir, background_cls_list):
58 | with open(os.path.join(gt_dir, "info_semantic.json")) as f:
59 | label_obj_list = json.load(f)["objects"]
60 |
61 | bg_meshes = []
62 | for obj in label_obj_list:
63 | if int(obj["class_id"]) in background_cls_list:
64 | obj_file = os.path.join(gt_dir, "mesh_semantic.ply_" + str(int(obj["id"])) + ".ply")
65 | obj_mesh = trimesh.load(obj_file)
66 | bg_meshes.append(obj_mesh)
67 |
68 | bg_mesh = trimesh.util.concatenate(bg_meshes)
69 | return bg_mesh
70 |
71 | def get_obj_ids(obj_dir):
72 | files = os.listdir(obj_dir)
73 | obj_ids = []
74 | for f in files:
75 | obj_id = f.split("obj")[1][:-1]
76 | if obj_id == '':
77 | continue
78 | obj_ids.append(int(obj_id))
79 | return obj_ids
80 |
81 | def get_obj_ids_ours(obj_dir):
82 | files = list(filter(os.path.isfile, glob.glob(os.path.join(obj_dir, '*[0-9].ply'))))
83 | epoch_count = set([int(os.path.basename(f).split('_')[1]) for f in files])
84 | max_epoch = max(epoch_count)
85 | # print(max_epoch)
86 | obj_file = [int(os.path.basename(f).split('.')[0].split('_')[2]) for f in files if int(os.path.basename(f).split('_')[1])==max_epoch]
87 | return sorted(obj_file), max_epoch
88 |
89 |
90 |
91 | def get_gt_mesh_from_objid(gt_dir, mapping_list):
92 | combined_meshes = []
93 | if len(mapping_list) == 1:
94 | return trimesh.load(os.path.join(gt_dir, 'cull_mesh_semantic.ply_'+str(mapping_list[0])+'.ply'))
95 | else:
96 | for idx in mapping_list:
97 | if os.path.isfile(os.path.join(gt_dir, 'cull_mesh_semantic.ply_'+str(int(idx))+'.ply')):
98 | obj_file = os.path.join(gt_dir, 'cull_mesh_semantic.ply_'+str(int(idx))+'.ply')
99 | obj_mesh = trimesh.load(obj_file)
100 | combined_meshes.append(obj_mesh)
101 | else:
102 | continue
103 | combine_mesh = trimesh.util.concatenate(combined_meshes)
104 | return combine_mesh
105 |
106 |
107 |
108 | if __name__ == "__main__":
109 | exp_name = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
110 | data_dir = "/media/hdd/Replica-Dataset/vmap" # where to store the data
111 | log_dir = './evaluation_results' # where to store the evaluation results
112 | info_dir = '../data/replica' # path to dataset information
113 | mesh_rec_root = "../exps/objectsdfplus_replica" # path to the reconstruction results
114 | os.makedirs(log_dir, exist_ok=True)
115 | for idx, exp in enumerate(exp_name):
116 |         if exp != "room2": continue  # only evaluate room2; remove this line to evaluate all scenes
117 | idx = idx + 1
118 | gt_dir = os.path.join(data_dir, exp[:-1]+"_"+exp[-1]+"/habitat")
119 | info_text = os.path.join(info_dir, 'scan'+str(idx), 'instance_mapping.txt')
120 | # mesh_dir = os.path.join()
121 | # mesh_rec_dir = os.path.join('') # path to reconstruction results
122 |         # get the latest folder for evaluation
123 | dirs = sorted(os.listdir(mesh_rec_root+f'_{idx}'))
124 | mesh_rec_dir = os.path.join(mesh_rec_root+f'_{idx}', dirs[-1], "plots")
125 | print(mesh_rec_dir)
126 |
127 | output_dir = os.path.join(log_dir, exp+'_{}'.format(mesh_rec_root.split('/')[-1]))
128 | os.makedirs(output_dir, exist_ok=True)
129 | metrics_3D = [[] for _ in range(8)]
130 |
131 | # only calculate the valid mesh in experiment
132 | instance_mapping = {}
133 | with open(info_text, 'r') as f:
134 | for l in f:
135 | (k, v_sem, v_ins) = l.split(',')
136 | instance_mapping[int(k)] = int(v_ins)
137 | label_mapping = sorted(set(instance_mapping.values()))
138 | # print(label_mapping)
139 | # get all valid obj index
140 | obj_ids, max_epoch = get_obj_ids_ours(mesh_rec_dir)
141 | # print(obj_ids)
142 | for obj_id in tqdm(obj_ids):
143 | inst_id = label_mapping[obj_id]
144 | # merge the gt mesh with the same instance_id
145 |             gt_inst_list = [] # indices of GT object meshes that belong to the same object according to instance_mapping
146 | for k, v in instance_mapping.items():
147 | if v == inst_id:
148 | gt_inst_list.append(k)
149 | mesh_gt = get_gt_mesh_from_objid(os.path.join(gt_dir, 'cull_object_mesh'), gt_inst_list)
150 | if obj_id == 0:
151 | N=200000
152 | else:
153 | N=10000
154 |
155 | # in order to evaluate the result, we need to cull the object mesh first and then evaluate the metric
156 | # mesh_rec = trimesh.load(os.path.join(mesh_dir, 'surface_{max_epoch}_{obj_id}.ply'))
157 | cull_rec_mesh_path = os.path.join(output_dir, f"{exp}_cull_surface_{max_epoch}_{obj_id}.ply")
158 |
159 | rec_mesh = os.path.join(mesh_rec_dir, f'surface_{max_epoch}_{obj_id}.ply')
160 | # print(rec_mesh)
161 | cmd = f"python cull_mesh.py --input_mesh {rec_mesh} --input_scalemat ../data/replica/scan{idx}/cameras.npz --traj ../data/replica/scan{idx}/traj.txt --output_mesh {cull_rec_mesh_path}"
162 | os.system(cmd)
163 | # evaluate the metric
164 | mesh_rec = trimesh.load(cull_rec_mesh_path)
165 | # use the biggest connected component for evaluation.
166 |
167 | metrics = calc_3d_metric(mesh_rec, mesh_gt, N=N)
168 | if metrics is None:
169 | continue
170 | np.save(output_dir + '/metric_obj{}.npy'.format(obj_id), np.array(metrics))
171 | metrics_3D[0].append(metrics[0]) # acc
172 | metrics_3D[1].append(metrics[1]) # comp
173 | metrics_3D[2].append(metrics[2]) # comp ratio 1cm
174 | metrics_3D[3].append(metrics[3]) # comp ratio 5cm
175 | metrics_3D[4].append(metrics[4]) # precision ratio 1cm
176 |         metrics_3D[5].append(metrics[5]) # precision ratio 5cm
177 |         metrics_3D[6].append(metrics[6]) # f_score 1cm
178 | metrics_3D[7].append(metrics[7]) # f_score 5cm
179 | metrics_3D = np.array(metrics_3D)
180 | np.save(output_dir + '/metrics_3D_obj.npy', metrics_3D)
181 | print("metrics 3D obj \n Acc | Comp | Comp Ratio 1cm | Comp Ratio 5cm \n", metrics_3D.mean(axis=1))
182 | print("-----------------------------------------")
183 | print("finish exp ", exp)
184 | # exit()
185 |
186 |
187 |
188 |
189 | # use culled object mesh for evaluation
190 |
191 |
192 |
193 |
194 |
195 |
196 | # background_cls_list = [5, 12, 30, 31, 40, 60, 92, 93, 95, 97, 98, 79]
197 | # exp_name = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
198 | # # data_dir = "/home/xin/data/vmap/"
199 | # data_dir = "/media/hdd/Replica-Dataset/vmap/"
200 | # # log_dir = "../logs/iMAP/"
201 | # # log_dir = "../logs/vMAP/"
202 | # log_dir = "/media/hdd/Replica-Dataset/vmap/vMAP_Replica_Results/"
203 |
204 | # for exp in tqdm(exp_name):
205 | # gt_dir = os.path.join(data_dir, exp[:-1]+"_"+exp[-1]+"/habitat")
206 | # exp_dir = os.path.join(log_dir, exp+'_vmap')
207 | # mesh_dir = os.path.join(exp_dir, "scene_mesh")
208 | # output_path = os.path.join(exp_dir, "eval_mesh")
209 | # os.makedirs(output_path, exist_ok=True)
210 | # metrics_3D = [[] for _ in range(4)]
211 |
212 | # # get obj ids
213 | # # obj_ids = np.loadtxt() # todo use a pre-defined obj list or use vMAP results
214 | # obj_ids = get_obj_ids(mesh_dir.replace("imap", "vmap"))
215 | # for obj_id in tqdm(obj_ids):
216 | # if obj_id == 0: # for bg
217 | # N = 200000
218 | # mesh_gt = get_gt_bg_mesh(gt_dir, background_cls_list)
219 | # else: # for obj
220 | # N = 10000
221 | # obj_file = os.path.join(gt_dir, "mesh_semantic.ply_" + str(obj_id) + ".ply")
222 | # mesh_gt = trimesh.load(obj_file)
223 |
224 | # if "vMAP" in exp_dir:
225 | # rec_meshfile = os.path.join(mesh_dir, "imap_frame2000_obj"+str(obj_id)+".obj")
226 | # # rec_meshfile = os.path.join(mesh_dir, )
227 | # elif "iMAP" in exp_dir:
228 | # rec_meshfile = os.path.join(mesh_dir, "frame_1999_obj0.obj")
229 | # else:
230 | # print("Not Implement")
231 | # exit(-1)
232 |
233 | # mesh_rec = trimesh.load(rec_meshfile)
234 | # # mesh_rec.invert() # niceslam mesh face needs invert
235 | # metrics = calc_3d_metric(mesh_rec, mesh_gt, N=N) # for objs use 10k, for scene use 200k points
236 | # if metrics is None:
237 | # continue
238 | # np.save(output_path + '/metric_obj{}.npy'.format(obj_id), np.array(metrics))
239 | # metrics_3D[0].append(metrics[0]) # acc
240 | # metrics_3D[1].append(metrics[1]) # comp
241 | # metrics_3D[2].append(metrics[2]) # comp ratio 1cm
242 | # metrics_3D[3].append(metrics[3]) # comp ratio 5cm
243 | # metrics_3D = np.array(metrics_3D)
244 | # np.save(output_path + '/metrics_3D_obj.npy', metrics_3D)
245 | # print("metrics 3D obj \n Acc | Comp | Comp Ratio 1cm | Comp Ratio 5cm \n", metrics_3D.mean(axis=1))
246 | # print("-----------------------------------------")
247 | # print("finish exp ", exp)
248 |
249 | # calculate the average result over all 8 scenes
250 |
--------------------------------------------------------------------------------
/replica_eval/eval_recon.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/cvg/nice-slam
2 | import argparse
3 | import random
4 |
5 | import numpy as np
6 | import open3d as o3d
7 | import torch
8 | import trimesh
9 | from scipy.spatial import cKDTree as KDTree
10 | import cv2
11 |
12 | def normalize(x):
13 | return x / np.linalg.norm(x)
14 |
15 |
16 | def viewmatrix(z, up, pos):
17 | vec2 = normalize(z)
18 | vec1_avg = up
19 | vec0 = normalize(np.cross(vec1_avg, vec2))
20 | vec1 = normalize(np.cross(vec2, vec0))
21 | m = np.stack([vec0, vec1, vec2, pos], 1)
22 | return m
23 |
24 |
25 | def completion_ratio(gt_points, rec_points, dist_th=0.05):
26 | gen_points_kd_tree = KDTree(rec_points)
27 | distances, _ = gen_points_kd_tree.query(gt_points)
28 |     comp_ratio = np.mean((distances < dist_th).astype(np.float32))
29 | return comp_ratio
30 |
31 |
32 | def accuracy(gt_points, rec_points):
33 | gt_points_kd_tree = KDTree(gt_points)
34 | distances, _ = gt_points_kd_tree.query(rec_points)
35 | acc = np.mean(distances)
36 | return acc, distances
37 |
38 |
39 | def completion(gt_points, rec_points):
40 |     rec_points_kd_tree = KDTree(rec_points)
41 |     distances, _ = rec_points_kd_tree.query(gt_points)
42 | comp = np.mean(distances)
43 | return comp, distances
44 |
45 | def write_vis_pcd(file, points, colors):
46 | pcd = o3d.geometry.PointCloud()
47 | pcd.points = o3d.utility.Vector3dVector(points)
48 | pcd.colors = o3d.utility.Vector3dVector(colors)
49 | o3d.io.write_point_cloud(file, pcd)
50 |
51 | def get_align_transformation(rec_meshfile, gt_meshfile):
52 | """
53 | Get the transformation matrix to align the reconstructed mesh to the ground truth mesh.
54 | """
55 | o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
56 | o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
57 | o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices)
58 | o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices)
59 | trans_init = np.eye(4)
60 | threshold = 0.1
61 | reg_p2p = o3d.pipelines.registration.registration_icp(
62 | o3d_rec_pc, o3d_gt_pc, threshold, trans_init,
63 | o3d.pipelines.registration.TransformationEstimationPointToPoint())
64 | transformation = reg_p2p.transformation
65 | return transformation
66 |
67 |
68 | def check_proj(points, W, H, fx, fy, cx, cy, c2w):
69 | """
70 | Check if points can be projected into the camera view.
71 |
72 | """
73 | c2w = c2w.copy()
74 | c2w[:3, 1] *= -1.0
75 | c2w[:3, 2] *= -1.0
76 | points = torch.from_numpy(points).cuda().clone()
77 | w2c = np.linalg.inv(c2w)
78 | w2c = torch.from_numpy(w2c).cuda().float()
79 | K = torch.from_numpy(
80 | np.array([[fx, .0, cx], [.0, fy, cy], [.0, .0, 1.0]]).reshape(3, 3)).cuda()
81 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda()
82 | homo_points = torch.cat(
83 | [points, ones], dim=1).reshape(-1, 4, 1).cuda().float() # (N, 4)
84 | cam_cord_homo = w2c@homo_points # (N, 4, 1)=(4,4)*(N, 4, 1)
85 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1)
86 | cam_cord[:, 0] *= -1
87 | uv = K.float()@cam_cord.float()
88 | z = uv[:, -1:]+1e-5
89 | uv = uv[:, :2]/z
90 | uv = uv.float().squeeze(-1).cpu().numpy()
91 | edge = 0
92 | mask = (0 <= -z[:, 0, 0].cpu().numpy()) & (uv[:, 0] < W -
93 | edge) & (uv[:, 0] > edge) & (uv[:, 1] < H-edge) & (uv[:, 1] > edge)
94 | return mask.sum() > 0
95 |
96 | def nn_correspondance(verts1, verts2):
97 | indices = []
98 | distances = []
99 | if len(verts1) == 0 or len(verts2) == 0:
100 | return indices, distances
101 |
102 | kdtree = KDTree(verts1)
103 | distances, indices = kdtree.query(verts2)
104 | distances = distances.reshape(-1)
105 |
106 | return distances, indices
107 |
108 |
109 | def calc_3d_metric(rec_meshfile, gt_meshfile, align=False):
110 | """
111 | 3D reconstruction metric.
112 |
113 | """
114 | mesh_rec = trimesh.load(rec_meshfile, process=False)
115 | mesh_gt = trimesh.load(gt_meshfile, process=False)
116 |
117 | if align:
118 | transformation = get_align_transformation(rec_meshfile, gt_meshfile)
119 | mesh_rec = mesh_rec.apply_transform(transformation)
120 |
121 |     # find the aligned bbox of the GT mesh and transform both meshes into that frame
122 | to_align, _ = trimesh.bounds.oriented_bounds(mesh_gt)
123 | mesh_gt.vertices = (to_align[:3, :3] @ mesh_gt.vertices.T + to_align[:3, 3:]).T
124 | mesh_rec.vertices = (to_align[:3, :3] @ mesh_rec.vertices.T + to_align[:3, 3:]).T
125 |
126 | min_points = mesh_gt.vertices.min(axis=0) * 1.005
127 | max_points = mesh_gt.vertices.max(axis=0) * 1.005
128 |
129 | mask_min = (mesh_rec.vertices - min_points[None]) > 0
130 | mask_max = (mesh_rec.vertices - max_points[None]) < 0
131 |
132 | mask = np.concatenate((mask_min, mask_max), axis=1).all(axis=1)
133 | face_mask = mask[mesh_rec.faces].all(axis=1)
134 |
135 | mesh_rec.update_vertices(mask)
136 | mesh_rec.update_faces(face_mask)
137 |
138 | rec_pc = trimesh.sample.sample_surface(mesh_rec, 200000)
139 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0])
140 |
141 | gt_pc = trimesh.sample.sample_surface(mesh_gt, 200000)
142 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0])
143 | accuracy_rec, dist_d2s = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices)
144 | completion_rec, dist_s2d = completion(gt_pc_tri.vertices, rec_pc_tri.vertices)
145 | completion_ratio_rec = completion_ratio(
146 | gt_pc_tri.vertices, rec_pc_tri.vertices)
147 |
148 | precision_ratio_rec = completion_ratio(
149 | rec_pc_tri.vertices, gt_pc_tri.vertices)
150 |
151 | fscore = 2 * precision_ratio_rec * completion_ratio_rec / (completion_ratio_rec + precision_ratio_rec)
152 |
153 | # normal consistency
154 | N = 200000
155 | pointcloud_pred, idx = mesh_rec.sample(N, return_index=True)
156 | pointcloud_pred = pointcloud_pred.astype(np.float32)
157 | normal_pred = mesh_rec.face_normals[idx]
158 |
159 | pointcloud_trgt, idx = mesh_gt.sample(N, return_index=True)
160 | pointcloud_trgt = pointcloud_trgt.astype(np.float32)
161 | normal_trgt = mesh_gt.face_normals[idx]
162 |
163 | _, index1 = nn_correspondance(pointcloud_pred, pointcloud_trgt)
164 | _, index2 = nn_correspondance(pointcloud_trgt, pointcloud_pred)
165 |
166 | normal_acc = np.abs((normal_pred * normal_trgt[index2.reshape(-1)]).sum(axis=-1)).mean()
167 | normal_comp = np.abs((normal_trgt * normal_pred[index1.reshape(-1)]).sum(axis=-1)).mean()
168 | normal_avg = (normal_acc + normal_comp) * 0.5
169 |
170 | accuracy_rec *= 100 # convert to cm
171 | completion_rec *= 100 # convert to cm
172 | completion_ratio_rec *= 100 # convert to %
173 | precision_ratio_rec *= 100 # convert to %
174 | fscore *= 100
175 | normal_acc *= 100
176 | normal_comp *= 100
177 | normal_avg *= 100
178 |
179 |     print('Acc: ', accuracy_rec, 'Comp:', completion_rec, 'Precision: ', precision_ratio_rec, 'Comp_ratio: ', completion_ratio_rec, 'F_score: ', fscore, 'Normal_acc: ', normal_acc, 'Normal_comp: ', normal_comp, 'Normal_avg: ',normal_avg)
180 |
181 | # add visualization and save the mesh output
182 | vis_dist = 0.05 # dist_th = 5 cm
183 | data_alpha = (dist_d2s.clip(max=vis_dist) / vis_dist).reshape(-1, 1)
184 | #data_color = R * data_alpha + W * (1-data_alpha)
185 | im_gray = (data_alpha * 255).astype(np.uint8)
186 | data_color = cv2.applyColorMap(im_gray, cv2.COLORMAP_JET)[:,0,[2, 0, 1]] / 255.
187 | write_vis_pcd(f'{rec_meshfile}_d2s.ply', rec_pc_tri.vertices, data_color)
188 |
189 | stl_alpha = (dist_s2d.clip(max=vis_dist) / vis_dist).reshape(-1, 1)
190 | #stl_color = R * stl_alpha + W * (1-stl_alpha)
191 | im_gray = (stl_alpha * 255).astype(np.uint8)
192 | stl_color = cv2.applyColorMap(im_gray, cv2.COLORMAP_JET)[:,0,[2, 0, 1]] / 255.
193 | write_vis_pcd(f'{rec_meshfile}_s2d.ply', gt_pc_tri.vertices, stl_color)
194 |
195 |
196 | def get_cam_position(gt_meshfile):
197 | mesh_gt = trimesh.load(gt_meshfile)
198 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt)
199 | extents[2] *= 0.7
200 | extents[1] *= 0.7
201 | extents[0] *= 0.3
202 | transform = np.linalg.inv(to_origin)
203 | transform[2, 3] += 0.4
204 | return extents, transform
205 |
206 |
207 | def calc_2d_metric(rec_meshfile, gt_meshfile, align=False, n_imgs=1000):
208 | """
209 | 2D reconstruction metric, depth L1 loss.
210 |
211 | """
212 | H = 500
213 | W = 500
214 | focal = 300
215 | fx = focal
216 | fy = focal
217 | cx = H/2.0-0.5
218 | cy = W/2.0-0.5
219 |
220 | gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile)
221 | rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile)
222 | unseen_gt_pointcloud_file = gt_meshfile.replace('.ply', '_pc_unseen.npy')
223 | pc_unseen = np.load(unseen_gt_pointcloud_file)
224 | if align:
225 | transformation = get_align_transformation(rec_meshfile, gt_meshfile)
226 | rec_mesh = rec_mesh.transform(transformation)
227 |
228 | # get vacant area inside the room
229 | extents, transform = get_cam_position(gt_meshfile)
230 |
231 | vis = o3d.visualization.Visualizer()
232 | vis.create_window(width=W, height=H)
233 | vis.get_render_option().mesh_show_back_face = True
234 | errors = []
235 | for i in range(n_imgs):
236 | while True:
237 |             # sample a view and check that the unseen region is not inside the camera view; if it is, resample
238 | up = [0, 0, -1]
239 | origin = trimesh.sample.volume_rectangular(
240 | extents, 1, transform=transform)
241 | origin = origin.reshape(-1)
242 | tx = round(random.uniform(-10000, +10000), 2)
243 | ty = round(random.uniform(-10000, +10000), 2)
244 | tz = round(random.uniform(-10000, +10000), 2)
245 | target = [tx, ty, tz]
246 | target = np.array(target)-np.array(origin)
247 | c2w = viewmatrix(target, up, origin)
248 | tmp = np.eye(4)
249 | tmp[:3, :] = c2w
250 | c2w = tmp
251 | seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w)
252 | if (~seen):
253 | break
254 |
255 | param = o3d.camera.PinholeCameraParameters()
256 | param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array
257 |
258 | param.intrinsic = o3d.camera.PinholeCameraIntrinsic(
259 | W, H, fx, fy, cx, cy)
260 |
261 | ctr = vis.get_view_control()
262 | ctr.set_constant_z_far(20)
263 | ctr.convert_from_pinhole_camera_parameters(param)
264 |
265 | vis.add_geometry(gt_mesh, reset_bounding_box=True,)
266 | ctr.convert_from_pinhole_camera_parameters(param)
267 | vis.poll_events()
268 | vis.update_renderer()
269 | gt_depth = vis.capture_depth_float_buffer(True)
270 | gt_depth = np.asarray(gt_depth)
271 | vis.remove_geometry(gt_mesh, reset_bounding_box=True,)
272 |
273 | vis.add_geometry(rec_mesh, reset_bounding_box=True,)
274 | ctr.convert_from_pinhole_camera_parameters(param)
275 | vis.poll_events()
276 | vis.update_renderer()
277 | ours_depth = vis.capture_depth_float_buffer(True)
278 | ours_depth = np.asarray(ours_depth)
279 | vis.remove_geometry(rec_mesh, reset_bounding_box=True,)
280 |
281 | errors += [np.abs(gt_depth-ours_depth).mean()]
282 |
283 | errors = np.array(errors)
284 | # from m to cm
285 | print('Depth L1: ', errors.mean()*100)
286 |
287 |
288 | if __name__ == '__main__':
289 |
290 | parser = argparse.ArgumentParser(
291 | description='Arguments to evaluate the reconstruction.'
292 | )
293 | parser.add_argument('--rec_mesh', type=str,
294 | help='reconstructed mesh file path')
295 | parser.add_argument('--gt_mesh', type=str,
296 | help='ground truth mesh file path')
297 | args = parser.parse_args()
298 | calc_3d_metric(args.rec_mesh, args.gt_mesh)
299 | #calc_2d_metric(args.rec_mesh, args.gt_mesh, n_imgs=1000)
300 |
--------------------------------------------------------------------------------
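
For reference, the distance-to-color mapping used above for the `*_d2s.ply` / `*_s2d.ply` visualizations (clip distances at 5 cm, normalize to [0, 1], apply OpenCV's JET colormap, convert BGR to RGB) can be exercised on its own. A minimal sketch; the helper name `error_to_rgb` is ours and not part of the repository:

```python
# Minimal sketch of the error-to-color mapping in eval_recon.py (helper name is ours).
import numpy as np
import cv2

def error_to_rgb(dist, vis_dist=0.05):
    """Map per-point distances (meters) to RGB colors in [0, 1], saturating at vis_dist."""
    alpha = (dist.clip(max=vis_dist) / vis_dist).reshape(-1, 1)   # normalize to [0, 1]
    gray = (alpha * 255).astype(np.uint8)                         # (N, 1) 8-bit image
    bgr = cv2.applyColorMap(gray, cv2.COLORMAP_JET)               # (N, 1, 3) BGR colors
    return bgr[:, 0, ::-1] / 255.0                                # (N, 3) RGB in [0, 1]

if __name__ == "__main__":
    colors = error_to_rgb(np.random.rand(1000) * 0.1)
    print(colors.shape, colors.min(), colors.max())
```
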
/replica_eval/evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import cv2 as cv
5 | import numpy as np
6 | import os
7 | import glob
8 |
9 | import trimesh
10 | from pathlib import Path
11 | import subprocess
12 |
13 |
14 | scans = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
15 |
16 |
17 | root_dir = "../exps" # path to the experiment results
18 | exp_name = 'objectsdfplus_replica' # experiment name
19 | out_dir = "evaluation/scene_results" # path to save the scene evaluation results
20 | Path(out_dir).mkdir(parents=True, exist_ok=True)
21 |
22 | # evaluation_txt_file = "evaluation/replica_objsdf_star.csv"
23 | evaluation_txt_file = "evaluation/rebuttal_oneobject.csv"
24 | evaluation_txt_file = open(evaluation_txt_file, 'w')
25 |
26 |
27 | for idx, scan in enumerate(scans):
28 | idx = idx + 1
29 | if scan != "room2":
30 | continue
31 | # test set
32 | #if not (idx in [4, 6, 7]):
33 | # continue
34 |
35 | cur_exp = f"{exp_name}_{idx}"
36 | cur_root = os.path.join(root_dir, cur_exp)
37 | # use the latest timestamp folder
38 | dirs = sorted(os.listdir(cur_root))
39 | cur_root = os.path.join(cur_root, dirs[-1])
40 | files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*_whole.ply"))))
41 |
42 | files.sort(key=lambda x:os.path.getmtime(x))
43 | ply_file = files[-1]
44 | print(ply_file)
45 |
46 | # cull the predicted mesh using the camera trajectory
47 | cull_mesh_out = os.path.join(out_dir, f"{scan}.ply")
48 | cmd = f"python cull_mesh.py --input_mesh {ply_file} --input_scalemat ../data/replica/Replica/scan{idx}/cameras.npz --traj ../data/replica/Replica/scan{idx}/traj.txt --output_mesh {cull_mesh_out}"
49 | print(cmd)
50 | os.system(cmd)
51 |
52 | cmd = f"python eval_recon.py --rec_mesh {cull_mesh_out} --gt_mesh ../data/replica/Replica/cull_GTmesh/{scan}.ply"
53 | print(cmd)
54 | # accuracy_rec, completion_rec, precision_ratio_rec, completion_ratio_rec, fscore, normal_acc, normal_comp, normal_avg
55 | output = subprocess.check_output(cmd, shell=True).decode("utf-8")
56 | output = output.replace(" ", ",")
57 | print(output)
58 |
59 | evaluation_txt_file.write(f"{scan},{Path(ply_file).name},{output}")
60 | evaluation_txt_file.flush()
--------------------------------------------------------------------------------
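
Since the loop above just turns whatever `eval_recon.py` prints into a comma-separated row, each CSV line can be parsed back into named metrics. A sketch under the assumption that `eval_recon.py` prints only the single metric line shown earlier; the function and the demo values are ours:

```python
# Sketch: parse one row of the CSV written above back into named metrics.
# Field order follows the comment in the loop above (Acc ... Normal_avg).
METRIC_KEYS = ["Acc", "Comp", "Precision", "Comp_ratio", "F_score",
               "Normal_acc", "Normal_comp", "Normal_avg"]

def parse_row(row):
    tokens = row.strip().split(",")
    scan_name, mesh_name = tokens[0], tokens[1]
    values = []
    for tok in tokens[2:]:
        try:
            values.append(float(tok))      # keep only the numeric tokens
        except ValueError:
            continue                       # skip labels like "Acc:" and empty fields
    return scan_name, mesh_name, dict(zip(METRIC_KEYS, values))

if __name__ == "__main__":
    # made-up row for illustration only
    demo = "room0,surface_whole.ply,Acc:,,0.012,Comp:,0.014,Precision:,,0.91,Comp_ratio:,,0.88,F_score:,,0.89,Normal_acc:,,0.90,Normal_comp:,,0.87,Normal_avg:,,0.885"
    print(parse_row(demo))
```
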
/replica_eval/evaluate_single_scene.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import subprocess
4 | import argparse
5 |
6 |
7 | scans = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"]
8 |
9 | if __name__ == "__main__":
10 |
11 | parser = argparse.ArgumentParser(
12 | description='Arguments to evaluate the mesh.'
13 | )
14 |
15 | parser.add_argument('--input_mesh', type=str, help='path to the mesh to be evaluated')
16 | parser.add_argument('--scan_id', type=str, help='scan id of the input mesh')
17 | parser.add_argument('--output_dir', type=str, default='evaluation_results_single', help='path to the output folder')
18 | args = parser.parse_args()
19 |
20 |
21 |
22 | out_dir = args.output_dir
23 | Path(out_dir).mkdir(parents=True, exist_ok=True)
24 |
25 | idx = args.scan_id
26 | scan = scans[int(idx) - 1]
27 |
28 | ply_file = args.input_mesh
29 |
30 | result_mesh_file = os.path.join(out_dir, "culled_mesh.ply")
31 |
32 | # cull the predicted mesh using the camera trajectory
33 | cull_mesh_out = os.path.join(out_dir, f"cull_{scan}.ply")
34 | # cmd = f"python cull_mesh.py --input_mesh {ply_file} --input_scalemat ../data/Replica/scan{idx}/cameras.npz --traj ../data/Replica/scan{idx}/traj.txt --output_mesh {cull_mesh_out}"
35 | cmd = f"python cull_mesh.py --input_mesh {ply_file} --input_scalemat ../data/replica/scan{idx}/cameras.npz --traj ../data/replica/scan{idx}/traj.txt --output_mesh {cull_mesh_out}"
36 | print(cmd)
37 | os.system(cmd)
38 |
39 | cmd = f"python eval_recon.py --rec_mesh {cull_mesh_out} --gt_mesh ../data/replica/cull_GTmesh/{scan}.ply"
40 | print(cmd)
41 | # accuracy_rec, completion_rec, precision_ratio_rec, completion_ratio_rec, fscore, normal_acc, normal_comp, normal_avg
42 | output = subprocess.check_output(cmd, shell=True).decode("utf-8")
43 | output = output.replace(" ", ",")
44 | print(output)
--------------------------------------------------------------------------------
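
A typical way to point this script at the newest mesh of a training run; the experiment folder name and layout below are assumptions borrowed from `evaluate.py`, not something this script enforces:

```python
# Sketch: run evaluate_single_scene.py on the most recent "*_whole.ply" mesh of a run.
# The experiment path is an assumption; adjust it to your own exps/ layout.
import glob
import os
import subprocess

exp_root = "../exps/objectsdfplus_replica_1"    # assumed experiment folder (scan 1 = room0)
latest_run = sorted(os.listdir(exp_root))[-1]   # timestamped run folders sort chronologically
meshes = sorted(glob.glob(os.path.join(exp_root, latest_run, "plots", "*_whole.ply")),
                key=os.path.getmtime)

subprocess.run(["python", "evaluate_single_scene.py",
                "--input_mesh", meshes[-1],
                "--scan_id", "1",
                "--output_dir", "evaluation_results_single"],
               check=True)
```
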
/replica_eval/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial import cKDTree as KDTree
3 |
4 | def completion_ratio(gt_points, rec_points, dist_th=0.01):
5 | gen_points_kd_tree = KDTree(rec_points)
6 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points)
7 | completion = np.mean((one_distances < dist_th).astype(float))  # np.float was removed in NumPy >= 1.24
8 | return completion
9 |
10 |
11 | def accuracy(gt_points, rec_points):
12 | gt_points_kd_tree = KDTree(gt_points)
13 | two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points)
14 | gen_to_gt_chamfer = np.mean(two_distances)
15 | return gen_to_gt_chamfer
16 |
17 |
18 | def completion(gt_points, rec_points):
19 | rec_points_kd_tree = KDTree(rec_points)
20 | one_distances, one_vertex_ids = rec_points_kd_tree.query(gt_points)
21 | gt_to_gen_chamfer = np.mean(one_distances)
22 | return gt_to_gen_chamfer
23 |
24 |
25 | def chamfer(gt_points, rec_points):
26 | # one direction
27 | gen_points_kd_tree = KDTree(rec_points)
28 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points)
29 | gt_to_gen_chamfer = np.mean(one_distances)
30 |
31 | # other direction
32 | gt_points_kd_tree = KDTree(gt_points)
33 | two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points)
34 | gen_to_gt_chamfer = np.mean(two_distances)
35 |
36 | return (gt_to_gen_chamfer + gen_to_gt_chamfer) / 2.
--------------------------------------------------------------------------------
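
A quick sanity check of these helpers on synthetic data (random points plus small Gaussian noise). Run it from `replica_eval/` so that `metrics` is importable; the point counts and noise level are arbitrary:

```python
# Toy sanity check for the helpers in metrics.py (run from replica_eval/).
import numpy as np
from metrics import accuracy, completion, completion_ratio, chamfer

rng = np.random.default_rng(0)
gt_points = rng.random((2000, 3))                                       # fake ground-truth samples
rec_points = gt_points + rng.normal(scale=0.005, size=gt_points.shape)  # noisy "reconstruction"

print("Accuracy (rec -> gt):   ", accuracy(gt_points, rec_points))
print("Completion (gt -> rec): ", completion(gt_points, rec_points))
print("Completion ratio @ 1cm: ", completion_ratio(gt_points, rec_points, dist_th=0.01))
print("Chamfer (symmetric):    ", chamfer(gt_points, rec_points))
```
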
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.0
2 | torchvision
3 | torchaudio
4 | trimesh
5 | pyhocon==0.3.59
6 | tqdm
7 | scikit-image
8 | matplotlib
9 | opencv-python
10 | tensorboard
11 | scipy
12 | numpy
13 | plotly
14 | ninja
15 | gdown
16 |
--------------------------------------------------------------------------------
/scannet_eval/evaluate.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/zju3dv/manhattan_sdf
2 | import numpy as np
3 | import open3d as o3d
4 | from sklearn.neighbors import KDTree
5 | import trimesh
6 | import torch
7 | import glob
8 | import os
9 | import pyrender
10 | import os
11 | from tqdm import tqdm
12 | from pathlib import Path
13 |
14 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
15 |
16 | def nn_correspondance(verts1, verts2):
17 | indices = []
18 | distances = []
19 | if len(verts1) == 0 or len(verts2) == 0:
20 | return np.array(distances)  # keep the return type consistent with the non-empty case
21 |
22 | kdtree = KDTree(verts1)
23 | distances, indices = kdtree.query(verts2)
24 | distances = distances.reshape(-1)
25 |
26 | return distances
27 |
28 |
29 | def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02):
30 | pcd_trgt = o3d.geometry.PointCloud()
31 | pcd_pred = o3d.geometry.PointCloud()
32 |
33 | pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3])
34 | pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3])
35 |
36 | if down_sample:
37 | pcd_pred = pcd_pred.voxel_down_sample(down_sample)
38 | pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)
39 |
40 | verts_pred = np.asarray(pcd_pred.points)
41 | verts_trgt = np.asarray(pcd_trgt.points)
42 |
43 | dist1 = nn_correspondance(verts_pred, verts_trgt)
44 | dist2 = nn_correspondance(verts_trgt, verts_pred)
45 |
46 | precision = np.mean((dist2 < threshold).astype('float'))
47 | recall = np.mean((dist1 < threshold).astype('float'))
48 | fscore = 2 * precision * recall / (precision + recall)
49 | metrics = {
50 | 'Acc': np.mean(dist2),
51 | 'Comp': np.mean(dist1),
52 | 'Prec': precision,
53 | 'Recall': recall,
54 | 'F-score': fscore,
55 | }
56 | return metrics
57 |
58 | # hard-coded image size
59 | H, W = 968, 1296
60 |
61 | # load pose
62 | def load_poses(scan_id):
63 | pose_path = os.path.join(f'../data/scannet/scan{scan_id}', 'pose')
64 | # pose_path = os.path.join(f'../data/custom/scan{scan_id}', 'pose')
65 | poses = []
66 | pose_paths = sorted(glob.glob(os.path.join(pose_path, '*.txt')),
67 | key=lambda x: int(os.path.basename(x)[:-4]))
68 | for pose_path in pose_paths[::10]:
69 | c2w = np.loadtxt(pose_path)
70 | if np.isfinite(c2w).all():  # skip poses with non-finite entries (some ScanNet poses are invalid)
71 | poses.append(c2w)
72 | poses = np.array(poses)
73 |
74 | return poses
75 |
76 |
77 | class Renderer():
78 | def __init__(self, height=480, width=640):
79 | self.renderer = pyrender.OffscreenRenderer(width, height)
80 | self.scene = pyrender.Scene()
81 | # self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES
82 |
83 | def __call__(self, height, width, intrinsics, pose, mesh):
84 | self.renderer.viewport_height = height
85 | self.renderer.viewport_width = width
86 | self.scene.clear()
87 | self.scene.add(mesh)
88 | cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2],
89 | fx=intrinsics[0, 0], fy=intrinsics[1, 1])
90 | self.scene.add(cam, pose=self.fix_pose(pose))
91 | return self.renderer.render(self.scene) # , self.render_flags)
92 |
93 | def fix_pose(self, pose):
94 | # 3D Rotation about the x-axis.
95 | t = np.pi
96 | c = np.cos(t)
97 | s = np.sin(t)
98 | R = np.array([[1, 0, 0],
99 | [0, c, -s],
100 | [0, s, c]])
101 | axis_transform = np.eye(4)
102 | axis_transform[:3, :3] = R
103 | return pose @ axis_transform
104 |
105 | def mesh_opengl(self, mesh):
106 | return pyrender.Mesh.from_trimesh(mesh)
107 |
108 | def delete(self):
109 | self.renderer.delete()
110 |
111 |
112 | def refuse(mesh, poses, K):  # re-fuse the predicted mesh into a TSDF volume rendered from the given camera poses
113 | renderer = Renderer()
114 | mesh_opengl = renderer.mesh_opengl(mesh)
115 | volume = o3d.pipelines.integration.ScalableTSDFVolume(
116 | voxel_length=0.01,
117 | sdf_trunc=3 * 0.01,
118 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
119 | )
120 |
121 | for pose in tqdm(poses):
122 | intrinsic = np.eye(4)
123 | intrinsic[:3, :3] = K
124 |
125 | rgb = np.ones((H, W, 3))
126 | rgb = (rgb * 255).astype(np.uint8)
127 | rgb = o3d.geometry.Image(rgb)
128 | _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl)
129 | depth_pred = o3d.geometry.Image(depth_pred)
130 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
131 | rgb, depth_pred, depth_scale=1.0, depth_trunc=5.0, convert_rgb_to_intensity=False
132 | )
133 | fx, fy, cx, cy = intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]
134 | intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy)
135 | extrinsic = np.linalg.inv(pose)
136 | volume.integrate(rgbd, intrinsic, extrinsic)
137 |
138 | return volume.extract_triangle_mesh()
139 |
140 |
141 | root_dir = "../exps/objectsdfplus_scannet" # path to your experiments
142 | exp_name = "scan" # experiment name
143 | out_dir ="evaluation/eval_mesh" # output directory for evaluation results
144 | Path(out_dir).mkdir(parents=True, exist_ok=True)
145 |
146 |
147 | scenes = ["scene0050_00", "scene0084_00", "scene0580_00", "scene0616_00"]
148 | all_results = []
149 | for idx, scan in enumerate(scenes):
150 | idx = idx + 1
151 | cur_exp = f"{exp_name}_{idx}"
152 | cur_root = os.path.join(root_dir, cur_exp)
153 | print(cur_root)
154 | if not os.path.exists(cur_root):
155 | print('Experiment folder {} does not exist, skipping'.format(cur_root))
156 | continue
157 | # use the latest timestamp folder
158 | dirs = sorted(os.listdir(cur_root))
159 | cur_root = os.path.join(cur_root, dirs[-1])
160 | files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*_whole.ply")))) # the whole mesh path
161 | # files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "*_whole.ply"))))
162 | # files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*.ply"))))
163 | print(files)
164 |
165 | # evaluate the latest mesh
166 | files.sort(key=lambda x:os.path.getmtime(x))
167 | ply_file = files[-1]
168 | print(ply_file)
169 |
170 | mesh = trimesh.load(ply_file)
171 |
172 | # transform to world coordinate
173 | cam_file = f"../data/scannet/scan{idx}/cameras.npz"
174 | # cam_file = f"../data/custom/scan{idx}/cameras.npz"
175 | scale_mat = np.load(cam_file)['scale_mat_0']
176 | mesh.vertices = (scale_mat[:3, :3] @ mesh.vertices.T + scale_mat[:3, 3:]).T
177 |
178 | # load pose and intrinsic for render depth
179 | poses = load_poses(idx)
180 |
181 | intrinsic_path = os.path.join(f'../data/scannet/scan{idx}/intrinsic/intrinsic_color.txt')
182 | # intrinsic_path = os.path.join(f'../data/custom/scan{idx}/intrinsic/intrinsic_color.txt')
183 | K = np.loadtxt(intrinsic_path)[:3, :3]
184 |
185 | mesh = refuse(mesh, poses, K)
186 |
187 | # save mesh
188 | out_mesh_path = os.path.join(out_dir, f"{exp_name}_scan_{idx}_{scan}.ply")
189 | o3d.io.write_triangle_mesh(out_mesh_path, mesh)
190 | mesh = trimesh.load(out_mesh_path)
191 |
192 |
193 | gt_mesh = os.path.join("../data/scannet/GTmesh", f"{scan}_vh_clean_2.ply")
194 | # gt_mesh = os.path.join("../data/scannet/GTmesh_lowres", f"{scan[5:]}.obj") # a low-res version of mesh for evaluation
195 |
196 | gt_mesh = trimesh.load(gt_mesh)
197 |
198 | metrics = evaluate(mesh, gt_mesh)
199 | print(metrics)
200 | all_results.append(metrics)
201 |
202 | # print all results
203 | for scan, metric in zip(scenes, all_results):
204 | values = [scan] + [str(metric[k]) for k in metric.keys()]
205 | out = ",".join(values)
206 | print(out)
207 |
208 |
209 | # average the metrics over all evaluated scenes
210 | mean_dict = {}
211 | for key in all_results[0].keys():
212 | mean_dict[key] = sum(d[key] for d in all_results) / len(all_results)
213 | print(mean_dict)
214 |
--------------------------------------------------------------------------------
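
The precision / recall / F-score logic inside `evaluate()` is compact enough to restate on its own. A minimal sketch with the same 5 cm threshold; the function name and the toy distances are ours:

```python
# Minimal sketch of the F-score computation in evaluate():
# dist1 holds gt->pred nearest-neighbor distances, dist2 holds pred->gt distances.
import numpy as np

def fscore_from_distances(dist1, dist2, threshold=0.05):
    precision = np.mean((dist2 < threshold).astype(float))  # predicted points close to GT
    recall = np.mean((dist1 < threshold).astype(float))     # GT points covered by the prediction
    fscore = 2 * precision * recall / (precision + recall)
    return precision, recall, fscore

if __name__ == "__main__":
    dist1 = np.array([0.01, 0.02, 0.08])   # toy gt->pred distances (meters)
    dist2 = np.array([0.01, 0.03])         # toy pred->gt distances (meters)
    print(fscore_from_distances(dist1, dist2))   # -> (1.0, 0.666..., 0.8)
```
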
/scripts/download_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 | echo "0 - Replica sample (one scene)"
4 | echo "1 - Replca"
5 | echo "2 - ScanNet"
6 | read -p "Enter the dataset ID you want to download: " ds_id
7 |
8 | mkdir -p data
9 | cd data
10 |
11 | if [ $ds_id == 0 ]
12 | then
13 | echo "Download Replica sample dataset, room00"
14 | # Download Replica scenes used in objectsdf++, includine rgb, depth, normal, semantic, instance, pose
15 | gdown --no-cookies 17U8RzDWCtUCNPTDF16pEhFqznbz5u8bc -O replica_sample.zip
16 | echo "done,start unzipping"
17 | unzip -o replica_sample.zip -d replica
18 | rm -rf replica_sample.zip
19 |
20 | elif [ $ds_id == 1 ]
21 | then
22 | echo "Download All Replica dataset, 8 scenes"
23 | # Download the Replica scenes used in objectsdf++, including rgb, depth, normal, semantic, instance, pose
24 | gdown --no-cookies 1IAFNQE3TNyE_ZNdJhDCcPPebWqbTuzYl -O replica.zip
25 | echo "done,start unzipping"
26 | unzip -o replica.zip
27 | rm -rf replica.zip
28 |
29 | elif [ $ds_id == 2 ]
30 | then
31 | # Download scannet scenes follow monosdf
32 | echo "Download ScanNet dataset, 4 scenes"
33 | gdown --no-cookies 1w-HZHhhvc71xOYhFBdZrLYu8FBsNWBhU -O scannet.zip
34 | echo "done,start unzipping"
35 | unzip -o scannet.zip
36 | rm -rf scannet.zip
37 |
38 | else
39 | echo "Invalid dataset ID"
40 | fi
41 |
42 | cd ..
--------------------------------------------------------------------------------
/scripts/download_meshes.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | gdown --no-cookies 1vDb3JN9lHTvjkBkmj8CQtQxng8fxPl39 -O results.zip
3 | unzip results.zip
4 | rm -rf results.zip
--------------------------------------------------------------------------------