├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── code
│   ├── confs
│   │   ├── RICO_scannet.conf
│   │   └── RICO_synthetic.conf
│   ├── datasets
│   │   ├── scene_dataset.py
│   │   └── scene_dataset_rico.py
│   ├── hashencoder
│   │   ├── __init__.py
│   │   ├── backend.py
│   │   ├── hashgrid.py
│   │   └── src
│   │       ├── bindings.cpp
│   │       ├── hashencoder.cu
│   │       └── hashencoder.h
│   ├── model
│   │   ├── density.py
│   │   ├── embedder.py
│   │   ├── loss.py
│   │   ├── network.py
│   │   ├── network_rico.py
│   │   └── ray_sampler.py
│   ├── slurm_run.sh
│   ├── training
│   │   ├── exp_runner.py
│   │   └── rico_train.py
│   └── utils
│       ├── general.py
│       ├── plots.py
│       └── rend_util.py
├── requirements.txt
├── scripts
│   ├── edit_render.py
│   └── extract_mesh_rico.py
└── synthetic_eval
    ├── evaluate.py
    └── evaluate_bgdepth.py
/.gitignore:
--------------------------------------------------------------------------------
1 | code/run_logs
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | exps/*
134 | exps*
135 | evals*
136 | data/DTU
137 | data/BlendedMVS
138 | data/Replica
139 | data/tnt_advanced
140 | data/
141 |
142 | code/tmp_build
143 | preprocess/feature_extractor/ckpts/
144 | synthetic_eval/evaluation/
145 |
146 | code/.idea/
147 | .DS_Store
148 | ._.DS_Store
149 | .idea/
150 |
151 | *.png
152 | *.ply
153 | *.txt
154 | *.jpg
155 | *.npy
156 | *.npz
157 | *.tar
158 | uploadtnt_*/
159 |
160 | *.json
161 | *.csv
162 | dtu_eval/Offical_DTU_Dataset/
163 | media/
164 | files_save/
165 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 autonomousvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
:couch_and_lamp: RICO: Regularizing the Unobservable for Indoor Compositional Reconstruction (ICCV2023)
4 |
5 | Zizhang Li,
6 | Xiaoyang Lyu,
7 | Yuanyuan Ding,
8 | Mengmeng Wang,
9 | Yiyi Liao,
10 | Yong Liu
11 |
12 |
13 | ICCV2023
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | We use geometry-motivated prior information to regularize the unobservable regions for indoor compositional reconstruction.
26 |
27 |
28 |
29 | ## TODO
30 | - [x] Training code
31 | - [x] Evaluation scripts
32 | - [x] Mesh extraction script
33 | - [x] Edited rendering script
34 | - [x] Dataset cleanup
35 |
36 | ## Setup
37 |
38 | ### Installation
39 | Clone the repository and create an anaconda environment called `rico` using
40 | ```
41 | git clone git@github.com:kyleleey/RICO.git
42 | cd RICO
43 |
44 | conda create -y -n rico python=3.8
45 | conda activate rico
46 |
47 | conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
48 |
49 | pip install -r requirements.txt
50 | ```
51 |
52 | ### Dataset
53 | We provide the processed ScanNet and synthetic scenes at this [link](https://drive.google.com/drive/folders/1yY9TYj-HaM2_I9qzsNQN8leOw6WFzVDA?usp=sharing). Please download the data and unzip it into the `data` folder; the resulting folder structure should be:
54 | ```
55 | └── RICO
56 | └── data
57 | ├── scannet
58 | ├── syn_data
59 | ```
60 | ## Training
61 |
62 | Run the following command to train RICO on synthetic scene 1:
63 | ```
64 | cd ./code
65 | bash slurm_run.sh PARTITION CFG_PATH SCAN_ID PORT
66 | ```
67 | where `PARTITION` is the slurm partition you're using. Use `confs/RICO_scannet.conf` or `confs/RICO_synthetic.conf` as `CFG_PATH` to train on a ScanNet or synthetic scene, and provide a specific `SCAN_ID` and `PORT`.
68 |
69 | If you are not in a slurm environment, you can simply run:
70 | ```
71 | python training/exp_runner.py --conf CFG_PATH --scan_id SCAN_ID --port PORT
72 | ```
73 |
74 | ## Evaluations
75 |
76 | To run quantitative evaluation on synthetic scenes for object and masked background depth:
77 | ```
78 | cd synthetic_eval
79 | python evaluate.py
80 | python evaluate_bgdepth.py
81 | ```
82 | Evaluation results will be saved in `synthetic_eval/evaluation` as .json files.
83 |
84 | We also provide scripts to post-process the experiment outputs after training.
85 |
86 | To extract the per-object mesh and the combined scene mesh:
87 | ```
88 | cd scripts
89 | python extract_mesh_rico.py
90 | ```
91 |
92 | To render translation edited results:
93 | ```
94 | cd scripts
95 | python edit_render.py
96 | ```
97 |
98 | You can change the settings in these scripts to run them on different experiment results.
99 |
100 | ## Acknowledgements
101 | This project is built upon [MonoSDF](https://github.com/autonomousvision/monosdf), [ObjSDF](https://github.com/QianyiWu/objsdf) and also the original [VolSDF](https://github.com/lioryariv/volsdf). To construct the synthetic scenes, we mainly use the function of [BlenderNeRF](https://github.com/maximeraafat/BlenderNeRF). We thank all the authors for their great work and repos.
102 |
103 |
104 | ## Citation
105 | If you find our code or paper useful, please cite
106 | ```bibtex
107 | @inproceedings{li2023rico,
108 | author = {Li, Zizhang and Lyu, Xiaoyang and Ding, Yuanyuan and Wang, Mengmeng and Liao, Yiyi and Liu, Yong},
109 | title = {RICO: Regularizing the Unobservable for Indoor Compositional Reconstruction},
110 | booktitle = {ICCV},
111 | year = {2023},
112 | }
113 | ```
114 |
--------------------------------------------------------------------------------
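A quick way to verify the data layout described in the README's Dataset section is to check a scene folder against the file patterns used by the dataset classes further below (`*_rgb.png`, `*_depth.npy`, `*_normal.npy`, `instance_mask/*.png`, `cameras.npz`, `instance_id.json`). This is only a minimal sketch; the scene path is an example for synthetic scene 1, not a script shipped with the repo.

```python
import os
from glob import glob

# Example scene directory (synthetic scene 1); adjust to your own layout.
scene_dir = os.path.join('data', 'syn_data', 'scene1')

# Patterns matching the globs in code/datasets/scene_dataset_rico.py
patterns = {
    'rgb images': '*_rgb.png',
    'monocular depth': '*_depth.npy',
    'monocular normal': '*_normal.npy',
    'instance masks': os.path.join('instance_mask', '*.png'),
}

for name, pattern in patterns.items():
    n = len(glob(os.path.join(scene_dir, pattern)))
    print(f'{name}: {n} files')

# Per-scene metadata files loaded by the dataset class.
for fname in ['cameras.npz', 'instance_id.json']:
    present = os.path.exists(os.path.join(scene_dir, fname))
    print(fname, 'found' if present else 'MISSING')
```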
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kyleleey/RICO/4254e6ff8581d21833e1b42e0352a7a63da788b1/assets/teaser.png
--------------------------------------------------------------------------------
/code/confs/RICO_scannet.conf:
--------------------------------------------------------------------------------
1 | train{
2 | expname = RICO_scannet
3 | dataset_class = datasets.scene_dataset_rico.RICO_SceneDatasetDN_Mask
4 | model_class = model.network_rico.RICONetwork
5 | loss_class = model.loss.RICOLoss
6 | learning_rate = 5.0e-4
7 | num_pixels = 1024
8 | checkpoint_freq = 10000
9 | plot_freq = 50
10 | split_n_pixels = 1024
11 | max_total_iters = 50000
12 | }
13 | plot{
14 | plot_nimgs = 1
15 | resolution = 512
16 | grid_boundary = [-1.1, 1.1]
17 | }
18 | loss{
19 | rgb_loss = torch.nn.L1Loss
20 | eikonal_weight = 0.05
21 | semantic_weight = 0.04
22 | bg_render_weight = 0.1
23 | lop_weight = 0.1
24 | lrd_weight = 0.1
25 | smooth_weight = 0.005
26 | depth_weight = 0.1
27 | normal_l1_weight = 0.05
28 | normal_cos_weight = 0.05
29 | }
30 | dataset{
31 | data_dir = syn_data
32 | img_res = [384, 384]
33 | scan_id = 1
34 | center_crop_type = no_crop
35 | data_prefix = scan
36 | }
37 | model{
38 | feature_vector_size = 256
39 | scene_bounding_sphere = 1.1
40 | render_bg = True
41 | render_bg_iter = 10
42 |
43 | Grid_MLP = True
44 |
45 | implicit_network
46 | {
47 | d_in = 3
48 | d_out = 1
49 | dims = [256, 256, 256, 256, 256, 256, 256, 256]
50 | geometric_init = True
51 | bias = 0.9
52 | skip_in = [4]
53 | weight_norm = True
54 | multires = 6
55 | inside_outside = True
56 | sigmoid = 20
57 | sigmoid_optim = False
58 | }
59 |
60 | rendering_network
61 | {
62 | mode = idr
63 | d_in = 9
64 | d_out = 3
65 | dims = [ 256, 256]
66 | weight_norm = True
67 | multires_view = 4
68 | per_image_code = True
69 | }
70 | density
71 | {
72 | variance_init = 0.05
73 | speed_factor = 10.0
74 | }
75 | ray_sampler
76 | {
77 | take_sphere_intersection = True
78 | near = 0.0
79 | N_samples = 64
80 | N_samples_extra = 32
81 | N_upsample_iters = 4
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
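The `.conf` files use HOCON syntax, as in the MonoSDF codebase this project builds on; presumably they are parsed with `pyhocon` inside `training/exp_runner.py` (not shown in this section). A minimal sketch of reading a few fields stand-alone, assuming `pyhocon` is installed:

```python
from pyhocon import ConfigFactory

# Hypothetical stand-alone loading of the training config; the actual parsing
# happens inside the training code (training/exp_runner.py / rico_train.py).
conf = ConfigFactory.parse_file('confs/RICO_scannet.conf')

print(conf.get_string('train.expname'))           # RICO_scannet
print(conf.get_float('train.learning_rate'))      # 0.0005
print(conf.get_list('plot.grid_boundary'))        # [-1.1, 1.1]
print(conf.get_config('model.implicit_network'))  # nested block as a ConfigTree
```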
/code/confs/RICO_synthetic.conf:
--------------------------------------------------------------------------------
1 | train{
2 | expname = RICO_synthetic
3 | dataset_class = datasets.scene_dataset_rico.RICO_SceneDatasetDN_Mask
4 | model_class = model.network_rico.RICONetwork
5 | loss_class = model.loss.RICOLoss
6 | learning_rate = 5.0e-4
7 | num_pixels = 1024
8 | checkpoint_freq = 10000
9 | plot_freq = 50
10 | split_n_pixels = 1024
11 | max_total_iters = 50000
12 | }
13 | plot{
14 | plot_nimgs = 1
15 | resolution = 512
16 | grid_boundary = [-1.1, 1.1]
17 | }
18 | loss{
19 | rgb_loss = torch.nn.L1Loss
20 | eikonal_weight = 0.05
21 | semantic_weight = 0.04
22 | bg_render_weight = 0.1
23 | lop_weight = 0.1
24 | lrd_weight = 0.1
25 | smooth_weight = 0.005
26 | depth_weight = 0.1
27 | normal_l1_weight = 0.05
28 | normal_cos_weight = 0.05
29 | }
30 | dataset{
31 | data_dir = syn_data
32 | img_res = [384, 384]
33 | scan_id = 1
34 | center_crop_type = no_crop
35 | data_prefix = scene
36 | }
37 | model{
38 | feature_vector_size = 256
39 | scene_bounding_sphere = 1.1
40 | render_bg = True
41 | render_bg_iter = 10
42 |
43 | Grid_MLP = True
44 |
45 | implicit_network
46 | {
47 | d_in = 3
48 | d_out = 1
49 | dims = [256, 256, 256, 256, 256, 256, 256, 256]
50 | geometric_init = True
51 | bias = 0.9
52 | skip_in = [4]
53 | weight_norm = True
54 | multires = 6
55 | inside_outside = True
56 | sigmoid = 20
57 | sigmoid_optim = False
58 | }
59 |
60 | rendering_network
61 | {
62 | mode = idr
63 | d_in = 9
64 | d_out = 3
65 | dims = [ 256, 256]
66 | weight_norm = True
67 | multires_view = 4
68 | per_image_code = True
69 | }
70 | density
71 | {
72 | variance_init = 0.05
73 | speed_factor = 10.0
74 | }
75 | ray_sampler
76 | {
77 | take_sphere_intersection = True
78 | near = 0.0
79 | N_samples = 64
80 | N_samples_extra = 32
81 | N_upsample_iters = 4
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/code/datasets/scene_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | import utils.general as utils
7 | from utils import rend_util
8 | from glob import glob
9 | import cv2
10 | import random
11 | import json
12 | from kornia import morphology as morph
13 |
14 | class SceneDataset(torch.utils.data.Dataset):
15 |
16 | def __init__(self,
17 | data_dir,
18 | img_res,
19 | scan_id=0,
20 | num_views=-1,
21 | ):
22 |
23 | self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id))
24 |
25 | self.total_pixels = img_res[0] * img_res[1]
26 | self.img_res = img_res
27 |
28 | assert os.path.exists(self.instance_dir), "Data directory is empty"
29 |
30 | self.num_views = num_views
31 | assert num_views in [-1, 3, 6, 9]
32 |
33 | self.sampling_idx = None
34 |
35 | image_dir = '{0}/image'.format(self.instance_dir)
36 | image_paths = sorted(utils.glob_imgs(image_dir))
37 | self.n_images = len(image_paths)
38 |
39 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir)
40 | camera_dict = np.load(self.cam_file)
41 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
42 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
43 |
44 | self.intrinsics_all = []
45 | self.pose_all = []
46 | for scale_mat, world_mat in zip(scale_mats, world_mats):
47 | P = world_mat @ scale_mat
48 | P = P[:3, :4]
49 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
50 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
51 | self.pose_all.append(torch.from_numpy(pose).float())
52 |
53 | self.rgb_images = []
54 | for path in image_paths:
55 | rgb = rend_util.load_rgb(path)
56 | rgb = rgb.reshape(3, -1).transpose(1, 0)
57 | self.rgb_images.append(torch.from_numpy(rgb).float())
58 |
59 | # use fake (all-ones) depth and normal images
60 | self.depth_images = []
61 | self.normal_images = []
62 |
63 | for path in image_paths:
64 | depth = np.ones_like(rgb[:, :1])
65 | self.depth_images.append(torch.from_numpy(depth).float())
66 | normal = np.ones_like(rgb)
67 | self.normal_images.append(torch.from_numpy(normal).float())
68 |
69 | def __len__(self):
70 | return self.n_images
71 |
72 | def __getitem__(self, idx):
73 | if self.num_views >= 0:
74 | image_ids = [25, 22, 28, 40, 44, 48, 0, 8, 13][:self.num_views]
75 | idx = image_ids[random.randint(0, self.num_views - 1)]
76 |
77 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
78 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
79 | uv = uv.reshape(2, -1).transpose(1, 0)
80 |
81 | sample = {
82 | "uv": uv,
83 | "intrinsics": self.intrinsics_all[idx],
84 | "pose": self.pose_all[idx]
85 | }
86 |
87 | ground_truth = {
88 | "rgb": self.rgb_images[idx],
89 | "depth": self.depth_images[idx],
90 | "normal": self.normal_images[idx],
91 | }
92 |
93 | if self.sampling_idx is not None:
94 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :]
95 | ground_truth["depth"] = self.depth_images[idx][self.sampling_idx, :]
96 | ground_truth["mask"] = torch.ones_like(self.depth_images[idx][self.sampling_idx, :])
97 | ground_truth["normal"] = self.normal_images[idx][self.sampling_idx, :]
98 |
99 | sample["uv"] = uv[self.sampling_idx, :]
100 |
101 | return idx, sample, ground_truth
102 |
103 | def collate_fn(self, batch_list):
104 | # gets a list of dictionaries and returns input, ground_truth as dictionaries for all batch instances
105 | batch_list = zip(*batch_list)
106 |
107 | all_parsed = []
108 | for entry in batch_list:
109 | if type(entry[0]) is dict:
110 | # make them all into a new dict
111 | ret = {}
112 | for k in entry[0].keys():
113 | ret[k] = torch.stack([obj[k] for obj in entry])
114 | all_parsed.append(ret)
115 | else:
116 | all_parsed.append(torch.LongTensor(entry))
117 |
118 | return tuple(all_parsed)
119 |
120 | def change_sampling_idx(self, sampling_size):
121 | if sampling_size == -1:
122 | self.sampling_idx = None
123 | else:
124 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size]
125 |
126 | def get_scale_mat(self):
127 | return np.load(self.cam_file)['scale_mat_0']
128 |
129 |
130 | # Dataset with monocular depth and normal
131 | class SceneDatasetDN(torch.utils.data.Dataset):
132 |
133 | def __init__(self,
134 | data_dir,
135 | img_res,
136 | scan_id=0,
137 | center_crop_type='xxxx',
138 | use_mask=False,
139 | num_views=-1
140 | ):
141 |
142 | if data_dir == 'syn_data':
143 | self.instance_dir = os.path.join('../data', data_dir, 'scene{0}'.format(scan_id))
144 | else:
145 | self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id))
146 |
147 | self.total_pixels = img_res[0] * img_res[1]
148 | self.img_res = img_res
149 | self.num_views = num_views
150 | assert num_views in [-1, 3, 6, 9]
151 |
152 | assert os.path.exists(self.instance_dir), "Data directory is empty"
153 |
154 | self.sampling_idx = None
155 |
156 | def glob_data(data_dir):
157 | data_paths = []
158 | data_paths.extend(glob(data_dir))
159 | data_paths = sorted(data_paths)
160 | return data_paths
161 |
162 | image_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_rgb.png"))
163 | depth_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_depth.npy"))
164 | normal_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_normal.npy"))
165 |
166 | # the mask is only used in the Replica dataset, as some monocular depth predictions have very large errors and we ignore them
167 | if use_mask:
168 | mask_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_mask.npy"))
169 | else:
170 | mask_paths = None
171 |
172 | self.n_images = len(image_paths)
173 |
174 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir)
175 | camera_dict = np.load(self.cam_file)
176 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
177 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
178 |
179 | self.intrinsics_all = []
180 | self.pose_all = []
181 | for scale_mat, world_mat in zip(scale_mats, world_mats):
182 | P = world_mat @ scale_mat
183 | P = P[:3, :4]
184 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
185 |
186 | # because we resize and center-crop to 384x384 when using the omnidata model, we need to adjust the camera intrinsics accordingly
187 | if center_crop_type == 'center_crop_for_replica':
188 | scale = 384 / 680
189 | offset = (1200 - 680 ) * 0.5
190 | intrinsics[0, 2] -= offset
191 | intrinsics[:2, :] *= scale
192 | elif center_crop_type == 'center_crop_for_tnt':
193 | scale = 384 / 540
194 | offset = (960 - 540) * 0.5
195 | intrinsics[0, 2] -= offset
196 | intrinsics[:2, :] *= scale
197 | elif center_crop_type == 'center_crop_for_dtu':
198 | scale = 384 / 1200
199 | offset = (1600 - 1200) * 0.5
200 | intrinsics[0, 2] -= offset
201 | intrinsics[:2, :] *= scale
202 | elif center_crop_type == 'padded_for_dtu':
203 | scale = 384 / 1200
204 | offset = 0
205 | intrinsics[0, 2] -= offset
206 | intrinsics[:2, :] *= scale
207 | elif center_crop_type == 'no_crop': # for the scannet dataset, we already adjust the camera intrinsics during preprocessing so there is nothing to do here
208 | pass
209 | else:
210 | raise NotImplementedError
211 |
212 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
213 | self.pose_all.append(torch.from_numpy(pose).float())
214 |
215 | self.rgb_images = []
216 | for path in image_paths:
217 | rgb = rend_util.load_rgb(path)
218 | rgb = rgb.reshape(3, -1).transpose(1, 0)
219 | self.rgb_images.append(torch.from_numpy(rgb).float())
220 |
221 | self.depth_images = []
222 | self.normal_images = []
223 |
224 | for dpath, npath in zip(depth_paths, normal_paths):
225 | depth = np.load(dpath)
226 | self.depth_images.append(torch.from_numpy(depth.reshape(-1, 1)).float())
227 |
228 | normal = np.load(npath)
229 | normal = normal.reshape(3, -1).transpose(1, 0)
230 | # important as the output of omnidata is normalized
231 | normal = normal * 2. - 1.
232 | self.normal_images.append(torch.from_numpy(normal).float())
233 |
234 | # load mask
235 | self.mask_images = []
236 | if mask_paths is None:
237 | for depth in self.depth_images:
238 | mask = torch.ones_like(depth)
239 | self.mask_images.append(mask)
240 | else:
241 | for path in mask_paths:
242 | mask = np.load(path)
243 | self.mask_images.append(torch.from_numpy(mask.reshape(-1, 1)).float())
244 |
245 | def __len__(self):
246 | return self.n_images
247 |
248 | def __getitem__(self, idx):
249 | if self.num_views >= 0:
250 | image_ids = [25, 22, 28, 40, 44, 48, 0, 8, 13][:self.num_views]
251 | idx = image_ids[random.randint(0, self.num_views - 1)]
252 |
253 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
254 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
255 | uv = uv.reshape(2, -1).transpose(1, 0)
256 |
257 | sample = {
258 | "uv": uv,
259 | "intrinsics": self.intrinsics_all[idx],
260 | "pose": self.pose_all[idx]
261 | }
262 |
263 | ground_truth = {
264 | "rgb": self.rgb_images[idx],
265 | "depth": self.depth_images[idx],
266 | "mask": self.mask_images[idx],
267 | "normal": self.normal_images[idx],
268 | }
269 |
270 | if self.sampling_idx is not None:
271 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :]
272 | ground_truth["full_rgb"] = self.rgb_images[idx]
273 | ground_truth["normal"] = self.normal_images[idx][self.sampling_idx, :]
274 | ground_truth["depth"] = self.depth_images[idx][self.sampling_idx, :]
275 | ground_truth["full_depth"] = self.depth_images[idx]
276 | ground_truth["mask"] = self.mask_images[idx][self.sampling_idx, :]
277 | ground_truth["full_mask"] = self.mask_images[idx]
278 |
279 | sample["uv"] = uv[self.sampling_idx, :]
280 |
281 | return idx, sample, ground_truth
282 |
283 | def collate_fn(self, batch_list):
284 | # gets a list of dictionaries and returns input, ground_truth as dictionaries for all batch instances
285 | batch_list = zip(*batch_list)
286 |
287 | all_parsed = []
288 | for entry in batch_list:
289 | if type(entry[0]) is dict:
290 | # make them all into a new dict
291 | ret = {}
292 | for k in entry[0].keys():
293 | ret[k] = torch.stack([obj[k] for obj in entry])
294 | all_parsed.append(ret)
295 | else:
296 | all_parsed.append(torch.LongTensor(entry))
297 |
298 | return tuple(all_parsed)
299 |
300 | def change_sampling_idx(self, sampling_size):
301 | if sampling_size == -1:
302 | self.sampling_idx = None
303 | else:
304 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size]
305 |
306 | def get_scale_mat(self):
307 | return np.load(self.cam_file)['scale_mat_0']
--------------------------------------------------------------------------------
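A minimal sketch of how these dataset classes are typically driven: random pixels are selected via `change_sampling_idx` and batches are assembled with the custom `collate_fn`. The scene arguments and sampling size below are example values (matching the synthetic conf), and the real training loop lives in `training/rico_train.py`, which is not shown in this section. Run from the `code` directory so the relative `../data` path resolves.

```python
import torch
from datasets.scene_dataset import SceneDatasetDN

# Example arguments; paths are resolved relative to the `code` directory
# ('../data/syn_data/scene1' for this combination).
dataset = SceneDatasetDN(data_dir='syn_data', img_res=[384, 384],
                         scan_id=1, center_crop_type='no_crop')

# Sample 1024 random pixels per image (num_pixels in the conf files);
# passing -1 restores full-image sampling.
dataset.change_sampling_idx(1024)

loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True,
                                     collate_fn=dataset.collate_fn)

indices, model_input, ground_truth = next(iter(loader))
print(model_input['uv'].shape)    # torch.Size([1, 1024, 2])
print(ground_truth['rgb'].shape)  # torch.Size([1, 1024, 3])
```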
/code/datasets/scene_dataset_rico.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 |
6 | import utils.general as utils
7 | from utils import rend_util
8 | from glob import glob
9 | import cv2
10 | import json
11 |
12 |
13 | # Dataset with monocular depth, normal, instance mask, etc.
14 | class RICO_SceneDatasetDN_Mask(torch.utils.data.Dataset):
15 | def __init__(self,
16 | data_dir,
17 | img_res,
18 | scan_id=0,
19 | center_crop_type='xxxx',
20 | use_mask=False,
21 | data_prefix='scan'
22 | ):
23 | # for scannet, data_prefix is 'scan', for synthetic dataset, data_prefix is 'scene'
24 | self.instance_dir = os.path.join('../data', data_dir, data_prefix+'{0}'.format(scan_id))
25 |
26 | self.total_pixels = img_res[0] * img_res[1]
27 | self.img_res = img_res
28 |
29 | assert os.path.exists(self.instance_dir), "Data directory is empty"
30 |
31 | self.sampling_idx = None
32 |
33 | with open(os.path.join(self.instance_dir, 'instance_id.json'), 'r') as f:
34 | id_dict = json.load(f)
35 | f.close()
36 | self.instance_dict = id_dict
37 | self.instance_ids = list(self.instance_dict.values())
38 | self.label_mapping = [0] + self.instance_ids # background ID is 0 and at the first of label_mapping
39 |
40 | def glob_data(data_dir):
41 | data_paths = []
42 | data_paths.extend(glob(data_dir))
43 | data_paths = sorted(data_paths)
44 | return data_paths
45 |
46 | image_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_rgb.png"))
47 | depth_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_depth.npy"))
48 | normal_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_normal.npy"))
49 |
50 | # the mask is only used in the Replica dataset, as some monocular depth predictions have very large errors and we ignore them
51 | if use_mask:
52 | mask_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_mask.npy"))
53 | else:
54 | mask_paths = None
55 |
56 | # This is the loading of Instance masks for RICO
57 | instance_mask_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "instance_mask", "*.png"))
58 |
59 | self.n_images = len(image_paths)
60 |
61 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir)
62 | camera_dict = np.load(self.cam_file)
63 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
64 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
65 |
66 | self.intrinsics_all = []
67 | self.pose_all = []
68 | for scale_mat, world_mat in zip(scale_mats, world_mats):
69 | P = world_mat @ scale_mat
70 | P = P[:3, :4]
71 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
72 |
73 | # because we resize and center-crop to 384x384 when using the omnidata model, we need to adjust the camera intrinsics accordingly
74 | # should be 'no_crop' for both datasets in RICO
75 | if center_crop_type == 'center_crop_for_replica':
76 | scale = 384 / 680
77 | offset = (1200 - 680 ) * 0.5
78 | intrinsics[0, 2] -= offset
79 | intrinsics[:2, :] *= scale
80 | elif center_crop_type == 'center_crop_for_tnt':
81 | scale = 384 / 540
82 | offset = (960 - 540) * 0.5
83 | intrinsics[0, 2] -= offset
84 | intrinsics[:2, :] *= scale
85 | elif center_crop_type == 'center_crop_for_dtu':
86 | scale = 384 / 1200
87 | offset = (1600 - 1200) * 0.5
88 | intrinsics[0, 2] -= offset
89 | intrinsics[:2, :] *= scale
90 | elif center_crop_type == 'padded_for_dtu':
91 | scale = 384 / 1200
92 | offset = 0
93 | intrinsics[0, 2] -= offset
94 | intrinsics[:2, :] *= scale
95 | elif center_crop_type == 'no_crop': # for the scannet dataset, we already adjust the camera intrinsics during preprocessing so there is nothing to do here
96 | pass
97 | else:
98 | raise NotImplementedError
99 |
100 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
101 | self.pose_all.append(torch.from_numpy(pose).float())
102 |
103 | self.rgb_images = []
104 | for path in image_paths:
105 | rgb = rend_util.load_rgb(path)
106 | rgb = rgb.reshape(3, -1).transpose(1, 0)
107 | self.rgb_images.append(torch.from_numpy(rgb).float())
108 |
109 | self.depth_images = []
110 | self.normal_images = []
111 |
112 | for dpath, npath in zip(depth_paths, normal_paths):
113 | depth = np.load(dpath)
114 | self.depth_images.append(torch.from_numpy(depth.reshape(-1, 1)).float())
115 |
116 | normal = np.load(npath)
117 | normal = normal.reshape(3, -1).transpose(1, 0)
118 | # important as the output of omnidata is normalized
119 | normal = normal * 2. - 1.
120 | self.normal_images.append(torch.from_numpy(normal).float())
121 |
122 | # load instance mask and map to label_mapping
123 | self.instance_masks = []
124 | self.instance_dilated_region_list = []
125 | for im_path in instance_mask_paths:
126 |
127 | instance_mask_pic = cv2.imread(im_path, -1)
128 | if len(instance_mask_pic.shape) == 3:
129 | instance_mask_pic = instance_mask_pic[:, :, 0]
130 | instance_mask = instance_mask_pic.reshape(1, -1).transpose(1, 0) # [HW, 1]
131 | instance_mask[instance_mask==255] = 0 # background is 0
132 |
133 | ins_list = np.unique(instance_mask)
134 | cur_sems = np.copy(instance_mask)
135 | for i in ins_list:
136 | if i not in self.label_mapping:
137 | cur_sems[instance_mask == i] = self.label_mapping.index(0)
138 | else:
139 | cur_sems[instance_mask == i] = self.label_mapping.index(i)
140 |
141 | self.instance_masks.append(torch.from_numpy(cur_sems).float())
142 |
143 | # load mask
144 | self.mask_images = []
145 | if mask_paths is None:
146 | for depth in self.depth_images:
147 | mask = torch.ones_like(depth)
148 | self.mask_images.append(mask)
149 | else:
150 | for path in mask_paths:
151 | mask = np.load(path)
152 | self.mask_images.append(torch.from_numpy(mask.reshape(-1, 1)).float())
153 |
154 | self.n_images = len(self.rgb_images)
155 |
156 | def __len__(self):
157 | return self.n_images
158 |
159 | def __getitem__(self, idx):
160 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
161 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
162 | uv = uv.reshape(2, -1).transpose(1, 0)
163 |
164 | sample = {
165 | "uv": uv,
166 | "intrinsics": self.intrinsics_all[idx],
167 | "pose": self.pose_all[idx]
168 | }
169 |
170 | ground_truth = {
171 | "rgb": self.rgb_images[idx],
172 | "depth": self.depth_images[idx],
173 | "mask": self.mask_images[idx],
174 | "normal": self.normal_images[idx],
175 | "instance_mask": self.instance_masks[idx],
176 | "use_syn_data": torch.Tensor([0.]).reshape(-1)
177 | }
178 |
179 | if self.sampling_idx is not None:
180 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :]
181 | ground_truth["full_rgb"] = self.rgb_images[idx]
182 | ground_truth["normal"] = self.normal_images[idx][self.sampling_idx, :]
183 | ground_truth["depth"] = self.depth_images[idx][self.sampling_idx, :]
184 | ground_truth["full_depth"] = self.depth_images[idx]
185 | ground_truth["mask"] = self.mask_images[idx][self.sampling_idx, :]
186 | ground_truth["full_mask"] = self.mask_images[idx]
187 |
188 | ground_truth["instance_mask"] = self.instance_masks[idx][self.sampling_idx, :]
189 | ground_truth["full_instance_mask"] = self.instance_masks[idx]
190 |
191 | sample["uv"] = uv[self.sampling_idx, :]
192 |
193 | return idx, sample, ground_truth
194 |
195 | def collate_fn(self, batch_list):
196 | # gets a list of dictionaries and returns input, ground_truth as dictionaries for all batch instances
197 | batch_list = zip(*batch_list)
198 |
199 | all_parsed = []
200 | for entry in batch_list:
201 | if type(entry[0]) is dict:
202 | # make them all into a new dict
203 | ret = {}
204 | for k in entry[0].keys():
205 | ret[k] = torch.stack([obj[k] for obj in entry])
206 | all_parsed.append(ret)
207 | else:
208 | all_parsed.append(torch.LongTensor(entry))
209 |
210 | return tuple(all_parsed)
211 |
212 | def change_sampling_idx(self, sampling_size):
213 | if sampling_size == -1:
214 | self.sampling_idx = None
215 | else:
216 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size]
217 |
218 | def get_scale_mat(self):
219 | return np.load(self.cam_file)['scale_mat_0']
--------------------------------------------------------------------------------
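The instance masks above are remapped from raw instance IDs (read from `instance_id.json`) to contiguous semantic labels: background pixels (value 255 in the PNGs) become 0, every known instance ID is replaced by its index in `label_mapping = [0] + instance_ids`, and unknown IDs fall back to background. A small worked example of that remapping, using hypothetical instance IDs:

```python
import numpy as np

# Hypothetical content of instance_id.json: {"chair": 3, "table": 7}
label_mapping = [0] + [3, 7]   # index 0 = background, 1 = chair, 2 = table

instance_mask = np.array([[255, 3], [7, 9]])   # 255 = background, 9 = unknown ID
instance_mask[instance_mask == 255] = 0

cur_sems = np.copy(instance_mask)
for i in np.unique(instance_mask):
    if i not in label_mapping:
        cur_sems[instance_mask == i] = label_mapping.index(0)  # unknown -> background
    else:
        cur_sems[instance_mask == i] = label_mapping.index(i)

print(cur_sems)   # [[0 1]
                  #  [2 0]]
```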
/code/hashencoder/__init__.py:
--------------------------------------------------------------------------------
1 | from .hashgrid import HashEncoder
--------------------------------------------------------------------------------
/code/hashencoder/backend.py:
--------------------------------------------------------------------------------
1 | from distutils.command.build import build
2 | import os
3 | from torch.utils.cpp_extension import load
4 | from pathlib import Path
5 |
6 | Path('./tmp_build/').mkdir(parents=True, exist_ok=True)
7 |
8 | _src_path = os.path.dirname(os.path.abspath(__file__))
9 |
10 | _backend = load(name='_hash_encoder',
11 | extra_cflags=['-O3', '-std=c++14'],
12 | extra_cuda_cflags=[
13 | '-O3', '-std=c++14', '-allow-unsupported-compiler',
14 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
15 | ],
16 | sources=[os.path.join(_src_path, 'src', f) for f in [
17 | 'hashencoder.cu',
18 | 'bindings.cpp',
19 | ]],
20 | build_directory='./tmp_build/',
21 | verbose=True,
22 | )
23 |
24 | __all__ = ['_backend']
--------------------------------------------------------------------------------
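`backend.py` JIT-compiles the CUDA extension on first import and writes the build artifacts to `./tmp_build/` relative to the current working directory, which is why `code/tmp_build` appears in `.gitignore` and training is launched from `code/`. A one-off smoke test to trigger and check the build, as a sketch:

```python
# Run from the `code` directory so ./tmp_build/ ends up in the expected place.
# The first import compiles hashencoder.cu / bindings.cpp with nvcc, which can
# take a few minutes; later imports reuse the cached build.
from hashencoder.backend import _backend

print(_backend)                                    # loaded torch extension module
print(hasattr(_backend, 'hash_encode_forward'))    # True if the build succeeded
```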
/code/hashencoder/hashgrid.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from math import ceil
3 | from cachetools import cached
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Function
9 | from torch.autograd.function import once_differentiable
10 | from torch.cuda.amp import custom_bwd, custom_fwd
11 |
12 | from .backend import _backend
13 |
14 | class _hash_encode(Function):
15 | @staticmethod
16 | @custom_fwd(cast_inputs=torch.half)
17 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False):
18 | # inputs: [B, D], float in [0, 1]
19 | # embeddings: [sO, C], float
20 | # offsets: [L + 1], int
21 | # RETURN: [B, F], float
22 |
23 | inputs = inputs.contiguous()
24 | embeddings = embeddings.contiguous()
25 | offsets = offsets.contiguous()
26 |
27 | B, D = inputs.shape # batch size, coord dim
28 | L = offsets.shape[0] - 1 # level
29 | C = embeddings.shape[1] # embedding dim for each level
30 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
31 | H = base_resolution # base resolution
32 |
33 | # L first, optimize cache for cuda kernel, but needs an extra permute later
34 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=inputs.dtype)
35 |
36 | if calc_grad_inputs:
37 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=inputs.dtype)
38 | else:
39 | dy_dx = torch.empty(1, device=inputs.device, dtype=inputs.dtype)
40 |
41 | _backend.hash_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx)
42 |
43 | # permute back to [B, L * C]
44 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C)
45 |
46 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
47 | ctx.dims = [B, D, C, L, S, H]
48 | ctx.calc_grad_inputs = calc_grad_inputs
49 |
50 | return outputs
51 |
52 | @staticmethod
53 | @custom_bwd
54 | def backward(ctx, grad):
55 |
56 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
57 | B, D, C, L, S, H = ctx.dims
58 | calc_grad_inputs = ctx.calc_grad_inputs
59 |
60 | # grad: [B, L * C] --> [L, B, C]
61 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()
62 |
63 | grad_inputs, grad_embeddings = _hash_encode_second_backward.apply(grad, inputs, embeddings, offsets, B, D, C, L, S, H, calc_grad_inputs, dy_dx)
64 |
65 | if calc_grad_inputs:
66 | return grad_inputs, grad_embeddings, None, None, None, None
67 | else:
68 | return None, grad_embeddings, None, None, None, None
69 |
70 |
71 | class _hash_encode_second_backward(Function):
72 | @staticmethod
73 | def forward(ctx, grad, inputs, embeddings, offsets, B, D, C, L, S, H, calc_grad_inputs, dy_dx):
74 |
75 | grad_inputs = torch.zeros_like(inputs)
76 | grad_embeddings = torch.zeros_like(embeddings)
77 |
78 | ctx.save_for_backward(grad, inputs, embeddings, offsets, dy_dx, grad_inputs, grad_embeddings)
79 | ctx.dims = [B, D, C, L, S, H]
80 | ctx.calc_grad_inputs = calc_grad_inputs
81 |
82 | _backend.hash_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs)
83 |
84 | return grad_inputs, grad_embeddings
85 |
86 | @staticmethod
87 | def backward(ctx, grad_grad_inputs, grad_grad_embeddings):
88 |
89 | grad, inputs, embeddings, offsets, dy_dx, grad_inputs, grad_embeddings = ctx.saved_tensors
90 | B, D, C, L, S, H = ctx.dims
91 | calc_grad_inputs = ctx.calc_grad_inputs
92 |
93 | grad_grad = torch.zeros_like(grad)
94 | grad2_embeddings = torch.zeros_like(embeddings)
95 |
96 | _backend.hash_encode_second_backward(grad, inputs, embeddings, offsets,
97 | B, D, C, L, S, H, calc_grad_inputs, dy_dx,
98 | grad_grad_inputs,
99 | grad_grad, grad2_embeddings)
100 |
101 | return grad_grad, None, grad2_embeddings, None, None, None, None, None, None, None, None, None
102 |
103 |
104 | hash_encode = _hash_encode.apply
105 |
106 |
107 | class HashEncoder(nn.Module):
108 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None):
109 | super().__init__()
110 |
111 | # the finest resolution desired at the last level; if provided, overrides per_level_scale
112 | if desired_resolution is not None:
113 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
114 |
115 | self.input_dim = input_dim # coord dims, 2 or 3
116 | self.num_levels = num_levels # num levels, each level multiply resolution by 2
117 | self.level_dim = level_dim # encode channels per level
118 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
119 | self.log2_hashmap_size = log2_hashmap_size
120 | self.base_resolution = base_resolution
121 | self.output_dim = num_levels * level_dim
122 |
123 | if level_dim % 2 != 0:
124 | print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward if fp16 is also enabled! (maybe fix later)')
125 |
126 | # allocate parameters
127 | offsets = []
128 | offset = 0
129 | self.max_params = 2 ** log2_hashmap_size
130 | for i in range(num_levels):
131 | resolution = int(np.ceil(base_resolution * per_level_scale ** i))
132 | params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number
133 | #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible
134 | offsets.append(offset)
135 | offset += params_in_level
136 | offsets.append(offset)
137 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
138 | self.register_buffer('offsets', offsets)
139 |
140 | self.n_params = offsets[-1] * level_dim
141 |
142 | # parameters
143 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim))
144 |
145 | self.reset_parameters()
146 |
147 | def reset_parameters(self):
148 | std = 1e-4
149 | self.embeddings.data.uniform_(-std, std)
150 |
151 | def __repr__(self):
152 | return f"HashEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} base_resolution={self.base_resolution} per_level_scale={self.per_level_scale} params={tuple(self.embeddings.shape)}"
153 |
154 | def forward(self, inputs, size=1):
155 | # inputs: [..., input_dim], normalized real world positions in [-size, size]
156 | # return: [..., num_levels * level_dim]
157 |
158 | inputs = (inputs + size) / (2 * size) # map to [0, 1]
159 |
160 | prefix_shape = list(inputs.shape[:-1])
161 | inputs = inputs.view(-1, self.input_dim)
162 |
163 | outputs = hash_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad)
164 | outputs = outputs.view(prefix_shape + [self.output_dim])
165 |
166 | return outputs
--------------------------------------------------------------------------------
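A minimal usage sketch of `HashEncoder` (requires a CUDA-capable GPU and the compiled extension above). The grid hyper-parameters below are illustrative values, mostly the class defaults, and not necessarily what RICO trains with:

```python
import torch
from hashencoder import HashEncoder

encoder = HashEncoder(input_dim=3, num_levels=16, level_dim=2,
                      base_resolution=16, log2_hashmap_size=19,
                      desired_resolution=2048).cuda()

# Points are expected in [-size, size]; with size=1 they are mapped to [0, 1].
points = (torch.rand(4096, 3, device='cuda') * 2.0 - 1.0).requires_grad_(True)
features = encoder(points, size=1)

print(features.shape)   # torch.Size([4096, 32])  (num_levels * level_dim)
```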
/code/hashencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include "hashencoder.h"
4 |
5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
6 | m.def("hash_encode_forward", &hash_encode_forward, "hash encode forward (CUDA)");
7 | m.def("hash_encode_backward", &hash_encode_backward, "hash encode backward (CUDA)");
8 | m.def("hash_encode_second_backward", &hash_encode_second_backward, "hash encode second backward (CUDA)");
9 | }
--------------------------------------------------------------------------------
/code/hashencoder/src/hashencoder.h:
--------------------------------------------------------------------------------
1 | #ifndef _HASH_ENCODE_H
2 | #define _HASH_ENCODE_H
3 |
4 | #include <stdint.h>
5 | #include <torch/torch.h>
6 | #include <ATen/ATen.h>
7 |
8 | // inputs: [B, D], float, in [0, 1]
9 | // embeddings: [sO, C], float
10 | // offsets: [L + 1], uint32_t
11 | // outputs: [B, L * C], float
12 | // H: base resolution
13 | void hash_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx);
14 | void hash_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs);
15 | void hash_encode_second_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, const at::Tensor grad_grad_inputs, at::Tensor grad_grad, at::Tensor grad2_embeddings);
16 |
17 | #endif
--------------------------------------------------------------------------------
/code/model/density.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class Density(nn.Module):
6 | def __init__(self, params_init={}):
7 | super().__init__()
8 | for p in params_init:
9 | param = nn.Parameter(torch.tensor(params_init[p]))
10 | setattr(self, p, param)
11 |
12 | def forward(self, sdf, beta=None):
13 | return self.density_func(sdf, beta=beta)
14 |
15 |
16 | class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
17 | def __init__(self, params_init={}, beta_min=0.0001):
18 | super().__init__(params_init=params_init)
19 | self.beta_min = torch.tensor(beta_min).cuda()
20 |
21 | def density_func(self, sdf, beta=None):
22 | if beta is None:
23 | beta = self.get_beta()
24 |
25 | alpha = 1 / beta
26 | return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))
27 |
28 | def get_beta(self):
29 | beta = self.beta.abs() + self.beta_min
30 | return beta
31 |
32 |
33 | class AbsDensity(Density): # like NeRF++
34 | def density_func(self, sdf, beta=None):
35 | return torch.abs(sdf)
36 |
37 |
38 | class SimpleDensity(Density): # like NeRF
39 | def __init__(self, params_init={}, noise_std=1.0):
40 | super().__init__(params_init=params_init)
41 | self.noise_std = noise_std
42 |
43 | def density_func(self, sdf, beta=None):
44 | if self.training and self.noise_std > 0.0:
45 | noise = torch.randn(sdf.shape).cuda() * self.noise_std
46 | sdf = sdf + noise
47 | return torch.relu(sdf)
48 |
--------------------------------------------------------------------------------
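As the comment on `LaplaceDensity` notes, `density_func` computes `alpha * Laplace(loc=0, scale=beta).cdf(-sdf)` with `alpha = 1 / beta`, rewritten with `expm1` for numerical stability. A small sketch checking that equivalence; the beta value is just an example for illustration:

```python
import torch

beta = torch.tensor(0.05)   # example value for illustration
alpha = 1.0 / beta
sdf = torch.linspace(-0.2, 0.2, 5)

# Stable form used in LaplaceDensity.density_func
stable = alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))

# Direct Laplace CDF evaluated at -sdf
laplace = torch.distributions.Laplace(loc=0.0, scale=beta)
direct = alpha * laplace.cdf(-sdf)

print(torch.allclose(stable, direct, atol=1e-5))   # True
```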
/code/model/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
4 |
5 | class Embedder:
6 | def __init__(self, **kwargs):
7 | self.kwargs = kwargs
8 | self.create_embedding_fn()
9 |
10 | def create_embedding_fn(self):
11 | embed_fns = []
12 | d = self.kwargs['input_dims']
13 | out_dim = 0
14 | if self.kwargs['include_input']:
15 | embed_fns.append(lambda x: x)
16 | out_dim += d
17 |
18 | max_freq = self.kwargs['max_freq_log2']
19 | N_freqs = self.kwargs['num_freqs']
20 |
21 | if self.kwargs['log_sampling']:
22 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
23 | else:
24 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
25 |
26 | for freq in freq_bands:
27 | for p_fn in self.kwargs['periodic_fns']:
28 | embed_fns.append(lambda x, p_fn=p_fn,
29 | freq=freq: p_fn(x * freq))
30 | out_dim += d
31 |
32 | self.embed_fns = embed_fns
33 | self.out_dim = out_dim
34 |
35 | def embed(self, inputs):
36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
37 |
38 | def get_embedder(multires, input_dims=3):
39 | embed_kwargs = {
40 | 'include_input': True,
41 | 'input_dims': input_dims,
42 | 'max_freq_log2': multires-1,
43 | 'num_freqs': multires,
44 | 'log_sampling': True,
45 | 'periodic_fns': [torch.sin, torch.cos],
46 | }
47 |
48 | embedder_obj = Embedder(**embed_kwargs)
49 | def embed(x, eo=embedder_obj): return eo.embed(x)
50 | return embed, embedder_obj.out_dim
51 |
--------------------------------------------------------------------------------
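A quick usage sketch of `get_embedder`: with `multires = 6` (the value in the conf files) and 3-D inputs, the output dimension is 3 (identity) + 3 x 2 x 6 (sin/cos at 6 frequencies) = 39. Assumes you run from the `code` directory so the `model` package is importable:

```python
import torch
from model.embedder import get_embedder

embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
print(out_dim)            # 39 = 3 + 3 * 2 * 6

points = torch.rand(1024, 3)
encoded = embed_fn(points)
print(encoded.shape)      # torch.Size([1024, 39])
```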
/code/model/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import utils.general as utils
4 | import math
5 |
6 | # copy from MiDaS
7 | def compute_scale_and_shift(prediction, target, mask):
8 | # system matrix: A = [[a_00, a_01], [a_10, a_11]]
9 | a_00 = torch.sum(mask * prediction * prediction, (1, 2))
10 | a_01 = torch.sum(mask * prediction, (1, 2))
11 | a_11 = torch.sum(mask, (1, 2))
12 |
13 | # right hand side: b = [b_0, b_1]
14 | b_0 = torch.sum(mask * prediction * target, (1, 2))
15 | b_1 = torch.sum(mask * target, (1, 2))
16 |
17 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
18 | x_0 = torch.zeros_like(b_0)
19 | x_1 = torch.zeros_like(b_1)
20 |
21 | det = a_00 * a_11 - a_01 * a_01
22 | valid = det.nonzero()
23 |
24 | x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
25 | x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
26 |
27 | return x_0, x_1
28 |
29 |
30 | def reduction_batch_based(image_loss, M):
31 | # average of all valid pixels of the batch
32 |
33 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
34 | divisor = torch.sum(M)
35 |
36 | if divisor == 0:
37 | return 0
38 | else:
39 | return torch.sum(image_loss) / divisor
40 |
41 |
42 | def reduction_image_based(image_loss, M):
43 | # mean of average of valid pixels of an image
44 |
45 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
46 | valid = M.nonzero()
47 |
48 | image_loss[valid] = image_loss[valid] / M[valid]
49 |
50 | return torch.mean(image_loss)
51 |
52 |
53 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based):
54 |
55 | M = torch.sum(mask, (1, 2))
56 | res = prediction - target
57 |
58 | _loss = mask * res * res
59 |
60 | image_loss = torch.sum(_loss, (1, 2))
61 |
62 | return reduction(image_loss, 2 * M)
63 |
64 |
65 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):
66 |
67 | M = torch.sum(mask, (1, 2))
68 |
69 | diff = prediction - target
70 | diff = torch.mul(mask, diff)
71 |
72 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
73 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
74 | grad_x = torch.mul(mask_x, grad_x)
75 |
76 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
77 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
78 | grad_y = torch.mul(mask_y, grad_y)
79 |
80 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
81 |
82 | return reduction(image_loss, M)
83 |
84 |
85 | class MSELoss(nn.Module):
86 | def __init__(self, reduction='batch-based'):
87 | super().__init__()
88 |
89 | if reduction == 'batch-based':
90 | self.__reduction = reduction_batch_based
91 | else:
92 | self.__reduction = reduction_image_based
93 |
94 | def forward(self, prediction, target, mask):
95 | return mse_loss(prediction, target, mask, reduction=self.__reduction)
96 |
97 |
98 | class GradientLoss(nn.Module):
99 | def __init__(self, scales=4, reduction='batch-based'):
100 | super().__init__()
101 |
102 | if reduction == 'batch-based':
103 | self.__reduction = reduction_batch_based
104 | else:
105 | self.__reduction = reduction_image_based
106 |
107 | self.__scales = scales
108 |
109 | def forward(self, prediction, target, mask):
110 | total = 0
111 |
112 | for scale in range(self.__scales):
113 | step = pow(2, scale)
114 |
115 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step],
116 | mask[:, ::step, ::step], reduction=self.__reduction)
117 |
118 | return total
119 |
120 |
121 | class ScaleAndShiftInvariantLoss(nn.Module):
122 | def __init__(self, alpha=0.5, scales=4, reduction='batch-based'):
123 | super().__init__()
124 |
125 | self.__data_loss = MSELoss(reduction=reduction)
126 | self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
127 | self.__alpha = alpha
128 |
129 | self.__prediction_ssi = None
130 |
131 | def forward(self, prediction, target, mask):
132 | scale, shift = compute_scale_and_shift(prediction, target, mask)
133 | self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)
134 |
135 | total = self.__data_loss(self.__prediction_ssi, target, mask)
136 | if self.__alpha > 0:
137 | total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)
138 |
139 | return total
140 |
141 | def __get_prediction_ssi(self):
142 | return self.__prediction_ssi
143 |
144 | prediction_ssi = property(__get_prediction_ssi)
145 | # end copy
146 |
147 |
148 | class MonoSDFLoss(nn.Module):
149 | def __init__(self, rgb_loss,
150 | eikonal_weight,
151 | smooth_weight = 0.005,
152 | depth_weight = 0.1,
153 | normal_l1_weight = 0.05,
154 | normal_cos_weight = 0.05,
155 | end_step = -1):
156 | super().__init__()
157 | self.eikonal_weight = eikonal_weight
158 | self.smooth_weight = smooth_weight
159 | self.depth_weight = depth_weight
160 | self.normal_l1_weight = normal_l1_weight
161 | self.normal_cos_weight = normal_cos_weight
162 | self.rgb_loss = utils.get_class(rgb_loss)(reduction='mean')
163 |
164 | self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1)
165 |
166 | print(f"using weight for loss RGB_1.0 EK_{self.eikonal_weight} SM_{self.smooth_weight} Depth_{self.depth_weight} NormalL1_{self.normal_l1_weight} NormalCos_{self.normal_cos_weight}")
167 |
168 | self.step = 0
169 | self.end_step = end_step
170 |
171 | def get_rgb_loss(self,rgb_values, rgb_gt):
172 | rgb_gt = rgb_gt.reshape(-1, 3)
173 | rgb_loss = self.rgb_loss(rgb_values, rgb_gt)
174 | return rgb_loss
175 |
176 | def get_eikonal_loss(self, grad_theta):
177 | eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
178 | return eikonal_loss
179 |
180 | def get_smooth_loss(self,model_outputs):
181 | # smoothness loss as unisurf
182 | g1 = model_outputs['grad_theta']
183 | g2 = model_outputs['grad_theta_nei']
184 |
185 | normals_1 = g1 / (g1.norm(2, dim=1).unsqueeze(-1) + 1e-5)
186 | normals_2 = g2 / (g2.norm(2, dim=1).unsqueeze(-1) + 1e-5)
187 | smooth_loss = torch.norm(normals_1 - normals_2, dim=-1).mean()
188 | return smooth_loss
189 |
190 | def get_depth_loss(self, depth_pred, depth_gt, mask):
191 | # TODO remove hard-coded scaling for depth
192 | return self.depth_loss(depth_pred.reshape(1, 32, 32), (depth_gt * 50 + 0.5).reshape(1, 32, 32), mask.reshape(1, 32, 32))
193 |
194 | def get_normal_loss(self, normal_pred, normal_gt):
195 | normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1)
196 | normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1)
197 | l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean()
198 | cos = (1. - torch.sum(normal_pred * normal_gt, dim = -1)).mean()
199 | return l1, cos
200 |
201 | def forward(self, model_outputs, ground_truth):
202 | rgb_gt = ground_truth['rgb'].cuda()
203 | # monocular depth and normal
204 | depth_gt = ground_truth['depth'].cuda()
205 | normal_gt = ground_truth['normal'].cuda()
206 |
207 | depth_pred = model_outputs['depth_values']
208 | normal_pred = model_outputs['normal_map'][None]
209 |
210 | rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt)
211 |
212 | if 'grad_theta' in model_outputs:
213 | eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
214 | else:
215 | eikonal_loss = torch.tensor(0.0).cuda().float()
216 |
217 | # only supervise the foreground normal
218 | mask = ((model_outputs['sdf'] > 0.).any(dim=-1) & (model_outputs['sdf'] < 0.).any(dim=-1))[None, :, None]
219 | # combine with GT
220 | mask = (ground_truth['mask'] > 0.5).cuda() & mask
221 |
222 | depth_loss = self.get_depth_loss(depth_pred, depth_gt, mask)
223 | if isinstance(depth_loss, float):
224 | depth_loss = torch.tensor(0.0).cuda().float()
225 |
226 | normal_l1, normal_cos = self.get_normal_loss(normal_pred * mask, normal_gt)
227 |
228 | smooth_loss = self.get_smooth_loss(model_outputs)
229 |
230 | # compute decay weights
231 | if self.end_step > 0:
232 | decay = math.exp(-self.step / self.end_step * 10.)
233 | else:
234 | decay = 1.0
235 |
236 | self.step += 1
237 |
238 | loss = rgb_loss + \
239 | self.eikonal_weight * eikonal_loss +\
240 | self.smooth_weight * smooth_loss +\
241 | decay * self.depth_weight * depth_loss +\
242 | decay * self.normal_l1_weight * normal_l1 +\
243 | decay * self.normal_cos_weight * normal_cos
244 |
245 | output = {
246 | 'loss': loss,
247 | 'rgb_loss': rgb_loss,
248 | 'eikonal_loss': eikonal_loss,
249 | 'smooth_loss': smooth_loss,
250 | 'depth_loss': depth_loss,
251 | 'normal_l1': normal_l1,
252 | 'normal_cos': normal_cos
253 | }
254 |
255 | return output
256 |
257 |
258 | class RICOLoss(nn.Module):
259 | def __init__(self, rgb_loss,
260 | eikonal_weight,
261 | semantic_weight = 0.04,
262 | bg_render_weight = 0.0,
263 | lop_weight = 0.1,
264 | lrd_weight = 0.1,
265 | smooth_weight = 0.005,
266 | depth_weight = 0.1,
267 | normal_l1_weight = 0.05,
268 | normal_cos_weight = 0.05,
269 | end_step = -1,
270 | epsilon_param = 0.05):
271 | super().__init__()
272 | self.eikonal_weight = eikonal_weight
273 | self.smooth_weight = smooth_weight
274 | self.depth_weight = depth_weight
275 | self.normal_l1_weight = normal_l1_weight
276 | self.normal_cos_weight = normal_cos_weight
277 | self.rgb_loss = utils.get_class(rgb_loss)(reduction='mean')
278 |
279 | self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1)
280 |
281 | self.semantic_weight = semantic_weight
282 | # self.semantic_loss = torch.nn.NLLLoss()
283 | self.semantic_loss = torch.nn.CrossEntropyLoss(ignore_index = -1)
284 |
285 | self.bg_render_weight = bg_render_weight # when using this loss, make sure the sampled idxs form a patch
286 |
287 | self.lop_weight = lop_weight
288 | self.lrd_weight = lrd_weight
289 |
290 | print(f"using weight for loss RGB_1.0 SEMANTIC_{self.semantic_weight} \
291 | Lop_{self.lop_weight} Lrd_{self.lrd_weight} BG_RENDER_{self.bg_render_weight} \
292 | EK_{self.eikonal_weight} SM_{self.smooth_weight} \
293 | Depth_{self.depth_weight} NormalL1_{self.normal_l1_weight} NormalCos_{self.normal_cos_weight}")
294 |
295 | self.step = 0
296 | self.end_step = end_step
297 |
298 | self.epsilon_param = epsilon_param
299 |
300 | def get_rgb_loss(self,rgb_values, rgb_gt):
301 | rgb_gt = rgb_gt.reshape(-1, 3)
302 | rgb_loss = self.rgb_loss(rgb_values, rgb_gt)
303 | return rgb_loss
304 |
305 | def get_eikonal_loss(self, grad_theta):
306 | eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
307 | return eikonal_loss
308 |
309 | def get_smooth_loss(self,model_outputs):
310 | # smoothness loss as unisurf
311 | g1 = model_outputs['grad_theta']
312 | g2 = model_outputs['grad_theta_nei']
313 |
314 | normals_1 = g1 / (g1.norm(2, dim=1).unsqueeze(-1) + 1e-5)
315 | normals_2 = g2 / (g2.norm(2, dim=1).unsqueeze(-1) + 1e-5)
316 | smooth_loss = torch.norm(normals_1 - normals_2, dim=-1).mean()
317 | return smooth_loss
318 |
319 | def get_depth_loss(self, depth_pred, depth_gt, mask):
320 | # TODO remove hard-coded scaling for depth
321 | return self.depth_loss(depth_pred.reshape(1, 32, 32), (depth_gt * 50 + 0.5).reshape(1, 32, 32), mask.reshape(1, 32, 32))
322 |
323 | def get_normal_loss(self, normal_pred, normal_gt):
324 | normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1)
325 | normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1)
326 | l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1)
327 | cos = (1. - torch.sum(normal_pred * normal_gt, dim = -1))
328 |
329 | l1 = l1.mean()
330 | cos = cos.mean()
331 |
332 | return l1, cos
333 |
334 | def get_semantic_loss(self, semantic_value, semantic_gt):
335 | semantic_gt = semantic_gt.squeeze()
336 | # semantic_loss = torch.nn.functional.nll_loss(semantic_value, semantic_gt)
337 | semantic_loss = self.semantic_loss(semantic_value, semantic_gt)
338 | # semantic_loss = self.semantic_loss(semantic_value, semantic_gt)
339 | return semantic_loss
340 |
341 | def get_bg_render_loss(self, bg_render_results, mask):
342 | bg_depth = bg_render_results['depth_values']
343 | bg_normal = bg_render_results['normal_map']
344 |
345 | bg_depth = bg_depth.reshape(1, 32, 32)
346 | bg_normal = bg_normal.reshape(32, 32, 3).permute(2, 0, 1)
347 |
348 | mask = mask.reshape(1, 32, 32)
349 |
350 | depth_grad = self.compute_grad_error(bg_depth, mask)
351 | normal_grad = self.compute_grad_error(bg_normal, mask.repeat(3, 1, 1))
352 |
353 | bg_render_loss = depth_grad + normal_grad
354 | return bg_render_loss
355 |
356 | def compute_grad_error(self, x, mask):
357 | scales = 4
358 | grad_loss = torch.tensor(0.0).cuda().float()
359 | for i in range(scales):
360 | step = pow(2, i)
361 |
362 | mask_step = mask[:, ::step, ::step]
363 | x_step = x[:, ::step, ::step]
364 |
365 | M = torch.sum(mask_step[:1], (1, 2))
366 |
367 | diff = torch.mul(mask_step, x_step)
368 |
369 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
370 | mask_x = torch.mul(mask_step[:, :, 1:], mask_step[:, :, :-1])
371 | grad_x = torch.mul(mask_x, grad_x)
372 |
373 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
374 | mask_y = torch.mul(mask_step[:, 1:, :], mask_step[:, :-1, :])
375 | grad_y = torch.mul(mask_y, grad_y)
376 |
377 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))
378 |
379 | divisor = torch.sum(M)
380 |
381 | if divisor == 0:
382 | scale_loss = torch.tensor(0.0).cuda().float()
383 | else:
384 | scale_loss = torch.sum(image_loss) / divisor
385 |
386 | grad_loss += scale_loss
387 |
388 | return grad_loss
389 |
390 | def get_lop_loss(self, obj_sdfs):
391 | margin_target = torch.ones(obj_sdfs.shape).cuda()
392 | # threshold = 0.05 * torch.ones(obj_sdfs.shape).cuda()
393 | threshold = self.epsilon_param * torch.ones(obj_sdfs.shape).cuda()
394 | loss = torch.nn.functional.margin_ranking_loss(obj_sdfs, threshold, margin_target)
395 |
396 | return loss
397 |
398 | def get_lrd_loss(self, obj_r_d, bg_r_d):
399 | if len(obj_r_d) == 0:
400 | loss = torch.tensor(0.0).cuda().float()
401 | return loss
402 |
403 | bg_r_d = bg_r_d.detach()
404 |
405 | obj_d = torch.where(obj_r_d > bg_r_d, bg_r_d, obj_r_d)
406 | loss = bg_r_d - obj_d
407 | loss = loss.mean()
408 |
409 | return loss
410 |
411 | def forward(self, model_outputs, ground_truth, iter_ratio=-1):
412 | rgb_gt = ground_truth['rgb'].cuda()
413 | # monocular depth and normal
414 | depth_gt = ground_truth['depth'].cuda()
415 | normal_gt = ground_truth['normal'].cuda()
416 |
417 | depth_pred = model_outputs['depth_values']
418 | normal_pred = model_outputs['normal_map'][None]
419 |
420 | rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt)
421 |
422 | if 'grad_theta' in model_outputs:
423 | eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
424 | else:
425 | eikonal_loss = torch.tensor(0.0).cuda().float()
426 |
427 |         # only supervise the foreground normals and depth
428 | mask = ((model_outputs['sdf'] > 0.).any(dim=-1) & (model_outputs['sdf'] < 0.).any(dim=-1))[None, :, None]
429 | # combine with GT
430 | mask = (ground_truth['mask'] > 0.5).cuda() & mask
431 |
432 | depth_loss = self.get_depth_loss(depth_pred, depth_gt, mask)
433 | if isinstance(depth_loss, float):
434 | depth_loss = torch.tensor(0.0).cuda().float()
435 |
436 | normal_l1, normal_cos = self.get_normal_loss(normal_pred * mask, normal_gt)
437 |
438 | if 'grad_theta_nei' in model_outputs:
439 | smooth_loss = self.get_smooth_loss(model_outputs)
440 | else:
441 | smooth_loss = torch.tensor(0.0).cuda().float()
442 |
443 | if 'semantic_values' in model_outputs:
444 | semantic_gt = ground_truth["instance_mask"].cuda().long()
445 | semantic_loss = self.get_semantic_loss(model_outputs['semantic_values'], semantic_gt)
446 | else:
447 | semantic_loss = torch.tensor(0.0).cuda().float()
448 |
449 | # background render smooth loss
450 | if self.bg_render_weight > 0 and model_outputs['background_render'] is not None:
451 | bg_mask = torch.argmax(model_outputs['background_render']['semantic_values'], dim=-1, keepdim=True)
452 | bg_mask = bg_mask != 0
453 | bg_mask = bg_mask.int()
454 | bg_render_loss = self.get_bg_render_loss(model_outputs['background_render'], bg_mask)
455 | else:
456 | bg_render_loss = torch.tensor(0.0).cuda().float()
457 |
458 | # Object Point SDF Loss
459 | lop_loss = self.get_lop_loss(model_outputs['obj_sdfs_behind_bg'])
460 | if torch.isnan(lop_loss):
461 | lop_loss = torch.tensor(0.0).cuda().float()
462 | # Reversed Depth Loss
463 | lrd_loss = self.get_lrd_loss(model_outputs['obj_d_vals'], model_outputs['bg_d_vals'])
464 |
465 | # compute decay weights
466 | if self.end_step > 0:
467 | decay = math.exp(-self.step / self.end_step * 10.)
468 | else:
469 | decay = 1.0
470 |
471 | self.step += 1
472 |
473 | loss = rgb_loss + \
474 | self.bg_render_weight * bg_render_loss+\
475 | self.eikonal_weight * eikonal_loss +\
476 | self.semantic_weight * semantic_loss +\
477 | self.smooth_weight * smooth_loss +\
478 | self.lop_weight * lop_loss +\
479 | self.lrd_weight * lrd_loss +\
480 | decay * self.depth_weight * depth_loss +\
481 | decay * self.normal_l1_weight * normal_l1 +\
482 | decay * self.normal_cos_weight * normal_cos
483 |
484 | output = {
485 | 'loss': loss,
486 | 'rgb_loss': rgb_loss,
487 | 'eikonal_loss': eikonal_loss,
488 | 'bg_render_loss': bg_render_loss,
489 | 'lop_loss': lop_loss,
490 | 'lrd_loss': lrd_loss,
491 | 'semantic_loss': semantic_loss,
492 | 'smooth_loss': smooth_loss,
493 | 'depth_loss': depth_loss,
494 | 'normal_l1': normal_l1,
495 | 'normal_cos': normal_cos
496 | }
497 |
498 | return output
--------------------------------------------------------------------------------
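
A minimal sketch of the object-point loss behaviour in `get_lop_loss` above: with target 1 and the default margin of 0, `margin_ranking_loss(obj_sdfs, threshold, target)` reduces to the hinge `mean(max(0, epsilon_param - sdf))`, so object SDF samples taken behind the background surface are pushed to be at least `epsilon_param` outside the objects. The SDF values below are hypothetical and the snippet runs on CPU rather than CUDA.

    import torch
    import torch.nn.functional as F

    epsilon = 0.05                                        # epsilon_param in RICOLoss
    obj_sdfs = torch.tensor([-0.10, 0.00, 0.03, 0.20])    # hypothetical object SDF samples behind the background

    target = torch.ones_like(obj_sdfs)                    # rank obj_sdfs above the threshold
    threshold = epsilon * torch.ones_like(obj_sdfs)

    # margin_ranking_loss(x1, x2, y) = mean(max(0, -y * (x1 - x2) + margin)), margin defaults to 0
    lop = F.margin_ranking_loss(obj_sdfs, threshold, target)

    # equivalent closed form: hinge on (epsilon - sdf)
    lop_ref = torch.clamp(epsilon - obj_sdfs, min=0.0).mean()
    print(lop.item(), lop_ref.item())                     # both ~0.055
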
/code/model/network.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import numpy as np
4 |
5 | from utils import rend_util
6 | from model.embedder import *
7 | from model.density import LaplaceDensity
8 | from model.ray_sampler import ErrorBoundSampler
9 | import matplotlib.pyplot as plt
10 |
11 |
12 | class ImplicitNetwork(nn.Module):
13 | def __init__(
14 | self,
15 | feature_vector_size,
16 | sdf_bounding_sphere,
17 | d_in,
18 | d_out,
19 | dims,
20 | geometric_init=True,
21 | bias=1.0,
22 | skip_in=(),
23 | weight_norm=True,
24 | multires=0,
25 | sphere_scale=1.0,
26 | inside_outside=False,
27 | ):
28 | super().__init__()
29 |
30 | self.sdf_bounding_sphere = sdf_bounding_sphere
31 | self.sphere_scale = sphere_scale
32 | dims = [d_in] + dims + [d_out + feature_vector_size]
33 |
34 | self.embed_fn = None
35 | if multires > 0:
36 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
37 | self.embed_fn = embed_fn
38 | dims[0] = input_ch
39 | print(multires, dims)
40 | self.num_layers = len(dims)
41 | self.skip_in = skip_in
42 |
43 | for l in range(0, self.num_layers - 1):
44 | if l + 1 in self.skip_in:
45 | out_dim = dims[l + 1] - dims[0]
46 | else:
47 | out_dim = dims[l + 1]
48 |
49 | lin = nn.Linear(dims[l], out_dim)
50 |
51 | if geometric_init:
52 | if l == self.num_layers - 2:
53 | if not inside_outside:
54 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
55 | torch.nn.init.constant_(lin.bias, -bias)
56 | else:
57 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
58 | torch.nn.init.constant_(lin.bias, bias)
59 |
60 | elif multires > 0 and l == 0:
61 | torch.nn.init.constant_(lin.bias, 0.0)
62 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
63 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
64 | elif multires > 0 and l in self.skip_in:
65 | torch.nn.init.constant_(lin.bias, 0.0)
66 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
67 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
68 | else:
69 | torch.nn.init.constant_(lin.bias, 0.0)
70 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
71 |
72 | if weight_norm:
73 | lin = nn.utils.weight_norm(lin)
74 |
75 | setattr(self, "lin" + str(l), lin)
76 |
77 | self.softplus = nn.Softplus(beta=100)
78 |
79 | def forward(self, input):
80 | if self.embed_fn is not None:
81 | input = self.embed_fn(input)
82 |
83 | x = input
84 |
85 | for l in range(0, self.num_layers - 1):
86 | lin = getattr(self, "lin" + str(l))
87 |
88 | if l in self.skip_in:
89 | x = torch.cat([x, input], 1) / np.sqrt(2)
90 |
91 | x = lin(x)
92 |
93 | if l < self.num_layers - 2:
94 | x = self.softplus(x)
95 |
96 | return x
97 |
98 | def gradient(self, x):
99 | x.requires_grad_(True)
100 | y = self.forward(x)[:,:1]
101 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
102 | gradients = torch.autograd.grad(
103 | outputs=y,
104 | inputs=x,
105 | grad_outputs=d_output,
106 | create_graph=True,
107 | retain_graph=True,
108 | only_inputs=True)[0]
109 | return gradients
110 |
111 | def get_outputs(self, x):
112 | x.requires_grad_(True)
113 | output = self.forward(x)
114 | sdf = output[:,:1]
115 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
116 | if self.sdf_bounding_sphere > 0.0:
117 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
118 | sdf = torch.minimum(sdf, sphere_sdf)
119 | feature_vectors = output[:, 1:]
120 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
121 | gradients = torch.autograd.grad(
122 | outputs=sdf,
123 | inputs=x,
124 | grad_outputs=d_output,
125 | create_graph=True,
126 | retain_graph=True,
127 | only_inputs=True)[0]
128 |
129 | return sdf, feature_vectors, gradients
130 |
131 | def get_sdf_vals(self, x):
132 | sdf = self.forward(x)[:,:1]
133 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
134 | if self.sdf_bounding_sphere > 0.0:
135 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
136 | sdf = torch.minimum(sdf, sphere_sdf)
137 | return sdf
138 |
139 |
140 | # from hashencoder.hashgrid import _hash_encode, HashEncoder
141 | class ImplicitNetworkGrid(nn.Module):
142 | def __init__(
143 | self,
144 | feature_vector_size,
145 | sdf_bounding_sphere,
146 | d_in,
147 | d_out,
148 | dims,
149 | geometric_init=True,
150 | bias=1.0,
151 | skip_in=(),
152 | weight_norm=True,
153 | multires=0,
154 | sphere_scale=1.0,
155 | inside_outside=False,
156 | base_size = 16,
157 | end_size = 2048,
158 | logmap = 19,
159 | num_levels=16,
160 | level_dim=2,
161 | divide_factor = 1.5, # used to normalize the points range for multi-res grid
162 | use_grid_feature = True
163 | ):
164 | super().__init__()
165 |
166 | self.sdf_bounding_sphere = sdf_bounding_sphere
167 | self.sphere_scale = sphere_scale
168 | dims = [d_in] + dims + [d_out + feature_vector_size]
169 | self.embed_fn = None
170 | self.divide_factor = divide_factor
171 | self.grid_feature_dim = num_levels * level_dim
172 | self.use_grid_feature = use_grid_feature
173 | dims[0] += self.grid_feature_dim
174 |
175 | print(f"using hash encoder with {num_levels} levels, each level with feature dim {level_dim}")
176 | print(f"resolution:{base_size} -> {end_size} with hash map size {logmap}")
177 | # self.encoding = HashEncoder(input_dim=3, num_levels=num_levels, level_dim=level_dim,
178 | # per_level_scale=2, base_resolution=base_size,
179 | # log2_hashmap_size=logmap, desired_resolution=end_size)
180 |
181 | '''
182 | # can also use tcnn for multi-res grid as it now supports eikonal loss
183 | base_size = 16
184 | hash = True
185 | smoothstep = True
186 | self.encoding = tcnn.Encoding(3, {
187 | "otype": "HashGrid" if hash else "DenseGrid",
188 | "n_levels": 16,
189 | "n_features_per_level": 2,
190 | "log2_hashmap_size": 19,
191 | "base_resolution": base_size,
192 | "per_level_scale": 1.34,
193 | "interpolation": "Smoothstep" if smoothstep else "Linear"
194 | })
195 | '''
196 |
197 | if multires > 0:
198 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
199 | self.embed_fn = embed_fn
200 | dims[0] += input_ch - 3
201 | print("network architecture")
202 | print(dims)
203 |
204 | self.num_layers = len(dims)
205 | self.skip_in = skip_in
206 |
207 | for l in range(0, self.num_layers - 1):
208 | if l + 1 in self.skip_in:
209 | out_dim = dims[l + 1] - dims[0]
210 | else:
211 | out_dim = dims[l + 1]
212 |
213 | lin = nn.Linear(dims[l], out_dim)
214 |
215 | if geometric_init:
216 | if l == self.num_layers - 2:
217 | if not inside_outside:
218 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
219 | torch.nn.init.constant_(lin.bias, -bias)
220 | else:
221 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
222 | torch.nn.init.constant_(lin.bias, bias)
223 |
224 | elif multires > 0 and l == 0:
225 | torch.nn.init.constant_(lin.bias, 0.0)
226 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
227 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
228 | elif multires > 0 and l in self.skip_in:
229 | torch.nn.init.constant_(lin.bias, 0.0)
230 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
231 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
232 | else:
233 | torch.nn.init.constant_(lin.bias, 0.0)
234 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
235 |
236 | if weight_norm:
237 | lin = nn.utils.weight_norm(lin)
238 |
239 | setattr(self, "lin" + str(l), lin)
240 |
241 | self.softplus = nn.Softplus(beta=100)
242 | self.cache_sdf = None
243 |
244 | def forward(self, input):
245 | if self.use_grid_feature:
246 |             # normalize the point range, since the encoding assumes points are in [-1, 1]
247 | feature = self.encoding(input / self.divide_factor)
248 | else:
249 | feature = torch.zeros_like(input[:, :1].repeat(1, self.grid_feature_dim))
250 |
251 | if self.embed_fn is not None:
252 | embed = self.embed_fn(input)
253 | input = torch.cat((embed, feature), dim=-1)
254 | else:
255 | input = torch.cat((input, feature), dim=-1)
256 |
257 | x = input
258 |
259 | for l in range(0, self.num_layers - 1):
260 | lin = getattr(self, "lin" + str(l))
261 |
262 | if l in self.skip_in:
263 | x = torch.cat([x, input], 1) / np.sqrt(2)
264 |
265 | x = lin(x)
266 |
267 | if l < self.num_layers - 2:
268 | x = self.softplus(x)
269 |
270 | return x
271 |
272 | def gradient(self, x):
273 | x.requires_grad_(True)
274 | y = self.forward(x)[:,:1]
275 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
276 | gradients = torch.autograd.grad(
277 | outputs=y,
278 | inputs=x,
279 | grad_outputs=d_output,
280 | create_graph=True,
281 | retain_graph=True,
282 | only_inputs=True)[0]
283 | return gradients
284 |
285 | def get_outputs(self, x):
286 | x.requires_grad_(True)
287 | output = self.forward(x)
288 | sdf = output[:,:1]
289 |
290 | feature_vectors = output[:, 1:]
291 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
292 | gradients = torch.autograd.grad(
293 | outputs=sdf,
294 | inputs=x,
295 | grad_outputs=d_output,
296 | create_graph=True,
297 | retain_graph=True,
298 | only_inputs=True)[0]
299 |
300 | return sdf, feature_vectors, gradients
301 |
302 | def get_sdf_vals(self, x):
303 | sdf = self.forward(x)[:,:1]
304 | return sdf
305 |
306 | def mlp_parameters(self):
307 | parameters = []
308 | for l in range(0, self.num_layers - 1):
309 | lin = getattr(self, "lin" + str(l))
310 | parameters += list(lin.parameters())
311 | return parameters
312 |
313 | def grid_parameters(self):
314 | print("grid parameters", len(list(self.encoding.parameters())))
315 | for p in self.encoding.parameters():
316 | print(p.shape)
317 | return self.encoding.parameters()
318 |
319 |
320 | class RenderingNetwork(nn.Module):
321 | def __init__(
322 | self,
323 | feature_vector_size,
324 | mode,
325 | d_in,
326 | d_out,
327 | dims,
328 | weight_norm=True,
329 | multires_view=0,
330 | per_image_code = False
331 | ):
332 | super().__init__()
333 |
334 | self.mode = mode
335 | dims = [d_in + feature_vector_size] + dims + [d_out]
336 |
337 | self.embedview_fn = None
338 | if multires_view > 0:
339 | embedview_fn, input_ch = get_embedder(multires_view)
340 | self.embedview_fn = embedview_fn
341 | dims[0] += (input_ch - 3)
342 |
343 | self.per_image_code = per_image_code
344 | if self.per_image_code:
345 |             # per-image appearance code, as in NeRF in the Wild
346 |             # (one 32-dim embedding per image)
347 |             # supports a maximum of 1024 images
348 | self.embeddings = nn.Parameter(torch.empty(1024, 32))
349 | std = 1e-4
350 | self.embeddings.data.uniform_(-std, std)
351 | dims[0] += 32
352 |
353 | print("rendering network architecture:")
354 | print(dims)
355 |
356 | self.num_layers = len(dims)
357 |
358 | for l in range(0, self.num_layers - 1):
359 | out_dim = dims[l + 1]
360 | lin = nn.Linear(dims[l], out_dim)
361 |
362 | if weight_norm:
363 | lin = nn.utils.weight_norm(lin)
364 |
365 | setattr(self, "lin" + str(l), lin)
366 |
367 | self.relu = nn.ReLU()
368 | self.sigmoid = torch.nn.Sigmoid()
369 |
370 | def forward(self, points, normals, view_dirs, feature_vectors, indices):
371 | if self.embedview_fn is not None:
372 | view_dirs = self.embedview_fn(view_dirs)
373 |
374 | if self.mode == 'idr':
375 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
376 | elif self.mode == 'nerf':
377 | rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1)
378 | else:
379 | raise NotImplementedError
380 |
381 | if self.per_image_code:
382 | image_code = self.embeddings[indices].expand(rendering_input.shape[0], -1)
383 | rendering_input = torch.cat([rendering_input, image_code], dim=-1)
384 |
385 | x = rendering_input
386 |
387 | for l in range(0, self.num_layers - 1):
388 | lin = getattr(self, "lin" + str(l))
389 |
390 | x = lin(x)
391 |
392 | if l < self.num_layers - 2:
393 | x = self.relu(x)
394 |
395 | x = self.sigmoid(x)
396 | return x
397 |
398 |
399 | class MonoSDFNetwork(nn.Module):
400 | def __init__(self, conf):
401 | super().__init__()
402 | self.feature_vector_size = conf.get_int('feature_vector_size')
403 | self.scene_bounding_sphere = conf.get_float('scene_bounding_sphere', default=1.0)
404 | self.white_bkgd = conf.get_bool('white_bkgd', default=False)
405 | self.bg_color = torch.tensor(conf.get_list("bg_color", default=[1.0, 1.0, 1.0])).float().cuda()
406 |
407 | Grid_MLP = conf.get_bool('Grid_MLP', default=False)
408 | self.Grid_MLP = Grid_MLP
409 | if Grid_MLP:
410 | self.implicit_network = ImplicitNetworkGrid(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network'))
411 | else:
412 | self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network'))
413 |
414 | self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network'))
415 |
416 | self.density = LaplaceDensity(**conf.get_config('density'))
417 | sampling_method = conf.get_string('sampling_method', default="errorbounded")
418 | self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, **conf.get_config('ray_sampler'))
419 |
420 |
421 | def forward(self, input, indices):
422 | # Parse model input
423 | intrinsics = input["intrinsics"]
424 | uv = input["uv"]
425 | pose = input["pose"]
426 |
427 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics)
428 |
429 | # we should use unnormalized ray direction for depth
430 | ray_dirs_tmp, _ = rend_util.get_camera_params(uv, torch.eye(4).to(pose.device)[None], intrinsics)
431 | depth_scale = ray_dirs_tmp[0, :, 2:]
432 |
433 | batch_size, num_pixels, _ = ray_dirs.shape
434 |
435 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
436 | ray_dirs = ray_dirs.reshape(-1, 3)
437 |
438 |
439 | z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self)
440 | N_samples = z_vals.shape[1]
441 |
442 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)
443 | points_flat = points.reshape(-1, 3)
444 |
445 |
446 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1)
447 | dirs_flat = dirs.reshape(-1, 3)
448 |
449 | sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat)
450 |
451 | rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors, indices)
452 | rgb = rgb_flat.reshape(-1, N_samples, 3)
453 |
454 | weights = self.volume_rendering(z_vals, sdf)
455 |
456 | rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1)
457 |
458 | depth_values = torch.sum(weights * z_vals, 1, keepdims=True) / (weights.sum(dim=1, keepdims=True) +1e-8)
459 | # we should scale rendered distance to depth along z direction
460 | depth_values = depth_scale * depth_values
461 |
462 | # white background assumption
463 | if self.white_bkgd:
464 | acc_map = torch.sum(weights, -1)
465 | rgb_values = rgb_values + (1. - acc_map[..., None]) * self.bg_color.unsqueeze(0)
466 |
467 | output = {
468 | 'rgb':rgb,
469 | 'rgb_values': rgb_values,
470 | 'depth_values': depth_values,
471 | 'z_vals': z_vals,
472 | 'depth_vals': z_vals * depth_scale,
473 | 'sdf': sdf.reshape(z_vals.shape),
474 | 'weights': weights,
475 | }
476 |
477 | if self.training:
478 | # Sample points for the eikonal loss
479 | n_eik_points = batch_size * num_pixels
480 |
481 | eikonal_points = torch.empty(n_eik_points, 3).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere).cuda()
482 |
483 | # add some of the near surface points
484 | eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3)
485 | eikonal_points = torch.cat([eikonal_points, eik_near_points], 0)
486 |             # add some neighbouring points, as in UNISURF
487 | neighbour_points = eikonal_points + (torch.rand_like(eikonal_points) - 0.5) * 0.01
488 | eikonal_points = torch.cat([eikonal_points, neighbour_points], 0)
489 |
490 | grad_theta = self.implicit_network.gradient(eikonal_points)
491 |
492 |             # split gradients into eikonal points and neighbour points
493 | output['grad_theta'] = grad_theta[:grad_theta.shape[0]//2]
494 | output['grad_theta_nei'] = grad_theta[grad_theta.shape[0]//2:]
495 |
496 | # compute normal map
497 | normals = gradients / (gradients.norm(2, -1, keepdim=True) + 1e-6)
498 | normals = normals.reshape(-1, N_samples, 3)
499 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1)
500 |
501 | # transform to local coordinate system
502 | rot = pose[0, :3, :3].permute(1, 0).contiguous()
503 | normal_map = rot @ normal_map.permute(1, 0)
504 | normal_map = normal_map.permute(1, 0).contiguous()
505 |
506 | output['normal_map'] = normal_map
507 |
508 | return output
509 |
510 | def volume_rendering(self, z_vals, sdf):
511 | density_flat = self.density(sdf)
512 | density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples
513 |
514 | dists = z_vals[:, 1:] - z_vals[:, :-1]
515 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
516 |
517 | # LOG SPACE
518 | free_energy = dists * density
519 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) # shift one step
520 |         alpha = 1 - torch.exp(-free_energy)  # probability that this interval is not empty
521 |         transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))  # probability that everything up to here is empty
522 |         weights = alpha * transmittance  # probability that the ray hits something in this interval
523 |
524 | return weights
525 |
--------------------------------------------------------------------------------
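
A self-contained sketch of the alpha compositing performed by `volume_rendering` above, with a made-up density bump standing in for the Laplace-transformed SDF density; the resulting weights along each ray sum to at most one.

    import torch

    def compositing_weights(z_vals, density):
        """z_vals, density: (n_rays, n_samples) -> per-sample compositing weights."""
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        dists = torch.cat([dists, torch.full_like(dists[:, :1], 1e10)], dim=-1)

        free_energy = dists * density
        shifted = torch.cat([torch.zeros_like(free_energy[:, :1]), free_energy[:, :-1]], dim=-1)
        alpha = 1.0 - torch.exp(-free_energy)                       # probability of stopping in each interval
        transmittance = torch.exp(-torch.cumsum(shifted, dim=-1))   # probability of reaching each interval
        return alpha * transmittance

    z_vals = torch.linspace(0.5, 3.0, 8).unsqueeze(0)               # one ray, 8 samples
    density = torch.tensor([[0., 0., 2., 8., 8., 2., 0., 0.]])      # hypothetical density bump
    w = compositing_weights(z_vals, density)
    print(w.sum(dim=-1))                                            # ~0.999: the ray is absorbed inside the bump
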
/code/model/ray_sampler.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import torch
4 |
5 | from utils import rend_util
6 |
7 | class RaySampler(metaclass=abc.ABCMeta):
8 | def __init__(self,near, far):
9 | self.near = near
10 | self.far = far
11 |
12 | @abc.abstractmethod
13 | def get_z_vals(self, ray_dirs, cam_loc, model):
14 | pass
15 |
16 | class UniformSampler(RaySampler):
17 | def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
18 | #super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R
19 |         super().__init__(near, 2.0 * scene_bounding_sphere * 1.75 if far == -1 else far) # default far is 2*R, scaled by 1.75
20 | self.N_samples = N_samples
21 | self.scene_bounding_sphere = scene_bounding_sphere
22 | self.take_sphere_intersection = take_sphere_intersection
23 |
24 | # dtu and bmvs
25 | def get_z_vals_dtu_bmvs(self, ray_dirs, cam_loc, model):
26 | if not self.take_sphere_intersection:
27 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda()
28 | else:
29 | sphere_intersections = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
30 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda()
31 | far = sphere_intersections[:,1:]
32 |
33 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda()
34 | z_vals = near * (1. - t_vals) + far * (t_vals)
35 |
36 | if model.training:
37 | # get intervals between samples
38 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
39 | upper = torch.cat([mids, z_vals[..., -1:]], -1)
40 | lower = torch.cat([z_vals[..., :1], mids], -1)
41 | # stratified samples in those intervals
42 | t_rand = torch.rand(z_vals.shape).cuda()
43 |
44 | z_vals = lower + (upper - lower) * t_rand
45 |
46 | return z_vals, near, far
47 |
48 | def near_far_from_cube(self, rays_o, rays_d, bound):
49 | tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3]
50 | tmax = (bound - rays_o) / (rays_d + 1e-15)
51 | near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0]
52 | far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0]
53 |         # if far < near, there is no intersection; set both near and far to a large value (1e9 here)
54 | mask = far < near
55 | near[mask] = 1e9
56 | far[mask] = 1e9
57 | # restrict near to a minimal value
58 | near = torch.clamp(near, min=self.near)
59 | far = torch.clamp(far, max=self.far)
60 | return near, far
61 |
62 | # currently this is used for replica scannet and T&T
63 | def get_z_vals(self, ray_dirs, cam_loc, model):
64 | if not self.take_sphere_intersection:
65 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda()
66 | else:
67 | _, far = self.near_far_from_cube(cam_loc, ray_dirs, bound=self.scene_bounding_sphere)
68 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda()
69 |
70 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda()
71 | z_vals = near * (1. - t_vals) + far * (t_vals)
72 |
73 | if model.training:
74 | # get intervals between samples
75 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
76 | upper = torch.cat([mids, z_vals[..., -1:]], -1)
77 | lower = torch.cat([z_vals[..., :1], mids], -1)
78 | # stratified samples in those intervals
79 | t_rand = torch.rand(z_vals.shape).cuda()
80 |
81 | z_vals = lower + (upper - lower) * t_rand
82 |
83 | return z_vals, near, far
84 |
85 |
86 | class ErrorBoundSampler(RaySampler):
87 | def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
88 | eps, beta_iters, max_total_iters,
89 | inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=1.0e-6):
90 | #super().__init__(near, 2.0 * scene_bounding_sphere)
91 | super().__init__(near, 2.0 * scene_bounding_sphere * 1.75)
92 |
93 | self.N_samples = N_samples
94 | self.N_samples_eval = N_samples_eval
95 | self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=True) # replica scannet and T&T courtroom
96 | #self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg) # dtu and bmvs
97 |
98 | self.N_samples_extra = N_samples_extra
99 |
100 | self.eps = eps
101 | self.beta_iters = beta_iters
102 | self.max_total_iters = max_total_iters
103 | self.scene_bounding_sphere = scene_bounding_sphere
104 | self.add_tiny = add_tiny
105 |
106 | self.inverse_sphere_bg = inverse_sphere_bg
107 | if inverse_sphere_bg:
108 | self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)
109 |
110 | def get_z_vals(self, ray_dirs, cam_loc, model):
111 | with torch.no_grad():
112 | beta0 = model.density.get_beta().detach()
113 |
114 | # Start with uniform sampling
115 | z_vals, near, far = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
116 | samples, samples_idx = z_vals, None
117 |
118 | # Get maximum beta from the upper bound (Lemma 2)
119 | dists = z_vals[:, 1:] - z_vals[:, :-1]
120 | bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
121 | beta = torch.sqrt(bound)
122 |
123 | total_iters, not_converge = 0, True
124 |
125 | # Algorithm 1
126 | while not_converge and total_iters < self.max_total_iters:
127 | points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
128 | points_flat = points.reshape(-1, 3)
129 |
130 | # Calculating the SDF only for the new sampled points
131 | # with torch.no_grad():
132 | samples_sdf = model.implicit_network.get_sdf_vals(points_flat)
133 | if samples_idx is not None:
134 | sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
135 | samples_sdf.reshape(-1, samples.shape[1])], -1)
136 | sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
137 | else:
138 | sdf = samples_sdf
139 |
140 |
141 | # Calculating the bound d* (Theorem 1)
142 | d = sdf.reshape(z_vals.shape)
143 | dists = z_vals[:, 1:] - z_vals[:, :-1]
144 | a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
145 | first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
146 | second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
147 | d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda()
148 | d_star[first_cond] = b[first_cond]
149 | d_star[second_cond] = c[second_cond]
150 | s = (a + b + c) / 2.0
151 | area_before_sqrt = s * (s - a) * (s - b) * (s - c)
152 | mask = ~first_cond & ~second_cond & (b + c - a > 0)
153 | d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
154 | d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign
155 |
156 |
157 | # Updating beta using line search
158 | curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
159 | beta[curr_error <= self.eps] = beta0
160 | beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
161 | for j in range(self.beta_iters):
162 | beta_mid = (beta_min + beta_max) / 2.
163 | curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
164 | beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
165 | beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
166 | beta = beta_max
167 |
168 | # Upsample more points
169 | density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
170 |
171 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
172 | free_energy = dists * density
173 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1)
174 | alpha = 1 - torch.exp(-free_energy)
175 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
176 | weights = alpha * transmittance # probability of the ray hits something here
177 |
178 | # Check if we are done and this is the last sampling
179 | total_iters += 1
180 | not_converge = beta.max() > beta0
181 |
182 | if not_converge and total_iters < self.max_total_iters:
183 | ''' Sample more points proportional to the current error bound'''
184 |
185 | N = self.N_samples_eval
186 |
187 | bins = z_vals
188 | error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
189 | error_integral = torch.cumsum(error_per_section, dim=-1)
190 | bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]
191 |
192 | pdf = bound_opacity + self.add_tiny
193 | pdf = pdf / torch.sum(pdf, -1, keepdim=True)
194 | cdf = torch.cumsum(pdf, -1)
195 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
196 |
197 | else:
198 | ''' Sample the final sample set to be used in the volume rendering integral '''
199 |
200 | N = self.N_samples
201 |
202 | bins = z_vals
203 | pdf = weights[..., :-1]
204 | pdf = pdf + 1e-5 # prevent nans
205 | pdf = pdf / torch.sum(pdf, -1, keepdim=True)
206 | cdf = torch.cumsum(pdf, -1)
207 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
208 |
209 |
210 | # Invert CDF
211 | if (not_converge and total_iters < self.max_total_iters) or (not model.training):
212 | u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1)
213 | else:
214 | u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda()
215 | u = u.contiguous()
216 |
217 | inds = torch.searchsorted(cdf, u, right=True)
218 | below = torch.max(torch.zeros_like(inds - 1), inds - 1)
219 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
220 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
221 |
222 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
223 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
224 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
225 |
226 | denom = (cdf_g[..., 1] - cdf_g[..., 0])
227 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
228 | t = (u - cdf_g[..., 0]) / denom
229 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
230 |
231 |
232 |             # Add more samples if we have not converged
233 | if not_converge and total_iters < self.max_total_iters:
234 | z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)
235 |
236 |
237 | z_samples = samples
238 | #TODO Use near and far from intersection
239 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda()
240 | if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
241 | far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]
242 |
243 | if self.N_samples_extra > 0:
244 | if model.training:
245 | sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
246 | else:
247 | sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
248 | z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
249 | else:
250 | z_vals_extra = torch.cat([near, far], -1)
251 |
252 | z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)
253 |
254 | # add some of the near surface points
255 | idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda()
256 | z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))
257 |
258 | if self.inverse_sphere_bg:
259 | z_vals_inverse_sphere, _, _ = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
260 | z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
261 | z_vals = (z_vals, z_vals_inverse_sphere)
262 |
263 | return z_vals, z_samples_eik
264 |
265 | def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
266 | density = model.density(sdf.reshape(z_vals.shape), beta=beta)
267 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1)
268 | integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
269 | error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
270 | error_integral = torch.cumsum(error_per_section, dim=-1)
271 | bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
272 |
273 | return bound_opacity.max(-1)[0]
--------------------------------------------------------------------------------
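
The CDF inversion inside `ErrorBoundSampler.get_z_vals` is standard hierarchical importance sampling. Below is a simplified, self-contained sketch of the same searchsorted/gather/interpolate pattern, using uniform `u` only and toy bins and pdf values.

    import torch

    def sample_from_bins(bins, pdf, n_samples):
        """Inverse-transform sampling: bins (B, M+1), pdf (B, M) -> samples (B, n_samples)."""
        pdf = pdf + 1e-5                                 # avoid empty bins / NaNs
        pdf = pdf / pdf.sum(-1, keepdim=True)
        cdf = torch.cumsum(pdf, -1)
        cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)   # (B, M+1)

        u = torch.linspace(0., 1., n_samples).expand(cdf.shape[0], n_samples).contiguous()

        inds = torch.searchsorted(cdf, u, right=True)
        below = (inds - 1).clamp(min=0)
        above = inds.clamp(max=cdf.shape[-1] - 1)

        cdf_b, cdf_a = torch.gather(cdf, 1, below), torch.gather(cdf, 1, above)
        bin_b, bin_a = torch.gather(bins, 1, below), torch.gather(bins, 1, above)

        t = (u - cdf_b) / (cdf_a - cdf_b).clamp(min=1e-5)
        return bin_b + t * (bin_a - bin_b)

    bins = torch.linspace(0., 4., 9).unsqueeze(0)        # one ray, 8 bins
    pdf = torch.tensor([[0., 0., 1., 4., 4., 1., 0., 0.]])
    print(sample_from_bins(bins, pdf, 16))               # samples cluster around z ~ 2
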
/code/slurm_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | PARTITION=$1
5 | NUM_NODES=1
6 | NUM_GPUS_PER_NODE=1
7 | CFG_PATH=$2
8 | SCAN_ID=$3
9 | PORT=$4
10 |
11 | srun -p ${PARTITION} \
12 | -N ${NUM_NODES} \
13 | --gres=gpu:${NUM_GPUS_PER_NODE} \
14 | --cpus-per-task=4 \
15 | -t 5-00:00:00 \
16 | python training/exp_runner.py --conf $CFG_PATH --scan_id $SCAN_ID --port $PORT
--------------------------------------------------------------------------------
/code/training/exp_runner.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../code')
4 | import argparse
5 | import torch
6 | import random
7 | import numpy as np
8 |
9 | import os
10 | import subprocess
11 | import datetime
12 | from training.rico_train import RICOTrainRunner
13 |
14 |
15 | if __name__ == '__main__':
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
19 | parser.add_argument('--nepoch', type=int, default=2000, help='number of epochs to train for')
20 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf')
21 | parser.add_argument('--expname', type=str, default='')
22 | parser.add_argument("--exps_folder", type=str, default="exps")
23 | #parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]')
24 | parser.add_argument('--is_continue', default=False, action="store_true",
25 | help='If set, indicates continuing from a previous run.')
26 | parser.add_argument('--timestamp', default='latest', type=str,
27 | help='The timestamp of the run to be used in case of continuing from a previous run.')
28 | parser.add_argument('--checkpoint', default='latest', type=str,
29 | help='The checkpoint epoch of the run to be used in case of continuing from a previous run.')
30 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')
31 | parser.add_argument('--cancel_vis', default=False, action="store_true",
32 | help='If set, cancel visualization in intermediate epochs.')
33 | parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel')
34 | parser.add_argument('--port', type=int, default=29500)
35 |
36 | opt = parser.parse_args()
37 |
38 | '''
39 | if opt.gpu == "auto":
40 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False,
41 | excludeID=[], excludeUUID=[])
42 | gpu = deviceIDs[0]
43 | else:
44 | gpu = opt.gpu
45 | '''
46 | gpu = opt.local_rank
47 |
48 | random.seed(0)
49 | np.random.seed(0)
50 | torch.manual_seed(0)
51 | torch.cuda.manual_seed_all(0)
52 | # torch.backends.cudnn.deterministic = True
53 | # torch.backends.cudnn.benchmark = False
54 | torch.backends.cudnn.benchmark = True
55 |
56 | # set distributed training
57 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
58 | rank = int(os.environ["RANK"])
59 | world_size = int(os.environ['WORLD_SIZE'])
60 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
61 | elif 'SLURM_PROCID' in os.environ:
62 | proc_id = int(os.environ['SLURM_PROCID'])
63 | ntasks = int(os.environ['SLURM_NTASKS'])
64 | node_list = os.environ['SLURM_NODELIST']
65 | num_gpus = torch.cuda.device_count()
66 | addr = subprocess.getoutput(
67 | 'scontrol show hostname {} | head -n1'.format(node_list)
68 | )
69 | port_str = str(opt.port)
70 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', port_str)
71 | os.environ['MASTER_ADDR'] = addr
72 | os.environ['WORLD_SIZE'] = str(ntasks)
73 | os.environ['RANK'] = str(proc_id)
74 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
75 | os.environ['LOCAL_SIZE'] = str(num_gpus)
76 | rank = proc_id
77 | world_size = ntasks
78 | print(f"RANK and WORLD_SIZE in SLURM environ: {rank}/{world_size}")
79 | else:
80 | rank = -1
81 | world_size = -1
82 |
83 | print(opt.local_rank)
84 | torch.cuda.set_device(opt.local_rank)
85 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank, timeout=datetime.timedelta(1, 1800))
86 | torch.distributed.barrier()
87 |
88 | torch.autograd.set_detect_anomaly(True)
89 |
90 | trainrunner = RICOTrainRunner(
91 | conf=opt.conf,
92 | batch_size=opt.batch_size,
93 | nepochs=opt.nepoch,
94 | expname=opt.expname,
95 | gpu_index=gpu,
96 | exps_folder_name=opt.exps_folder,
97 | is_continue=opt.is_continue,
98 | timestamp=opt.timestamp,
99 | checkpoint=opt.checkpoint,
100 | scan_id=opt.scan_id,
101 | do_vis=not opt.cancel_vis
102 | )
103 |
104 | trainrunner.run()
105 |
--------------------------------------------------------------------------------
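
`exp_runner.py` initializes `torch.distributed` with the `env://` method, so the rendezvous variables must be present even for a single-GPU run outside SLURM. A hedged sketch of a local launcher follows; the working directory, config path and scan id are assumptions, and a CUDA GPU is required for the `nccl` backend. Alternatively, `torchrun --nproc_per_node=1` sets the same variables.

    import os
    import subprocess

    # minimal env:// rendezvous for one local process (illustrative values)
    env = dict(os.environ,
               RANK="0", WORLD_SIZE="1", LOCAL_RANK="0",
               MASTER_ADDR="127.0.0.1", MASTER_PORT="29500")

    # assumes launching from the repository root, so cwd="code" matches sys.path.append('../code')
    subprocess.run(
        ["python", "training/exp_runner.py",
         "--conf", "confs/RICO_scannet.conf", "--scan_id", "1", "--port", "29500"],
        cwd="code", env=env, check=True)
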
/code/training/rico_train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from datetime import datetime
4 | from pyhocon import ConfigFactory
5 | import sys
6 | import torch
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 | import utils.general as utils
11 | import utils.plots as plt
12 | from utils import rend_util
13 | from utils.general import get_time
14 | from torch.utils.tensorboard import SummaryWriter
15 | from model.loss import compute_scale_and_shift
16 | from utils.general import BackprojectDepth
17 |
18 | import torch.distributed as dist
19 |
20 | class RICOTrainRunner():
21 | def __init__(self,**kwargs):
22 | torch.set_default_dtype(torch.float32)
23 | torch.set_num_threads(1)
24 |
25 | self.conf = ConfigFactory.parse_file(kwargs['conf'])
26 | self.batch_size = kwargs['batch_size']
27 | self.nepochs = kwargs['nepochs']
28 | self.exps_folder_name = kwargs['exps_folder_name']
29 | self.GPU_INDEX = kwargs['gpu_index']
30 |
31 | self.expname = self.conf.get_string('train.expname') + kwargs['expname']
32 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1)
33 | if scan_id != -1:
34 | self.expname = self.expname + '_{0}'.format(scan_id)
35 |
36 | if kwargs['is_continue'] and kwargs['timestamp'] == 'latest':
37 | if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)):
38 | timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname))
39 | if (len(timestamps)) == 0:
40 | is_continue = False
41 | timestamp = None
42 | else:
43 | timestamp = sorted(timestamps)[-1]
44 | is_continue = True
45 | else:
46 | is_continue = False
47 | timestamp = None
48 | else:
49 | timestamp = kwargs['timestamp']
50 | is_continue = kwargs['is_continue']
51 |
52 | if self.GPU_INDEX == 0:
53 | utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name))
54 | self.expdir = os.path.join('../', self.exps_folder_name, self.expname)
55 | utils.mkdir_ifnotexists(self.expdir)
56 | self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())
57 | utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp))
58 |
59 | self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots')
60 | utils.mkdir_ifnotexists(self.plots_dir)
61 |
62 | # create checkpoints dirs
63 | self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints')
64 | utils.mkdir_ifnotexists(self.checkpoints_path)
65 | self.model_params_subdir = "ModelParameters"
66 | self.optimizer_params_subdir = "OptimizerParameters"
67 | self.scheduler_params_subdir = "SchedulerParameters"
68 |
69 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir))
70 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir))
71 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir))
72 |
73 | os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf')))
74 |
75 | print('shell command : {0}'.format(' '.join(sys.argv)))
76 |
77 | print('Loading data ...')
78 |
79 | dataset_conf = self.conf.get_config('dataset')
80 | if kwargs['scan_id'] != -1:
81 | dataset_conf['scan_id'] = kwargs['scan_id']
82 |
83 | self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(**dataset_conf)
84 |
85 | self.max_total_iters = self.conf.get_int('train.max_total_iters', default=50000)
86 | self.ds_len = len(self.train_dataset)
87 | print('Finish loading data. Data-set size: {0}'.format(self.ds_len))
88 | # use total iterations to compute how many epochs
89 | self.nepochs = int(self.max_total_iters / self.ds_len)
90 | print('RUNNING FOR {0}'.format(self.nepochs))
91 |
92 | if len(self.train_dataset.label_mapping) > 0:
93 |             # a hack to let the network know how many categories there are, so it doesn't have to be set manually in the config file
94 | self.conf['model']['implicit_network']['d_out'] = len(self.train_dataset.label_mapping)
95 | print('RUNNING FOR {0} CLASSES'.format(len(self.train_dataset.label_mapping)))
96 |
97 | self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset,
98 | batch_size=self.batch_size,
99 | shuffle=True,
100 | collate_fn=self.train_dataset.collate_fn,
101 | num_workers=4)
102 | self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset,
103 | batch_size=self.conf.get_int('plot.plot_nimgs'),
104 | shuffle=True,
105 | collate_fn=self.train_dataset.collate_fn
106 | )
107 |
108 | conf_model = self.conf.get_config('model')
109 | instance_ids = self.train_dataset.instance_ids
110 | print('Instance IDs: ', instance_ids)
111 | print('Label mappings: ', self.train_dataset.label_mapping)
112 |
113 | self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=conf_model)
114 |
115 | if torch.cuda.is_available():
116 | self.model.cuda()
117 |
118 | self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss'))
119 |
120 | self.lr = self.conf.get_float('train.learning_rate')
121 |
122 | # current model uses MLP and a unified lr
123 | print('using optimizer w unified lr')
124 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.99), eps=1e-15)
125 |
126 | # Exponential learning rate scheduler
127 | decay_rate = self.conf.get_float('train.sched_decay_rate', default=0.1)
128 | decay_steps = self.nepochs * len(self.train_dataset)
129 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, decay_rate ** (1./decay_steps))
130 |
131 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.GPU_INDEX], broadcast_buffers=False, find_unused_parameters=True)
132 |
133 | self.do_vis = kwargs['do_vis']
134 |
135 | self.start_epoch = 0
136 | if is_continue:
137 | old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints')
138 |
139 | saved_model_state = torch.load(
140 | os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth"))
141 | self.model.load_state_dict(saved_model_state["model_state_dict"])
142 | self.start_epoch = saved_model_state['epoch']
143 |
144 | data = torch.load(
145 | os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth"))
146 | self.optimizer.load_state_dict(data["optimizer_state_dict"])
147 |
148 | data = torch.load(
149 | os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth"))
150 | self.scheduler.load_state_dict(data["scheduler_state_dict"])
151 |
152 | self.num_pixels = self.conf.get_int('train.num_pixels')
153 | self.total_pixels = self.train_dataset.total_pixels
154 | self.img_res = self.train_dataset.img_res
155 | self.n_batches = len(self.train_dataloader)
156 | self.plot_freq = self.conf.get_int('train.plot_freq')
157 | self.checkpoint_freq = self.conf.get_int('train.checkpoint_freq', default=100)
158 | self.split_n_pixels = self.conf.get_int('train.split_n_pixels', default=10000)
159 | self.plot_conf = self.conf.get_config('plot')
160 | self.backproject = BackprojectDepth(1, self.img_res[0], self.img_res[1]).cuda()
161 | self.n_sem = self.conf.get_int('model.implicit_network.d_out')
162 | assert self.n_sem == len(self.train_dataset.label_mapping)
163 |
164 | def save_checkpoints(self, epoch):
165 | torch.save(
166 | {"epoch": epoch, "model_state_dict": self.model.state_dict()},
167 | os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth"))
168 | torch.save(
169 | {"epoch": epoch, "model_state_dict": self.model.state_dict()},
170 | os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth"))
171 |
172 | torch.save(
173 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
174 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth"))
175 | torch.save(
176 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()},
177 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth"))
178 |
179 | torch.save(
180 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
181 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth"))
182 | torch.save(
183 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()},
184 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth"))
185 |
186 | def run(self):
187 | print("training...")
188 | if self.GPU_INDEX == 0 :
189 | self.writer = SummaryWriter(log_dir=os.path.join(self.plots_dir, 'logs'))
190 |
191 | self.iter_step = 0
192 | for epoch in range(self.start_epoch, self.nepochs + 1):
193 |
194 | if (self.GPU_INDEX == 0 and epoch % self.checkpoint_freq == 0) or (self.GPU_INDEX == 0 and epoch == self.nepochs):
195 | self.save_checkpoints(epoch)
196 |
197 | if (self.GPU_INDEX == 0 and self.do_vis and epoch % self.plot_freq == 0) or (self.GPU_INDEX == 0 and self.do_vis and epoch == self.nepochs):
198 | self.model.eval()
199 |
200 | self.train_dataset.change_sampling_idx(-1)
201 |
202 | indices, model_input, ground_truth = next(iter(self.plot_dataloader))
203 | model_input["intrinsics"] = model_input["intrinsics"].cuda()
204 | model_input["uv"] = model_input["uv"].cuda()
205 | model_input['pose'] = model_input['pose'].cuda()
206 |
207 | split = utils.split_input(model_input, self.total_pixels, n_pixels=self.split_n_pixels)
208 | res = []
209 | for s in tqdm(split):
210 | out = self.model(s, indices)
211 | d = {'rgb_values': out['rgb_values'].detach(),
212 | 'normal_map': out['normal_map'].detach(),
213 | 'depth_values': out['depth_values'].detach(),
214 | 'semantic_values': out['semantic_values'].detach()}
215 | res.append(d)
216 |
217 | batch_size = ground_truth['rgb'].shape[0]
218 | model_outputs = utils.merge_output(res, self.total_pixels, batch_size)
219 | plot_data = self.get_plot_data(model_input, model_outputs, model_input['pose'], ground_truth['rgb'], ground_truth['normal'], ground_truth['depth'], ground_truth['instance_mask'])
220 |
221 | plot_mesh = True
222 |
223 | plt.plot_rico(
224 | self.model.module.implicit_network,
225 | indices,
226 | plot_data,
227 | self.plots_dir,
228 | epoch,
229 | self.img_res,
230 | plot_mesh,
231 | **self.plot_conf
232 | )
233 |
234 | self.model.train()
235 | self.train_dataset.change_sampling_idx(self.num_pixels)
236 |
237 | for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader):
238 | model_input["intrinsics"] = model_input["intrinsics"].cuda()
239 | model_input["uv"] = model_input["uv"].cuda()
240 | model_input['pose'] = model_input['pose'].cuda()
241 |
242 | model_input['instance_mask'] = ground_truth["instance_mask"].cuda().reshape(-1).long()
243 |
244 | self.optimizer.zero_grad()
245 |
246 | model_outputs = self.model(model_input, indices, iter_step=self.iter_step)
247 |
248 | loss_output = self.loss(model_outputs, ground_truth, iter_ratio=self.iter_step / self.max_total_iters)
249 | loss = loss_output['loss']
250 | loss.backward()
251 | self.optimizer.step()
252 |
253 | psnr = rend_util.get_psnr(model_outputs['rgb_values'],
254 | ground_truth['rgb'].cuda().reshape(-1,3))
255 |
256 | self.iter_step += 1
257 |
258 | if self.GPU_INDEX == 0:
259 | if data_index % 25 == 0:
260 | head_str = '{0}_{1} [{2}] ({3}/{4}): '.format(self.expname, self.timestamp, epoch, data_index, self.n_batches)
261 | loss_print_str = ''
262 | for k, v in loss_output.items():
263 | loss_print_str = loss_print_str + '{0} = {1}, '.format(k, v.item())
264 | print_str = head_str + loss_print_str + 'psnr = {0}'.format(psnr.item())
265 | print(print_str)
266 |
267 | for k, v in loss_output.items():
268 | self.writer.add_scalar(f'Loss/{k}', v.item(), self.iter_step)
269 |
270 | self.writer.add_scalar('Statistics/s_value', self.model.module.get_s_value().item(), self.iter_step)
271 | self.writer.add_scalar('Statistics/psnr', psnr.item(), self.iter_step)
272 |
273 | self.train_dataset.change_sampling_idx(self.num_pixels)
274 | self.scheduler.step()
275 |
276 | self.save_checkpoints(epoch)
277 |
278 |
279 | def get_plot_data(self, model_input, model_outputs, pose, rgb_gt, normal_gt, depth_gt, semantic_gt):
280 | batch_size, num_samples, _ = rgb_gt.shape
281 |
282 | rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3)
283 | normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3)
284 | normal_map = (normal_map + 1.) / 2.
285 |
286 | depth_map = model_outputs['depth_values'].reshape(batch_size, num_samples)
287 | depth_gt = depth_gt.to(depth_map.device)
288 | scale, shift = compute_scale_and_shift(depth_map[..., None], depth_gt, depth_gt > 0.)
289 | depth_map = depth_map * scale + shift
290 |
291 | # save point cloud
292 | depth = depth_map.reshape(1, 1, self.img_res[0], self.img_res[1])
293 | pred_points = self.get_point_cloud(depth, model_input, model_outputs)
294 |
295 | gt_depth = depth_gt.reshape(1, 1, self.img_res[0], self.img_res[1])
296 | gt_points = self.get_point_cloud(gt_depth, model_input, model_outputs)
297 |
298 | # semantic map
299 | semantic_map = model_outputs['semantic_values'].argmax(dim=-1).reshape(batch_size, num_samples, 1)
300 |         # in label_mapping, the background id maps to output index 0,
301 |         # the first foreground instance id (e.g. 3) maps to index 1, and so on,
302 |         # so the argmax over the semantic logits directly gives the label_mapping index
303 |
304 | plot_data = {
305 | 'rgb_gt': rgb_gt,
306 | 'normal_gt': (normal_gt + 1.)/ 2.,
307 | 'depth_gt': depth_gt,
308 | 'pose': pose,
309 | 'rgb_eval': rgb_eval,
310 | 'normal_map': normal_map,
311 | 'depth_map': depth_map,
312 | "pred_points": pred_points,
313 | "gt_points": gt_points,
314 | "semantic_map": semantic_map,
315 | "semantic_gt": semantic_gt,
316 | }
317 |
318 | return plot_data
319 |
320 | def get_point_cloud(self, depth, model_input, model_outputs):
321 | color = model_outputs["rgb_values"].reshape(-1, 3)
322 |
323 | K_inv = torch.inverse(model_input["intrinsics"][0])[None]
324 | points = self.backproject(depth, K_inv)[0, :3, :].permute(1, 0)
325 | points = torch.cat([points, color], dim=-1)
326 | return points.detach().cpu().numpy()
327 |
--------------------------------------------------------------------------------
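
The learning-rate schedule above sets `gamma = decay_rate ** (1 / decay_steps)` and calls `scheduler.step()` once per iteration, so the learning rate decays smoothly by a total factor of `decay_rate` over the whole run. A quick sketch with illustrative numbers (the actual rate comes from `train.learning_rate` in the config):

    import torch

    lr0, decay_rate = 5e-4, 0.1        # illustrative lr; 0.1 matches the default sched_decay_rate
    decay_steps = 50000                # nepochs * len(dataset) in the trainer

    opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=lr0)
    sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=decay_rate ** (1.0 / decay_steps))

    for _ in range(decay_steps):
        opt.step()                     # no real gradients here; we only advance the schedule
        sched.step()

    print(opt.param_groups[0]["lr"])   # ~= lr0 * decay_rate = 5e-5
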
/code/utils/general.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import time
7 | from torchvision import transforms
8 | import numpy as np
9 |
10 | def mkdir_ifnotexists(directory):
11 | if not os.path.exists(directory):
12 | os.mkdir(directory)
13 |
14 | def get_class(kls):
15 | parts = kls.split('.')
16 | module = ".".join(parts[:-1])
17 | m = __import__(module)
18 | for comp in parts[1:]:
19 | m = getattr(m, comp)
20 | return m
21 |
22 | def glob_imgs(path):
23 | imgs = []
24 | for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']:
25 | imgs.extend(glob(os.path.join(path, ext)))
26 | return imgs
27 |
28 | def split_input(model_input, total_pixels, n_pixels=10000):
29 | '''
30 | Split the input to fit Cuda memory for large resolution.
31 | Can decrease the value of n_pixels in case of cuda out of memory error.
32 | '''
33 | split = []
34 | for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)):
35 | data = model_input.copy()
36 | data['uv'] = torch.index_select(model_input['uv'], 1, indx)
37 | if 'object_mask' in data:
38 | data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
39 | if 'depth' in data:
40 | data['depth'] = torch.index_select(model_input['depth'], 1, indx)
41 | if 'instance_mask' in data:
42 | data['instance_mask'] = torch.index_select(model_input['instance_mask'], 1, indx)
43 | split.append(data)
44 | return split
45 |
46 | def merge_output(res, total_pixels, batch_size):
47 | ''' Merge the split output. '''
48 |
49 | model_outputs = {}
50 | for entry in res[0]:
51 | if res[0][entry] is None:
52 | continue
53 | if len(res[0][entry].shape) == 1:
54 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
55 | 1).reshape(batch_size * total_pixels)
56 | else:
57 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
58 | 1).reshape(batch_size * total_pixels, -1)
59 |
60 | return model_outputs
61 |
62 | def concat_home_dir(path):
63 | return os.path.join(os.environ['HOME'],'data',path)
64 |
65 | def get_time():
66 | torch.cuda.synchronize()
67 | return time.time()
68 |
69 | trans_topil = transforms.ToPILImage()
70 |
71 |
72 | class BackprojectDepth(nn.Module):
73 | """Layer to transform a depth image into a point cloud
74 | """
75 | def __init__(self, batch_size, height, width):
76 | super(BackprojectDepth, self).__init__()
77 |
78 | self.batch_size = batch_size
79 | self.height = height
80 | self.width = width
81 |
82 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
83 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
84 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords),
85 | requires_grad=False)
86 |
87 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
88 | requires_grad=False)
89 |
90 | self.pix_coords = torch.unsqueeze(torch.stack(
91 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0)
92 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1)
93 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1),
94 | requires_grad=False)
95 |
96 | def forward(self, depth, inv_K):
97 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
98 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points
99 | cam_points = torch.cat([cam_points, self.ones], 1)
100 | return cam_points
101 |
--------------------------------------------------------------------------------
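
A quick usage sketch for BackprojectDepth from the file above (toy sizes and intrinsics, not values from the repo's configs): given a depth map and inverse intrinsics it returns homogeneous camera-space points of shape [B, 4, H*W], which get_point_cloud then slices down to xyz.

    import torch
    from utils.general import BackprojectDepth  # assumes code/ is on PYTHONPATH

    H, W = 4, 5
    backproject = BackprojectDepth(batch_size=1, height=H, width=W)

    K = torch.eye(4)
    K[0, 0] = K[1, 1] = 2.0           # toy focal lengths
    K[0, 2], K[1, 2] = W / 2, H / 2   # toy principal point
    depth = torch.ones(1, 1, H, W)    # flat depth map

    cam_points = backproject(depth, torch.inverse(K)[None])  # [1, 4, H*W] homogeneous points
    xyz = cam_points[0, :3, :].permute(1, 0)                 # [H*W, 3], as used in get_point_cloud
    print(xyz.shape)
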
/code/utils/plots.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from skimage import measure
4 | import torchvision
5 | import trimesh
6 | from PIL import Image
7 | import matplotlib.pyplot as plt
8 | import cv2
9 |
10 | from utils import rend_util
11 | from utils.general import trans_topil
12 |
13 |
14 | def plot(implicit_network, indices, plot_data, path, epoch, img_res, plot_mesh, plot_nimgs, resolution, grid_boundary, level=0):
15 |
16 | if plot_data is not None:
17 | cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose'])
18 |
19 | # plot images
20 | plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res, indices)
21 |
22 | # plot normal maps
23 | plot_normal_maps(plot_data['normal_map'], plot_data['normal_gt'], path, epoch, plot_nimgs, img_res, indices)
24 |
25 | # plot depth maps
26 | plot_depth_maps(plot_data['depth_map'], plot_data['depth_gt'], path, epoch, plot_nimgs, img_res, indices)
27 |
28 | # concat output images to single large image
29 | images = []
30 | for name in ["rendering", "depth", "normal"]:
31 | images.append(cv2.imread('{0}/{1}_{2}_{3}.png'.format(path, name, epoch, indices[0])))
32 |
33 | images = np.concatenate(images, axis=1)
34 | cv2.imwrite('{0}/merge_{1}_{2}.png'.format(path, epoch, indices[0]), images)
35 |
36 | if plot_mesh:
37 | surface_traces = get_surface_sliding(send_path=path,
38 | epoch=epoch,
39 | sdf=lambda x: implicit_network(x)[:, 0],
40 | resolution=resolution,
41 | grid_boundary=grid_boundary,
42 | level=level
43 | )
44 |
45 |
46 | def plot_rico(implicit_network, indices, plot_data, path, epoch, img_res, plot_mesh, plot_nimgs, resolution, grid_boundary, level=0):
47 |
48 | if plot_data is not None:
49 | cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose'])
50 |
51 | # plot images
52 | plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res, indices)
53 |
54 | # plot normal maps
55 | plot_normal_maps(plot_data['normal_map'], plot_data['normal_gt'], path, epoch, plot_nimgs, img_res, indices)
56 |
57 | # plot depth maps
58 | plot_depth_maps(plot_data['depth_map'], plot_data['depth_gt'], path, epoch, plot_nimgs, img_res, indices)
59 |
60 | # plot semantic maps
61 | plot_seg_images(plot_data['semantic_map'], plot_data['semantic_gt'], path, epoch, plot_nimgs, img_res, indices)
62 |
63 | # concat output images to single large image
64 | images = []
65 | for name in ["rendering", "semantic", "depth", "normal"]:
66 | images.append(cv2.imread('{0}/{1}_{2}_{3}.png'.format(path, name, epoch, indices[0])))
67 |
68 | images = np.concatenate(images, axis=1)
69 | cv2.imwrite('{0}/merge_{1}_{2}.png'.format(path, epoch, indices[0]), images)
70 |
71 | if plot_mesh:
72 | sem_num = implicit_network.d_out
73 | f = torch.nn.MaxPool1d(sem_num)
74 | for indx in range(sem_num):
75 | # plot each object and background and save in different files
76 | _ = get_surface_sliding(
77 | send_path=[path, str(indx)],
78 | epoch=epoch,
79 | sdf = lambda x: implicit_network(x)[:, indx],
80 | resolution=resolution,
81 | grid_boundary=grid_boundary,
82 | level=level
83 | )
84 |
85 | # plot the overall scene
86 | surface_traces = get_surface_sliding(
87 | send_path=[path, 'all'],
88 | epoch=epoch,
89 | sdf=lambda x: -f(-implicit_network(x)[:, :sem_num].unsqueeze(1)).squeeze(-1).squeeze(-1),
90 | resolution=resolution,
91 | grid_boundary=grid_boundary,
92 | level=level
93 | )
94 |
95 |
96 | avg_pool_3d = torch.nn.AvgPool3d(2, stride=2)
97 | upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
98 |
99 | @torch.no_grad()
100 | def get_surface_sliding(send_path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0):
101 | if isinstance(send_path, list):
102 | path = send_path[0]
103 | mesh_name = send_path[1]
104 | else:
105 | path = send_path
106 | mesh_name = ''
107 |
108 | resN = resolution
109 | cropN = resolution
110 | level = 0
111 | N = resN // cropN
112 |
113 | grid_min = [grid_boundary[0], grid_boundary[0], grid_boundary[0]]
114 | grid_max = [grid_boundary[1], grid_boundary[1], grid_boundary[1]]
115 |
116 | xs = np.linspace(grid_min[0], grid_max[0], N+1)
117 | ys = np.linspace(grid_min[1], grid_max[1], N+1)
118 | zs = np.linspace(grid_min[2], grid_max[2], N+1)
119 |
120 | print(xs)
121 | print(ys)
122 | print(zs)
123 | meshes = []
124 | for i in range(N):
125 | for j in range(N):
126 | for k in range(N):
127 | print(i, j, k)
128 | x_min, x_max = xs[i], xs[i+1]
129 | y_min, y_max = ys[j], ys[j+1]
130 | z_min, z_max = zs[k], zs[k+1]
131 |
132 | x = np.linspace(x_min, x_max, cropN)
133 | y = np.linspace(y_min, y_max, cropN)
134 | z = np.linspace(z_min, z_max, cropN)
135 |
136 | xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
137 | points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()
138 |
139 | def evaluate(points):
140 | z = []
141 | for _, pnts in enumerate(torch.split(points, 100000, dim=0)):
142 | z.append(sdf(pnts))
143 | z = torch.cat(z, axis=0)
144 | return z
145 |
146 | # construct point pyramids
147 | points = points.reshape(cropN, cropN, cropN, 3).permute(3, 0, 1, 2)
148 | points_pyramid = [points]
149 | for _ in range(3):
150 | points = avg_pool_3d(points[None])[0]
151 | points_pyramid.append(points)
152 | points_pyramid = points_pyramid[::-1]
153 |
154 | # evaluate pyramid with mask
155 | mask = None
156 | threshold = 2 * (x_max - x_min)/cropN * 8
157 | for pid, pts in enumerate(points_pyramid):
158 | coarse_N = pts.shape[-1]
159 | pts = pts.reshape(3, -1).permute(1, 0).contiguous()
160 |
161 | if mask is None:
162 | pts_sdf = evaluate(pts)
163 | else:
164 | mask = mask.reshape(-1)
165 | pts_to_eval = pts[mask]
166 | #import pdb; pdb.set_trace()
167 | if pts_to_eval.shape[0] > 0:
168 | pts_sdf_eval = evaluate(pts_to_eval.contiguous())
169 | pts_sdf[mask] = pts_sdf_eval
170 | print("ratio", pts_to_eval.shape[0] / pts.shape[0])
171 |
172 | if pid < 3:
173 | # update mask
174 | mask = torch.abs(pts_sdf) < threshold
175 | mask = mask.reshape(coarse_N, coarse_N, coarse_N)[None, None]
176 | mask = upsample(mask.float()).bool()
177 |
178 | pts_sdf = pts_sdf.reshape(coarse_N, coarse_N, coarse_N)[None, None]
179 | pts_sdf = upsample(pts_sdf)
180 | pts_sdf = pts_sdf.reshape(-1)
181 |
182 | threshold /= 2.
183 |
184 | z = pts_sdf.detach().cpu().numpy()
185 |
186 | if (not (np.min(z) > level or np.max(z) < level)):
187 | z = z.astype(np.float32)
188 | verts, faces, normals, values = measure.marching_cubes(
189 | volume=z.reshape(cropN, cropN, cropN), #.transpose([1, 0, 2]),
190 | level=level,
191 | spacing=(
192 | (x_max - x_min)/(cropN-1),
193 | (y_max - y_min)/(cropN-1),
194 | (z_max - z_min)/(cropN-1) ))
195 | print(np.array([x_min, y_min, z_min]))
196 | print(verts.min(), verts.max())
197 | verts = verts + np.array([x_min, y_min, z_min])
198 | print(verts.min(), verts.max())
199 |
200 | meshcrop = trimesh.Trimesh(verts, faces, normals)
201 | #meshcrop.export(f"{i}_{j}_{k}.ply")
202 | meshes.append(meshcrop)
203 | try:
204 | combined = trimesh.util.concatenate(meshes)
205 |
206 | combined.export('{0}/surface_{1}_{2}.ply'.format(path, epoch, mesh_name), 'ply')
207 | except:
208 | print('no mesh')
209 |
210 | def get_3D_scatter_trace(points, name='', size=3, caption=None):
211 | assert points.shape[1] == 3, "3d scatter plot input points are not correctly shaped "
212 | assert len(points.shape) == 2, "3d scatter plot input points are not correctly shaped "
213 |
214 | trace = go.Scatter3d(
215 | x=points[:, 0].cpu(),
216 | y=points[:, 1].cpu(),
217 | z=points[:, 2].cpu(),
218 | mode='markers',
219 | name=name,
220 | marker=dict(
221 | size=size,
222 | line=dict(
223 | width=2,
224 | ),
225 | opacity=1.0,
226 | ), text=caption)
227 |
228 | return trace
229 |
230 |
231 | def get_3D_quiver_trace(points, directions, color='#bd1540', name=''):
232 | assert points.shape[1] == 3, "3d cone plot input points are not correctly shaped "
233 | assert len(points.shape) == 2, "3d cone plot input points are not correctly shaped "
234 | assert directions.shape[1] == 3, "3d cone plot input directions are not correctly shaped "
235 | assert len(directions.shape) == 2, "3d cone plot input directions are not correctly shaped "
236 |
237 | trace = go.Cone(
238 | name=name,
239 | x=points[:, 0].cpu(),
240 | y=points[:, 1].cpu(),
241 | z=points[:, 2].cpu(),
242 | u=directions[:, 0].cpu(),
243 | v=directions[:, 1].cpu(),
244 | w=directions[:, 2].cpu(),
245 | sizemode='absolute',
246 | sizeref=0.125,
247 | showscale=False,
248 | colorscale=[[0, color], [1, color]],
249 | anchor="tail"
250 | )
251 |
252 | return trace
253 |
254 |
255 | def get_surface_trace(path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0):
256 | grid = get_grid_uniform(resolution, grid_boundary)
257 | points = grid['grid_points']
258 |
259 | z = []
260 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
261 | z.append(sdf(pnts.cuda()).detach().cpu().numpy())
262 | z = np.concatenate(z, axis=0)
263 |
264 | if (not (np.min(z) > level or np.max(z) < level)):
265 |
266 | z = z.astype(np.float32)
267 |
268 | verts, faces, normals, values = measure.marching_cubes(
269 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
270 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
271 | level=level,
272 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
273 | grid['xyz'][0][2] - grid['xyz'][0][1],
274 | grid['xyz'][0][2] - grid['xyz'][0][1]))
275 |
276 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
277 | '''
278 | I, J, K = faces.transpose()
279 |
280 | traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],
281 | i=I, j=J, k=K, name='implicit_surface',
282 | color='#ffffff', opacity=1.0, flatshading=False,
283 | lighting=dict(diffuse=1, ambient=0, specular=0),
284 | lightposition=dict(x=0, y=0, z=-1), showlegend=True)]
285 | '''
286 | meshexport = trimesh.Trimesh(verts, faces, normals)
287 | meshexport.export('{0}/surface_{1}.ply'.format(path, epoch), 'ply')
288 |
289 | if return_mesh:
290 | return meshexport
291 | #return traces
292 | return None
293 |
294 | def get_surface_high_res_mesh(sdf, resolution=100, grid_boundary=[-2.0, 2.0], level=0, take_components=True):
295 | # get low res mesh to sample point cloud
296 | grid = get_grid_uniform(100, grid_boundary)
297 | z = []
298 | points = grid['grid_points']
299 |
300 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
301 | z.append(sdf(pnts).detach().cpu().numpy())
302 | z = np.concatenate(z, axis=0)
303 |
304 | z = z.astype(np.float32)
305 |
306 | verts, faces, normals, values = measure.marching_cubes(
307 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
308 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
309 | level=level,
310 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
311 | grid['xyz'][0][2] - grid['xyz'][0][1],
312 | grid['xyz'][0][2] - grid['xyz'][0][1]))
313 |
314 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
315 |
316 | mesh_low_res = trimesh.Trimesh(verts, faces, normals)
317 | if take_components:
318 | components = mesh_low_res.split(only_watertight=False)
319 | areas = np.array([c.area for c in components], dtype=np.float64)  # np.float alias was removed in recent NumPy
320 | mesh_low_res = components[areas.argmax()]
321 |
322 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
323 | recon_pc = torch.from_numpy(recon_pc).float().cuda()
324 |
325 | # Center and align the recon pc
326 | s_mean = recon_pc.mean(dim=0)
327 | s_cov = recon_pc - s_mean
328 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
329 | vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]
330 | if torch.det(vecs) < 0:
331 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)
332 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
333 | (recon_pc - s_mean).unsqueeze(-1)).squeeze()
334 |
335 | grid_aligned = get_grid(helper.cpu(), resolution)
336 |
337 | grid_points = grid_aligned['grid_points']
338 |
339 | g = []
340 | for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)):
341 | g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
342 | pnts.unsqueeze(-1)).squeeze() + s_mean)
343 | grid_points = torch.cat(g, dim=0)
344 |
345 | # MC to new grid
346 | points = grid_points
347 | z = []
348 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
349 | z.append(sdf(pnts).detach().cpu().numpy())
350 | z = np.concatenate(z, axis=0)
351 |
352 | meshexport = None
353 | if (not (np.min(z) > level or np.max(z) < level)):
354 |
355 | z = z.astype(np.float32)
356 |
357 | verts, faces, normals, values = measure.marching_cubes(
358 | volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0],
359 | grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),
360 | level=level,
361 | spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
362 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
363 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))
364 |
365 | verts = torch.from_numpy(verts).cuda().float()
366 | verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
367 | verts.unsqueeze(-1)).squeeze()
368 | verts = (verts + grid_points[0]).cpu().numpy()
369 |
370 | meshexport = trimesh.Trimesh(verts, faces, normals)
371 |
372 | return meshexport
373 |
374 |
375 | def get_surface_by_grid(grid_params, sdf, resolution=100, level=0, higher_res=False):
376 | grid_params = grid_params * [[1.5], [1.0]]
377 |
378 | # params = PLOT_DICT[scan_id]
379 | input_min = torch.tensor(grid_params[0]).float()
380 | input_max = torch.tensor(grid_params[1]).float()
381 |
382 | if higher_res:
383 | # get low res mesh to sample point cloud
384 | grid = get_grid(None, 100, input_min=input_min, input_max=input_max, eps=0.0)
385 | z = []
386 | points = grid['grid_points']
387 |
388 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
389 | z.append(sdf(pnts).detach().cpu().numpy())
390 | z = np.concatenate(z, axis=0)
391 |
392 | z = z.astype(np.float32)
393 |
394 | verts, faces, normals, values = measure.marching_cubes(
395 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
396 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
397 | level=level,
398 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
399 | grid['xyz'][0][2] - grid['xyz'][0][1],
400 | grid['xyz'][0][2] - grid['xyz'][0][1]))
401 |
402 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
403 |
404 | mesh_low_res = trimesh.Trimesh(verts, faces, normals)
405 | components = mesh_low_res.split(only_watertight=False)
406 | areas = np.array([c.area for c in components], dtype=np.float64)  # np.float alias was removed in recent NumPy
407 | mesh_low_res = components[areas.argmax()]
408 |
409 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
410 | recon_pc = torch.from_numpy(recon_pc).float().cuda()
411 |
412 | # Center and align the recon pc
413 | s_mean = recon_pc.mean(dim=0)
414 | s_cov = recon_pc - s_mean
415 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
416 | vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]
417 | if torch.det(vecs) < 0:
418 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)
419 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
420 | (recon_pc - s_mean).unsqueeze(-1)).squeeze()
421 |
422 | grid_aligned = get_grid(helper.cpu(), resolution, eps=0.01)
423 | else:
424 | grid_aligned = get_grid(None, resolution, input_min=input_min, input_max=input_max, eps=0.0)
425 |
426 | grid_points = grid_aligned['grid_points']
427 |
428 | if higher_res:
429 | g = []
430 | for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)):
431 | g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
432 | pnts.unsqueeze(-1)).squeeze() + s_mean)
433 | grid_points = torch.cat(g, dim=0)
434 |
435 | # MC to new grid
436 | points = grid_points
437 | z = []
438 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)):
439 | z.append(sdf(pnts).detach().cpu().numpy())
440 | z = np.concatenate(z, axis=0)
441 |
442 | meshexport = None
443 | if (not (np.min(z) > level or np.max(z) < level)):
444 |
445 | z = z.astype(np.float32)
446 |
447 | verts, faces, normals, values = measure.marching_cubes(
448 | volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0],
449 | grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),
450 | level=level,
451 | spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
452 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],
453 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))
454 |
455 | if higher_res:
456 | verts = torch.from_numpy(verts).cuda().float()
457 | verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
458 | verts.unsqueeze(-1)).squeeze()
459 | verts = (verts + grid_points[0]).cpu().numpy()
460 | else:
461 | verts = verts + np.array([grid_aligned['xyz'][0][0], grid_aligned['xyz'][1][0], grid_aligned['xyz'][2][0]])
462 |
463 | meshexport = trimesh.Trimesh(verts, faces, normals)
464 |
465 | # CUTTING MESH ACCORDING TO THE BOUNDING BOX
466 | if higher_res:
467 | bb = grid_params
468 | transformation = np.eye(4)
469 | transformation[:3, 3] = (bb[1,:] + bb[0,:])/2.
470 | bounding_box = trimesh.creation.box(extents=bb[1,:] - bb[0,:], transform=transformation)
471 |
472 | meshexport = meshexport.slice_plane(bounding_box.facets_origin, -bounding_box.facets_normal)
473 |
474 | return meshexport
475 |
476 | def get_grid_uniform(resolution, grid_boundary=[-2.0, 2.0]):
477 | x = np.linspace(grid_boundary[0], grid_boundary[1], resolution)
478 | y = x
479 | z = x
480 |
481 | xx, yy, zz = np.meshgrid(x, y, z)
482 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)
483 |
484 | return {"grid_points": grid_points,
485 | "shortest_axis_length": 2.0,
486 | "xyz": [x, y, z],
487 | "shortest_axis_index": 0}
488 |
489 | def get_grid(points, resolution, input_min=None, input_max=None, eps=0.1):
490 | if input_min is None or input_max is None:
491 | input_min = torch.min(points, dim=0)[0].squeeze().numpy()
492 | input_max = torch.max(points, dim=0)[0].squeeze().numpy()
493 |
494 | bounding_box = input_max - input_min
495 | shortest_axis = np.argmin(bounding_box)
496 | if (shortest_axis == 0):
497 | x = np.linspace(input_min[shortest_axis] - eps,
498 | input_max[shortest_axis] + eps, resolution)
499 | length = np.max(x) - np.min(x)
500 | y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
501 | z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))
502 | elif (shortest_axis == 1):
503 | y = np.linspace(input_min[shortest_axis] - eps,
504 | input_max[shortest_axis] + eps, resolution)
505 | length = np.max(y) - np.min(y)
506 | x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
507 | z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))
508 | elif (shortest_axis == 2):
509 | z = np.linspace(input_min[shortest_axis] - eps,
510 | input_max[shortest_axis] + eps, resolution)
511 | length = np.max(z) - np.min(z)
512 | x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
513 | y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))
514 |
515 | xx, yy, zz = np.meshgrid(x, y, z)
516 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()
517 | return {"grid_points": grid_points,
518 | "shortest_axis_length": length,
519 | "xyz": [x, y, z],
520 | "shortest_axis_index": shortest_axis}
521 |
522 |
523 | def plot_normal_maps(normal_maps, ground_true, path, epoch, plot_nrow, img_res, indices):
524 | ground_true = ground_true.cuda()
525 | normal_maps = torch.cat((normal_maps, ground_true), dim=0)
526 | normal_maps_plot = lin2img(normal_maps, img_res)
527 |
528 | tensor = torchvision.utils.make_grid(normal_maps_plot,
529 | scale_each=False,
530 | normalize=False,
531 | nrow=plot_nrow).cpu().detach().numpy()
532 | tensor = tensor.transpose(1, 2, 0)
533 | scale_factor = 255
534 | tensor = (tensor * scale_factor).astype(np.uint8)
535 |
536 | img = Image.fromarray(tensor)
537 | img.save('{0}/normal_{1}_{2}.png'.format(path, epoch, indices[0]))
538 |
539 | #import pdb; pdb.set_trace()
540 | #trans_topil(normal_maps_plot[0, :, :, 260:260+680]).save('{0}/2normal_{1}.png'.format(path, epoch))
541 |
542 |
543 | def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res, indices, exposure=False):
544 | ground_true = ground_true.cuda()
545 |
546 | output_vs_gt = torch.cat((rgb_points, ground_true), dim=0)
547 | output_vs_gt_plot = lin2img(output_vs_gt, img_res)
548 |
549 | tensor = torchvision.utils.make_grid(output_vs_gt_plot,
550 | scale_each=False,
551 | normalize=False,
552 | nrow=plot_nrow).cpu().detach().numpy()
553 |
554 | tensor = tensor.transpose(1, 2, 0)
555 | scale_factor = 255
556 | tensor = (tensor * scale_factor).astype(np.uint8)
557 |
558 | img = Image.fromarray(tensor)
559 | if exposure:
560 | img.save('{0}/exposure_{1}_{2}.png'.format(path, epoch, indices[0]))
561 | else:
562 | img.save('{0}/rendering_{1}_{2}.png'.format(path, epoch, indices[0]))
563 |
564 |
565 | def plot_depth_maps(depth_maps, ground_true, path, epoch, plot_nrow, img_res, indices):
566 | ground_true = ground_true.cuda()
567 | depth_maps = torch.cat((depth_maps[..., None], ground_true), dim=0)
568 | depth_maps_plot = lin2img(depth_maps, img_res)
569 | depth_maps_plot = depth_maps_plot.expand(-1, 3, -1, -1)
570 |
571 | tensor = torchvision.utils.make_grid(depth_maps_plot,
572 | scale_each=False,
573 | normalize=False,
574 | nrow=plot_nrow).cpu().detach().numpy()
575 | tensor = tensor.transpose(1, 2, 0)
576 |
577 | save_path = '{0}/depth_{1}_{2}.png'.format(path, epoch, indices[0])
578 |
579 | plt.imsave(save_path, tensor[:, :, 0], cmap='viridis')
580 |
581 |
582 | def colored_data(x, cmap='jet', d_min=None, d_max=None):
583 | if d_min is None:
584 | d_min = np.min(x)
585 | if d_max is None:
586 | d_max = np.max(x)
587 | x_relative = (x - d_min) / (d_max - d_min)
588 | cmap_ = plt.cm.get_cmap(cmap)
589 | return (255 * cmap_(x_relative)[:,:,:3]).astype(np.uint8) # H, W, C
590 |
591 |
592 | def plot_seg_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res, indices):
593 | ground_true = ground_true.cuda()
594 |
595 | output_vs_gt = torch.cat((rgb_points, ground_true), dim=0)
596 | output_vs_gt_plot = lin2img(output_vs_gt, img_res)
597 |
598 | tensor = torchvision.utils.make_grid(output_vs_gt_plot,
599 | scale_each=False,
600 | normalize=False,
601 | nrow=plot_nrow).cpu().detach().numpy()
602 | tensor = tensor.transpose(1, 2, 0)[:, :, 0]
603 | tensor = colored_data(tensor)
604 |
605 | img = Image.fromarray(tensor)
606 | img.save('{0}/semantic_{1}_{2}.png'.format(path, epoch, indices[0]))
607 |
608 |
609 | def lin2img(tensor, img_res):
610 | batch_size, num_samples, channels = tensor.shape
611 | return tensor.permute(0, 2, 1).view(batch_size, channels, img_res[0], img_res[1])
612 |
--------------------------------------------------------------------------------
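
In plot_rico above, the full-scene mesh is extracted from the per-object SDFs with the -MaxPool1d(-x) expression, which is simply a per-point minimum over the K semantic channels, i.e. the SDF of the union of background and objects. A small standalone check of that equivalence:

    import torch

    K = 5                                    # semantic channels (background + objects)
    pool = torch.nn.MaxPool1d(K)
    sdf_raw = torch.randn(7, K)              # per-object SDFs at 7 query points

    # same expression as the scene-level sdf lambda in plot_rico
    scene_sdf = -pool(-sdf_raw.unsqueeze(1)).squeeze(-1).squeeze(-1)
    assert torch.allclose(scene_sdf, sdf_raw.min(dim=-1).values)
    print(scene_sdf.shape)  # torch.Size([7])
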
/code/utils/rend_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio.v2 as imageio
3 | import skimage
4 | import cv2
5 | import torch
6 | from torch.nn import functional as F
7 |
8 |
9 | def get_psnr(img1, img2, normalize_rgb=False):
10 | if normalize_rgb: # [-1,1] --> [0,1]
11 | img1 = (img1 + 1.) / 2.
12 | img2 = (img2 + 1. ) / 2.
13 |
14 | mse = torch.mean((img1 - img2) ** 2)
15 | psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())
16 |
17 | return psnr
18 |
19 |
20 | def load_rgb(path, normalize_rgb = False):
21 | img = imageio.imread(path)
22 | img = skimage.img_as_float32(img)
23 |
24 | if normalize_rgb: # [-1,1] --> [0,1]
25 | img -= 0.5
26 | img *= 2.
27 | img = img.transpose(2, 0, 1)
28 | return img
29 |
30 |
31 | def load_K_Rt_from_P(filename, P=None):
32 | if P is None:
33 | lines = open(filename).read().splitlines()
34 | if len(lines) == 4:
35 | lines = lines[1:]
36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
37 | P = np.asarray(lines).astype(np.float32).squeeze()
38 |
39 | out = cv2.decomposeProjectionMatrix(P)
40 | K = out[0]
41 | R = out[1]
42 | t = out[2]
43 |
44 | K = K/K[2,2]
45 | intrinsics = np.eye(4)
46 | intrinsics[:3, :3] = K
47 |
48 | pose = np.eye(4, dtype=np.float32)
49 | pose[:3, :3] = R.transpose()
50 | pose[:3,3] = (t[:3] / t[3])[:,0]
51 |
52 | return intrinsics, pose
53 |
54 |
55 | def get_camera_params(uv, pose, intrinsics):
56 | if pose.shape[1] == 7: #In case of quaternion vector representation
57 | cam_loc = pose[:, 4:]
58 | R = quat_to_rot(pose[:,:4])
59 | p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
60 | p[:, :3, :3] = R
61 | p[:, :3, 3] = cam_loc
62 | else: # In case of pose matrix representation
63 | cam_loc = pose[:, :3, 3]
64 | p = pose
65 |
66 | batch_size, num_samples, _ = uv.shape
67 |
68 | depth = torch.ones((batch_size, num_samples)).cuda()
69 | x_cam = uv[:, :, 0].view(batch_size, -1)
70 | y_cam = uv[:, :, 1].view(batch_size, -1)
71 | z_cam = depth.view(batch_size, -1)
72 |
73 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
74 |
75 | # permute for batch matrix product
76 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
77 |
78 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
79 | ray_dirs = world_coords - cam_loc[:, None, :]
80 | ray_dirs = F.normalize(ray_dirs, dim=2)
81 |
82 | return ray_dirs, cam_loc
83 |
84 |
85 | def get_camera_for_plot(pose):
86 | if pose.shape[1] == 7: #In case of quaternion vector representation
87 | cam_loc = pose[:, 4:].detach()
88 | R = quat_to_rot(pose[:,:4].detach())
89 | else: # In case of pose matrix representation
90 | cam_loc = pose[:, :3, 3]
91 | R = pose[:, :3, :3]
92 | cam_dir = R[:, :3, 2]
93 | return cam_loc, cam_dir
94 |
95 |
96 | def lift(x, y, z, intrinsics):
97 | # parse intrinsics
98 | intrinsics = intrinsics.cuda()
99 | fx = intrinsics[:, 0, 0]
100 | fy = intrinsics[:, 1, 1]
101 | cx = intrinsics[:, 0, 2]
102 | cy = intrinsics[:, 1, 2]
103 | sk = intrinsics[:, 0, 1]
104 |
105 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
106 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
107 |
108 | # homogeneous
109 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)
110 |
111 |
112 | def quat_to_rot(q):
113 | batch_size, _ = q.shape
114 | q = F.normalize(q, dim=1)
115 | R = torch.ones((batch_size, 3,3)).cuda()
116 | qr=q[:,0]
117 | qi = q[:, 1]
118 | qj = q[:, 2]
119 | qk = q[:, 3]
120 | R[:, 0, 0]=1-2 * (qj**2 + qk**2)
121 | R[:, 0, 1] = 2 * (qj *qi -qk*qr)
122 | R[:, 0, 2] = 2 * (qi * qk + qr * qj)
123 | R[:, 1, 0] = 2 * (qj * qi + qk * qr)
124 | R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
125 | R[:, 1, 2] = 2*(qj*qk - qi*qr)
126 | R[:, 2, 0] = 2 * (qk * qi-qj * qr)
127 | R[:, 2, 1] = 2 * (qj*qk + qi*qr)
128 | R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
129 | return R
130 |
131 |
132 | def rot_to_quat(R):
133 | batch_size, _,_ = R.shape
134 | q = torch.ones((batch_size, 4)).cuda()
135 |
136 | R00 = R[:, 0,0]
137 | R01 = R[:, 0, 1]
138 | R02 = R[:, 0, 2]
139 | R10 = R[:, 1, 0]
140 | R11 = R[:, 1, 1]
141 | R12 = R[:, 1, 2]
142 | R20 = R[:, 2, 0]
143 | R21 = R[:, 2, 1]
144 | R22 = R[:, 2, 2]
145 |
146 | q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2
147 | q[:, 1]=(R21-R12)/(4*q[:,0])
148 | q[:, 2] = (R02 - R20) / (4 * q[:, 0])
149 | q[:, 3] = (R10 - R01) / (4 * q[:, 0])
150 | return q
151 |
152 |
153 | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0):
154 | # Input: n_rays x 3 ; n_rays x 3
155 | # Output: n_rays x 1, n_rays x 1 (close and far)
156 |
157 | ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),
158 | cam_loc.view(-1, 3, 1)).squeeze(-1)
159 | under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)
160 |
161 | # sanity check
162 | if (under_sqrt <= 0).sum() > 0:
163 | print('BOUNDING SPHERE PROBLEM!')
164 | exit()
165 |
166 | sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot
167 | sphere_intersections = sphere_intersections.clamp_min(0.0)
168 |
169 | return sphere_intersections
170 |
171 |
172 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5):
173 | # device = weights.get_device()
174 | device = weights.device
175 | # Get pdf
176 | weights = weights + 1e-5 # prevent nans
177 | pdf = weights / torch.sum(weights, -1, keepdim=True)
178 | cdf = torch.cumsum(pdf, -1)
179 | cdf = torch.cat(
180 | [torch.zeros_like(cdf[..., :1], device=device), cdf], -1
181 | ) # (batch, len(bins))
182 |
183 | # Take uniform samples
184 | if det:
185 | u = torch.linspace(0.0, 1.0, steps=N_importance, device=device)
186 | u = u.expand(list(cdf.shape[:-1]) + [N_importance])
187 | else:
188 | u = torch.rand(list(cdf.shape[:-1]) + [N_importance], device=device)
189 | u = u.contiguous()
190 |
191 | # Invert CDF
192 | inds = torch.searchsorted(cdf.detach(), u, right=False)
193 |
194 | below = torch.clamp_min(inds-1, 0)
195 | above = torch.clamp_max(inds, cdf.shape[-1]-1)
196 | # (batch, N_importance, 2) ==> (B, batch, N_importance, 2)
197 | inds_g = torch.stack([below, above], -1)
198 |
199 | matched_shape = [*inds_g.shape[:-1], cdf.shape[-1]] # fix prefix shape
200 |
201 | cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g)
202 | bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g) # fix prefix shape
203 |
204 | denom = cdf_g[..., 1] - cdf_g[..., 0]
205 | denom[denom < eps] = 1  # where the cdf is flat, avoid division by zero
206 | t = (u - cdf_g[..., 0]) / denom
207 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
208 |
209 | return samples
--------------------------------------------------------------------------------
/scripts/edit_render.py:
--------------------------------------------------------------------------------
32 | scale, shift = compute_scale_and_shift(depth_map[..., None], depth_gt, depth_gt > 0.)
33 | depth_map = depth_map * scale + shift
34 |
35 | # save point cloud
36 | depth = depth_map.reshape(1, 1, 384, 384)
37 | # pred_points = get_point_cloud(depth, model_input, model_outputs)
38 |
39 | gt_depth = depth_gt.reshape(1, 1, 384, 384)
40 | # gt_points = get_point_cloud(gt_depth, model_input, model_outputs)
41 |
42 | # semantic map
43 | semantic_map = model_outputs['semantic_values'].argmax(dim=-1).reshape(batch_size, num_samples, 1)
44 | # in label_mapping, the background has original label 0 and mapped index 0,
45 | # while e.g. the first foreground object may have original label 3 but mapped index 1,
46 | # so the argmax over semantic channels gives the mapped (label_mapping) index when correct
47 |
48 | plot_data = {
49 | 'rgb_gt': rgb_gt,
50 | 'normal_gt': (normal_gt + 1.)/ 2.,
51 | 'depth_gt': depth_gt,
52 | 'pose': pose,
53 | 'rgb_eval': rgb_eval,
54 | 'normal_map': normal_map,
55 | 'depth_map': depth_map,
56 | # "pred_points": pred_points,
57 | # "gt_points": gt_points,
58 | "semantic_map": semantic_map,
59 | "semantic_gt": semantic_gt,
60 | }
61 |
62 | return plot_data
63 |
64 | def get_sdf_vals_edit(pts, model, idx, edit_param, edit_type):
65 | with torch.no_grad():
66 | sdf_original = model.implicit_network.forward(pts)[:,:model.implicit_network.d_out] # [N_pts, K]
67 |
68 | if edit_type == 'translate':
69 | edit_pts = pts - edit_param
70 |
71 | sdf_edit = model.implicit_network.forward(edit_pts)[:,:model.implicit_network.d_out] # [N_pts, K]
72 |
73 | sdf_original[:, idx] = sdf_original[:, idx] * 0. + sdf_edit[:, idx]
74 |
75 | sdf = sdf_original
76 |
77 | sdf = -model.implicit_network.pool(-sdf.unsqueeze(1)).squeeze(-1) # take the per-point minimum over object SDFs to get the scene SDF (any SDF bound is applied before this min)
78 | return sdf
79 |
80 | def neus_sample_edit(cam_loc, ray_dirs, model, idx, edit_param, edit_type):
81 | device = cam_loc.device
82 | perturb = False
83 | _, far = model.near_far_from_cube(cam_loc, ray_dirs, bound=model.scene_bounding_sphere)
84 | near = model.near * torch.ones(ray_dirs.shape[0], 1).cuda()
85 |
86 | _t = torch.linspace(0, 1, model.N_samples).float().to(device)
87 | z_vals = near * (1 - _t) + far * _t
88 |
89 | with torch.no_grad():
90 | _z = z_vals # [N, 64]
91 |
92 | # follow the ObjectSDF setting and use the min SDF for sampling
93 | _pts = cam_loc.unsqueeze(-2) + _z.unsqueeze(-1) * ray_dirs.unsqueeze(-2)
94 | N_rays, N_steps = _pts.shape[0], _pts.shape[1]
95 |
96 | _sdf = get_sdf_vals_edit(_pts.reshape(-1, 3), model, idx, edit_param, edit_type)
97 |
98 | _sdf = _sdf.reshape(N_rays, N_steps)
99 |
100 | for i in range(model.N_upsample_iters):
101 | prev_sdf, next_sdf = _sdf[..., :-1], _sdf[..., 1:]
102 | prev_z_vals, next_z_vals = _z[..., :-1], _z[..., 1:]
103 | mid_sdf = (prev_sdf + next_sdf) * 0.5
104 | dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5)
105 | prev_dot_val = torch.cat([torch.zeros_like(dot_val[..., :1], device=device), dot_val[..., :-1]], dim=-1)
106 | dot_val = torch.stack([prev_dot_val, dot_val], dim=-1)
107 | dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False)
108 | dot_val = dot_val.clamp(-10.0, 0.0)
109 |
110 | dist = (next_z_vals - prev_z_vals)
111 | prev_esti_sdf = mid_sdf - dot_val * dist * 0.5
112 | next_esti_sdf = mid_sdf + dot_val * dist * 0.5
113 |
114 | prev_cdf = cdf_Phi_s(prev_esti_sdf, 64 * (2**i))
115 | next_cdf = cdf_Phi_s(next_esti_sdf, 64 * (2**i))
116 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5)
117 | _w = alpha_to_w(alpha)
118 | z_fine = rend_util.sample_pdf(_z, _w, model.N_samples_extra // model.N_upsample_iters, det=not perturb)
119 | _z = torch.cat([_z, z_fine], dim=-1)
120 |
121 | _pts_fine = cam_loc.unsqueeze(-2) + z_fine.unsqueeze(-1) * ray_dirs.unsqueeze(-2)
122 | N_rays, N_steps_fine = _pts_fine.shape[0], _pts_fine.shape[1]
123 |
124 | sdf_fine = get_sdf_vals_edit(_pts_fine.reshape(-1, 3), model, idx, edit_param, edit_type)
125 |
126 | sdf_fine = sdf_fine.reshape(N_rays, N_steps_fine)
127 | _sdf = torch.cat([_sdf, sdf_fine], dim=-1)
128 | _z, z_sort_indices = torch.sort(_z, dim=-1)
129 |
130 | _sdf = torch.gather(_sdf, 1, z_sort_indices)
131 |
132 | z_all = _z
133 |
134 | return z_all
135 |
136 | def get_sdf_vals_and_sdfs_edit(pts, model, idx, edit_param, edit_type):
137 | with torch.no_grad():
138 | sdf_original = model.implicit_network.forward(pts)[:,:model.implicit_network.d_out] # [N_pts, K]
139 |
140 | if edit_type == 'translate':
141 | edit_pts = pts - edit_param
142 |
143 | sdf_edit = model.implicit_network.forward(edit_pts)[:,:model.implicit_network.d_out] # [N_pts, K]
144 |
145 | sdf_original[:, idx] = sdf_original[:, idx] * 0. + sdf_edit[:, idx]
146 |
147 | sdf = sdf_original
148 |
149 | sdf_all = sdf
150 | sdf = -model.implicit_network.pool(-sdf.unsqueeze(1)).squeeze(-1)
151 | return sdf, sdf_all
152 |
153 | def get_outputs_edit(points, model, idx, edit_param, edit_type):
154 | points.requires_grad_(True)
155 |
156 | # directly use the original geometry feature vector
157 | # fuse sdf together
158 | # then compute semantic, gradient, sdf
159 |
160 | original_output = model.implicit_network.forward(points)
161 | sdf_original = original_output[:,:model.implicit_network.d_out]
162 | feature_vectors = original_output[:,model.implicit_network.d_out:]
163 |
164 | if edit_type == 'translate':
165 | edit_pts = points - edit_param
166 | edit_output = model.implicit_network.forward(edit_pts)
167 | sdf_edit = edit_output[:, :model.implicit_network.d_out]
168 |
169 | sdf_raw = sdf_original
170 | sdf_raw[:, idx] = sdf_original[:, idx] * 0. + sdf_edit[:, idx]
171 |
172 | sigmoid_value = model.implicit_network.sigmoid
173 | semantic = sigmoid_value * torch.sigmoid(-sigmoid_value * sdf_raw)
174 |
175 | sdf = -model.implicit_network.pool(-sdf_raw.unsqueeze(1)).squeeze(-1)
176 |
177 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
178 | gradients = torch.autograd.grad(
179 | outputs=sdf,
180 | inputs=points,
181 | grad_outputs=d_output,
182 | create_graph=True,
183 | retain_graph=True,
184 | only_inputs=True)[0]
185 |
186 | return sdf, feature_vectors, gradients, semantic, sdf_raw
187 |
188 |
189 | def render_edit(model, input, indices, idx=0, edit_param=[0., 0., 0.], edit_type='translate'):
190 | '''
191 | Currently only support one object
192 | if edit_type == 'translate', then edit_param is [dx, dy, dz]
193 | if edit_type == 'rotate', then edit_param is []: TODO
194 | just use neus
195 | '''
196 | assert idx > 0
197 | edit_param = torch.tensor(edit_param).cuda()
198 |
199 | intrinsics = input["intrinsics"].cuda()
200 | uv = input["uv"].cuda()
201 | pose = input["pose"].cuda()
202 |
203 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics)
204 | # we should use unnormalized ray direction for depth
205 | ray_dirs_tmp, _ = rend_util.get_camera_params(uv, torch.eye(4).to(pose.device)[None], intrinsics)
206 | depth_scale = ray_dirs_tmp[0, :, 2:] # [N, 1]
207 |
208 | batch_size, num_pixels, _ = ray_dirs.shape
209 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
210 | ray_dirs = ray_dirs.reshape(-1, 3)
211 |
212 | '''
213 | Sample points with edited forward
214 | '''
215 | z_vals = neus_sample_edit(cam_loc, ray_dirs, model, idx, edit_param, edit_type)
216 |
217 | N_samples_tmp = z_vals.shape[1]
218 |
219 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) # [N_rays, N_samples_tmp, 3]
220 | points_flat_tmp = points.reshape(-1, 3)
221 |
222 | sdf_tmp, sdf_all_tmp = get_sdf_vals_and_sdfs_edit(points_flat_tmp, model, idx, edit_param, edit_type)
223 | sdf_tmp = sdf_tmp.reshape(-1, N_samples_tmp)
224 | s_value = model.get_s_value()
225 |
226 | cdf, opacity_alpha = sdf_to_alpha(sdf_tmp, s_value) # [N_rays, N_samples_tmp-1]
227 |
228 | sdf_all_tmp = sdf_all_tmp.reshape(-1, N_samples_tmp, model.num_semantic)
229 |
230 | z_mid_vals = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1])
231 | N_samples = z_mid_vals.shape[1]
232 |
233 | points_mid = cam_loc.unsqueeze(1) + z_mid_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) # [N_rays, N_samples, 3]
234 | points_flat = points_mid.reshape(-1, 3)
235 |
236 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1)
237 | dirs_flat = dirs.reshape(-1, 3)
238 |
239 | sdf, feature_vectors, gradients, semantic, sdf_raw = get_outputs_edit(points_flat, model, idx, edit_param, edit_type)
240 |
241 | # here the rgb output might be wrong
242 | rgb_flat = model.rendering_network(points_flat, gradients, dirs_flat, feature_vectors, indices)
243 | rgb = rgb_flat.reshape(-1, N_samples, 3)
244 |
245 | semantic = semantic.reshape(-1, N_samples, model.num_semantic)
246 |
247 | weights = alpha_to_w(opacity_alpha)
248 |
249 | rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1)
250 | semantic_values = torch.sum(weights.unsqueeze(-1)*semantic, 1)
251 | raw_depth_values = torch.sum(weights * z_mid_vals, 1, keepdims=True) / (weights.sum(dim=1, keepdims=True) +1e-8)
252 | depth_values = depth_scale * raw_depth_values
253 |
254 | output = {
255 | 'rgb_values': rgb_values,
256 | 'semantic_values': semantic_values,
257 | 'depth_values': depth_values,
258 | }
259 |
260 | # compute normal map
261 | normals = gradients / (gradients.norm(2, -1, keepdim=True) + 1e-6)
262 | normals = normals.reshape(-1, N_samples, 3)
263 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1)
264 |
265 | # transform to local coordinate system
266 | rot = pose[0, :3, :3].permute(1, 0).contiguous()
267 | normal_map = rot @ normal_map.permute(1, 0)
268 | normal_map = normal_map.permute(1, 0).contiguous()
269 |
270 | output['normal_map'] = normal_map
271 |
272 | return output
273 |
274 |
275 | edit_idx = 1
276 | edit_param = [0., 0., 0.]
277 | edit_type = 'translate'
278 |
279 | exp_name = 'RICO_synthetic_1'
280 | scan_id = int(exp_name[-1])
281 |
282 | exp_path = os.path.join('../exps/', exp_name)
283 | timestamp = sorted(os.listdir(exp_path))[-1]  # sort so the latest timestamp is picked unless a specific one is needed
284 | exp_path = os.path.join(exp_path, timestamp)
285 |
286 | conf = ConfigFactory.parse_file(os.path.join(exp_path, 'runconf.conf'))
287 | dataset_conf = conf.get_config('dataset')
288 | dataset_conf['scan_id'] = scan_id
289 | conf_model = conf.get_config('model')
290 |
291 | train_dataset = utils.get_class(conf.get_string('train.dataset_class'))(**dataset_conf)
292 | plot_dataloader = torch.utils.data.DataLoader(
293 | train_dataset,
294 | batch_size=conf.get_int('plot.plot_nimgs'),
295 | shuffle=False,
296 | collate_fn=train_dataset.collate_fn)
297 |
298 | model = utils.get_class(conf.get_string('train.model_class'))(conf=conf_model)
299 |
300 | if torch.cuda.is_available():
301 | model.cuda()
302 |
303 | ckpt_path = os.path.join(exp_path, 'checkpoints/ModelParameters', 'latest.pth')
304 | ckpt = torch.load(ckpt_path)
305 | print(ckpt['epoch'])
306 |
307 | # model.load_state_dict(ckpt['model_state_dict'])
308 | # load in a non-DDP fashion
309 | model.load_state_dict({k.replace('module.',''): v for k,v in ckpt['model_state_dict'].items()})
310 | os.makedirs('./tmp_edit', exist_ok=True)
311 |
312 | model.eval()
313 |
314 | data_idx = 75
315 | vis_data = plot_dataloader.dataset[data_idx]
316 |
317 | indices, model_input, ground_truth = vis_data
318 | indices = torch.tensor([indices])
319 | print(indices)
320 | for k, v in model_input.items():
321 | model_input[k] = v.unsqueeze(0)
322 | for k, v in ground_truth.items():
323 | ground_truth[k] = v.unsqueeze(0)
324 |
325 | model_input["intrinsics"] = model_input["intrinsics"].cuda()
326 | model_input["uv"] = model_input["uv"].cuda()
327 | model_input['pose'] = model_input['pose'].cuda()
328 |
329 | split = utils.split_input(model_input, 384*384, n_pixels=128)
330 | res = []
331 |
332 | for s in tqdm(split):
333 | # out = model(s, indices)
334 | out = render_edit(model, s, indices, edit_idx, edit_param, edit_type)
335 | d = {'rgb_values': out['rgb_values'].detach(),
336 | 'normal_map': out['normal_map'].detach(),
337 | 'depth_values': out['depth_values'].detach(),
338 | 'semantic_values': out['semantic_values'].detach()}
339 | if 'rgb_un_values' in out:
340 | d['rgb_un_values'] = out['rgb_un_values'].detach()
341 | res.append(d)
342 |
343 | batch_size = ground_truth['rgb'].shape[0]
344 | model_outputs = utils.merge_output(res, 384*384, batch_size)
345 | plot_data = get_plot_data(model_input, model_outputs, model_input['pose'], ground_truth['rgb'], ground_truth['normal'], ground_truth['depth'], ground_truth['instance_mask'])
346 |
347 | plot_conf = conf.get_config('plot')
348 | plot_conf['obj_boxes'] = None
349 | plts.plot_rico(
350 | None,
351 | indices,
352 | plot_data,
353 | './tmp_edit/',
354 | ckpt['epoch'],
355 | [384, 384],
356 | plot_mesh = False,
357 | **plot_conf
358 | )
--------------------------------------------------------------------------------
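
The translate edit in render_edit works by querying the frozen network at pts - edit_param for the selected object channel, which moves that object's zero level set by +edit_param while leaving the other channels untouched. A minimal sanity check with an analytic sphere SDF standing in for one object channel (not the learned network):

    import torch

    def sphere_sdf(p, center, radius=0.3):
        return (p - center).norm(dim=-1) - radius

    t = torch.tensor([0.5, 0.0, 0.0])   # a translate edit_param
    p = torch.rand(10, 3)

    # evaluating the origin-centred SDF at p - t equals evaluating a sphere moved by +t at p
    assert torch.allclose(sphere_sdf(p - t, torch.zeros(3)), sphere_sdf(p, t))
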
/scripts/extract_mesh_rico.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from skimage import measure
4 | import torchvision
5 | import trimesh
6 | from PIL import Image
7 | import matplotlib.pyplot as plt
8 | import cv2
9 | import os
10 | import json
11 | from pyhocon import ConfigFactory
12 |
13 | import sys
14 | sys.path.append("../code")
15 | import utils.general as utils
16 |
17 |
18 | exp_name = 'RICO_synthetic_1'
19 | scan_id = int(exp_name[-1])
20 |
21 | avg_pool_3d = torch.nn.AvgPool3d(2, stride=2)
22 | upsample = torch.nn.Upsample(scale_factor=2, mode='nearest')
23 |
24 | @torch.no_grad()
25 | def get_surface_sliding(send_path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0):
26 | if isinstance(send_path, list):
27 | path = send_path[0]
28 | mesh_name = send_path[1]
29 | else:
30 | path = send_path
31 | mesh_name = ''
32 |
33 | # assert resolution % 512 == 0
34 | resN = resolution
35 | cropN = resolution
36 | level = 0
37 | N = resN // cropN
38 |
39 | if len(grid_boundary) == 2:
40 | grid_min = [grid_boundary[0], grid_boundary[0], grid_boundary[0]]
41 | grid_max = [grid_boundary[1], grid_boundary[1], grid_boundary[1]]
42 | elif len(grid_boundary) == 6: # xmin, ymin, zmin, xmax, ymax, zmax
43 | grid_min = [grid_boundary[0], grid_boundary[1], grid_boundary[2]]
44 | grid_max = [grid_boundary[3], grid_boundary[4], grid_boundary[5]]
45 | xs = np.linspace(grid_min[0], grid_max[0], N+1)
46 | ys = np.linspace(grid_min[1], grid_max[1], N+1)
47 | zs = np.linspace(grid_min[2], grid_max[2], N+1)
48 |
49 | print(xs)
50 | print(ys)
51 | print(zs)
52 | meshes = []
53 | for i in range(N):
54 | for j in range(N):
55 | for k in range(N):
56 | print(i, j, k)
57 | x_min, x_max = xs[i], xs[i+1]
58 | y_min, y_max = ys[j], ys[j+1]
59 | z_min, z_max = zs[k], zs[k+1]
60 |
61 | x = np.linspace(x_min, x_max, cropN)
62 | y = np.linspace(y_min, y_max, cropN)
63 | z = np.linspace(z_min, z_max, cropN)
64 |
65 | xx, yy, zz = np.meshgrid(x, y, z, indexing='ij')
66 | points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()
67 |
68 | def evaluate(points):
69 | z = []
70 | for _, pnts in enumerate(torch.split(points, 100000, dim=0)):
71 | z.append(sdf(pnts))
72 | z = torch.cat(z, axis=0)
73 | return z
74 |
75 | # construct point pyramids
76 | points = points.reshape(cropN, cropN, cropN, 3).permute(3, 0, 1, 2)
77 |
78 | points_pyramid = [points]
79 | for _ in range(3):
80 | points = avg_pool_3d(points[None])[0]
81 | points_pyramid.append(points)
82 | points_pyramid = points_pyramid[::-1]
83 |
84 | # evaluate pyramid with mask
85 | mask = None
86 | threshold = 2 * (x_max - x_min)/cropN * 8
87 | for pid, pts in enumerate(points_pyramid):
88 | coarse_N = pts.shape[-1]
89 | pts = pts.reshape(3, -1).permute(1, 0).contiguous()
90 |
91 | if mask is None:
92 | pts_sdf = evaluate(pts)
93 | else:
94 | mask = mask.reshape(-1)
95 | pts_to_eval = pts[mask]
96 | #import pdb; pdb.set_trace()
97 | if pts_to_eval.shape[0] > 0:
98 | pts_sdf_eval = evaluate(pts_to_eval.contiguous())
99 | pts_sdf[mask] = pts_sdf_eval
100 | print("ratio", pts_to_eval.shape[0] / pts.shape[0])
101 |
102 | if pid < 3:
103 | # update mask
104 | mask = torch.abs(pts_sdf) < threshold
105 | mask = mask.reshape(coarse_N, coarse_N, coarse_N)[None, None]
106 | mask = upsample(mask.float()).bool()
107 |
108 | pts_sdf = pts_sdf.reshape(coarse_N, coarse_N, coarse_N)[None, None]
109 | pts_sdf = upsample(pts_sdf)
110 | pts_sdf = pts_sdf.reshape(-1)
111 |
112 | threshold /= 2.
113 |
114 | z = pts_sdf.detach().cpu().numpy()
115 |
116 | if (not (np.min(z) > level or np.max(z) < level)):
117 | z = z.astype(np.float32)
118 | verts, faces, normals, values = measure.marching_cubes(
119 | volume=z.reshape(cropN, cropN, cropN), #.transpose([1, 0, 2]),
120 | level=level,
121 | spacing=(
122 | (x_max - x_min)/(cropN-1),
123 | (y_max - y_min)/(cropN-1),
124 | (z_max - z_min)/(cropN-1) ))
125 | print(np.array([x_min, y_min, z_min]))
126 | print(verts.min(), verts.max())
127 | verts = verts + np.array([x_min, y_min, z_min])
128 | print(verts.min(), verts.max())
129 |
130 | meshcrop = trimesh.Trimesh(verts, faces, normals)
131 | #meshcrop.export(f"{i}_{j}_{k}.ply")
132 | meshes.append(meshcrop)
133 |
134 | combined = trimesh.util.concatenate(meshes)
135 |
136 | combined.export('{0}/surface_{1}_{2}.ply'.format(path, epoch, mesh_name), 'ply')
137 |
138 |
139 | exp_path = os.path.join('../exps/', exp_name)
140 | timestamp = sorted(os.listdir(exp_path))[-1]  # sort so the latest timestamp is picked unless a specific one is needed
141 | exp_path = os.path.join(exp_path, timestamp)
142 |
143 | conf = ConfigFactory.parse_file(os.path.join(exp_path, 'runconf.conf'))
144 | dataset_conf = conf.get_config('dataset')
145 | conf_model = conf.get_config('model')
146 |
147 | model = utils.get_class(conf.get_string('train.model_class'))(conf=conf_model)
148 |
149 | if torch.cuda.is_available():
150 | model.cuda()
151 |
152 | ckpt_path = os.path.join(exp_path, 'checkpoints/ModelParameters', 'latest.pth')
153 | ckpt = torch.load(ckpt_path)
154 | print(ckpt['epoch'])
155 |
156 | # model.load_state_dict(ckpt['model_state_dict'])
157 | # load in a non-DDP fashion
158 | model.load_state_dict({k.replace('module.',''): v for k,v in ckpt['model_state_dict'].items()})
159 | os.makedirs('./tmp', exist_ok=True)
160 |
161 | sem_num = model.implicit_network.d_out
162 | f = torch.nn.MaxPool1d(sem_num)
163 |
164 | for indx in range(sem_num):
165 | obj_grid_boundary = [-1.1, 1.1]
166 | _ = get_surface_sliding(
167 | send_path=['./tmp/', str(indx)],
168 | epoch=ckpt['epoch'],
169 | sdf = lambda x: model.implicit_network(x)[:, indx],
170 | resolution=512,
171 | grid_boundary=obj_grid_boundary,
172 | level=0.
173 | )
174 |
175 | _ = get_surface_sliding(
176 | send_path=['./tmp/', 'all'],
177 | epoch=ckpt['epoch'],
178 | sdf=lambda x: -f(-model.implicit_network(x)[:, :sem_num].unsqueeze(1)).squeeze(-1).squeeze(-1),
179 | resolution=512,
180 | grid_boundary=[-1.1, 1.1],
181 | level=0.
182 | )
183 |
184 | print('finish')
185 |
186 |
--------------------------------------------------------------------------------
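
get_surface_sliding runs skimage's marching cubes per crop with spacing set to the voxel size and then shifts the vertices by the crop's minimum corner, so every piece lands in world coordinates before trimesh concatenates them. The same convention on a toy unit-sphere SDF over the script's [-1.1, 1.1] boundary (lower resolution here just to keep it quick):

    import numpy as np
    from skimage import measure

    lo, hi, n = -1.1, 1.1, 64
    xs = np.linspace(lo, hi, n)
    xx, yy, zz = np.meshgrid(xs, xs, xs, indexing='ij')
    vol = np.sqrt(xx ** 2 + yy ** 2 + zz ** 2) - 1.0     # SDF of a unit sphere

    verts, faces, normals, _ = measure.marching_cubes(
        vol, level=0.0, spacing=((hi - lo) / (n - 1),) * 3)
    verts = verts + np.array([lo, lo, lo])               # same shift as the per-crop offset above

    print(np.abs(np.linalg.norm(verts, axis=1) - 1.0).max())  # ~0: vertices lie on the sphere
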
/synthetic_eval/evaluate.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import numpy as np
3 | import open3d as o3d
4 | from sklearn.neighbors import KDTree
5 | import trimesh
6 | import torch
7 | import glob
8 | import os
9 | import pyrender
10 | import os
11 | import cv2
12 | import json
13 | from tqdm import tqdm
14 | from pathlib import Path
15 |
16 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
17 |
18 | def nn_correspondance(verts1, verts2):
19 | indices = []
20 | distances = []
21 | if len(verts1) == 0 or len(verts2) == 0:
22 | return np.array(distances)  # keep a single array return, consistent with the non-empty case
23 |
24 | kdtree = KDTree(verts1)
25 | distances, indices = kdtree.query(verts2)
26 | distances = distances.reshape(-1)
27 |
28 | return distances
29 |
30 |
31 | def evaluate(mesh_pred, mesh_trgt, obj_type='bg', threshold=.05, down_sample=.02):
32 | pcd_trgt = o3d.geometry.PointCloud()
33 | pcd_pred = o3d.geometry.PointCloud()
34 |
35 | trgt_pts = mesh_trgt.vertices[:, :3]
36 | pred_pts = mesh_pred.vertices[:, :3]
37 |
38 | if obj_type == 'obj':
39 | pts_mask = pred_pts[:, 2] < -0.9
40 | pred_pts = pred_pts[pts_mask]
41 |
42 | pcd_trgt.points = o3d.utility.Vector3dVector(trgt_pts)
43 | pcd_pred.points = o3d.utility.Vector3dVector(pred_pts)
44 |
45 | if down_sample:
46 | pcd_pred = pcd_pred.voxel_down_sample(down_sample)
47 | pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)
48 |
49 | verts_pred = np.asarray(pcd_pred.points)
50 | verts_trgt = np.asarray(pcd_trgt.points)
51 |
52 | dist1 = nn_correspondance(verts_pred, verts_trgt)
53 | dist2 = nn_correspondance(verts_trgt, verts_pred)
54 |
55 | precision = np.mean((dist2 < threshold).astype('float'))
56 | recal = np.mean((dist1 < threshold).astype('float'))
57 | fscore = 2 * precision * recal / (precision + recal)
58 | chamfer = (np.mean(dist2) + np.mean(dist1)) / 2
59 | metrics = {
60 | 'Acc': np.mean(dist2),
61 | 'Comp': np.mean(dist1),
62 | 'Chamfer': chamfer,
63 | 'Prec': precision,
64 | 'Recal': recal,
65 | 'F-score': fscore,
66 | }
67 | return metrics
68 |
69 | # hard-coded image size
70 | H, W = 384, 384
71 |
72 | def average_dicts(dicts):
73 | # input is a list of dict
74 | # all the dict have same keys
75 | dict_num = len(dicts)
76 | keys = dicts[0].keys()
77 | ret = {}
78 |
79 | for k in keys:
80 | values = [x[k] for x in dicts]
81 | value = np.array(values).mean()
82 | ret[k] = value
83 |
84 | return ret
85 |
86 |
87 | root_dir = "../exps/"
88 | exp_name = "RICO_synthetic"
89 | out_dir = "evaluation/" + exp_name
90 | Path(out_dir).mkdir(parents=True, exist_ok=True)
91 |
92 |
93 | scenes = {
94 | 1: 'scene1',
95 | 2: 'scene2',
96 | 3: 'scene3',
97 | 4: 'scene4',
98 | 5: 'scene5',
99 | }
100 |
101 | all_obj_results = []
102 | all_obj_results_dict = OrderedDict()
103 |
104 | for k, v in scenes.items():
105 |
106 | cur_exp = f"{exp_name}_{k}"
107 | cur_root = os.path.join(root_dir, cur_exp)
108 | if not os.path.isdir(cur_root):
109 | continue
110 | # use last timestamps
111 | dirs = sorted(os.listdir(cur_root))
112 | cur_root = os.path.join(cur_root, dirs[-1])
113 |
114 | files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*.ply"))))
115 |
116 | # evaluate the object meshes: after sorting by mtime, the first mesh is the background and the last is the full scene
117 | files.sort(key=lambda x:os.path.getmtime(x))
118 |
119 | cam_file = f"../data/syn_data/scene{k}/cameras.npz"
120 | scale_mat = np.load(cam_file)['scale_mat_0']
121 |
122 | ply_files = files[1: -1]
123 | # print(ply_files)
124 |
125 | cnt = 1
126 | obj_results = []
127 | obj_results_dict = OrderedDict()
128 | for ply_file in ply_files:
129 |
130 | mesh = trimesh.load(ply_file)
131 | mesh.vertices = (scale_mat[:3, :3] @ mesh.vertices.T + scale_mat[:3, 3:]).T
132 |
133 | gt_mesh = os.path.join(f"../data/syn_data/scene{k}/GT_mesh", f"object{cnt}.ply")
134 |
135 | gt_mesh = trimesh.load(gt_mesh)
136 |
137 | metrics = evaluate(mesh, gt_mesh, 'obj')
138 | obj_results.append(metrics)
139 | obj_results_dict[cnt] = metrics
140 |
141 | cnt += 1
142 |
143 | obj_results = average_dicts(obj_results)
144 | all_obj_results.append(obj_results)
145 | all_obj_results_dict[k] = obj_results_dict
146 |
147 | # the average result print
148 | all_obj_results = average_dicts(all_obj_results)
149 | print('objects:')
150 | print(all_obj_results)
151 |
152 | # all the result save
153 | obj_json_str = json.dumps(all_obj_results_dict, indent=4)
154 | obj_json_file = os.path.join('evaluation', exp_name + '_obj.json')
155 |
156 | with open(obj_json_file, 'w') as json_file:
157 | json_file.write(obj_json_str)
158 | json_file.close()
--------------------------------------------------------------------------------
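
The metrics in evaluate() above reduce to two nearest-neighbour passes with a KDTree: prediction-to-GT distances give accuracy and precision, GT-to-prediction distances give completeness and recall, and Chamfer / F-score are derived from those. A toy run on synthetic point sets using the same 0.05 threshold:

    import numpy as np
    from sklearn.neighbors import KDTree

    rng = np.random.default_rng(0)
    gt = rng.random((2000, 3))
    pred = gt + 0.01 * rng.standard_normal((2000, 3))   # a slightly noisy "reconstruction"

    acc_d = KDTree(gt).query(pred)[0].reshape(-1)       # pred -> GT, i.e. dist2 in evaluate()
    comp_d = KDTree(pred).query(gt)[0].reshape(-1)      # GT -> pred, i.e. dist1 in evaluate()

    thr = 0.05
    prec, rec = (acc_d < thr).mean(), (comp_d < thr).mean()
    print({'Acc': acc_d.mean(), 'Comp': comp_d.mean(),
           'Chamfer': 0.5 * (acc_d.mean() + comp_d.mean()),
           'Prec': prec, 'Recal': rec, 'F-score': 2 * prec * rec / (prec + rec)})
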
/synthetic_eval/evaluate_bgdepth.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import numpy as np
3 | import open3d as o3d
4 | from sklearn.neighbors import KDTree
5 | import trimesh
6 | import torch
7 | import glob
8 | import os
9 | import pyrender
10 | import os
11 | import cv2
12 | import json
13 | from tqdm import tqdm
14 | from pathlib import Path
15 |
16 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
17 |
18 | def load_K_Rt_from_P(filename, P=None):
19 | if P is None:
20 | lines = open(filename).read().splitlines()
21 | if len(lines) == 4:
22 | lines = lines[1:]
23 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
24 | P = np.asarray(lines).astype(np.float32).squeeze()
25 |
26 | out = cv2.decomposeProjectionMatrix(P)
27 | K = out[0]
28 | R = out[1]
29 | t = out[2]
30 |
31 | K = K/K[2,2]
32 | intrinsics = np.eye(4)
33 | intrinsics[:3, :3] = K
34 |
35 | pose = np.eye(4, dtype=np.float32)
36 | pose[:3, :3] = R.transpose()
37 | pose[:3,3] = (t[:3] / t[3])[:,0]
38 |
39 | return intrinsics, pose
40 |
41 | # hard-coded image size
42 | H, W = 384, 384
43 |
44 | # load pose
45 | def load_poses(scan_id, object_id):
46 | pose_path = os.path.join(f'../data/syn_data/scene{scan_id}', 'cameras.npz')
47 |
48 | camera_dict = np.load(pose_path)
49 | len_pose = len(camera_dict.files) // 2
50 |
51 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(len_pose)]
52 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(len_pose)]
53 | P = world_mats[0] @ scale_mats[0]
54 | P = P[:3, :4]
55 | intrinsics, pose = load_K_Rt_from_P(None, P)
56 |
57 | poses = []
58 | cnt = 0
59 |
60 | masks_path = os.path.join(f'../data/syn_data/scene{scan_id}', 'instance_mask')
61 | mask_files = sorted(os.listdir(masks_path))
62 |
63 | id_json = os.path.join(f'../data/syn_data/scene{scan_id}', 'instance_id.json')
64 | with open(id_json, 'r') as f:
65 | id_data = json.load(f)
66 | f.close()
67 |
68 | if object_id > 0: # a valid object id
69 | obj_idx = id_data[f'obj_{object_id-1}']
70 | else:
71 | obj_idx = -1 # invalid id (used for the background); in this case all poses are loaded
72 |
73 | for scale_mat, world_mat in zip(scale_mats, world_mats):
74 | # first check whether the object appears in the image corresponding to this pose
75 | mask = cv2.imread(os.path.join(masks_path, mask_files[cnt]))
76 | mask = np.array(mask)
77 | mask = np.unique(mask)
78 |
79 | if obj_idx == -1:
80 | orig_pose = world_mat
81 | pose = np.linalg.inv(orig_pose) @ intrinsics
82 | poses.append(np.array(pose))
83 |
84 | elif obj_idx in mask:
85 | orig_pose = world_mat
86 | pose = np.linalg.inv(orig_pose) @ intrinsics
87 | poses.append(np.array(pose))
88 |
89 | cnt += 1
90 |
91 | poses = np.array(poses)
92 | print(poses.shape)
93 | return poses, intrinsics
94 |
95 |
96 | class Renderer():
97 | def __init__(self, height=480, width=640):
98 | self.renderer = pyrender.OffscreenRenderer(width, height)
99 | self.scene = pyrender.Scene()
100 | self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES
101 |
102 | def __call__(self, height, width, intrinsics, pose, mesh, need_flag=True):
103 | self.renderer.viewport_height = height
104 | self.renderer.viewport_width = width
105 | self.scene.clear()
106 | self.scene.add(mesh)
107 | cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2],
108 | fx=intrinsics[0, 0], fy=intrinsics[1, 1])
109 | self.scene.add(cam, pose=self.fix_pose(pose))
110 | if need_flag:
111 | return self.renderer.render(self.scene, self.render_flags)
112 | else:
113 | return self.renderer.render(self.scene) # , self.render_flags)
114 |
115 | def fix_pose(self, pose):
116 | # rotate by pi about the x-axis: converts the OpenCV camera convention (z forward, y down) to the OpenGL/pyrender convention (z backward, y up)
117 | t = np.pi
118 | c = np.cos(t)
119 | s = np.sin(t)
120 | R = np.array([[1, 0, 0],
121 | [0, c, -s],
122 | [0, s, c]])
123 | axis_transform = np.eye(4)
124 | axis_transform[:3, :3] = R
125 | return pose @ axis_transform
126 |
127 | def mesh_opengl(self, mesh):
128 | return pyrender.Mesh.from_trimesh(mesh)
129 |
130 | def delete(self):
131 | self.renderer.delete()
132 |
133 |
134 | def refuse_depth(mesh, poses, K, need_flag=False, scan_id=-1):
135 | renderer = Renderer()
136 | mesh_opengl = renderer.mesh_opengl(mesh)
137 |
138 | depths = []
139 |
140 | for pose in tqdm(poses):
141 | intrinsic = np.eye(4)
142 | intrinsic[:3, :3] = K
143 |
144 | rgb = np.ones((H, W, 3)) # note: the rgb image built here is never used; only the rendered depth below is kept
145 | rgb = (rgb * 255).astype(np.uint8)
146 | rgb = o3d.geometry.Image(rgb)
147 | _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl, need_flag=need_flag)
148 | depths.append(depth_pred)
149 |
150 | return depths
151 |
152 |
153 | def average_dicts(dicts):
154 | # input is a list of dicts
155 | # all dicts share the same keys
156 | dict_num = len(dicts)
157 | keys = dicts[0].keys()
158 | ret = {}
159 |
160 | for k in keys:
161 | values = [x[k] for x in dicts]
162 | value = np.array(values).mean()
163 | ret[k] = value
164 |
165 | return ret
166 |
167 |
168 | root_dir = "../exps/"
169 | exp_name = "RICO_synthetic"
170 | out_dir = "evaluation/" + exp_name
171 | Path(out_dir).mkdir(parents=True, exist_ok=True)
172 |
173 |
174 | scenes = {
175 | 1: 'scene1',
176 | 2: 'scene2',
177 | 3: 'scene3',
178 | 4: 'scene4',
179 | 5: 'scene5',
180 | }
181 |
182 | all_bg_results = []
183 | all_bg_results_dict = OrderedDict()
184 |
185 | for k, v in scenes.items():
186 |
187 | cur_exp = f"{exp_name}_{k}"
188 | cur_root = os.path.join(root_dir, cur_exp)
189 | if not os.path.isdir(cur_root):
190 | continue
191 | # use the latest timestamp directory
192 | dirs = sorted(os.listdir(cur_root))
193 | cur_root = os.path.join(cur_root, dirs[-1])
194 |
195 | files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*.ply"))))
196 |
197 | # evaluate the object and background meshes; the first file is the background and the last is the combined scene
198 | files.sort(key=lambda x:os.path.getmtime(x))
199 |
200 | bg_file = files[0]
201 | print(bg_file)
202 | bg_mesh = trimesh.load(bg_file)
203 |
204 | cam_file = f"../data/syn_data/scene{k}/cameras.npz"
205 | scale_mat = np.load(cam_file)['scale_mat_0']
206 | bg_mesh.vertices = (scale_mat[:3, :3] @ bg_mesh.vertices.T + scale_mat[:3, 3:]).T
207 |
208 | poses, K = load_poses(k, -1)
209 | K = K[:3, :3]
210 | bg_mesh_depth = refuse_depth(bg_mesh, poses, K, scan_id=k)
211 |
212 | gt_mesh = os.path.join(f"../data/syn_data/scene{k}/GT_mesh", f"background.ply")
213 | gt_mesh = trimesh.load(gt_mesh)
214 |
215 | gt_mesh.vertex_normals = -gt_mesh.vertex_normals
216 |
217 | gt_mesh_depth = refuse_depth(gt_mesh, poses, K, need_flag=True, scan_id=k)
218 |
219 | masks_path = os.path.join(f'../data/syn_data/scene{k}', 'instance_mask')
220 | mask_files = sorted(os.listdir(masks_path))
221 | masks = [cv2.imread(os.path.join(masks_path, x)) for x in mask_files]
222 |
223 | depth_errors = []
224 | for gt_depth, pred_depth, seg_mask in zip(gt_mesh_depth, bg_mesh_depth, masks):
225 | seg = seg_mask
226 | seg = np.array(seg)
227 | seg = seg[:, :, 0] > 0 # object regions: the background depth is evaluated only where it is occluded by objects (see the helper sketch after this file)
228 |
229 | gtd = gt_depth[seg]
230 | prd = pred_depth[seg]
231 |
232 | mse = (np.square(gtd - prd)).mean(axis=0)
233 | depth_errors.append(mse)
234 | depth_errors = np.array(depth_errors)
235 | metrics = {'bg_depth_error': depth_errors.mean().astype(float)}
236 |
237 | all_bg_results.append(metrics)
238 | all_bg_results_dict[k] = metrics
239 |
240 | # print the metric averaged over all scenes
241 | all_bg_results = average_dicts(all_bg_results)
242 | print('background:')
243 | print(all_bg_results)
244 | # save the per-scene results to JSON
245 | bg_json_str = json.dumps(all_bg_results_dict, indent=4)
246 | bg_json_file = os.path.join('evaluation', exp_name + '_bg_depth.json')
247 |
248 | with open(bg_json_file, 'w') as json_file:
249 | json_file.write(bg_json_str)
250 | json_file.close()
--------------------------------------------------------------------------------
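
The depth-error loop at the end of evaluate_bgdepth.py compares the rendered depth of the reconstructed background mesh against the rendered depth of the GT background mesh, restricted to pixels covered by objects in the instance masks, i.e. exactly where the background is occluded in the input images. A minimal sketch of that per-view metric, factored into a helper (the function name and the toy arrays are illustrative, not part of the repo):

    import numpy as np

    def bg_depth_mse_sketch(gt_depth, pred_depth, seg_mask):
        # seg_mask: H x W x 3 instance-mask image as returned by cv2.imread;
        # any non-zero value in the first channel marks an object pixel.
        obj_region = np.asarray(seg_mask)[:, :, 0] > 0
        gtd = gt_depth[obj_region]    # GT background depth behind objects
        prd = pred_depth[obj_region]  # reconstructed background depth at the same pixels
        return float(np.mean(np.square(gtd - prd)))

    # Toy usage with synthetic depth maps and a square object mask (illustrative only):
    H, W = 384, 384
    gt = np.random.rand(H, W).astype(np.float32)
    pred = gt + 0.01 * np.random.randn(H, W).astype(np.float32)
    mask = np.zeros((H, W, 3), dtype=np.uint8)
    mask[100:200, 100:200, 0] = 1
    print(bg_depth_mse_sketch(gt, pred, mask))

Both evaluation scripts appear to assume they are launched from inside synthetic_eval/ (they use the relative paths ../exps and ../data/syn_data) and that headless OpenGL rendering is available, since evaluate_bgdepth.py sets PYOPENGL_PLATFORM=egl before creating the pyrender.OffscreenRenderer; the resulting metrics are written to evaluation/RICO_synthetic_obj.json and evaluation/RICO_synthetic_bg_depth.json.
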