├── .gitignore ├── LICENSE ├── README.md ├── assets └── teaser.png ├── code ├── confs │ ├── RICO_scannet.conf │ └── RICO_synthetic.conf ├── datasets │ ├── scene_dataset.py │ └── scene_dataset_rico.py ├── hashencoder │ ├── __init__.py │ ├── backend.py │ ├── hashgrid.py │ └── src │ │ ├── bindings.cpp │ │ ├── hashencoder.cu │ │ └── hashencoder.h ├── model │ ├── density.py │ ├── embedder.py │ ├── loss.py │ ├── network.py │ ├── network_rico.py │ └── ray_sampler.py ├── slurm_run.sh ├── training │ ├── exp_runner.py │ └── rico_train.py └── utils │ ├── general.py │ ├── plots.py │ └── rend_util.py ├── requirements.txt ├── scripts ├── edit_render.py └── extract_mesh_rico.py └── synthetic_eval ├── evaluate.py └── evaluate_bgdepth.py /.gitignore: -------------------------------------------------------------------------------- 1 | code/run_logs 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | exps/* 134 | exps* 135 | evals* 136 | data/DTU 137 | data/BlendedMVS 138 | data/Replica 139 | data/tnt_advanced 140 | data/ 141 | 142 | code/tmp_build 143 | preprocess/feature_extractor/ckpts/ 144 | synthetic_eval/evaluation/ 145 | 146 | code/.idea/ 147 | .DS_Store 148 | ._.DS_Store 149 | .idea/ 150 | 151 | *.png 152 | *.ply 153 | *.txt 154 | *.jpg 155 | *.npy 156 | *.npz 157 | *.tar 158 | uploadtnt_*/ 159 | 160 | *.json 161 | *.csv 162 | dtu_eval/Offical_DTU_Dataset/ 163 | media/ 164 | files_save/ 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 autonomousvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# :couch_and_lamp: RICO: Regularizing the Unobservable for Indoor Compositional Reconstruction (ICCV2023)

Zizhang Li, Xiaoyang Lyu, Yuanyuan Ding, Mengmeng Wang, Yiyi Liao, Yong Liu

ICCV2023 · Paper

![Logo](assets/teaser.png)

We use geometry motivated prior information to regularize the unobservable regions for indoor compositional reconstruction.

28 | 29 | ## TODO 30 | - [x] Training code 31 | - [x] Evaluation scripts 32 | - [x] Mesh extraction script 33 | - [x] Editted rendering script 34 | - [x] Dataset clean 35 | 36 | ## Setup 37 | 38 | ### Installation 39 | Clone the repository and create an anaconda environment called rico using 40 | ``` 41 | git clone git@github.com:kyleleey/RICO.git 42 | cd RICO 43 | 44 | conda create -y -n rico python=3.8 45 | conda activate rico 46 | 47 | conda install pytorch torchvision cudatoolkit=11.3 -c pytorch 48 | 49 | pip install -r requirements.txt 50 | ``` 51 | 52 | ### Dataset 53 | We provide processed scannet and synthetic scenes in this [link](https://drive.google.com/drive/folders/1yY9TYj-HaM2_I9qzsNQN8leOw6WFzVDA?usp=sharing). Please download the data and unzip in the `data` folder, the resulting folder structure should be: 54 | ``` 55 | └── RICO 56 | └── data 57 | ├── scannet 58 | ├── syn_data 59 | ``` 60 | ## Training 61 | 62 | Run the following command to train rico on the synthetic scene 1: 63 | ``` 64 | cd ./code 65 | bash slurm_run.sh PARTITION CFG_PATH SCAN_ID PORT 66 | ``` 67 | where `PARTITION` is the slurm partition name you're using. You can use `confs/RICO_scannet.conf` or `confs/RICO_synthetic.conf` for `CFG_PATH` to train on ScanNet or synthetic scene. You also need to provide specific `SCAN_ID` and `PORT`. 68 | 69 | If you are not in a slurm environment you can simply run: 70 | ``` 71 | python training/exp_runner.py --conf CFG_PATH --scan_id SCAN_ID --port PORT 72 | ``` 73 | 74 | ## Evaluations 75 | 76 | To run quantitative evaluation on synthetic scenes for object and masked background depth: 77 | ``` 78 | cd synthetic_eval 79 | python evaluate.py 80 | python evaluate_bgdepth.py 81 | ``` 82 | Evaluation results will be saved in `synthetic_eval/evaluation` as .json files. 83 | 84 | We also provide other scripts for experiment files after training. 85 | 86 | To extract the per-object mesh and the combined scene mesh: 87 | ``` 88 | cd scripts 89 | python extract_mesh_rico.py 90 | ``` 91 | 92 | To render translation edited results: 93 | ``` 94 | cd scripts 95 | python edit_render.py 96 | ``` 97 | 98 | You can change the detailed settings in these scripts to run on top of different experiment results. 99 | 100 | ## Acknowledgements 101 | This project is built upon [MonoSDF](https://github.com/autonomousvision/monosdf), [ObjSDF](https://github.com/QianyiWu/objsdf) and also the original [VolSDF](https://github.com/lioryariv/volsdf). To construct the synthetic scenes, we mainly use the function of [BlenderNeRF](https://github.com/maximeraafat/BlenderNeRF). We thank all the authors for their great work and repos. 
102 | 103 | 104 | ## Citation 105 | If you find our code or paper useful, please cite 106 | ```bibtex 107 | @inproceedings{li2023rico, 108 | author = {Li, Zizhang and Lyu, Xiaoyang and Ding, Yuanyuan and Wang, Mengmeng and Liao, Yiyi and Liu, Yong}, 109 | title = {RICO: Regularizing the Unobservable for Indoor Compositional Reconstruction}, 110 | booktitle = {ICCV}, 111 | year = {2023}, 112 | } 113 | ``` 114 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyleleey/RICO/4254e6ff8581d21833e1b42e0352a7a63da788b1/assets/teaser.png -------------------------------------------------------------------------------- /code/confs/RICO_scannet.conf: -------------------------------------------------------------------------------- 1 | train{ 2 | expname = RICO_scannet 3 | dataset_class = datasets.scene_dataset_rico.RICO_SceneDatasetDN_Mask 4 | model_class = model.network_rico.RICONetwork 5 | loss_class = model.loss.RICOLoss 6 | learning_rate = 5.0e-4 7 | num_pixels = 1024 8 | checkpoint_freq = 10000 9 | plot_freq = 50 10 | split_n_pixels = 1024 11 | max_total_iters = 50000 12 | } 13 | plot{ 14 | plot_nimgs = 1 15 | resolution = 512 16 | grid_boundary = [-1.1, 1.1] 17 | } 18 | loss{ 19 | rgb_loss = torch.nn.L1Loss 20 | eikonal_weight = 0.05 21 | semantic_weight = 0.04 22 | bg_render_weight = 0.1 23 | lop_weight = 0.1 24 | lrd_weight = 0.1 25 | smooth_weight = 0.005 26 | depth_weight = 0.1 27 | normal_l1_weight = 0.05 28 | normal_cos_weight = 0.05 29 | } 30 | dataset{ 31 | data_dir = syn_data 32 | img_res = [384, 384] 33 | scan_id = 1 34 | center_crop_type = no_crop 35 | data_prefix = scan 36 | } 37 | model{ 38 | feature_vector_size = 256 39 | scene_bounding_sphere = 1.1 40 | render_bg = True 41 | render_bg_iter = 10 42 | 43 | Grid_MLP = True 44 | 45 | implicit_network 46 | { 47 | d_in = 3 48 | d_out = 1 49 | dims = [256, 256, 256, 256, 256, 256, 256, 256] 50 | geometric_init = True 51 | bias = 0.9 52 | skip_in = [4] 53 | weight_norm = True 54 | multires = 6 55 | inside_outside = True 56 | sigmoid = 20 57 | sigmoid_optim = False 58 | } 59 | 60 | rendering_network 61 | { 62 | mode = idr 63 | d_in = 9 64 | d_out = 3 65 | dims = [ 256, 256] 66 | weight_norm = True 67 | multires_view = 4 68 | per_image_code = True 69 | } 70 | density 71 | { 72 | variance_init = 0.05 73 | speed_factor = 10.0 74 | } 75 | ray_sampler 76 | { 77 | take_sphere_intersection = True 78 | near = 0.0 79 | N_samples = 64 80 | N_samples_extra = 32 81 | N_upsample_iters = 4 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /code/confs/RICO_synthetic.conf: -------------------------------------------------------------------------------- 1 | train{ 2 | expname = RICO_synthetic 3 | dataset_class = datasets.scene_dataset_rico.RICO_SceneDatasetDN_Mask 4 | model_class = model.network_rico.RICONetwork 5 | loss_class = model.loss.RICOLoss 6 | learning_rate = 5.0e-4 7 | num_pixels = 1024 8 | checkpoint_freq = 10000 9 | plot_freq = 50 10 | split_n_pixels = 1024 11 | max_total_iters = 50000 12 | } 13 | plot{ 14 | plot_nimgs = 1 15 | resolution = 512 16 | grid_boundary = [-1.1, 1.1] 17 | } 18 | loss{ 19 | rgb_loss = torch.nn.L1Loss 20 | eikonal_weight = 0.05 21 | semantic_weight = 0.04 22 | bg_render_weight = 0.1 23 | lop_weight = 0.1 24 | lrd_weight = 0.1 25 | smooth_weight = 0.005 26 | depth_weight = 0.1 27 | normal_l1_weight = 0.05 28 
| normal_cos_weight = 0.05 29 | } 30 | dataset{ 31 | data_dir = syn_data 32 | img_res = [384, 384] 33 | scan_id = 1 34 | center_crop_type = no_crop 35 | data_prefix = scene 36 | } 37 | model{ 38 | feature_vector_size = 256 39 | scene_bounding_sphere = 1.1 40 | render_bg = True 41 | render_bg_iter = 10 42 | 43 | Grid_MLP = True 44 | 45 | implicit_network 46 | { 47 | d_in = 3 48 | d_out = 1 49 | dims = [256, 256, 256, 256, 256, 256, 256, 256] 50 | geometric_init = True 51 | bias = 0.9 52 | skip_in = [4] 53 | weight_norm = True 54 | multires = 6 55 | inside_outside = True 56 | sigmoid = 20 57 | sigmoid_optim = False 58 | } 59 | 60 | rendering_network 61 | { 62 | mode = idr 63 | d_in = 9 64 | d_out = 3 65 | dims = [ 256, 256] 66 | weight_norm = True 67 | multires_view = 4 68 | per_image_code = True 69 | } 70 | density 71 | { 72 | variance_init = 0.05 73 | speed_factor = 10.0 74 | } 75 | ray_sampler 76 | { 77 | take_sphere_intersection = True 78 | near = 0.0 79 | N_samples = 64 80 | N_samples_extra = 32 81 | N_upsample_iters = 4 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /code/datasets/scene_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | import utils.general as utils 7 | from utils import rend_util 8 | from glob import glob 9 | import cv2 10 | import random 11 | import json 12 | from kornia import morphology as morph 13 | 14 | class SceneDataset(torch.utils.data.Dataset): 15 | 16 | def __init__(self, 17 | data_dir, 18 | img_res, 19 | scan_id=0, 20 | num_views=-1, 21 | ): 22 | 23 | self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id)) 24 | 25 | self.total_pixels = img_res[0] * img_res[1] 26 | self.img_res = img_res 27 | 28 | assert os.path.exists(self.instance_dir), "Data directory is empty" 29 | 30 | self.num_views = num_views 31 | assert num_views in [-1, 3, 6, 9] 32 | 33 | self.sampling_idx = None 34 | 35 | image_dir = '{0}/image'.format(self.instance_dir) 36 | image_paths = sorted(utils.glob_imgs(image_dir)) 37 | self.n_images = len(image_paths) 38 | 39 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir) 40 | camera_dict = np.load(self.cam_file) 41 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 42 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 43 | 44 | self.intrinsics_all = [] 45 | self.pose_all = [] 46 | for scale_mat, world_mat in zip(scale_mats, world_mats): 47 | P = world_mat @ scale_mat 48 | P = P[:3, :4] 49 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) 50 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) 51 | self.pose_all.append(torch.from_numpy(pose).float()) 52 | 53 | self.rgb_images = [] 54 | for path in image_paths: 55 | rgb = rend_util.load_rgb(path) 56 | rgb = rgb.reshape(3, -1).transpose(1, 0) 57 | self.rgb_images.append(torch.from_numpy(rgb).float()) 58 | 59 | # used a fake depth image and normal image 60 | self.depth_images = [] 61 | self.normal_images = [] 62 | 63 | for path in image_paths: 64 | depth = np.ones_like(rgb[:, :1]) 65 | self.depth_images.append(torch.from_numpy(depth).float()) 66 | normal = np.ones_like(rgb) 67 | self.normal_images.append(torch.from_numpy(normal).float()) 68 | 69 | def __len__(self): 70 | return self.n_images 71 | 72 | def __getitem__(self, idx): 73 | if self.num_views >= 0: 
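            # Sparse-view setting: when num_views is 3, 6 or 9, ignore the dataloader index and
            # pick one view at random from the fixed list of image ids below; num_views == -1 uses all views.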
74 | image_ids = [25, 22, 28, 40, 44, 48, 0, 8, 13][:self.num_views] 75 | idx = image_ids[random.randint(0, self.num_views - 1)] 76 | 77 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) 78 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float() 79 | uv = uv.reshape(2, -1).transpose(1, 0) 80 | 81 | sample = { 82 | "uv": uv, 83 | "intrinsics": self.intrinsics_all[idx], 84 | "pose": self.pose_all[idx] 85 | } 86 | 87 | ground_truth = { 88 | "rgb": self.rgb_images[idx], 89 | "depth": self.depth_images[idx], 90 | "normal": self.normal_images[idx], 91 | } 92 | 93 | if self.sampling_idx is not None: 94 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :] 95 | ground_truth["depth"] = self.depth_images[idx][self.sampling_idx, :] 96 | ground_truth["mask"] = torch.ones_like(self.depth_images[idx][self.sampling_idx, :]) 97 | ground_truth["normal"] = self.normal_images[idx][self.sampling_idx, :] 98 | 99 | sample["uv"] = uv[self.sampling_idx, :] 100 | 101 | return idx, sample, ground_truth 102 | 103 | def collate_fn(self, batch_list): 104 | # get list of dictionaries and returns input, ground_true as dictionary for all batch instances 105 | batch_list = zip(*batch_list) 106 | 107 | all_parsed = [] 108 | for entry in batch_list: 109 | if type(entry[0]) is dict: 110 | # make them all into a new dict 111 | ret = {} 112 | for k in entry[0].keys(): 113 | ret[k] = torch.stack([obj[k] for obj in entry]) 114 | all_parsed.append(ret) 115 | else: 116 | all_parsed.append(torch.LongTensor(entry)) 117 | 118 | return tuple(all_parsed) 119 | 120 | def change_sampling_idx(self, sampling_size): 121 | if sampling_size == -1: 122 | self.sampling_idx = None 123 | else: 124 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size] 125 | 126 | def get_scale_mat(self): 127 | return np.load(self.cam_file)['scale_mat_0'] 128 | 129 | 130 | # Dataset with monocular depth and normal 131 | class SceneDatasetDN(torch.utils.data.Dataset): 132 | 133 | def __init__(self, 134 | data_dir, 135 | img_res, 136 | scan_id=0, 137 | center_crop_type='xxxx', 138 | use_mask=False, 139 | num_views=-1 140 | ): 141 | 142 | if data_dir == 'syn_data': 143 | self.instance_dir = os.path.join('../data', data_dir, 'scene{0}'.format(scan_id)) 144 | else: 145 | self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id)) 146 | 147 | self.total_pixels = img_res[0] * img_res[1] 148 | self.img_res = img_res 149 | self.num_views = num_views 150 | assert num_views in [-1, 3, 6, 9] 151 | 152 | assert os.path.exists(self.instance_dir), "Data directory is empty" 153 | 154 | self.sampling_idx = None 155 | 156 | def glob_data(data_dir): 157 | data_paths = [] 158 | data_paths.extend(glob(data_dir)) 159 | data_paths = sorted(data_paths) 160 | return data_paths 161 | 162 | image_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_rgb.png")) 163 | depth_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_depth.npy")) 164 | normal_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_normal.npy")) 165 | 166 | # mask is only used in the replica dataset as some monocular depth predictions have very large error and we ignore it 167 | if use_mask: 168 | mask_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_mask.npy")) 169 | else: 170 | mask_paths = None 171 | 172 | self.n_images = len(image_paths) 173 | 174 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir) 175 | camera_dict = np.load(self.cam_file) 176 | scale_mats = 
[camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 177 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 178 | 179 | self.intrinsics_all = [] 180 | self.pose_all = [] 181 | for scale_mat, world_mat in zip(scale_mats, world_mats): 182 | P = world_mat @ scale_mat 183 | P = P[:3, :4] 184 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) 185 | 186 | # because we do resize and center crop 384x384 when using omnidata model, we need to adjust the camera intrinsic accordingly 187 | if center_crop_type == 'center_crop_for_replica': 188 | scale = 384 / 680 189 | offset = (1200 - 680 ) * 0.5 190 | intrinsics[0, 2] -= offset 191 | intrinsics[:2, :] *= scale 192 | elif center_crop_type == 'center_crop_for_tnt': 193 | scale = 384 / 540 194 | offset = (960 - 540) * 0.5 195 | intrinsics[0, 2] -= offset 196 | intrinsics[:2, :] *= scale 197 | elif center_crop_type == 'center_crop_for_dtu': 198 | scale = 384 / 1200 199 | offset = (1600 - 1200) * 0.5 200 | intrinsics[0, 2] -= offset 201 | intrinsics[:2, :] *= scale 202 | elif center_crop_type == 'padded_for_dtu': 203 | scale = 384 / 1200 204 | offset = 0 205 | intrinsics[0, 2] -= offset 206 | intrinsics[:2, :] *= scale 207 | elif center_crop_type == 'no_crop': # for scannet dataset, we already adjust the camera intrinsic duing preprocessing so nothing to be done here 208 | pass 209 | else: 210 | raise NotImplementedError 211 | 212 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) 213 | self.pose_all.append(torch.from_numpy(pose).float()) 214 | 215 | self.rgb_images = [] 216 | for path in image_paths: 217 | rgb = rend_util.load_rgb(path) 218 | rgb = rgb.reshape(3, -1).transpose(1, 0) 219 | self.rgb_images.append(torch.from_numpy(rgb).float()) 220 | 221 | self.depth_images = [] 222 | self.normal_images = [] 223 | 224 | for dpath, npath in zip(depth_paths, normal_paths): 225 | depth = np.load(dpath) 226 | self.depth_images.append(torch.from_numpy(depth.reshape(-1, 1)).float()) 227 | 228 | normal = np.load(npath) 229 | normal = normal.reshape(3, -1).transpose(1, 0) 230 | # important as the output of omnidata is normalized 231 | normal = normal * 2. - 1. 
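                # (omnidata stores normal components in [0, 1]; the line above remaps them to [-1, 1])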
232 | self.normal_images.append(torch.from_numpy(normal).float()) 233 | 234 | # load mask 235 | self.mask_images = [] 236 | if mask_paths is None: 237 | for depth in self.depth_images: 238 | mask = torch.ones_like(depth) 239 | self.mask_images.append(mask) 240 | else: 241 | for path in mask_paths: 242 | mask = np.load(path) 243 | self.mask_images.append(torch.from_numpy(mask.reshape(-1, 1)).float()) 244 | 245 | def __len__(self): 246 | return self.n_images 247 | 248 | def __getitem__(self, idx): 249 | if self.num_views >= 0: 250 | image_ids = [25, 22, 28, 40, 44, 48, 0, 8, 13][:self.num_views] 251 | idx = image_ids[random.randint(0, self.num_views - 1)] 252 | 253 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) 254 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float() 255 | uv = uv.reshape(2, -1).transpose(1, 0) 256 | 257 | sample = { 258 | "uv": uv, 259 | "intrinsics": self.intrinsics_all[idx], 260 | "pose": self.pose_all[idx] 261 | } 262 | 263 | ground_truth = { 264 | "rgb": self.rgb_images[idx], 265 | "depth": self.depth_images[idx], 266 | "mask": self.mask_images[idx], 267 | "normal": self.normal_images[idx], 268 | } 269 | 270 | if self.sampling_idx is not None: 271 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :] 272 | ground_truth["full_rgb"] = self.rgb_images[idx] 273 | ground_truth["normal"] = self.normal_images[idx][self.sampling_idx, :] 274 | ground_truth["depth"] = self.depth_images[idx][self.sampling_idx, :] 275 | ground_truth["full_depth"] = self.depth_images[idx] 276 | ground_truth["mask"] = self.mask_images[idx][self.sampling_idx, :] 277 | ground_truth["full_mask"] = self.mask_images[idx] 278 | 279 | sample["uv"] = uv[self.sampling_idx, :] 280 | 281 | return idx, sample, ground_truth 282 | 283 | def collate_fn(self, batch_list): 284 | # get list of dictionaries and returns input, ground_true as dictionary for all batch instances 285 | batch_list = zip(*batch_list) 286 | 287 | all_parsed = [] 288 | for entry in batch_list: 289 | if type(entry[0]) is dict: 290 | # make them all into a new dict 291 | ret = {} 292 | for k in entry[0].keys(): 293 | ret[k] = torch.stack([obj[k] for obj in entry]) 294 | all_parsed.append(ret) 295 | else: 296 | all_parsed.append(torch.LongTensor(entry)) 297 | 298 | return tuple(all_parsed) 299 | 300 | def change_sampling_idx(self, sampling_size): 301 | if sampling_size == -1: 302 | self.sampling_idx = None 303 | else: 304 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size] 305 | 306 | def get_scale_mat(self): 307 | return np.load(self.cam_file)['scale_mat_0'] -------------------------------------------------------------------------------- /code/datasets/scene_dataset_rico.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | import utils.general as utils 7 | from utils import rend_util 8 | from glob import glob 9 | import cv2 10 | import json 11 | 12 | 13 | # Dataset with monocular depth and normal and mask and etc. 
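# Label convention used by this dataset: instance ids are read from instance_id.json and
# label_mapping = [0] + instance_ids, so semantic label 0 is the background and labels
# 1..K index the K object instances; mask pixels equal to 255, or with an id that is not
# listed in instance_id.json, are mapped to the background label.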
14 | class RICO_SceneDatasetDN_Mask(torch.utils.data.Dataset): 15 | def __init__(self, 16 | data_dir, 17 | img_res, 18 | scan_id=0, 19 | center_crop_type='xxxx', 20 | use_mask=False, 21 | data_prefix='scan' 22 | ): 23 | # for scannet, data_prefix is 'scan', for synthetic dataset, data_prefix is 'scene' 24 | self.instance_dir = os.path.join('../data', data_dir, data_prefix+'{0}'.format(scan_id)) 25 | 26 | self.total_pixels = img_res[0] * img_res[1] 27 | self.img_res = img_res 28 | 29 | assert os.path.exists(self.instance_dir), "Data directory is empty" 30 | 31 | self.sampling_idx = None 32 | 33 | with open(os.path.join(self.instance_dir, 'instance_id.json'), 'r') as f: 34 | id_dict = json.load(f) 35 | f.close() 36 | self.instance_dict = id_dict 37 | self.instance_ids = list(self.instance_dict.values()) 38 | self.label_mapping = [0] + self.instance_ids # background ID is 0 and at the first of label_mapping 39 | 40 | def glob_data(data_dir): 41 | data_paths = [] 42 | data_paths.extend(glob(data_dir)) 43 | data_paths = sorted(data_paths) 44 | return data_paths 45 | 46 | image_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_rgb.png")) 47 | depth_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_depth.npy")) 48 | normal_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_normal.npy")) 49 | 50 | # mask is only used in the replica dataset as some monocular depth predictions have very large error and we ignore it 51 | if use_mask: 52 | mask_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_mask.npy")) 53 | else: 54 | mask_paths = None 55 | 56 | # This is the loading of Instance masks for RICO 57 | instance_mask_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "instance_mask", "*.png")) 58 | 59 | self.n_images = len(image_paths) 60 | 61 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir) 62 | camera_dict = np.load(self.cam_file) 63 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 64 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 65 | 66 | self.intrinsics_all = [] 67 | self.pose_all = [] 68 | for scale_mat, world_mat in zip(scale_mats, world_mats): 69 | P = world_mat @ scale_mat 70 | P = P[:3, :4] 71 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) 72 | 73 | # because we do resize and center crop 384x384 when using omnidata model, we need to adjust the camera intrinsic accordingly 74 | # should be "no-crop" for both datasets in RICO 75 | if center_crop_type == 'center_crop_for_replica': 76 | scale = 384 / 680 77 | offset = (1200 - 680 ) * 0.5 78 | intrinsics[0, 2] -= offset 79 | intrinsics[:2, :] *= scale 80 | elif center_crop_type == 'center_crop_for_tnt': 81 | scale = 384 / 540 82 | offset = (960 - 540) * 0.5 83 | intrinsics[0, 2] -= offset 84 | intrinsics[:2, :] *= scale 85 | elif center_crop_type == 'center_crop_for_dtu': 86 | scale = 384 / 1200 87 | offset = (1600 - 1200) * 0.5 88 | intrinsics[0, 2] -= offset 89 | intrinsics[:2, :] *= scale 90 | elif center_crop_type == 'padded_for_dtu': 91 | scale = 384 / 1200 92 | offset = 0 93 | intrinsics[0, 2] -= offset 94 | intrinsics[:2, :] *= scale 95 | elif center_crop_type == 'no_crop': # for scannet dataset, we already adjust the camera intrinsic duing preprocessing so nothing to be done here 96 | pass 97 | else: 98 | raise NotImplementedError 99 | 100 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) 101 | 
self.pose_all.append(torch.from_numpy(pose).float()) 102 | 103 | self.rgb_images = [] 104 | for path in image_paths: 105 | rgb = rend_util.load_rgb(path) 106 | rgb = rgb.reshape(3, -1).transpose(1, 0) 107 | self.rgb_images.append(torch.from_numpy(rgb).float()) 108 | 109 | self.depth_images = [] 110 | self.normal_images = [] 111 | 112 | for dpath, npath in zip(depth_paths, normal_paths): 113 | depth = np.load(dpath) 114 | self.depth_images.append(torch.from_numpy(depth.reshape(-1, 1)).float()) 115 | 116 | normal = np.load(npath) 117 | normal = normal.reshape(3, -1).transpose(1, 0) 118 | # important as the output of omnidata is normalized 119 | normal = normal * 2. - 1. 120 | self.normal_images.append(torch.from_numpy(normal).float()) 121 | 122 | # load instance mask and map to label_mapping 123 | self.instance_masks = [] 124 | self.instance_dilated_region_list = [] 125 | for im_path in instance_mask_paths: 126 | 127 | instance_mask_pic = cv2.imread(im_path, -1) 128 | if len(instance_mask_pic.shape) == 3: 129 | instance_mask_pic = instance_mask_pic[:, :, 0] 130 | instance_mask = instance_mask_pic.reshape(1, -1).transpose(1, 0) # [HW, 1] 131 | instance_mask[instance_mask==255] = 0 # background is 0 132 | 133 | ins_list = np.unique(instance_mask) 134 | cur_sems = np.copy(instance_mask) 135 | for i in ins_list: 136 | if i not in self.label_mapping: 137 | cur_sems[instance_mask == i] = self.label_mapping.index(0) 138 | else: 139 | cur_sems[instance_mask == i] = self.label_mapping.index(i) 140 | 141 | self.instance_masks.append(torch.from_numpy(cur_sems).float()) 142 | 143 | # load mask 144 | self.mask_images = [] 145 | if mask_paths is None: 146 | for depth in self.depth_images: 147 | mask = torch.ones_like(depth) 148 | self.mask_images.append(mask) 149 | else: 150 | for path in mask_paths: 151 | mask = np.load(path) 152 | self.mask_images.append(torch.from_numpy(mask.reshape(-1, 1)).float()) 153 | 154 | self.n_images = len(self.rgb_images) 155 | 156 | def __len__(self): 157 | return self.n_images 158 | 159 | def __getitem__(self, idx): 160 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) 161 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float() 162 | uv = uv.reshape(2, -1).transpose(1, 0) 163 | 164 | sample = { 165 | "uv": uv, 166 | "intrinsics": self.intrinsics_all[idx], 167 | "pose": self.pose_all[idx] 168 | } 169 | 170 | ground_truth = { 171 | "rgb": self.rgb_images[idx], 172 | "depth": self.depth_images[idx], 173 | "mask": self.mask_images[idx], 174 | "normal": self.normal_images[idx], 175 | "instance_mask": self.instance_masks[idx], 176 | "use_syn_data": torch.Tensor([0.]).reshape(-1) 177 | } 178 | 179 | if self.sampling_idx is not None: 180 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :] 181 | ground_truth["full_rgb"] = self.rgb_images[idx] 182 | ground_truth["normal"] = self.normal_images[idx][self.sampling_idx, :] 183 | ground_truth["depth"] = self.depth_images[idx][self.sampling_idx, :] 184 | ground_truth["full_depth"] = self.depth_images[idx] 185 | ground_truth["mask"] = self.mask_images[idx][self.sampling_idx, :] 186 | ground_truth["full_mask"] = self.mask_images[idx] 187 | 188 | ground_truth["instance_mask"] = self.instance_masks[idx][self.sampling_idx, :] 189 | ground_truth["full_instance_mask"] = self.instance_masks[idx] 190 | 191 | sample["uv"] = uv[self.sampling_idx, :] 192 | 193 | return idx, sample, ground_truth 194 | 195 | def collate_fn(self, batch_list): 196 | # get list of dictionaries and returns input, ground_true as 
dictionary for all batch instances 197 | batch_list = zip(*batch_list) 198 | 199 | all_parsed = [] 200 | for entry in batch_list: 201 | if type(entry[0]) is dict: 202 | # make them all into a new dict 203 | ret = {} 204 | for k in entry[0].keys(): 205 | ret[k] = torch.stack([obj[k] for obj in entry]) 206 | all_parsed.append(ret) 207 | else: 208 | all_parsed.append(torch.LongTensor(entry)) 209 | 210 | return tuple(all_parsed) 211 | 212 | def change_sampling_idx(self, sampling_size): 213 | if sampling_size == -1: 214 | self.sampling_idx = None 215 | else: 216 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size] 217 | 218 | def get_scale_mat(self): 219 | return np.load(self.cam_file)['scale_mat_0'] -------------------------------------------------------------------------------- /code/hashencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .hashgrid import HashEncoder -------------------------------------------------------------------------------- /code/hashencoder/backend.py: -------------------------------------------------------------------------------- 1 | from distutils.command.build import build 2 | import os 3 | from torch.utils.cpp_extension import load 4 | from pathlib import Path 5 | 6 | Path('./tmp_build/').mkdir(parents=True, exist_ok=True) 7 | 8 | _src_path = os.path.dirname(os.path.abspath(__file__)) 9 | 10 | _backend = load(name='_hash_encoder', 11 | extra_cflags=['-O3', '-std=c++14'], 12 | extra_cuda_cflags=[ 13 | '-O3', '-std=c++14', '-allow-unsupported-compiler', 14 | '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__', 15 | ], 16 | sources=[os.path.join(_src_path, 'src', f) for f in [ 17 | 'hashencoder.cu', 18 | 'bindings.cpp', 19 | ]], 20 | build_directory='./tmp_build/', 21 | verbose=True, 22 | ) 23 | 24 | __all__ = ['_backend'] -------------------------------------------------------------------------------- /code/hashencoder/hashgrid.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from math import ceil 3 | from cachetools import cached 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Function 9 | from torch.autograd.function import once_differentiable 10 | from torch.cuda.amp import custom_bwd, custom_fwd 11 | 12 | from .backend import _backend 13 | 14 | class _hash_encode(Function): 15 | @staticmethod 16 | @custom_fwd(cast_inputs=torch.half) 17 | def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False): 18 | # inputs: [B, D], float in [0, 1] 19 | # embeddings: [sO, C], float 20 | # offsets: [L + 1], int 21 | # RETURN: [B, F], float 22 | 23 | inputs = inputs.contiguous() 24 | embeddings = embeddings.contiguous() 25 | offsets = offsets.contiguous() 26 | 27 | B, D = inputs.shape # batch size, coord dim 28 | L = offsets.shape[0] - 1 # level 29 | C = embeddings.shape[1] # embedding dim for each level 30 | S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f 31 | H = base_resolution # base resolution 32 | 33 | # L first, optimize cache for cuda kernel, but needs an extra permute later 34 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=inputs.dtype) 35 | 36 | if calc_grad_inputs: 37 | dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=inputs.dtype) 38 | else: 39 | dy_dx = torch.empty(1, device=inputs.device, dtype=inputs.dtype) 40 | 
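        # The CUDA kernel writes per-level features into `outputs` (laid out as [L, B, C]) and,
        # when calc_grad_inputs is set, also fills dy_dx with d(outputs)/d(inputs) for the backward pass.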
41 | _backend.hash_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx) 42 | 43 | # permute back to [B, L * C] 44 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 45 | 46 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 47 | ctx.dims = [B, D, C, L, S, H] 48 | ctx.calc_grad_inputs = calc_grad_inputs 49 | 50 | return outputs 51 | 52 | @staticmethod 53 | @custom_bwd 54 | def backward(ctx, grad): 55 | 56 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 57 | B, D, C, L, S, H = ctx.dims 58 | calc_grad_inputs = ctx.calc_grad_inputs 59 | 60 | # grad: [B, L * C] --> [L, B, C] 61 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 62 | 63 | grad_inputs, grad_embeddings = _hash_encode_second_backward.apply(grad, inputs, embeddings, offsets, B, D, C, L, S, H, calc_grad_inputs, dy_dx) 64 | 65 | if calc_grad_inputs: 66 | return grad_inputs, grad_embeddings, None, None, None, None 67 | else: 68 | return None, grad_embeddings, None, None, None, None 69 | 70 | 71 | class _hash_encode_second_backward(Function): 72 | @staticmethod 73 | def forward(ctx, grad, inputs, embeddings, offsets, B, D, C, L, S, H, calc_grad_inputs, dy_dx): 74 | 75 | grad_inputs = torch.zeros_like(inputs) 76 | grad_embeddings = torch.zeros_like(embeddings) 77 | 78 | ctx.save_for_backward(grad, inputs, embeddings, offsets, dy_dx, grad_inputs, grad_embeddings) 79 | ctx.dims = [B, D, C, L, S, H] 80 | ctx.calc_grad_inputs = calc_grad_inputs 81 | 82 | _backend.hash_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs) 83 | 84 | return grad_inputs, grad_embeddings 85 | 86 | @staticmethod 87 | def backward(ctx, grad_grad_inputs, grad_grad_embeddings): 88 | 89 | grad, inputs, embeddings, offsets, dy_dx, grad_inputs, grad_embeddings = ctx.saved_tensors 90 | B, D, C, L, S, H = ctx.dims 91 | calc_grad_inputs = ctx.calc_grad_inputs 92 | 93 | grad_grad = torch.zeros_like(grad) 94 | grad2_embeddings = torch.zeros_like(embeddings) 95 | 96 | _backend.hash_encode_second_backward(grad, inputs, embeddings, offsets, 97 | B, D, C, L, S, H, calc_grad_inputs, dy_dx, 98 | grad_grad_inputs, 99 | grad_grad, grad2_embeddings) 100 | 101 | return grad_grad, None, grad2_embeddings, None, None, None, None, None, None, None, None, None 102 | 103 | 104 | hash_encode = _hash_encode.apply 105 | 106 | 107 | class HashEncoder(nn.Module): 108 | def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None): 109 | super().__init__() 110 | 111 | # the finest resolution desired at the last level, if provided, overridee per_level_scale 112 | if desired_resolution is not None: 113 | per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1)) 114 | 115 | self.input_dim = input_dim # coord dims, 2 or 3 116 | self.num_levels = num_levels # num levels, each level multiply resolution by 2 117 | self.level_dim = level_dim # encode channels per level 118 | self.per_level_scale = per_level_scale # multiply resolution by this scale at each level. 119 | self.log2_hashmap_size = log2_hashmap_size 120 | self.base_resolution = base_resolution 121 | self.output_dim = num_levels * level_dim 122 | 123 | if level_dim % 2 != 0: 124 | print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward is also enabled fp16! 
(maybe fix later)') 125 | 126 | # allocate parameters 127 | offsets = [] 128 | offset = 0 129 | self.max_params = 2 ** log2_hashmap_size 130 | for i in range(num_levels): 131 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 132 | params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number 133 | #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible 134 | offsets.append(offset) 135 | offset += params_in_level 136 | offsets.append(offset) 137 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 138 | self.register_buffer('offsets', offsets) 139 | 140 | self.n_params = offsets[-1] * level_dim 141 | 142 | # parameters 143 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 144 | 145 | self.reset_parameters() 146 | 147 | def reset_parameters(self): 148 | std = 1e-4 149 | self.embeddings.data.uniform_(-std, std) 150 | 151 | def __repr__(self): 152 | return f"HashEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} base_resolution={self.base_resolution} per_level_scale={self.per_level_scale} params={tuple(self.embeddings.shape)}" 153 | 154 | def forward(self, inputs, size=1): 155 | # inputs: [..., input_dim], normalized real world positions in [-size, size] 156 | # return: [..., num_levels * level_dim] 157 | 158 | inputs = (inputs + size) / (2 * size) # map to [0, 1] 159 | 160 | prefix_shape = list(inputs.shape[:-1]) 161 | inputs = inputs.view(-1, self.input_dim) 162 | 163 | outputs = hash_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad) 164 | outputs = outputs.view(prefix_shape + [self.output_dim]) 165 | 166 | return outputs -------------------------------------------------------------------------------- /code/hashencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "hashencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("hash_encode_forward", &hash_encode_forward, "hash encode forward (CUDA)"); 7 | m.def("hash_encode_backward", &hash_encode_backward, "hash encode backward (CUDA)"); 8 | m.def("hash_encode_second_backward", &hash_encode_second_backward, "hash encode second backward (CUDA)"); 9 | } -------------------------------------------------------------------------------- /code/hashencoder/src/hashencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | // inputs: [B, D], float, in [0, 1] 9 | // embeddings: [sO, C], float 10 | // offsets: [L + 1], uint32_t 11 | // outputs: [B, L * C], float 12 | // H: base resolution 13 | void hash_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx); 14 | void hash_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs); 15 | void hash_encode_second_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, 
const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, const at::Tensor grad_grad_inputs, at::Tensor grad_grad, at::Tensor grad2_embeddings); 16 | 17 | #endif -------------------------------------------------------------------------------- /code/model/density.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class Density(nn.Module): 6 | def __init__(self, params_init={}): 7 | super().__init__() 8 | for p in params_init: 9 | param = nn.Parameter(torch.tensor(params_init[p])) 10 | setattr(self, p, param) 11 | 12 | def forward(self, sdf, beta=None): 13 | return self.density_func(sdf, beta=beta) 14 | 15 | 16 | class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf) 17 | def __init__(self, params_init={}, beta_min=0.0001): 18 | super().__init__(params_init=params_init) 19 | self.beta_min = torch.tensor(beta_min).cuda() 20 | 21 | def density_func(self, sdf, beta=None): 22 | if beta is None: 23 | beta = self.get_beta() 24 | 25 | alpha = 1 / beta 26 | return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta)) 27 | 28 | def get_beta(self): 29 | beta = self.beta.abs() + self.beta_min 30 | return beta 31 | 32 | 33 | class AbsDensity(Density): # like NeRF++ 34 | def density_func(self, sdf, beta=None): 35 | return torch.abs(sdf) 36 | 37 | 38 | class SimpleDensity(Density): # like NeRF 39 | def __init__(self, params_init={}, noise_std=1.0): 40 | super().__init__(params_init=params_init) 41 | self.noise_std = noise_std 42 | 43 | def density_func(self, sdf, beta=None): 44 | if self.training and self.noise_std > 0.0: 45 | noise = torch.randn(sdf.shape).cuda() * self.noise_std 46 | sdf = sdf + noise 47 | return torch.relu(sdf) 48 | -------------------------------------------------------------------------------- /code/model/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ 4 | 5 | class Embedder: 6 | def __init__(self, **kwargs): 7 | self.kwargs = kwargs 8 | self.create_embedding_fn() 9 | 10 | def create_embedding_fn(self): 11 | embed_fns = [] 12 | d = self.kwargs['input_dims'] 13 | out_dim = 0 14 | if self.kwargs['include_input']: 15 | embed_fns.append(lambda x: x) 16 | out_dim += d 17 | 18 | max_freq = self.kwargs['max_freq_log2'] 19 | N_freqs = self.kwargs['num_freqs'] 20 | 21 | if self.kwargs['log_sampling']: 22 | freq_bands = 2. 
** torch.linspace(0., max_freq, N_freqs) 23 | else: 24 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 25 | 26 | for freq in freq_bands: 27 | for p_fn in self.kwargs['periodic_fns']: 28 | embed_fns.append(lambda x, p_fn=p_fn, 29 | freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | def get_embedder(multires, input_dims=3): 39 | embed_kwargs = { 40 | 'include_input': True, 41 | 'input_dims': input_dims, 42 | 'max_freq_log2': multires-1, 43 | 'num_freqs': multires, 44 | 'log_sampling': True, 45 | 'periodic_fns': [torch.sin, torch.cos], 46 | } 47 | 48 | embedder_obj = Embedder(**embed_kwargs) 49 | def embed(x, eo=embedder_obj): return eo.embed(x) 50 | return embed, embedder_obj.out_dim 51 | -------------------------------------------------------------------------------- /code/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import utils.general as utils 4 | import math 5 | 6 | # copy from MiDaS 7 | def compute_scale_and_shift(prediction, target, mask): 8 | # system matrix: A = [[a_00, a_01], [a_10, a_11]] 9 | a_00 = torch.sum(mask * prediction * prediction, (1, 2)) 10 | a_01 = torch.sum(mask * prediction, (1, 2)) 11 | a_11 = torch.sum(mask, (1, 2)) 12 | 13 | # right hand side: b = [b_0, b_1] 14 | b_0 = torch.sum(mask * prediction * target, (1, 2)) 15 | b_1 = torch.sum(mask * target, (1, 2)) 16 | 17 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b 18 | x_0 = torch.zeros_like(b_0) 19 | x_1 = torch.zeros_like(b_1) 20 | 21 | det = a_00 * a_11 - a_01 * a_01 22 | valid = det.nonzero() 23 | 24 | x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] 25 | x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] 26 | 27 | return x_0, x_1 28 | 29 | 30 | def reduction_batch_based(image_loss, M): 31 | # average of all valid pixels of the batch 32 | 33 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) 34 | divisor = torch.sum(M) 35 | 36 | if divisor == 0: 37 | return 0 38 | else: 39 | return torch.sum(image_loss) / divisor 40 | 41 | 42 | def reduction_image_based(image_loss, M): 43 | # mean of average of valid pixels of an image 44 | 45 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) 46 | valid = M.nonzero() 47 | 48 | image_loss[valid] = image_loss[valid] / M[valid] 49 | 50 | return torch.mean(image_loss) 51 | 52 | 53 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based): 54 | 55 | M = torch.sum(mask, (1, 2)) 56 | res = prediction - target 57 | 58 | _loss = mask * res * res 59 | 60 | image_loss = torch.sum(_loss, (1, 2)) 61 | 62 | return reduction(image_loss, 2 * M) 63 | 64 | 65 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): 66 | 67 | M = torch.sum(mask, (1, 2)) 68 | 69 | diff = prediction - target 70 | diff = torch.mul(mask, diff) 71 | 72 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) 73 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) 74 | grad_x = torch.mul(mask_x, grad_x) 75 | 76 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) 77 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) 78 | grad_y = torch.mul(mask_y, grad_y) 79 | 80 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) 81 | 82 | return 
reduction(image_loss, M) 83 | 84 | 85 | class MSELoss(nn.Module): 86 | def __init__(self, reduction='batch-based'): 87 | super().__init__() 88 | 89 | if reduction == 'batch-based': 90 | self.__reduction = reduction_batch_based 91 | else: 92 | self.__reduction = reduction_image_based 93 | 94 | def forward(self, prediction, target, mask): 95 | return mse_loss(prediction, target, mask, reduction=self.__reduction) 96 | 97 | 98 | class GradientLoss(nn.Module): 99 | def __init__(self, scales=4, reduction='batch-based'): 100 | super().__init__() 101 | 102 | if reduction == 'batch-based': 103 | self.__reduction = reduction_batch_based 104 | else: 105 | self.__reduction = reduction_image_based 106 | 107 | self.__scales = scales 108 | 109 | def forward(self, prediction, target, mask): 110 | total = 0 111 | 112 | for scale in range(self.__scales): 113 | step = pow(2, scale) 114 | 115 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], 116 | mask[:, ::step, ::step], reduction=self.__reduction) 117 | 118 | return total 119 | 120 | 121 | class ScaleAndShiftInvariantLoss(nn.Module): 122 | def __init__(self, alpha=0.5, scales=4, reduction='batch-based'): 123 | super().__init__() 124 | 125 | self.__data_loss = MSELoss(reduction=reduction) 126 | self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction) 127 | self.__alpha = alpha 128 | 129 | self.__prediction_ssi = None 130 | 131 | def forward(self, prediction, target, mask): 132 | scale, shift = compute_scale_and_shift(prediction, target, mask) 133 | self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) 134 | 135 | total = self.__data_loss(self.__prediction_ssi, target, mask) 136 | if self.__alpha > 0: 137 | total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) 138 | 139 | return total 140 | 141 | def __get_prediction_ssi(self): 142 | return self.__prediction_ssi 143 | 144 | prediction_ssi = property(__get_prediction_ssi) 145 | # end copy 146 | 147 | 148 | class MonoSDFLoss(nn.Module): 149 | def __init__(self, rgb_loss, 150 | eikonal_weight, 151 | smooth_weight = 0.005, 152 | depth_weight = 0.1, 153 | normal_l1_weight = 0.05, 154 | normal_cos_weight = 0.05, 155 | end_step = -1): 156 | super().__init__() 157 | self.eikonal_weight = eikonal_weight 158 | self.smooth_weight = smooth_weight 159 | self.depth_weight = depth_weight 160 | self.normal_l1_weight = normal_l1_weight 161 | self.normal_cos_weight = normal_cos_weight 162 | self.rgb_loss = utils.get_class(rgb_loss)(reduction='mean') 163 | 164 | self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1) 165 | 166 | print(f"using weight for loss RGB_1.0 EK_{self.eikonal_weight} SM_{self.smooth_weight} Depth_{self.depth_weight} NormalL1_{self.normal_l1_weight} NormalCos_{self.normal_cos_weight}") 167 | 168 | self.step = 0 169 | self.end_step = end_step 170 | 171 | def get_rgb_loss(self,rgb_values, rgb_gt): 172 | rgb_gt = rgb_gt.reshape(-1, 3) 173 | rgb_loss = self.rgb_loss(rgb_values, rgb_gt) 174 | return rgb_loss 175 | 176 | def get_eikonal_loss(self, grad_theta): 177 | eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean() 178 | return eikonal_loss 179 | 180 | def get_smooth_loss(self,model_outputs): 181 | # smoothness loss as unisurf 182 | g1 = model_outputs['grad_theta'] 183 | g2 = model_outputs['grad_theta_nei'] 184 | 185 | normals_1 = g1 / (g1.norm(2, dim=1).unsqueeze(-1) + 1e-5) 186 | normals_2 = g2 / (g2.norm(2, dim=1).unsqueeze(-1) + 1e-5) 187 | smooth_loss = 
torch.norm(normals_1 - normals_2, dim=-1).mean() 188 | return smooth_loss 189 | 190 | def get_depth_loss(self, depth_pred, depth_gt, mask): 191 | # TODO remove hard-coded scaling for depth 192 | return self.depth_loss(depth_pred.reshape(1, 32, 32), (depth_gt * 50 + 0.5).reshape(1, 32, 32), mask.reshape(1, 32, 32)) 193 | 194 | def get_normal_loss(self, normal_pred, normal_gt): 195 | normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1) 196 | normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1) 197 | l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean() 198 | cos = (1. - torch.sum(normal_pred * normal_gt, dim = -1)).mean() 199 | return l1, cos 200 | 201 | def forward(self, model_outputs, ground_truth): 202 | rgb_gt = ground_truth['rgb'].cuda() 203 | # monocular depth and normal 204 | depth_gt = ground_truth['depth'].cuda() 205 | normal_gt = ground_truth['normal'].cuda() 206 | 207 | depth_pred = model_outputs['depth_values'] 208 | normal_pred = model_outputs['normal_map'][None] 209 | 210 | rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt) 211 | 212 | if 'grad_theta' in model_outputs: 213 | eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta']) 214 | else: 215 | eikonal_loss = torch.tensor(0.0).cuda().float() 216 | 217 | # only supervised the foreground normal 218 | mask = ((model_outputs['sdf'] > 0.).any(dim=-1) & (model_outputs['sdf'] < 0.).any(dim=-1))[None, :, None] 219 | # combine with GT 220 | mask = (ground_truth['mask'] > 0.5).cuda() & mask 221 | 222 | depth_loss = self.get_depth_loss(depth_pred, depth_gt, mask) 223 | if isinstance(depth_loss, float): 224 | depth_loss = torch.tensor(0.0).cuda().float() 225 | 226 | normal_l1, normal_cos = self.get_normal_loss(normal_pred * mask, normal_gt) 227 | 228 | smooth_loss = self.get_smooth_loss(model_outputs) 229 | 230 | # compute decay weights 231 | if self.end_step > 0: 232 | decay = math.exp(-self.step / self.end_step * 10.) 
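            # i.e. the prior weight starts at 1.0 and decays to exp(-10) ~= 4.5e-5 at end_step,
            # gradually phasing out the monocular depth/normal supervision terms below.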
233 | else: 234 | decay = 1.0 235 | 236 | self.step += 1 237 | 238 | loss = rgb_loss + \ 239 | self.eikonal_weight * eikonal_loss +\ 240 | self.smooth_weight * smooth_loss +\ 241 | decay * self.depth_weight * depth_loss +\ 242 | decay * self.normal_l1_weight * normal_l1 +\ 243 | decay * self.normal_cos_weight * normal_cos 244 | 245 | output = { 246 | 'loss': loss, 247 | 'rgb_loss': rgb_loss, 248 | 'eikonal_loss': eikonal_loss, 249 | 'smooth_loss': smooth_loss, 250 | 'depth_loss': depth_loss, 251 | 'normal_l1': normal_l1, 252 | 'normal_cos': normal_cos 253 | } 254 | 255 | return output 256 | 257 | 258 | class RICOLoss(nn.Module): 259 | def __init__(self, rgb_loss, 260 | eikonal_weight, 261 | semantic_weight = 0.04, 262 | bg_render_weight = 0.0, 263 | lop_weight = 0.1, 264 | lrd_weight = 0.1, 265 | smooth_weight = 0.005, 266 | depth_weight = 0.1, 267 | normal_l1_weight = 0.05, 268 | normal_cos_weight = 0.05, 269 | end_step = -1, 270 | epsilon_param = 0.05): 271 | super().__init__() 272 | self.eikonal_weight = eikonal_weight 273 | self.smooth_weight = smooth_weight 274 | self.depth_weight = depth_weight 275 | self.normal_l1_weight = normal_l1_weight 276 | self.normal_cos_weight = normal_cos_weight 277 | self.rgb_loss = utils.get_class(rgb_loss)(reduction='mean') 278 | 279 | self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1) 280 | 281 | self.semantic_weight = semantic_weight 282 | # self.semantic_loss = torch.nn.NLLLoss() 283 | self.semantic_loss = torch.nn.CrossEntropyLoss(ignore_index = -1) 284 | 285 | self.bg_render_weight = bg_render_weight # when use this loss, make sure the sampled idx is in patch 286 | 287 | self.lop_weight = lop_weight 288 | self.lrd_weight = lrd_weight 289 | 290 | print(f"using weight for loss RGB_1.0 SEMANTIC_{self.semantic_weight} \ 291 | Lop_{self.lop_weight} Lrd_{self.lrd_weight} BG_RENDER_{self.bg_render_weight} \ 292 | EK_{self.eikonal_weight} SM_{self.smooth_weight} \ 293 | Depth_{self.depth_weight} NormalL1_{self.normal_l1_weight} NormalCos_{self.normal_cos_weight}") 294 | 295 | self.step = 0 296 | self.end_step = end_step 297 | 298 | self.epsilon_param = epsilon_param 299 | 300 | def get_rgb_loss(self,rgb_values, rgb_gt): 301 | rgb_gt = rgb_gt.reshape(-1, 3) 302 | rgb_loss = self.rgb_loss(rgb_values, rgb_gt) 303 | return rgb_loss 304 | 305 | def get_eikonal_loss(self, grad_theta): 306 | eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean() 307 | return eikonal_loss 308 | 309 | def get_smooth_loss(self,model_outputs): 310 | # smoothness loss as unisurf 311 | g1 = model_outputs['grad_theta'] 312 | g2 = model_outputs['grad_theta_nei'] 313 | 314 | normals_1 = g1 / (g1.norm(2, dim=1).unsqueeze(-1) + 1e-5) 315 | normals_2 = g2 / (g2.norm(2, dim=1).unsqueeze(-1) + 1e-5) 316 | smooth_loss = torch.norm(normals_1 - normals_2, dim=-1).mean() 317 | return smooth_loss 318 | 319 | def get_depth_loss(self, depth_pred, depth_gt, mask): 320 | # TODO remove hard-coded scaling for depth 321 | return self.depth_loss(depth_pred.reshape(1, 32, 32), (depth_gt * 50 + 0.5).reshape(1, 32, 32), mask.reshape(1, 32, 32)) 322 | 323 | def get_normal_loss(self, normal_pred, normal_gt): 324 | normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1) 325 | normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1) 326 | l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1) 327 | cos = (1. 
- torch.sum(normal_pred * normal_gt, dim = -1)) 328 | 329 | l1 = l1.mean() 330 | cos = cos.mean() 331 | 332 | return l1, cos 333 | 334 | def get_semantic_loss(self, semantic_value, semantic_gt): 335 | semantic_gt = semantic_gt.squeeze() 336 | # semantic_loss = torch.nn.functional.nll_loss(semantic_value, semantic_gt) 337 | semantic_loss = self.semantic_loss(semantic_value, semantic_gt) 338 | # semantic_loss = self.semantic_loss(semantic_value, semantic_gt) 339 | return semantic_loss 340 | 341 | def get_bg_render_loss(self, bg_render_results, mask): 342 | bg_depth = bg_render_results['depth_values'] 343 | bg_normal = bg_render_results['normal_map'] 344 | 345 | bg_depth = bg_depth.reshape(1, 32, 32) 346 | bg_normal = bg_normal.reshape(32, 32, 3).permute(2, 0, 1) 347 | 348 | mask = mask.reshape(1, 32, 32) 349 | 350 | depth_grad = self.compute_grad_error(bg_depth, mask) 351 | normal_grad = self.compute_grad_error(bg_normal, mask.repeat(3, 1, 1)) 352 | 353 | bg_render_loss = depth_grad + normal_grad 354 | return bg_render_loss 355 | 356 | def compute_grad_error(self, x, mask): 357 | scales = 4 358 | grad_loss = torch.tensor(0.0).cuda().float() 359 | for i in range(scales): 360 | step = pow(2, i) 361 | 362 | mask_step = mask[:, ::step, ::step] 363 | x_step = x[:, ::step, ::step] 364 | 365 | M = torch.sum(mask_step[:1], (1, 2)) 366 | 367 | diff = torch.mul(mask_step, x_step) 368 | 369 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) 370 | mask_x = torch.mul(mask_step[:, :, 1:], mask_step[:, :, :-1]) 371 | grad_x = torch.mul(mask_x, grad_x) 372 | 373 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) 374 | mask_y = torch.mul(mask_step[:, 1:, :], mask_step[:, :-1, :]) 375 | grad_y = torch.mul(mask_y, grad_y) 376 | 377 | image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) 378 | 379 | divisor = torch.sum(M) 380 | 381 | if divisor == 0: 382 | scale_loss = torch.tensor(0.0).cuda().float() 383 | else: 384 | scale_loss = torch.sum(image_loss) / divisor 385 | 386 | grad_loss += scale_loss 387 | 388 | return grad_loss 389 | 390 | def get_lop_loss(self, obj_sdfs): 391 | margin_target = torch.ones(obj_sdfs.shape).cuda() 392 | # threshold = 0.05 * torch.ones(obj_sdfs.shape).cuda() 393 | threshold = self.epsilon_param * torch.ones(obj_sdfs.shape).cuda() 394 | loss = torch.nn.functional.margin_ranking_loss(obj_sdfs, threshold, margin_target) 395 | 396 | return loss 397 | 398 | def get_lrd_loss(self, obj_r_d, bg_r_d): 399 | if len(obj_r_d) == 0: 400 | loss = torch.tensor(0.0).cuda().float() 401 | return loss 402 | 403 | bg_r_d = bg_r_d.detach() 404 | 405 | obj_d = torch.where(obj_r_d > bg_r_d, bg_r_d, obj_r_d) 406 | loss = bg_r_d - obj_d 407 | loss = loss.mean() 408 | 409 | return loss 410 | 411 | def forward(self, model_outputs, ground_truth, iter_ratio=-1): 412 | rgb_gt = ground_truth['rgb'].cuda() 413 | # monocular depth and normal 414 | depth_gt = ground_truth['depth'].cuda() 415 | normal_gt = ground_truth['normal'].cuda() 416 | 417 | depth_pred = model_outputs['depth_values'] 418 | normal_pred = model_outputs['normal_map'][None] 419 | 420 | rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt) 421 | 422 | if 'grad_theta' in model_outputs: 423 | eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta']) 424 | else: 425 | eikonal_loss = torch.tensor(0.0).cuda().float() 426 | 427 | # only supervised the foreground normal 428 | mask = ((model_outputs['sdf'] > 0.).any(dim=-1) & (model_outputs['sdf'] < 0.).any(dim=-1))[None, :, None] 429 | # combine with GT 430 | mask = 
(ground_truth['mask'] > 0.5).cuda() & mask 431 | 432 | depth_loss = self.get_depth_loss(depth_pred, depth_gt, mask) 433 | if isinstance(depth_loss, float): 434 | depth_loss = torch.tensor(0.0).cuda().float() 435 | 436 | normal_l1, normal_cos = self.get_normal_loss(normal_pred * mask, normal_gt) 437 | 438 | if 'grad_theta_nei' in model_outputs: 439 | smooth_loss = self.get_smooth_loss(model_outputs) 440 | else: 441 | smooth_loss = torch.tensor(0.0).cuda().float() 442 | 443 | if 'semantic_values' in model_outputs: 444 | semantic_gt = ground_truth["instance_mask"].cuda().long() 445 | semantic_loss = self.get_semantic_loss(model_outputs['semantic_values'], semantic_gt) 446 | else: 447 | semantic_loss = torch.tensor(0.0).cuda().float() 448 | 449 | # background render smooth loss 450 | if self.bg_render_weight > 0 and model_outputs['background_render'] is not None: 451 | bg_mask = torch.argmax(model_outputs['background_render']['semantic_values'], dim=-1, keepdim=True) 452 | bg_mask = bg_mask != 0 453 | bg_mask = bg_mask.int() 454 | bg_render_loss = self.get_bg_render_loss(model_outputs['background_render'], bg_mask) 455 | else: 456 | bg_render_loss = torch.tensor(0.0).cuda().float() 457 | 458 | # Object Point SDF Loss 459 | lop_loss = self.get_lop_loss(model_outputs['obj_sdfs_behind_bg']) 460 | if torch.isnan(lop_loss): 461 | lop_loss = torch.tensor(0.0).cuda().float() 462 | # Reversed Depth Loss 463 | lrd_loss = self.get_lrd_loss(model_outputs['obj_d_vals'], model_outputs['bg_d_vals']) 464 | 465 | # compute decay weights 466 | if self.end_step > 0: 467 | decay = math.exp(-self.step / self.end_step * 10.) 468 | else: 469 | decay = 1.0 470 | 471 | self.step += 1 472 | 473 | loss = rgb_loss + \ 474 | self.bg_render_weight * bg_render_loss+\ 475 | self.eikonal_weight * eikonal_loss +\ 476 | self.semantic_weight * semantic_loss +\ 477 | self.smooth_weight * smooth_loss +\ 478 | self.lop_weight * lop_loss +\ 479 | self.lrd_weight * lrd_loss +\ 480 | decay * self.depth_weight * depth_loss +\ 481 | decay * self.normal_l1_weight * normal_l1 +\ 482 | decay * self.normal_cos_weight * normal_cos 483 | 484 | output = { 485 | 'loss': loss, 486 | 'rgb_loss': rgb_loss, 487 | 'eikonal_loss': eikonal_loss, 488 | 'bg_render_loss': bg_render_loss, 489 | 'lop_loss': lop_loss, 490 | 'lrd_loss': lrd_loss, 491 | 'semantic_loss': semantic_loss, 492 | 'smooth_loss': smooth_loss, 493 | 'depth_loss': depth_loss, 494 | 'normal_l1': normal_l1, 495 | 'normal_cos': normal_cos 496 | } 497 | 498 | return output -------------------------------------------------------------------------------- /code/model/network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from utils import rend_util 6 | from model.embedder import * 7 | from model.density import LaplaceDensity 8 | from model.ray_sampler import ErrorBoundSampler 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | class ImplicitNetwork(nn.Module): 13 | def __init__( 14 | self, 15 | feature_vector_size, 16 | sdf_bounding_sphere, 17 | d_in, 18 | d_out, 19 | dims, 20 | geometric_init=True, 21 | bias=1.0, 22 | skip_in=(), 23 | weight_norm=True, 24 | multires=0, 25 | sphere_scale=1.0, 26 | inside_outside=False, 27 | ): 28 | super().__init__() 29 | 30 | self.sdf_bounding_sphere = sdf_bounding_sphere 31 | self.sphere_scale = sphere_scale 32 | dims = [d_in] + dims + [d_out + feature_vector_size] 33 | 34 | self.embed_fn = None 35 | if 
multires > 0: 36 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 37 | self.embed_fn = embed_fn 38 | dims[0] = input_ch 39 | print(multires, dims) 40 | self.num_layers = len(dims) 41 | self.skip_in = skip_in 42 | 43 | for l in range(0, self.num_layers - 1): 44 | if l + 1 in self.skip_in: 45 | out_dim = dims[l + 1] - dims[0] 46 | else: 47 | out_dim = dims[l + 1] 48 | 49 | lin = nn.Linear(dims[l], out_dim) 50 | 51 | if geometric_init: 52 | if l == self.num_layers - 2: 53 | if not inside_outside: 54 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 55 | torch.nn.init.constant_(lin.bias, -bias) 56 | else: 57 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 58 | torch.nn.init.constant_(lin.bias, bias) 59 | 60 | elif multires > 0 and l == 0: 61 | torch.nn.init.constant_(lin.bias, 0.0) 62 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 63 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 64 | elif multires > 0 and l in self.skip_in: 65 | torch.nn.init.constant_(lin.bias, 0.0) 66 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 67 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 68 | else: 69 | torch.nn.init.constant_(lin.bias, 0.0) 70 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 71 | 72 | if weight_norm: 73 | lin = nn.utils.weight_norm(lin) 74 | 75 | setattr(self, "lin" + str(l), lin) 76 | 77 | self.softplus = nn.Softplus(beta=100) 78 | 79 | def forward(self, input): 80 | if self.embed_fn is not None: 81 | input = self.embed_fn(input) 82 | 83 | x = input 84 | 85 | for l in range(0, self.num_layers - 1): 86 | lin = getattr(self, "lin" + str(l)) 87 | 88 | if l in self.skip_in: 89 | x = torch.cat([x, input], 1) / np.sqrt(2) 90 | 91 | x = lin(x) 92 | 93 | if l < self.num_layers - 2: 94 | x = self.softplus(x) 95 | 96 | return x 97 | 98 | def gradient(self, x): 99 | x.requires_grad_(True) 100 | y = self.forward(x)[:,:1] 101 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 102 | gradients = torch.autograd.grad( 103 | outputs=y, 104 | inputs=x, 105 | grad_outputs=d_output, 106 | create_graph=True, 107 | retain_graph=True, 108 | only_inputs=True)[0] 109 | return gradients 110 | 111 | def get_outputs(self, x): 112 | x.requires_grad_(True) 113 | output = self.forward(x) 114 | sdf = output[:,:1] 115 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded ''' 116 | if self.sdf_bounding_sphere > 0.0: 117 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) 118 | sdf = torch.minimum(sdf, sphere_sdf) 119 | feature_vectors = output[:, 1:] 120 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) 121 | gradients = torch.autograd.grad( 122 | outputs=sdf, 123 | inputs=x, 124 | grad_outputs=d_output, 125 | create_graph=True, 126 | retain_graph=True, 127 | only_inputs=True)[0] 128 | 129 | return sdf, feature_vectors, gradients 130 | 131 | def get_sdf_vals(self, x): 132 | sdf = self.forward(x)[:,:1] 133 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded ''' 134 | if self.sdf_bounding_sphere > 0.0: 135 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) 136 | sdf = torch.minimum(sdf, sphere_sdf) 137 | return sdf 138 | 139 | 140 | # from hashencoder.hashgrid import _hash_encode, HashEncoder 141 | class ImplicitNetworkGrid(nn.Module): 142 | 
def __init__( 143 | self, 144 | feature_vector_size, 145 | sdf_bounding_sphere, 146 | d_in, 147 | d_out, 148 | dims, 149 | geometric_init=True, 150 | bias=1.0, 151 | skip_in=(), 152 | weight_norm=True, 153 | multires=0, 154 | sphere_scale=1.0, 155 | inside_outside=False, 156 | base_size = 16, 157 | end_size = 2048, 158 | logmap = 19, 159 | num_levels=16, 160 | level_dim=2, 161 | divide_factor = 1.5, # used to normalize the points range for multi-res grid 162 | use_grid_feature = True 163 | ): 164 | super().__init__() 165 | 166 | self.sdf_bounding_sphere = sdf_bounding_sphere 167 | self.sphere_scale = sphere_scale 168 | dims = [d_in] + dims + [d_out + feature_vector_size] 169 | self.embed_fn = None 170 | self.divide_factor = divide_factor 171 | self.grid_feature_dim = num_levels * level_dim 172 | self.use_grid_feature = use_grid_feature 173 | dims[0] += self.grid_feature_dim 174 | 175 | print(f"using hash encoder with {num_levels} levels, each level with feature dim {level_dim}") 176 | print(f"resolution:{base_size} -> {end_size} with hash map size {logmap}") 177 | # self.encoding = HashEncoder(input_dim=3, num_levels=num_levels, level_dim=level_dim, 178 | # per_level_scale=2, base_resolution=base_size, 179 | # log2_hashmap_size=logmap, desired_resolution=end_size) 180 | 181 | ''' 182 | # can also use tcnn for multi-res grid as it now supports eikonal loss 183 | base_size = 16 184 | hash = True 185 | smoothstep = True 186 | self.encoding = tcnn.Encoding(3, { 187 | "otype": "HashGrid" if hash else "DenseGrid", 188 | "n_levels": 16, 189 | "n_features_per_level": 2, 190 | "log2_hashmap_size": 19, 191 | "base_resolution": base_size, 192 | "per_level_scale": 1.34, 193 | "interpolation": "Smoothstep" if smoothstep else "Linear" 194 | }) 195 | ''' 196 | 197 | if multires > 0: 198 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 199 | self.embed_fn = embed_fn 200 | dims[0] += input_ch - 3 201 | print("network architecture") 202 | print(dims) 203 | 204 | self.num_layers = len(dims) 205 | self.skip_in = skip_in 206 | 207 | for l in range(0, self.num_layers - 1): 208 | if l + 1 in self.skip_in: 209 | out_dim = dims[l + 1] - dims[0] 210 | else: 211 | out_dim = dims[l + 1] 212 | 213 | lin = nn.Linear(dims[l], out_dim) 214 | 215 | if geometric_init: 216 | if l == self.num_layers - 2: 217 | if not inside_outside: 218 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 219 | torch.nn.init.constant_(lin.bias, -bias) 220 | else: 221 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 222 | torch.nn.init.constant_(lin.bias, bias) 223 | 224 | elif multires > 0 and l == 0: 225 | torch.nn.init.constant_(lin.bias, 0.0) 226 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 227 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 228 | elif multires > 0 and l in self.skip_in: 229 | torch.nn.init.constant_(lin.bias, 0.0) 230 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 231 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 232 | else: 233 | torch.nn.init.constant_(lin.bias, 0.0) 234 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 235 | 236 | if weight_norm: 237 | lin = nn.utils.weight_norm(lin) 238 | 239 | setattr(self, "lin" + str(l), lin) 240 | 241 | self.softplus = nn.Softplus(beta=100) 242 | self.cache_sdf = None 243 | 244 | def forward(self, input): 245 | if self.use_grid_feature: 246 | # normalize point range as encoding assume points 
are in [-1, 1] 247 | feature = self.encoding(input / self.divide_factor) 248 | else: 249 | feature = torch.zeros_like(input[:, :1].repeat(1, self.grid_feature_dim)) 250 | 251 | if self.embed_fn is not None: 252 | embed = self.embed_fn(input) 253 | input = torch.cat((embed, feature), dim=-1) 254 | else: 255 | input = torch.cat((input, feature), dim=-1) 256 | 257 | x = input 258 | 259 | for l in range(0, self.num_layers - 1): 260 | lin = getattr(self, "lin" + str(l)) 261 | 262 | if l in self.skip_in: 263 | x = torch.cat([x, input], 1) / np.sqrt(2) 264 | 265 | x = lin(x) 266 | 267 | if l < self.num_layers - 2: 268 | x = self.softplus(x) 269 | 270 | return x 271 | 272 | def gradient(self, x): 273 | x.requires_grad_(True) 274 | y = self.forward(x)[:,:1] 275 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 276 | gradients = torch.autograd.grad( 277 | outputs=y, 278 | inputs=x, 279 | grad_outputs=d_output, 280 | create_graph=True, 281 | retain_graph=True, 282 | only_inputs=True)[0] 283 | return gradients 284 | 285 | def get_outputs(self, x): 286 | x.requires_grad_(True) 287 | output = self.forward(x) 288 | sdf = output[:,:1] 289 | 290 | feature_vectors = output[:, 1:] 291 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) 292 | gradients = torch.autograd.grad( 293 | outputs=sdf, 294 | inputs=x, 295 | grad_outputs=d_output, 296 | create_graph=True, 297 | retain_graph=True, 298 | only_inputs=True)[0] 299 | 300 | return sdf, feature_vectors, gradients 301 | 302 | def get_sdf_vals(self, x): 303 | sdf = self.forward(x)[:,:1] 304 | return sdf 305 | 306 | def mlp_parameters(self): 307 | parameters = [] 308 | for l in range(0, self.num_layers - 1): 309 | lin = getattr(self, "lin" + str(l)) 310 | parameters += list(lin.parameters()) 311 | return parameters 312 | 313 | def grid_parameters(self): 314 | print("grid parameters", len(list(self.encoding.parameters()))) 315 | for p in self.encoding.parameters(): 316 | print(p.shape) 317 | return self.encoding.parameters() 318 | 319 | 320 | class RenderingNetwork(nn.Module): 321 | def __init__( 322 | self, 323 | feature_vector_size, 324 | mode, 325 | d_in, 326 | d_out, 327 | dims, 328 | weight_norm=True, 329 | multires_view=0, 330 | per_image_code = False 331 | ): 332 | super().__init__() 333 | 334 | self.mode = mode 335 | dims = [d_in + feature_vector_size] + dims + [d_out] 336 | 337 | self.embedview_fn = None 338 | if multires_view > 0: 339 | embedview_fn, input_ch = get_embedder(multires_view) 340 | self.embedview_fn = embedview_fn 341 | dims[0] += (input_ch - 3) 342 | 343 | self.per_image_code = per_image_code 344 | if self.per_image_code: 345 | # nerf in the wild parameter 346 | # parameters 347 | # maximum 1024 images 348 | self.embeddings = nn.Parameter(torch.empty(1024, 32)) 349 | std = 1e-4 350 | self.embeddings.data.uniform_(-std, std) 351 | dims[0] += 32 352 | 353 | print("rendering network architecture:") 354 | print(dims) 355 | 356 | self.num_layers = len(dims) 357 | 358 | for l in range(0, self.num_layers - 1): 359 | out_dim = dims[l + 1] 360 | lin = nn.Linear(dims[l], out_dim) 361 | 362 | if weight_norm: 363 | lin = nn.utils.weight_norm(lin) 364 | 365 | setattr(self, "lin" + str(l), lin) 366 | 367 | self.relu = nn.ReLU() 368 | self.sigmoid = torch.nn.Sigmoid() 369 | 370 | def forward(self, points, normals, view_dirs, feature_vectors, indices): 371 | if self.embedview_fn is not None: 372 | view_dirs = self.embedview_fn(view_dirs) 373 | 374 | if self.mode == 'idr': 375 | rendering_input = torch.cat([points, 
view_dirs, normals, feature_vectors], dim=-1) 376 | elif self.mode == 'nerf': 377 | rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1) 378 | else: 379 | raise NotImplementedError 380 | 381 | if self.per_image_code: 382 | image_code = self.embeddings[indices].expand(rendering_input.shape[0], -1) 383 | rendering_input = torch.cat([rendering_input, image_code], dim=-1) 384 | 385 | x = rendering_input 386 | 387 | for l in range(0, self.num_layers - 1): 388 | lin = getattr(self, "lin" + str(l)) 389 | 390 | x = lin(x) 391 | 392 | if l < self.num_layers - 2: 393 | x = self.relu(x) 394 | 395 | x = self.sigmoid(x) 396 | return x 397 | 398 | 399 | class MonoSDFNetwork(nn.Module): 400 | def __init__(self, conf): 401 | super().__init__() 402 | self.feature_vector_size = conf.get_int('feature_vector_size') 403 | self.scene_bounding_sphere = conf.get_float('scene_bounding_sphere', default=1.0) 404 | self.white_bkgd = conf.get_bool('white_bkgd', default=False) 405 | self.bg_color = torch.tensor(conf.get_list("bg_color", default=[1.0, 1.0, 1.0])).float().cuda() 406 | 407 | Grid_MLP = conf.get_bool('Grid_MLP', default=False) 408 | self.Grid_MLP = Grid_MLP 409 | if Grid_MLP: 410 | self.implicit_network = ImplicitNetworkGrid(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network')) 411 | else: 412 | self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network')) 413 | 414 | self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network')) 415 | 416 | self.density = LaplaceDensity(**conf.get_config('density')) 417 | sampling_method = conf.get_string('sampling_method', default="errorbounded") 418 | self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, **conf.get_config('ray_sampler')) 419 | 420 | 421 | def forward(self, input, indices): 422 | # Parse model input 423 | intrinsics = input["intrinsics"] 424 | uv = input["uv"] 425 | pose = input["pose"] 426 | 427 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics) 428 | 429 | # we should use unnormalized ray direction for depth 430 | ray_dirs_tmp, _ = rend_util.get_camera_params(uv, torch.eye(4).to(pose.device)[None], intrinsics) 431 | depth_scale = ray_dirs_tmp[0, :, 2:] 432 | 433 | batch_size, num_pixels, _ = ray_dirs.shape 434 | 435 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3) 436 | ray_dirs = ray_dirs.reshape(-1, 3) 437 | 438 | 439 | z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self) 440 | N_samples = z_vals.shape[1] 441 | 442 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) 443 | points_flat = points.reshape(-1, 3) 444 | 445 | 446 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1) 447 | dirs_flat = dirs.reshape(-1, 3) 448 | 449 | sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat) 450 | 451 | rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors, indices) 452 | rgb = rgb_flat.reshape(-1, N_samples, 3) 453 | 454 | weights = self.volume_rendering(z_vals, sdf) 455 | 456 | rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1) 457 | 458 | depth_values = torch.sum(weights * z_vals, 1, keepdims=True) / (weights.sum(dim=1, keepdims=True) +1e-8) 459 | # we should scale rendered distance to depth along z direction 460 | depth_values = depth_scale * depth_values 461 | 462 | # white 
background assumption 463 | if self.white_bkgd: 464 | acc_map = torch.sum(weights, -1) 465 | rgb_values = rgb_values + (1. - acc_map[..., None]) * self.bg_color.unsqueeze(0) 466 | 467 | output = { 468 | 'rgb':rgb, 469 | 'rgb_values': rgb_values, 470 | 'depth_values': depth_values, 471 | 'z_vals': z_vals, 472 | 'depth_vals': z_vals * depth_scale, 473 | 'sdf': sdf.reshape(z_vals.shape), 474 | 'weights': weights, 475 | } 476 | 477 | if self.training: 478 | # Sample points for the eikonal loss 479 | n_eik_points = batch_size * num_pixels 480 | 481 | eikonal_points = torch.empty(n_eik_points, 3).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere).cuda() 482 | 483 | # add some of the near surface points 484 | eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3) 485 | eikonal_points = torch.cat([eikonal_points, eik_near_points], 0) 486 | # add some neighbour points as unisurf 487 | neighbour_points = eikonal_points + (torch.rand_like(eikonal_points) - 0.5) * 0.01 488 | eikonal_points = torch.cat([eikonal_points, neighbour_points], 0) 489 | 490 | grad_theta = self.implicit_network.gradient(eikonal_points) 491 | 492 | # split gradient to eikonal points and heighbour ponits 493 | output['grad_theta'] = grad_theta[:grad_theta.shape[0]//2] 494 | output['grad_theta_nei'] = grad_theta[grad_theta.shape[0]//2:] 495 | 496 | # compute normal map 497 | normals = gradients / (gradients.norm(2, -1, keepdim=True) + 1e-6) 498 | normals = normals.reshape(-1, N_samples, 3) 499 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1) 500 | 501 | # transform to local coordinate system 502 | rot = pose[0, :3, :3].permute(1, 0).contiguous() 503 | normal_map = rot @ normal_map.permute(1, 0) 504 | normal_map = normal_map.permute(1, 0).contiguous() 505 | 506 | output['normal_map'] = normal_map 507 | 508 | return output 509 | 510 | def volume_rendering(self, z_vals, sdf): 511 | density_flat = self.density(sdf) 512 | density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples 513 | 514 | dists = z_vals[:, 1:] - z_vals[:, :-1] 515 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1) 516 | 517 | # LOG SPACE 518 | free_energy = dists * density 519 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) # shift one step 520 | alpha = 1 - torch.exp(-free_energy) # probability of it is not empty here 521 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability of everything is empty up to now 522 | weights = alpha * transmittance # probability of the ray hits something here 523 | 524 | return weights 525 | -------------------------------------------------------------------------------- /code/model/ray_sampler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from tkinter.messagebox import NO 3 | import torch 4 | 5 | from utils import rend_util 6 | 7 | class RaySampler(metaclass=abc.ABCMeta): 8 | def __init__(self,near, far): 9 | self.near = near 10 | self.far = far 11 | 12 | @abc.abstractmethod 13 | def get_z_vals(self, ray_dirs, cam_loc, model): 14 | pass 15 | 16 | class UniformSampler(RaySampler): 17 | def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1): 18 | #super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R 19 | super().__init__(near, 2.0 * 
scene_bounding_sphere * 1.75 if far == -1 else far) # default far is 2*R 20 | self.N_samples = N_samples 21 | self.scene_bounding_sphere = scene_bounding_sphere 22 | self.take_sphere_intersection = take_sphere_intersection 23 | 24 | # dtu and bmvs 25 | def get_z_vals_dtu_bmvs(self, ray_dirs, cam_loc, model): 26 | if not self.take_sphere_intersection: 27 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda() 28 | else: 29 | sphere_intersections = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere) 30 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda() 31 | far = sphere_intersections[:,1:] 32 | 33 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda() 34 | z_vals = near * (1. - t_vals) + far * (t_vals) 35 | 36 | if model.training: 37 | # get intervals between samples 38 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 39 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 40 | lower = torch.cat([z_vals[..., :1], mids], -1) 41 | # stratified samples in those intervals 42 | t_rand = torch.rand(z_vals.shape).cuda() 43 | 44 | z_vals = lower + (upper - lower) * t_rand 45 | 46 | return z_vals, near, far 47 | 48 | def near_far_from_cube(self, rays_o, rays_d, bound): 49 | tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3] 50 | tmax = (bound - rays_o) / (rays_d + 1e-15) 51 | near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0] 52 | far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0] 53 | # if far < near, means no intersection, set both near and far to inf (1e9 here) 54 | mask = far < near 55 | near[mask] = 1e9 56 | far[mask] = 1e9 57 | # restrict near to a minimal value 58 | near = torch.clamp(near, min=self.near) 59 | far = torch.clamp(far, max=self.far) 60 | return near, far 61 | 62 | # currently this is used for replica scannet and T&T 63 | def get_z_vals(self, ray_dirs, cam_loc, model): 64 | if not self.take_sphere_intersection: 65 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda() 66 | else: 67 | _, far = self.near_far_from_cube(cam_loc, ray_dirs, bound=self.scene_bounding_sphere) 68 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda() 69 | 70 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda() 71 | z_vals = near * (1. 
- t_vals) + far * (t_vals) 72 | 73 | if model.training: 74 | # get intervals between samples 75 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 76 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 77 | lower = torch.cat([z_vals[..., :1], mids], -1) 78 | # stratified samples in those intervals 79 | t_rand = torch.rand(z_vals.shape).cuda() 80 | 81 | z_vals = lower + (upper - lower) * t_rand 82 | 83 | return z_vals, near, far 84 | 85 | 86 | class ErrorBoundSampler(RaySampler): 87 | def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra, 88 | eps, beta_iters, max_total_iters, 89 | inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=1.0e-6): 90 | #super().__init__(near, 2.0 * scene_bounding_sphere) 91 | super().__init__(near, 2.0 * scene_bounding_sphere * 1.75) 92 | 93 | self.N_samples = N_samples 94 | self.N_samples_eval = N_samples_eval 95 | self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=True) # replica scannet and T&T courtroom 96 | #self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg) # dtu and bmvs 97 | 98 | self.N_samples_extra = N_samples_extra 99 | 100 | self.eps = eps 101 | self.beta_iters = beta_iters 102 | self.max_total_iters = max_total_iters 103 | self.scene_bounding_sphere = scene_bounding_sphere 104 | self.add_tiny = add_tiny 105 | 106 | self.inverse_sphere_bg = inverse_sphere_bg 107 | if inverse_sphere_bg: 108 | self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0) 109 | 110 | def get_z_vals(self, ray_dirs, cam_loc, model): 111 | with torch.no_grad(): 112 | beta0 = model.density.get_beta().detach() 113 | 114 | # Start with uniform sampling 115 | z_vals, near, far = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model) 116 | samples, samples_idx = z_vals, None 117 | 118 | # Get maximum beta from the upper bound (Lemma 2) 119 | dists = z_vals[:, 1:] - z_vals[:, :-1] 120 | bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1) 121 | beta = torch.sqrt(bound) 122 | 123 | total_iters, not_converge = 0, True 124 | 125 | # Algorithm 1 126 | while not_converge and total_iters < self.max_total_iters: 127 | points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1) 128 | points_flat = points.reshape(-1, 3) 129 | 130 | # Calculating the SDF only for the new sampled points 131 | # with torch.no_grad(): 132 | samples_sdf = model.implicit_network.get_sdf_vals(points_flat) 133 | if samples_idx is not None: 134 | sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]), 135 | samples_sdf.reshape(-1, samples.shape[1])], -1) 136 | sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1) 137 | else: 138 | sdf = samples_sdf 139 | 140 | 141 | # Calculating the bound d* (Theorem 1) 142 | d = sdf.reshape(z_vals.shape) 143 | dists = z_vals[:, 1:] - z_vals[:, :-1] 144 | a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs() 145 | first_cond = a.pow(2) + b.pow(2) <= c.pow(2) 146 | second_cond = a.pow(2) + c.pow(2) <= b.pow(2) 147 | d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda() 148 | d_star[first_cond] = b[first_cond] 149 | d_star[second_cond] = c[second_cond] 150 | s = (a + b + c) / 2.0 151 | area_before_sqrt = s * (s - a) * (s - b) * (s - c) 152 | mask = ~first_cond & ~second_cond & (b + c - a > 0) 153 | d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask]) 154 | d_star = 
(d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign 155 | 156 | 157 | # Updating beta using line search 158 | curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star) 159 | beta[curr_error <= self.eps] = beta0 160 | beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta 161 | for j in range(self.beta_iters): 162 | beta_mid = (beta_min + beta_max) / 2. 163 | curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star) 164 | beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps] 165 | beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps] 166 | beta = beta_max 167 | 168 | # Upsample more points 169 | density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1)) 170 | 171 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1) 172 | free_energy = dists * density 173 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) 174 | alpha = 1 - torch.exp(-free_energy) 175 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) 176 | weights = alpha * transmittance # probability of the ray hits something here 177 | 178 | # Check if we are done and this is the last sampling 179 | total_iters += 1 180 | not_converge = beta.max() > beta0 181 | 182 | if not_converge and total_iters < self.max_total_iters: 183 | ''' Sample more points proportional to the current error bound''' 184 | 185 | N = self.N_samples_eval 186 | 187 | bins = z_vals 188 | error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2) 189 | error_integral = torch.cumsum(error_per_section, dim=-1) 190 | bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1] 191 | 192 | pdf = bound_opacity + self.add_tiny 193 | pdf = pdf / torch.sum(pdf, -1, keepdim=True) 194 | cdf = torch.cumsum(pdf, -1) 195 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 196 | 197 | else: 198 | ''' Sample the final sample set to be used in the volume rendering integral ''' 199 | 200 | N = self.N_samples 201 | 202 | bins = z_vals 203 | pdf = weights[..., :-1] 204 | pdf = pdf + 1e-5 # prevent nans 205 | pdf = pdf / torch.sum(pdf, -1, keepdim=True) 206 | cdf = torch.cumsum(pdf, -1) 207 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 208 | 209 | 210 | # Invert CDF 211 | if (not_converge and total_iters < self.max_total_iters) or (not model.training): 212 | u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1) 213 | else: 214 | u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda() 215 | u = u.contiguous() 216 | 217 | inds = torch.searchsorted(cdf, u, right=True) 218 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 219 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 220 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 221 | 222 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 223 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 224 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 225 | 226 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 227 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 228 | t = (u - cdf_g[..., 0]) / denom 229 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 230 | 231 | 232 | # Adding samples if we not converged 233 | if 
not_converge and total_iters < self.max_total_iters: 234 | z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1) 235 | 236 | 237 | z_samples = samples 238 | #TODO Use near and far from intersection 239 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda() 240 | if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection 241 | far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:] 242 | 243 | if self.N_samples_extra > 0: 244 | if model.training: 245 | sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra] 246 | else: 247 | sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long() 248 | z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1) 249 | else: 250 | z_vals_extra = torch.cat([near, far], -1) 251 | 252 | z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1) 253 | 254 | # add some of the near surface points 255 | idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda() 256 | z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1)) 257 | 258 | if self.inverse_sphere_bg: 259 | z_vals_inverse_sphere, _, _ = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model) 260 | z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere) 261 | z_vals = (z_vals, z_vals_inverse_sphere) 262 | 263 | return z_vals, z_samples_eik 264 | 265 | def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star): 266 | density = model.density(sdf.reshape(z_vals.shape), beta=beta) 267 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1) 268 | integral_estimation = torch.cumsum(shifted_free_energy, dim=-1) 269 | error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) 
/ (4 * beta ** 2) 270 | error_integral = torch.cumsum(error_per_section, dim=-1) 271 | bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1]) 272 | 273 | return bound_opacity.max(-1)[0] -------------------------------------------------------------------------------- /code/slurm_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | 4 | PARTITION=$1 5 | NUM_NODES=1 6 | NUM_GPUS_PER_NODE=1 7 | CFG_PATH=$2 8 | SCAN_ID=$3 9 | PORT=$4 10 | 11 | srun -p ${PARTITION} \ 12 | -N ${NUM_NODES} \ 13 | --gres=gpu:${NUM_GPUS_PER_NODE} \ 14 | --cpus-per-task=4 \ 15 | -t 5-00:00:00 \ 16 | python training/exp_runner.py --conf $CFG_PATH --scan_id $SCAN_ID --port $PORT -------------------------------------------------------------------------------- /code/training/exp_runner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../code') 4 | import argparse 5 | import torch 6 | import random 7 | import numpy as np 8 | 9 | import os 10 | import subprocess 11 | import datetime 12 | from training.rico_train import RICOTrainRunner 13 | 14 | 15 | if __name__ == '__main__': 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 19 | parser.add_argument('--nepoch', type=int, default=2000, help='number of epochs to train for') 20 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf') 21 | parser.add_argument('--expname', type=str, default='') 22 | parser.add_argument("--exps_folder", type=str, default="exps") 23 | #parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]') 24 | parser.add_argument('--is_continue', default=False, action="store_true", 25 | help='If set, indicates continuing from a previous run.') 26 | parser.add_argument('--timestamp', default='latest', type=str, 27 | help='The timestamp of the run to be used in case of continuing from a previous run.') 28 | parser.add_argument('--checkpoint', default='latest', type=str, 29 | help='The checkpoint epoch of the run to be used in case of continuing from a previous run.') 30 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.') 31 | parser.add_argument('--cancel_vis', default=False, action="store_true", 32 | help='If set, cancel visualization in intermediate epochs.') 33 | parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel') 34 | parser.add_argument('--port', type=int, default=29500) 35 | 36 | opt = parser.parse_args() 37 | 38 | ''' 39 | if opt.gpu == "auto": 40 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, 41 | excludeID=[], excludeUUID=[]) 42 | gpu = deviceIDs[0] 43 | else: 44 | gpu = opt.gpu 45 | ''' 46 | gpu = opt.local_rank 47 | 48 | random.seed(0) 49 | np.random.seed(0) 50 | torch.manual_seed(0) 51 | torch.cuda.manual_seed_all(0) 52 | # torch.backends.cudnn.deterministic = True 53 | # torch.backends.cudnn.benchmark = False 54 | torch.backends.cudnn.benchmark = True 55 | 56 | # set distributed training 57 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 58 | rank = int(os.environ["RANK"]) 59 | world_size = int(os.environ['WORLD_SIZE']) 60 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 61 | elif 'SLURM_PROCID' in os.environ: 62 | proc_id = 
int(os.environ['SLURM_PROCID']) 63 | ntasks = int(os.environ['SLURM_NTASKS']) 64 | node_list = os.environ['SLURM_NODELIST'] 65 | num_gpus = torch.cuda.device_count() 66 | addr = subprocess.getoutput( 67 | 'scontrol show hostname {} | head -n1'.format(node_list) 68 | ) 69 | port_str = str(opt.port) 70 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', port_str) 71 | os.environ['MASTER_ADDR'] = addr 72 | os.environ['WORLD_SIZE'] = str(ntasks) 73 | os.environ['RANK'] = str(proc_id) 74 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 75 | os.environ['LOCAL_SIZE'] = str(num_gpus) 76 | rank = proc_id 77 | world_size = ntasks 78 | print(f"RANK and WORLD_SIZE in SLURM environ: {rank}/{world_size}") 79 | else: 80 | rank = -1 81 | world_size = -1 82 | 83 | print(opt.local_rank) 84 | torch.cuda.set_device(opt.local_rank) 85 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank, timeout=datetime.timedelta(1, 1800)) 86 | torch.distributed.barrier() 87 | 88 | torch.autograd.set_detect_anomaly(True) 89 | 90 | trainrunner = RICOTrainRunner( 91 | conf=opt.conf, 92 | batch_size=opt.batch_size, 93 | nepochs=opt.nepoch, 94 | expname=opt.expname, 95 | gpu_index=gpu, 96 | exps_folder_name=opt.exps_folder, 97 | is_continue=opt.is_continue, 98 | timestamp=opt.timestamp, 99 | checkpoint=opt.checkpoint, 100 | scan_id=opt.scan_id, 101 | do_vis=not opt.cancel_vis 102 | ) 103 | 104 | trainrunner.run() 105 | -------------------------------------------------------------------------------- /code/training/rico_train.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os 3 | from datetime import datetime 4 | from pyhocon import ConfigFactory 5 | import sys 6 | import torch 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | import utils.general as utils 11 | import utils.plots as plt 12 | from utils import rend_util 13 | from utils.general import get_time 14 | from torch.utils.tensorboard import SummaryWriter 15 | from model.loss import compute_scale_and_shift 16 | from utils.general import BackprojectDepth 17 | 18 | import torch.distributed as dist 19 | 20 | class RICOTrainRunner(): 21 | def __init__(self,**kwargs): 22 | torch.set_default_dtype(torch.float32) 23 | torch.set_num_threads(1) 24 | 25 | self.conf = ConfigFactory.parse_file(kwargs['conf']) 26 | self.batch_size = kwargs['batch_size'] 27 | self.nepochs = kwargs['nepochs'] 28 | self.exps_folder_name = kwargs['exps_folder_name'] 29 | self.GPU_INDEX = kwargs['gpu_index'] 30 | 31 | self.expname = self.conf.get_string('train.expname') + kwargs['expname'] 32 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1) 33 | if scan_id != -1: 34 | self.expname = self.expname + '_{0}'.format(scan_id) 35 | 36 | if kwargs['is_continue'] and kwargs['timestamp'] == 'latest': 37 | if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)): 38 | timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname)) 39 | if (len(timestamps)) == 0: 40 | is_continue = False 41 | timestamp = None 42 | else: 43 | timestamp = sorted(timestamps)[-1] 44 | is_continue = True 45 | else: 46 | is_continue = False 47 | timestamp = None 48 | else: 49 | timestamp = kwargs['timestamp'] 50 | is_continue = kwargs['is_continue'] 51 | 52 | if self.GPU_INDEX == 0: 53 | utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name)) 54 | self.expdir = os.path.join('../', 
self.exps_folder_name, self.expname) 55 | utils.mkdir_ifnotexists(self.expdir) 56 | self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now()) 57 | utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp)) 58 | 59 | self.plots_dir = os.path.join(self.expdir, self.timestamp, 'plots') 60 | utils.mkdir_ifnotexists(self.plots_dir) 61 | 62 | # create checkpoints dirs 63 | self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints') 64 | utils.mkdir_ifnotexists(self.checkpoints_path) 65 | self.model_params_subdir = "ModelParameters" 66 | self.optimizer_params_subdir = "OptimizerParameters" 67 | self.scheduler_params_subdir = "SchedulerParameters" 68 | 69 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir)) 70 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir)) 71 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir)) 72 | 73 | os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf'))) 74 | 75 | print('shell command : {0}'.format(' '.join(sys.argv))) 76 | 77 | print('Loading data ...') 78 | 79 | dataset_conf = self.conf.get_config('dataset') 80 | if kwargs['scan_id'] != -1: 81 | dataset_conf['scan_id'] = kwargs['scan_id'] 82 | 83 | self.train_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(**dataset_conf) 84 | 85 | self.max_total_iters = self.conf.get_int('train.max_total_iters', default=50000) 86 | self.ds_len = len(self.train_dataset) 87 | print('Finish loading data. Data-set size: {0}'.format(self.ds_len)) 88 | # use total iterations to compute how many epochs 89 | self.nepochs = int(self.max_total_iters / self.ds_len) 90 | print('RUNNING FOR {0}'.format(self.nepochs)) 91 | 92 | if len(self.train_dataset.label_mapping) > 0: 93 | # a hack way to let network know how many categories, so don't need to manually set in config file 94 | self.conf['model']['implicit_network']['d_out'] = len(self.train_dataset.label_mapping) 95 | print('RUNNING FOR {0} CLASSES'.format(len(self.train_dataset.label_mapping))) 96 | 97 | self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset, 98 | batch_size=self.batch_size, 99 | shuffle=True, 100 | collate_fn=self.train_dataset.collate_fn, 101 | num_workers=4) 102 | self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset, 103 | batch_size=self.conf.get_int('plot.plot_nimgs'), 104 | shuffle=True, 105 | collate_fn=self.train_dataset.collate_fn 106 | ) 107 | 108 | conf_model = self.conf.get_config('model') 109 | instance_ids = self.train_dataset.instance_ids 110 | print('Instance IDs: ', instance_ids) 111 | print('Label mappings: ', self.train_dataset.label_mapping) 112 | 113 | self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=conf_model) 114 | 115 | if torch.cuda.is_available(): 116 | self.model.cuda() 117 | 118 | self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss')) 119 | 120 | self.lr = self.conf.get_float('train.learning_rate') 121 | 122 | # current model uses MLP and a unified lr 123 | print('using optimizer w unified lr') 124 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.99), eps=1e-15) 125 | 126 | # Exponential learning rate scheduler 127 | decay_rate = self.conf.get_float('train.sched_decay_rate', default=0.1) 128 | decay_steps = self.nepochs * len(self.train_dataset) 129 | self.scheduler = 
torch.optim.lr_scheduler.ExponentialLR(self.optimizer, decay_rate ** (1./decay_steps)) 130 | 131 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.GPU_INDEX], broadcast_buffers=False, find_unused_parameters=True) 132 | 133 | self.do_vis = kwargs['do_vis'] 134 | 135 | self.start_epoch = 0 136 | if is_continue: 137 | old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints') 138 | 139 | saved_model_state = torch.load( 140 | os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth")) 141 | self.model.load_state_dict(saved_model_state["model_state_dict"]) 142 | self.start_epoch = saved_model_state['epoch'] 143 | 144 | data = torch.load( 145 | os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth")) 146 | self.optimizer.load_state_dict(data["optimizer_state_dict"]) 147 | 148 | data = torch.load( 149 | os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth")) 150 | self.scheduler.load_state_dict(data["scheduler_state_dict"]) 151 | 152 | self.num_pixels = self.conf.get_int('train.num_pixels') 153 | self.total_pixels = self.train_dataset.total_pixels 154 | self.img_res = self.train_dataset.img_res 155 | self.n_batches = len(self.train_dataloader) 156 | self.plot_freq = self.conf.get_int('train.plot_freq') 157 | self.checkpoint_freq = self.conf.get_int('train.checkpoint_freq', default=100) 158 | self.split_n_pixels = self.conf.get_int('train.split_n_pixels', default=10000) 159 | self.plot_conf = self.conf.get_config('plot') 160 | self.backproject = BackprojectDepth(1, self.img_res[0], self.img_res[1]).cuda() 161 | self.n_sem = self.conf.get_int('model.implicit_network.d_out') 162 | assert self.n_sem == len(self.train_dataset.label_mapping) 163 | 164 | def save_checkpoints(self, epoch): 165 | torch.save( 166 | {"epoch": epoch, "model_state_dict": self.model.state_dict()}, 167 | os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth")) 168 | torch.save( 169 | {"epoch": epoch, "model_state_dict": self.model.state_dict()}, 170 | os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth")) 171 | 172 | torch.save( 173 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()}, 174 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth")) 175 | torch.save( 176 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()}, 177 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth")) 178 | 179 | torch.save( 180 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()}, 181 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth")) 182 | torch.save( 183 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()}, 184 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth")) 185 | 186 | def run(self): 187 | print("training...") 188 | if self.GPU_INDEX == 0 : 189 | self.writer = SummaryWriter(log_dir=os.path.join(self.plots_dir, 'logs')) 190 | 191 | self.iter_step = 0 192 | for epoch in range(self.start_epoch, self.nepochs + 1): 193 | 194 | if (self.GPU_INDEX == 0 and epoch % self.checkpoint_freq == 0) or (self.GPU_INDEX == 0 and epoch == self.nepochs): 195 | self.save_checkpoints(epoch) 196 | 197 | if (self.GPU_INDEX == 0 and self.do_vis and epoch % self.plot_freq == 0) or (self.GPU_INDEX == 0 and self.do_vis and epoch == self.nepochs): 198 | 
self.model.eval() 199 | 200 | self.train_dataset.change_sampling_idx(-1) 201 | 202 | indices, model_input, ground_truth = next(iter(self.plot_dataloader)) 203 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 204 | model_input["uv"] = model_input["uv"].cuda() 205 | model_input['pose'] = model_input['pose'].cuda() 206 | 207 | split = utils.split_input(model_input, self.total_pixels, n_pixels=self.split_n_pixels) 208 | res = [] 209 | for s in tqdm(split): 210 | out = self.model(s, indices) 211 | d = {'rgb_values': out['rgb_values'].detach(), 212 | 'normal_map': out['normal_map'].detach(), 213 | 'depth_values': out['depth_values'].detach(), 214 | 'semantic_values': out['semantic_values'].detach()} 215 | res.append(d) 216 | 217 | batch_size = ground_truth['rgb'].shape[0] 218 | model_outputs = utils.merge_output(res, self.total_pixels, batch_size) 219 | plot_data = self.get_plot_data(model_input, model_outputs, model_input['pose'], ground_truth['rgb'], ground_truth['normal'], ground_truth['depth'], ground_truth['instance_mask']) 220 | 221 | plot_mesh = True 222 | 223 | plt.plot_rico( 224 | self.model.module.implicit_network, 225 | indices, 226 | plot_data, 227 | self.plots_dir, 228 | epoch, 229 | self.img_res, 230 | plot_mesh, 231 | **self.plot_conf 232 | ) 233 | 234 | self.model.train() 235 | self.train_dataset.change_sampling_idx(self.num_pixels) 236 | 237 | for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader): 238 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 239 | model_input["uv"] = model_input["uv"].cuda() 240 | model_input['pose'] = model_input['pose'].cuda() 241 | 242 | model_input['instance_mask'] = ground_truth["instance_mask"].cuda().reshape(-1).long() 243 | 244 | self.optimizer.zero_grad() 245 | 246 | model_outputs = self.model(model_input, indices, iter_step=self.iter_step) 247 | 248 | loss_output = self.loss(model_outputs, ground_truth, iter_ratio=self.iter_step / self.max_total_iters) 249 | loss = loss_output['loss'] 250 | loss.backward() 251 | self.optimizer.step() 252 | 253 | psnr = rend_util.get_psnr(model_outputs['rgb_values'], 254 | ground_truth['rgb'].cuda().reshape(-1,3)) 255 | 256 | self.iter_step += 1 257 | 258 | if self.GPU_INDEX == 0: 259 | if data_index % 25 == 0: 260 | head_str = '{0}_{1} [{2}] ({3}/{4}): '.format(self.expname, self.timestamp, epoch, data_index, self.n_batches) 261 | loss_print_str = '' 262 | for k, v in loss_output.items(): 263 | loss_print_str = loss_print_str + '{0} = {1}, '.format(k, v.item()) 264 | print_str = head_str + loss_print_str + 'psnr = {0}'.format(psnr.item()) 265 | print(print_str) 266 | 267 | for k, v in loss_output.items(): 268 | self.writer.add_scalar(f'Loss/{k}', v.item(), self.iter_step) 269 | 270 | self.writer.add_scalar('Statistics/s_value', self.model.module.get_s_value().item(), self.iter_step) 271 | self.writer.add_scalar('Statistics/psnr', psnr.item(), self.iter_step) 272 | 273 | self.train_dataset.change_sampling_idx(self.num_pixels) 274 | self.scheduler.step() 275 | 276 | self.save_checkpoints(epoch) 277 | 278 | 279 | def get_plot_data(self, model_input, model_outputs, pose, rgb_gt, normal_gt, depth_gt, semantic_gt): 280 | batch_size, num_samples, _ = rgb_gt.shape 281 | 282 | rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3) 283 | normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3) 284 | normal_map = (normal_map + 1.) / 2. 
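# The depth map computed just below is aligned to the monocular depth cue with
# compute_scale_and_shift (imported from model.loss) before it is visualized and
# backprojected to a point cloud. As an illustration only, here is a minimal
# standalone sketch of the per-image least-squares alignment such a function is
# assumed to perform; the helper name scale_shift_lstsq is hypothetical and not
# part of this repository:
import torch

def scale_shift_lstsq(pred, gt, mask):
    # Solve min_{s,t} sum(mask * (s * pred + t - gt)**2) via the 2x2 normal equations.
    mask = mask.float()
    a00 = torch.sum(mask * pred * pred)
    a01 = torch.sum(mask * pred)
    a11 = torch.sum(mask)
    b0 = torch.sum(mask * pred * gt)
    b1 = torch.sum(mask * gt)
    det = a00 * a11 - a01 * a01
    # det == 0 (e.g. an empty mask) would need a guard in real code; omitted in this sketch.
    s = (a11 * b0 - a01 * b1) / det
    t = (-a01 * b0 + a00 * b1) / det
    return s, t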
285 | 286 | depth_map = model_outputs['depth_values'].reshape(batch_size, num_samples) 287 | depth_gt = depth_gt.to(depth_map.device) 288 | scale, shift = compute_scale_and_shift(depth_map[..., None], depth_gt, depth_gt > 0.) 289 | depth_map = depth_map * scale + shift 290 | 291 | # save point cloud 292 | depth = depth_map.reshape(1, 1, self.img_res[0], self.img_res[1]) 293 | pred_points = self.get_point_cloud(depth, model_input, model_outputs) 294 | 295 | gt_depth = depth_gt.reshape(1, 1, self.img_res[0], self.img_res[1]) 296 | gt_points = self.get_point_cloud(gt_depth, model_input, model_outputs) 297 | 298 | # semantic map 299 | semantic_map = model_outputs['semantic_values'].argmax(dim=-1).reshape(batch_size, num_samples, 1) 300 | # in label mapping, 0 is bg idx and 0 301 | # for instance, first fg is 3 and 1 302 | # so when using argmax, the output will be label_mapping idx if correct 303 | 304 | plot_data = { 305 | 'rgb_gt': rgb_gt, 306 | 'normal_gt': (normal_gt + 1.)/ 2., 307 | 'depth_gt': depth_gt, 308 | 'pose': pose, 309 | 'rgb_eval': rgb_eval, 310 | 'normal_map': normal_map, 311 | 'depth_map': depth_map, 312 | "pred_points": pred_points, 313 | "gt_points": gt_points, 314 | "semantic_map": semantic_map, 315 | "semantic_gt": semantic_gt, 316 | } 317 | 318 | return plot_data 319 | 320 | def get_point_cloud(self, depth, model_input, model_outputs): 321 | color = model_outputs["rgb_values"].reshape(-1, 3) 322 | 323 | K_inv = torch.inverse(model_input["intrinsics"][0])[None] 324 | points = self.backproject(depth, K_inv)[0, :3, :].permute(1, 0) 325 | points = torch.cat([points, color], dim=-1) 326 | return points.detach().cpu().numpy() 327 | -------------------------------------------------------------------------------- /code/utils/general.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import time 7 | from torchvision import transforms 8 | import numpy as np 9 | 10 | def mkdir_ifnotexists(directory): 11 | if not os.path.exists(directory): 12 | os.mkdir(directory) 13 | 14 | def get_class(kls): 15 | parts = kls.split('.') 16 | module = ".".join(parts[:-1]) 17 | m = __import__(module) 18 | for comp in parts[1:]: 19 | m = getattr(m, comp) 20 | return m 21 | 22 | def glob_imgs(path): 23 | imgs = [] 24 | for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG']: 25 | imgs.extend(glob(os.path.join(path, ext))) 26 | return imgs 27 | 28 | def split_input(model_input, total_pixels, n_pixels=10000): 29 | ''' 30 | Split the input to fit Cuda memory for large resolution. 31 | Can decrease the value of n_pixels in case of cuda out of memory error. 32 | ''' 33 | split = [] 34 | for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)): 35 | data = model_input.copy() 36 | data['uv'] = torch.index_select(model_input['uv'], 1, indx) 37 | if 'object_mask' in data: 38 | data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx) 39 | if 'depth' in data: 40 | data['depth'] = torch.index_select(model_input['depth'], 1, indx) 41 | if 'instance_mask' in data: 42 | data['instance_mask'] = torch.index_select(model_input['instance_mask'], 1, indx) 43 | split.append(data) 44 | return split 45 | 46 | def merge_output(res, total_pixels, batch_size): 47 | ''' Merge the split output. 
''' 48 | 49 | model_outputs = {} 50 | for entry in res[0]: 51 | if res[0][entry] is None: 52 | continue 53 | if len(res[0][entry].shape) == 1: 54 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res], 55 | 1).reshape(batch_size * total_pixels) 56 | else: 57 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res], 58 | 1).reshape(batch_size * total_pixels, -1) 59 | 60 | return model_outputs 61 | 62 | def concat_home_dir(path): 63 | return os.path.join(os.environ['HOME'],'data',path) 64 | 65 | def get_time(): 66 | torch.cuda.synchronize() 67 | return time.time() 68 | 69 | trans_topil = transforms.ToPILImage() 70 | 71 | 72 | class BackprojectDepth(nn.Module): 73 | """Layer to transform a depth image into a point cloud 74 | """ 75 | def __init__(self, batch_size, height, width): 76 | super(BackprojectDepth, self).__init__() 77 | 78 | self.batch_size = batch_size 79 | self.height = height 80 | self.width = width 81 | 82 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') 83 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32) 84 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), 85 | requires_grad=False) 86 | 87 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), 88 | requires_grad=False) 89 | 90 | self.pix_coords = torch.unsqueeze(torch.stack( 91 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 92 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 93 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), 94 | requires_grad=False) 95 | 96 | def forward(self, depth, inv_K): 97 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 98 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 99 | cam_points = torch.cat([cam_points, self.ones], 1) 100 | return cam_points 101 | -------------------------------------------------------------------------------- /code/utils/plots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage import measure 4 | import torchvision 5 | import trimesh 6 | from PIL import Image 7 | import matplotlib.pyplot as plt 8 | import cv2 9 | 10 | from utils import rend_util 11 | from utils.general import trans_topil 12 | 13 | 14 | def plot(implicit_network, indices, plot_data, path, epoch, img_res, plot_mesh, plot_nimgs, resolution, grid_boundary, level=0): 15 | 16 | if plot_data is not None: 17 | cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose']) 18 | 19 | # plot images 20 | plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res, indices) 21 | 22 | # plot normal maps 23 | plot_normal_maps(plot_data['normal_map'], plot_data['normal_gt'], path, epoch, plot_nimgs, img_res, indices) 24 | 25 | # plot depth maps 26 | plot_depth_maps(plot_data['depth_map'], plot_data['depth_gt'], path, epoch, plot_nimgs, img_res, indices) 27 | 28 | # concat output images to single large image 29 | images = [] 30 | for name in ["rendering", "depth", "normal"]: 31 | images.append(cv2.imread('{0}/{1}_{2}_{3}.png'.format(path, name, epoch, indices[0]))) 32 | 33 | images = np.concatenate(images, axis=1) 34 | cv2.imwrite('{0}/merge_{1}_{2}.png'.format(path, epoch, indices[0]), images) 35 | 36 | if plot_mesh: 37 | surface_traces = get_surface_sliding(send_path=path, 38 | epoch=epoch, 39 | sdf=lambda x: implicit_network(x)[:, 0], 40 | 
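                                            # i.e. the mesh is the zero level set of the
                                            # first network output channel (the scene SDF)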
resolution=resolution, 41 | grid_boundary=grid_boundary, 42 | level=level 43 | ) 44 | 45 | 46 | def plot_rico(implicit_network, indices, plot_data, path, epoch, img_res, plot_mesh, plot_nimgs, resolution, grid_boundary, level=0): 47 | 48 | if plot_data is not None: 49 | cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose']) 50 | 51 | # plot images 52 | plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res, indices) 53 | 54 | # plot normal maps 55 | plot_normal_maps(plot_data['normal_map'], plot_data['normal_gt'], path, epoch, plot_nimgs, img_res, indices) 56 | 57 | # plot depth maps 58 | plot_depth_maps(plot_data['depth_map'], plot_data['depth_gt'], path, epoch, plot_nimgs, img_res, indices) 59 | 60 | # plot semantic maps 61 | plot_seg_images(plot_data['semantic_map'], plot_data['semantic_gt'], path, epoch, plot_nimgs, img_res, indices) 62 | 63 | # concat output images to single large image 64 | images = [] 65 | for name in ["rendering", "semantic", "depth", "normal"]: 66 | images.append(cv2.imread('{0}/{1}_{2}_{3}.png'.format(path, name, epoch, indices[0]))) 67 | 68 | images = np.concatenate(images, axis=1) 69 | cv2.imwrite('{0}/merge_{1}_{2}.png'.format(path, epoch, indices[0]), images) 70 | 71 | if plot_mesh: 72 | sem_num = implicit_network.d_out 73 | f = torch.nn.MaxPool1d(sem_num) 74 | for indx in range(sem_num): 75 | # plot each object and background and save in different files 76 | _ = get_surface_sliding( 77 | send_path=[path, str(indx)], 78 | epoch=epoch, 79 | sdf = lambda x: implicit_network(x)[:, indx], 80 | resolution=resolution, 81 | grid_boundary=grid_boundary, 82 | level=level 83 | ) 84 | 85 | # plot the overall scene 86 | surface_traces = get_surface_sliding( 87 | send_path=[path, 'all'], 88 | epoch=epoch, 89 | sdf=lambda x: -f(-implicit_network(x)[:, :sem_num].unsqueeze(1)).squeeze(-1).squeeze(-1), 90 | resolution=resolution, 91 | grid_boundary=grid_boundary, 92 | level=level 93 | ) 94 | 95 | 96 | avg_pool_3d = torch.nn.AvgPool3d(2, stride=2) 97 | upsample = torch.nn.Upsample(scale_factor=2, mode='nearest') 98 | 99 | @torch.no_grad() 100 | def get_surface_sliding(send_path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0): 101 | if isinstance(send_path, list): 102 | path = send_path[0] 103 | mesh_name = send_path[1] 104 | else: 105 | path = send_path 106 | mesh_name = '' 107 | 108 | resN = resolution 109 | cropN = resolution 110 | level = 0 111 | N = resN // cropN 112 | 113 | grid_min = [grid_boundary[0], grid_boundary[0], grid_boundary[0]] 114 | grid_max = [grid_boundary[1], grid_boundary[1], grid_boundary[1]] 115 | 116 | xs = np.linspace(grid_min[0], grid_max[0], N+1) 117 | ys = np.linspace(grid_min[1], grid_max[1], N+1) 118 | zs = np.linspace(grid_min[2], grid_max[2], N+1) 119 | 120 | print(xs) 121 | print(ys) 122 | print(zs) 123 | meshes = [] 124 | for i in range(N): 125 | for j in range(N): 126 | for k in range(N): 127 | print(i, j, k) 128 | x_min, x_max = xs[i], xs[i+1] 129 | y_min, y_max = ys[j], ys[j+1] 130 | z_min, z_max = zs[k], zs[k+1] 131 | 132 | x = np.linspace(x_min, x_max, cropN) 133 | y = np.linspace(y_min, y_max, cropN) 134 | z = np.linspace(z_min, z_max, cropN) 135 | 136 | xx, yy, zz = np.meshgrid(x, y, z, indexing='ij') 137 | points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda() 138 | 139 | def evaluate(points): 140 | z = [] 141 | for _, pnts in enumerate(torch.split(points, 100000, dim=0)): 142 | z.append(sdf(pnts)) 143 | z = 
torch.cat(z, axis=0) 144 | return z 145 | 146 | # construct point pyramids 147 | points = points.reshape(cropN, cropN, cropN, 3).permute(3, 0, 1, 2) 148 | points_pyramid = [points] 149 | for _ in range(3): 150 | points = avg_pool_3d(points[None])[0] 151 | points_pyramid.append(points) 152 | points_pyramid = points_pyramid[::-1] 153 | 154 | # evalute pyramid with mask 155 | mask = None 156 | threshold = 2 * (x_max - x_min)/cropN * 8 157 | for pid, pts in enumerate(points_pyramid): 158 | coarse_N = pts.shape[-1] 159 | pts = pts.reshape(3, -1).permute(1, 0).contiguous() 160 | 161 | if mask is None: 162 | pts_sdf = evaluate(pts) 163 | else: 164 | mask = mask.reshape(-1) 165 | pts_to_eval = pts[mask] 166 | #import pdb; pdb.set_trace() 167 | if pts_to_eval.shape[0] > 0: 168 | pts_sdf_eval = evaluate(pts_to_eval.contiguous()) 169 | pts_sdf[mask] = pts_sdf_eval 170 | print("ratio", pts_to_eval.shape[0] / pts.shape[0]) 171 | 172 | if pid < 3: 173 | # update mask 174 | mask = torch.abs(pts_sdf) < threshold 175 | mask = mask.reshape(coarse_N, coarse_N, coarse_N)[None, None] 176 | mask = upsample(mask.float()).bool() 177 | 178 | pts_sdf = pts_sdf.reshape(coarse_N, coarse_N, coarse_N)[None, None] 179 | pts_sdf = upsample(pts_sdf) 180 | pts_sdf = pts_sdf.reshape(-1) 181 | 182 | threshold /= 2. 183 | 184 | z = pts_sdf.detach().cpu().numpy() 185 | 186 | if (not (np.min(z) > level or np.max(z) < level)): 187 | z = z.astype(np.float32) 188 | verts, faces, normals, values = measure.marching_cubes( 189 | volume=z.reshape(cropN, cropN, cropN), #.transpose([1, 0, 2]), 190 | level=level, 191 | spacing=( 192 | (x_max - x_min)/(cropN-1), 193 | (y_max - y_min)/(cropN-1), 194 | (z_max - z_min)/(cropN-1) )) 195 | print(np.array([x_min, y_min, z_min])) 196 | print(verts.min(), verts.max()) 197 | verts = verts + np.array([x_min, y_min, z_min]) 198 | print(verts.min(), verts.max()) 199 | 200 | meshcrop = trimesh.Trimesh(verts, faces, normals) 201 | #meshcrop.export(f"{i}_{j}_{k}.ply") 202 | meshes.append(meshcrop) 203 | try: 204 | combined = trimesh.util.concatenate(meshes) 205 | 206 | combined.export('{0}/surface_{1}_{2}.ply'.format(path, epoch, mesh_name), 'ply') 207 | except: 208 | print('no mesh') 209 | 210 | def get_3D_scatter_trace(points, name='', size=3, caption=None): 211 | assert points.shape[1] == 3, "3d scatter plot input points are not correctely shaped " 212 | assert len(points.shape) == 2, "3d scatter plot input points are not correctely shaped " 213 | 214 | trace = go.Scatter3d( 215 | x=points[:, 0].cpu(), 216 | y=points[:, 1].cpu(), 217 | z=points[:, 2].cpu(), 218 | mode='markers', 219 | name=name, 220 | marker=dict( 221 | size=size, 222 | line=dict( 223 | width=2, 224 | ), 225 | opacity=1.0, 226 | ), text=caption) 227 | 228 | return trace 229 | 230 | 231 | def get_3D_quiver_trace(points, directions, color='#bd1540', name=''): 232 | assert points.shape[1] == 3, "3d cone plot input points are not correctely shaped " 233 | assert len(points.shape) == 2, "3d cone plot input points are not correctely shaped " 234 | assert directions.shape[1] == 3, "3d cone plot input directions are not correctely shaped " 235 | assert len(directions.shape) == 2, "3d cone plot input directions are not correctely shaped " 236 | 237 | trace = go.Cone( 238 | name=name, 239 | x=points[:, 0].cpu(), 240 | y=points[:, 1].cpu(), 241 | z=points[:, 2].cpu(), 242 | u=directions[:, 0].cpu(), 243 | v=directions[:, 1].cpu(), 244 | w=directions[:, 2].cpu(), 245 | sizemode='absolute', 246 | sizeref=0.125, 247 | showscale=False, 248 | 
colorscale=[[0, color], [1, color]], 249 | anchor="tail" 250 | ) 251 | 252 | return trace 253 | 254 | 255 | def get_surface_trace(path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0): 256 | grid = get_grid_uniform(resolution, grid_boundary) 257 | points = grid['grid_points'] 258 | 259 | z = [] 260 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 261 | z.append(sdf(pnts.cuda()).detach().cpu().numpy()) 262 | z = np.concatenate(z, axis=0) 263 | 264 | if (not (np.min(z) > level or np.max(z) < level)): 265 | 266 | z = z.astype(np.float32) 267 | 268 | verts, faces, normals, values = measure.marching_cubes( 269 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 270 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]), 271 | level=level, 272 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 273 | grid['xyz'][0][2] - grid['xyz'][0][1], 274 | grid['xyz'][0][2] - grid['xyz'][0][1])) 275 | 276 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) 277 | ''' 278 | I, J, K = faces.transpose() 279 | 280 | traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2], 281 | i=I, j=J, k=K, name='implicit_surface', 282 | color='#ffffff', opacity=1.0, flatshading=False, 283 | lighting=dict(diffuse=1, ambient=0, specular=0), 284 | lightposition=dict(x=0, y=0, z=-1), showlegend=True)] 285 | ''' 286 | meshexport = trimesh.Trimesh(verts, faces, normals) 287 | meshexport.export('{0}/surface_{1}.ply'.format(path, epoch), 'ply') 288 | 289 | if return_mesh: 290 | return meshexport 291 | #return traces 292 | return None 293 | 294 | def get_surface_high_res_mesh(sdf, resolution=100, grid_boundary=[-2.0, 2.0], level=0, take_components=True): 295 | # get low res mesh to sample point cloud 296 | grid = get_grid_uniform(100, grid_boundary) 297 | z = [] 298 | points = grid['grid_points'] 299 | 300 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 301 | z.append(sdf(pnts).detach().cpu().numpy()) 302 | z = np.concatenate(z, axis=0) 303 | 304 | z = z.astype(np.float32) 305 | 306 | verts, faces, normals, values = measure.marching_cubes( 307 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 308 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]), 309 | level=level, 310 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 311 | grid['xyz'][0][2] - grid['xyz'][0][1], 312 | grid['xyz'][0][2] - grid['xyz'][0][1])) 313 | 314 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) 315 | 316 | mesh_low_res = trimesh.Trimesh(verts, faces, normals) 317 | if take_components: 318 | components = mesh_low_res.split(only_watertight=False) 319 | areas = np.array([c.area for c in components], dtype=np.float) 320 | mesh_low_res = components[areas.argmax()] 321 | 322 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0] 323 | recon_pc = torch.from_numpy(recon_pc).float().cuda() 324 | 325 | # Center and align the recon pc 326 | s_mean = recon_pc.mean(dim=0) 327 | s_cov = recon_pc - s_mean 328 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov) 329 | vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0] 330 | if torch.det(vecs) < 0: 331 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs) 332 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), 333 | (recon_pc - s_mean).unsqueeze(-1)).squeeze() 334 | 335 | grid_aligned = get_grid(helper.cpu(), resolution) 336 | 337 | grid_points = grid_aligned['grid_points'] 
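    # The block above aligns the sampling grid with the object's principal axes:
    # the eigenvectors of the surface point-cloud covariance define the rotation.
    # A minimal stand-alone sketch of that idea (hypothetical helper, using the
    # symmetric solver torch.linalg.eigh rather than the complex-valued eig above):
    #
    #   def principal_frame(pc):                      # pc: [N, 3] point cloud
    #       centered = pc - pc.mean(dim=0)
    #       cov = centered.t() @ centered             # 3x3 (unnormalized) covariance
    #       _, vecs = torch.linalg.eigh(cov)          # eigenvectors as columns
    #       if torch.det(vecs) < 0:                   # keep a right-handed frame
    #           vecs = vecs * torch.tensor([-1., 1., 1.], device=vecs.device)
    #       return vecs                               # columns = principal directions
    #
    #   # aligned = (pc - pc.mean(dim=0)) @ principal_frame(pc)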
338 | 339 | g = [] 340 | for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)): 341 | g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), 342 | pnts.unsqueeze(-1)).squeeze() + s_mean) 343 | grid_points = torch.cat(g, dim=0) 344 | 345 | # MC to new grid 346 | points = grid_points 347 | z = [] 348 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 349 | z.append(sdf(pnts).detach().cpu().numpy()) 350 | z = np.concatenate(z, axis=0) 351 | 352 | meshexport = None 353 | if (not (np.min(z) > level or np.max(z) < level)): 354 | 355 | z = z.astype(np.float32) 356 | 357 | verts, faces, normals, values = measure.marching_cubes( 358 | volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0], 359 | grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]), 360 | level=level, 361 | spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], 362 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], 363 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1])) 364 | 365 | verts = torch.from_numpy(verts).cuda().float() 366 | verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), 367 | verts.unsqueeze(-1)).squeeze() 368 | verts = (verts + grid_points[0]).cpu().numpy() 369 | 370 | meshexport = trimesh.Trimesh(verts, faces, normals) 371 | 372 | return meshexport 373 | 374 | 375 | def get_surface_by_grid(grid_params, sdf, resolution=100, level=0, higher_res=False): 376 | grid_params = grid_params * [[1.5], [1.0]] 377 | 378 | # params = PLOT_DICT[scan_id] 379 | input_min = torch.tensor(grid_params[0]).float() 380 | input_max = torch.tensor(grid_params[1]).float() 381 | 382 | if higher_res: 383 | # get low res mesh to sample point cloud 384 | grid = get_grid(None, 100, input_min=input_min, input_max=input_max, eps=0.0) 385 | z = [] 386 | points = grid['grid_points'] 387 | 388 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 389 | z.append(sdf(pnts).detach().cpu().numpy()) 390 | z = np.concatenate(z, axis=0) 391 | 392 | z = z.astype(np.float32) 393 | 394 | verts, faces, normals, values = measure.marching_cubes( 395 | volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0], 396 | grid['xyz'][2].shape[0]).transpose([1, 0, 2]), 397 | level=level, 398 | spacing=(grid['xyz'][0][2] - grid['xyz'][0][1], 399 | grid['xyz'][0][2] - grid['xyz'][0][1], 400 | grid['xyz'][0][2] - grid['xyz'][0][1])) 401 | 402 | verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]]) 403 | 404 | mesh_low_res = trimesh.Trimesh(verts, faces, normals) 405 | components = mesh_low_res.split(only_watertight=False) 406 | areas = np.array([c.area for c in components], dtype=np.float) 407 | mesh_low_res = components[areas.argmax()] 408 | 409 | recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0] 410 | recon_pc = torch.from_numpy(recon_pc).float().cuda() 411 | 412 | # Center and align the recon pc 413 | s_mean = recon_pc.mean(dim=0) 414 | s_cov = recon_pc - s_mean 415 | s_cov = torch.mm(s_cov.transpose(0, 1), s_cov) 416 | vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0] 417 | if torch.det(vecs) < 0: 418 | vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs) 419 | helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1), 420 | (recon_pc - s_mean).unsqueeze(-1)).squeeze() 421 | 422 | grid_aligned = get_grid(helper.cpu(), resolution, eps=0.01) 423 | else: 424 | grid_aligned = get_grid(None, resolution, 
input_min=input_min, input_max=input_max, eps=0.0) 425 | 426 | grid_points = grid_aligned['grid_points'] 427 | 428 | if higher_res: 429 | g = [] 430 | for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)): 431 | g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2), 432 | pnts.unsqueeze(-1)).squeeze() + s_mean) 433 | grid_points = torch.cat(g, dim=0) 434 | 435 | # MC to new grid 436 | points = grid_points 437 | z = [] 438 | for i, pnts in enumerate(torch.split(points, 100000, dim=0)): 439 | z.append(sdf(pnts).detach().cpu().numpy()) 440 | z = np.concatenate(z, axis=0) 441 | 442 | meshexport = None 443 | if (not (np.min(z) > level or np.max(z) < level)): 444 | 445 | z = z.astype(np.float32) 446 | 447 | verts, faces, normals, values = measure.marching_cubes( 448 | volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0], 449 | grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]), 450 | level=level, 451 | spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], 452 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1], 453 | grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1])) 454 | 455 | if higher_res: 456 | verts = torch.from_numpy(verts).cuda().float() 457 | verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2), 458 | verts.unsqueeze(-1)).squeeze() 459 | verts = (verts + grid_points[0]).cpu().numpy() 460 | else: 461 | verts = verts + np.array([grid_aligned['xyz'][0][0], grid_aligned['xyz'][1][0], grid_aligned['xyz'][2][0]]) 462 | 463 | meshexport = trimesh.Trimesh(verts, faces, normals) 464 | 465 | # CUTTING MESH ACCORDING TO THE BOUNDING BOX 466 | if higher_res: 467 | bb = grid_params 468 | transformation = np.eye(4) 469 | transformation[:3, 3] = (bb[1,:] + bb[0,:])/2. 
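        # slice_plane keeps the half-space each plane normal points into; slicing against
        # every box facet with the negated (inward-pointing) facet normals therefore crops
        # the mesh to the inside of the grid_params bounding box.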
470 | bounding_box = trimesh.creation.box(extents=bb[1,:] - bb[0,:], transform=transformation) 471 | 472 | meshexport = meshexport.slice_plane(bounding_box.facets_origin, -bounding_box.facets_normal) 473 | 474 | return meshexport 475 | 476 | def get_grid_uniform(resolution, grid_boundary=[-2.0, 2.0]): 477 | x = np.linspace(grid_boundary[0], grid_boundary[1], resolution) 478 | y = x 479 | z = x 480 | 481 | xx, yy, zz = np.meshgrid(x, y, z) 482 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float) 483 | 484 | return {"grid_points": grid_points, 485 | "shortest_axis_length": 2.0, 486 | "xyz": [x, y, z], 487 | "shortest_axis_index": 0} 488 | 489 | def get_grid(points, resolution, input_min=None, input_max=None, eps=0.1): 490 | if input_min is None or input_max is None: 491 | input_min = torch.min(points, dim=0)[0].squeeze().numpy() 492 | input_max = torch.max(points, dim=0)[0].squeeze().numpy() 493 | 494 | bounding_box = input_max - input_min 495 | shortest_axis = np.argmin(bounding_box) 496 | if (shortest_axis == 0): 497 | x = np.linspace(input_min[shortest_axis] - eps, 498 | input_max[shortest_axis] + eps, resolution) 499 | length = np.max(x) - np.min(x) 500 | y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 501 | z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1)) 502 | elif (shortest_axis == 1): 503 | y = np.linspace(input_min[shortest_axis] - eps, 504 | input_max[shortest_axis] + eps, resolution) 505 | length = np.max(y) - np.min(y) 506 | x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1)) 507 | z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1)) 508 | elif (shortest_axis == 2): 509 | z = np.linspace(input_min[shortest_axis] - eps, 510 | input_max[shortest_axis] + eps, resolution) 511 | length = np.max(z) - np.min(z) 512 | x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1)) 513 | y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1)) 514 | 515 | xx, yy, zz = np.meshgrid(x, y, z) 516 | grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda() 517 | return {"grid_points": grid_points, 518 | "shortest_axis_length": length, 519 | "xyz": [x, y, z], 520 | "shortest_axis_index": shortest_axis} 521 | 522 | 523 | def plot_normal_maps(normal_maps, ground_true, path, epoch, plot_nrow, img_res, indices): 524 | ground_true = ground_true.cuda() 525 | normal_maps = torch.cat((normal_maps, ground_true), dim=0) 526 | normal_maps_plot = lin2img(normal_maps, img_res) 527 | 528 | tensor = torchvision.utils.make_grid(normal_maps_plot, 529 | scale_each=False, 530 | normalize=False, 531 | nrow=plot_nrow).cpu().detach().numpy() 532 | tensor = tensor.transpose(1, 2, 0) 533 | scale_factor = 255 534 | tensor = (tensor * scale_factor).astype(np.uint8) 535 | 536 | img = Image.fromarray(tensor) 537 | img.save('{0}/normal_{1}_{2}.png'.format(path, epoch, indices[0])) 538 | 539 | #import pdb; pdb.set_trace() 540 | #trans_topil(normal_maps_plot[0, :, :, 260:260+680]).save('{0}/2normal_{1}.png'.format(path, epoch)) 541 | 542 | 543 | def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res, indices, exposure=False): 544 | ground_true = ground_true.cuda() 545 | 546 | output_vs_gt = 
torch.cat((rgb_points, ground_true), dim=0) 547 | output_vs_gt_plot = lin2img(output_vs_gt, img_res) 548 | 549 | tensor = torchvision.utils.make_grid(output_vs_gt_plot, 550 | scale_each=False, 551 | normalize=False, 552 | nrow=plot_nrow).cpu().detach().numpy() 553 | 554 | tensor = tensor.transpose(1, 2, 0) 555 | scale_factor = 255 556 | tensor = (tensor * scale_factor).astype(np.uint8) 557 | 558 | img = Image.fromarray(tensor) 559 | if exposure: 560 | img.save('{0}/exposure_{1}_{2}.png'.format(path, epoch, indices[0])) 561 | else: 562 | img.save('{0}/rendering_{1}_{2}.png'.format(path, epoch, indices[0])) 563 | 564 | 565 | def plot_depth_maps(depth_maps, ground_true, path, epoch, plot_nrow, img_res, indices): 566 | ground_true = ground_true.cuda() 567 | depth_maps = torch.cat((depth_maps[..., None], ground_true), dim=0) 568 | depth_maps_plot = lin2img(depth_maps, img_res) 569 | depth_maps_plot = depth_maps_plot.expand(-1, 3, -1, -1) 570 | 571 | tensor = torchvision.utils.make_grid(depth_maps_plot, 572 | scale_each=False, 573 | normalize=False, 574 | nrow=plot_nrow).cpu().detach().numpy() 575 | tensor = tensor.transpose(1, 2, 0) 576 | 577 | save_path = '{0}/depth_{1}_{2}.png'.format(path, epoch, indices[0]) 578 | 579 | plt.imsave(save_path, tensor[:, :, 0], cmap='viridis') 580 | 581 | 582 | def colored_data(x, cmap='jet', d_min=None, d_max=None): 583 | if d_min is None: 584 | d_min = np.min(x) 585 | if d_max is None: 586 | d_max = np.max(x) 587 | x_relative = (x - d_min) / (d_max - d_min) 588 | cmap_ = plt.cm.get_cmap(cmap) 589 | return (255 * cmap_(x_relative)[:,:,:3]).astype(np.uint8) # H, W, C 590 | 591 | 592 | def plot_seg_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res, indices): 593 | ground_true = ground_true.cuda() 594 | 595 | output_vs_gt = torch.cat((rgb_points, ground_true), dim=0) 596 | output_vs_gt_plot = lin2img(output_vs_gt, img_res) 597 | 598 | tensor = torchvision.utils.make_grid(output_vs_gt_plot, 599 | scale_each=False, 600 | normalize=False, 601 | nrow=plot_nrow).cpu().detach().numpy() 602 | tensor = tensor.transpose(1, 2, 0)[:, :, 0] 603 | tensor = colored_data(tensor) 604 | 605 | img = Image.fromarray(tensor) 606 | img.save('{0}/semantic_{1}_{2}.png'.format(path, epoch, indices[0])) 607 | 608 | 609 | def lin2img(tensor, img_res): 610 | batch_size, num_samples, channels = tensor.shape 611 | return tensor.permute(0, 2, 1).view(batch_size, channels, img_res[0], img_res[1]) 612 | -------------------------------------------------------------------------------- /code/utils/rend_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio.v2 as imageio 3 | import skimage 4 | import cv2 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | 9 | def get_psnr(img1, img2, normalize_rgb=False): 10 | if normalize_rgb: # [-1,1] --> [0,1] 11 | img1 = (img1 + 1.) / 2. 12 | img2 = (img2 + 1. ) / 2. 13 | 14 | mse = torch.mean((img1 - img2) ** 2) 15 | psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda()) 16 | 17 | return psnr 18 | 19 | 20 | def load_rgb(path, normalize_rgb = False): 21 | img = imageio.imread(path) 22 | img = skimage.img_as_float32(img) 23 | 24 | if normalize_rgb: # [-1,1] --> [0,1] 25 | img -= 0.5 26 | img *= 2. 
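        # note: this maps [0,1] -> [-1,1], i.e. the inverse of the remap done in get_psnr above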
27 | img = img.transpose(2, 0, 1) 28 | return img 29 | 30 | 31 | def load_K_Rt_from_P(filename, P=None): 32 | if P is None: 33 | lines = open(filename).read().splitlines() 34 | if len(lines) == 4: 35 | lines = lines[1:] 36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 37 | P = np.asarray(lines).astype(np.float32).squeeze() 38 | 39 | out = cv2.decomposeProjectionMatrix(P) 40 | K = out[0] 41 | R = out[1] 42 | t = out[2] 43 | 44 | K = K/K[2,2] 45 | intrinsics = np.eye(4) 46 | intrinsics[:3, :3] = K 47 | 48 | pose = np.eye(4, dtype=np.float32) 49 | pose[:3, :3] = R.transpose() 50 | pose[:3,3] = (t[:3] / t[3])[:,0] 51 | 52 | return intrinsics, pose 53 | 54 | 55 | def get_camera_params(uv, pose, intrinsics): 56 | if pose.shape[1] == 7: #In case of quaternion vector representation 57 | cam_loc = pose[:, 4:] 58 | R = quat_to_rot(pose[:,:4]) 59 | p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float() 60 | p[:, :3, :3] = R 61 | p[:, :3, 3] = cam_loc 62 | else: # In case of pose matrix representation 63 | cam_loc = pose[:, :3, 3] 64 | p = pose 65 | 66 | batch_size, num_samples, _ = uv.shape 67 | 68 | depth = torch.ones((batch_size, num_samples)).cuda() 69 | x_cam = uv[:, :, 0].view(batch_size, -1) 70 | y_cam = uv[:, :, 1].view(batch_size, -1) 71 | z_cam = depth.view(batch_size, -1) 72 | 73 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics) 74 | 75 | # permute for batch matrix product 76 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1) 77 | 78 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3] 79 | ray_dirs = world_coords - cam_loc[:, None, :] 80 | ray_dirs = F.normalize(ray_dirs, dim=2) 81 | 82 | return ray_dirs, cam_loc 83 | 84 | 85 | def get_camera_for_plot(pose): 86 | if pose.shape[1] == 7: #In case of quaternion vector representation 87 | cam_loc = pose[:, 4:].detach() 88 | R = quat_to_rot(pose[:,:4].detach()) 89 | else: # In case of pose matrix representation 90 | cam_loc = pose[:, :3, 3] 91 | R = pose[:, :3, :3] 92 | cam_dir = R[:, :3, 2] 93 | return cam_loc, cam_dir 94 | 95 | 96 | def lift(x, y, z, intrinsics): 97 | # parse intrinsics 98 | intrinsics = intrinsics.cuda() 99 | fx = intrinsics[:, 0, 0] 100 | fy = intrinsics[:, 1, 1] 101 | cx = intrinsics[:, 0, 2] 102 | cy = intrinsics[:, 1, 2] 103 | sk = intrinsics[:, 0, 1] 104 | 105 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z 106 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z 107 | 108 | # homogeneous 109 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1) 110 | 111 | 112 | def quat_to_rot(q): 113 | batch_size, _ = q.shape 114 | q = F.normalize(q, dim=1) 115 | R = torch.ones((batch_size, 3,3)).cuda() 116 | qr=q[:,0] 117 | qi = q[:, 1] 118 | qj = q[:, 2] 119 | qk = q[:, 3] 120 | R[:, 0, 0]=1-2 * (qj**2 + qk**2) 121 | R[:, 0, 1] = 2 * (qj *qi -qk*qr) 122 | R[:, 0, 2] = 2 * (qi * qk + qr * qj) 123 | R[:, 1, 0] = 2 * (qj * qi + qk * qr) 124 | R[:, 1, 1] = 1-2 * (qi**2 + qk**2) 125 | R[:, 1, 2] = 2*(qj*qk - qi*qr) 126 | R[:, 2, 0] = 2 * (qk * qi-qj * qr) 127 | R[:, 2, 1] = 2 * (qj*qk + qi*qr) 128 | R[:, 2, 2] = 1-2 * (qi**2 + qj**2) 129 | return R 130 | 131 | 132 | def rot_to_quat(R): 133 | batch_size, _,_ = R.shape 134 | q = torch.ones((batch_size, 4)).cuda() 135 | 136 | R00 = R[:, 0,0] 137 | R01 = R[:, 0, 1] 138 | R02 = R[:, 0, 2] 139 | R10 = R[:, 1, 0] 140 | R11 = R[:, 1, 1] 141 | R12 = R[:, 1, 2] 142 | R20 = R[:, 2, 0] 143 | R21 
= R[:, 2, 1] 144 | R22 = R[:, 2, 2] 145 | 146 | q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2 147 | q[:, 1]=(R21-R12)/(4*q[:,0]) 148 | q[:, 2] = (R02 - R20) / (4 * q[:, 0]) 149 | q[:, 3] = (R10 - R01) / (4 * q[:, 0]) 150 | return q 151 | 152 | 153 | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0): 154 | # Input: n_rays x 3 ; n_rays x 3 155 | # Output: n_rays x 1, n_rays x 1 (close and far) 156 | 157 | ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3), 158 | cam_loc.view(-1, 3, 1)).squeeze(-1) 159 | under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2) 160 | 161 | # sanity check 162 | if (under_sqrt <= 0).sum() > 0: 163 | print('BOUNDING SPHERE PROBLEM!') 164 | exit() 165 | 166 | sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot 167 | sphere_intersections = sphere_intersections.clamp_min(0.0) 168 | 169 | return sphere_intersections 170 | 171 | 172 | def sample_pdf(bins, weights, N_importance, det=False, eps=1e-5): 173 | # device = weights.get_device() 174 | device = weights.device 175 | # Get pdf 176 | weights = weights + 1e-5 # prevent nans 177 | pdf = weights / torch.sum(weights, -1, keepdim=True) 178 | cdf = torch.cumsum(pdf, -1) 179 | cdf = torch.cat( 180 | [torch.zeros_like(cdf[..., :1], device=device), cdf], -1 181 | ) # (batch, len(bins)) 182 | 183 | # Take uniform samples 184 | if det: 185 | u = torch.linspace(0.0, 1.0, steps=N_importance, device=device) 186 | u = u.expand(list(cdf.shape[:-1]) + [N_importance]) 187 | else: 188 | u = torch.rand(list(cdf.shape[:-1]) + [N_importance], device=device) 189 | u = u.contiguous() 190 | 191 | # Invert CDF 192 | inds = torch.searchsorted(cdf.detach(), u, right=False) 193 | 194 | below = torch.clamp_min(inds-1, 0) 195 | above = torch.clamp_max(inds, cdf.shape[-1]-1) 196 | # (batch, N_importance, 2) ==> (B, batch, N_importance, 2) 197 | inds_g = torch.stack([below, above], -1) 198 | 199 | matched_shape = [*inds_g.shape[:-1], cdf.shape[-1]] # fix prefix shape 200 | 201 | cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g) 202 | bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g) # fix prefix shape 203 | 204 | denom = cdf_g[..., 1] - cdf_g[..., 0] 205 | denom[denom 0.) 
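    # For reference: sample_pdf implements standard NeRF-style inverse-CDF sampling.
    # A typical implementation finishes from here by guarding the denominator and
    # linearly interpolating inside each bin (assumed sketch, not verbatim from this file):
    #
    #   denom = torch.where(denom < eps, torch.ones_like(denom), denom)
    #   t = (u - cdf_g[..., 0]) / denom
    #   samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    #   return samples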
33 | depth_map = depth_map * scale + shift 34 | 35 | # save point cloud 36 | depth = depth_map.reshape(1, 1, 384, 384) 37 | # pred_points = get_point_cloud(depth, model_input, model_outputs) 38 | 39 | gt_depth = depth_gt.reshape(1, 1, 384, 384) 40 | # gt_points = get_point_cloud(gt_depth, model_input, model_outputs) 41 | 42 | # semantic map 43 | semantic_map = model_outputs['semantic_values'].argmax(dim=-1).reshape(batch_size, num_samples, 1) 44 | # in label mapping, 0 is bg idx and 0 45 | # for instance, first fg is 3 and 1 46 | # so when using argmax, the output will be label_mapping idx if correct 47 | 48 | plot_data = { 49 | 'rgb_gt': rgb_gt, 50 | 'normal_gt': (normal_gt + 1.)/ 2., 51 | 'depth_gt': depth_gt, 52 | 'pose': pose, 53 | 'rgb_eval': rgb_eval, 54 | 'normal_map': normal_map, 55 | 'depth_map': depth_map, 56 | # "pred_points": pred_points, 57 | # "gt_points": gt_points, 58 | "semantic_map": semantic_map, 59 | "semantic_gt": semantic_gt, 60 | } 61 | 62 | return plot_data 63 | 64 | def get_sdf_vals_edit(pts, model, idx, edit_param, edit_type): 65 | with torch.no_grad(): 66 | sdf_original = model.implicit_network.forward(pts)[:,:model.implicit_network.d_out] # [N_pts, K] 67 | 68 | if edit_type == 'translate': 69 | edit_pts = pts - edit_param 70 | 71 | sdf_edit = model.implicit_network.forward(edit_pts)[:,:model.implicit_network.d_out] # [N_pts, K] 72 | 73 | sdf_original[:, idx] = sdf_original[:, idx] * 0. + sdf_edit[:, idx] 74 | 75 | sdf = sdf_original 76 | 77 | sdf = -model.implicit_network.pool(-sdf.unsqueeze(1)).squeeze(-1) # get the minium value of sdf if bound apply before min 78 | return sdf 79 | 80 | def neus_sample_edit(cam_loc, ray_dirs, model, idx, edit_param, edit_type): 81 | device = cam_loc.device 82 | perturb = False 83 | _, far = model.near_far_from_cube(cam_loc, ray_dirs, bound=model.scene_bounding_sphere) 84 | near = model.near * torch.ones(ray_dirs.shape[0], 1).cuda() 85 | 86 | _t = torch.linspace(0, 1, model.N_samples).float().to(device) 87 | z_vals = near * (1 - _t) + far * _t 88 | 89 | with torch.no_grad(): 90 | _z = z_vals # [N, 64] 91 | 92 | # follow the objsdf setting and use min sdf for sample 93 | _pts = cam_loc.unsqueeze(-2) + _z.unsqueeze(-1) * ray_dirs.unsqueeze(-2) 94 | N_rays, N_steps = _pts.shape[0], _pts.shape[1] 95 | 96 | _sdf = get_sdf_vals_edit(_pts.reshape(-1, 3), model, idx, edit_param, edit_type) 97 | 98 | _sdf = _sdf.reshape(N_rays, N_steps) 99 | 100 | for i in range(model.N_upsample_iters): 101 | prev_sdf, next_sdf = _sdf[..., :-1], _sdf[..., 1:] 102 | prev_z_vals, next_z_vals = _z[..., :-1], _z[..., 1:] 103 | mid_sdf = (prev_sdf + next_sdf) * 0.5 104 | dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) 105 | prev_dot_val = torch.cat([torch.zeros_like(dot_val[..., :1], device=device), dot_val[..., :-1]], dim=-1) 106 | dot_val = torch.stack([prev_dot_val, dot_val], dim=-1) 107 | dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False) 108 | dot_val = dot_val.clamp(-10.0, 0.0) 109 | 110 | dist = (next_z_vals - prev_z_vals) 111 | prev_esti_sdf = mid_sdf - dot_val * dist * 0.5 112 | next_esti_sdf = mid_sdf + dot_val * dist * 0.5 113 | 114 | prev_cdf = cdf_Phi_s(prev_esti_sdf, 64 * (2**i)) 115 | next_cdf = cdf_Phi_s(next_esti_sdf, 64 * (2**i)) 116 | alpha = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) 117 | _w = alpha_to_w(alpha) 118 | z_fine = rend_util.sample_pdf(_z, _w, model.N_samples_extra // model.N_upsample_iters, det=not perturb) 119 | _z = torch.cat([_z, z_fine], dim=-1) 120 | 121 | _pts_fine = cam_loc.unsqueeze(-2) + 
z_fine.unsqueeze(-1) * ray_dirs.unsqueeze(-2) 122 | N_rays, N_steps_fine = _pts_fine.shape[0], _pts_fine.shape[1] 123 | 124 | sdf_fine = get_sdf_vals_edit(_pts_fine.reshape(-1, 3), model, idx, edit_param, edit_type) 125 | 126 | sdf_fine = sdf_fine.reshape(N_rays, N_steps_fine) 127 | _sdf = torch.cat([_sdf, sdf_fine], dim=-1) 128 | _z, z_sort_indices = torch.sort(_z, dim=-1) 129 | 130 | _sdf = torch.gather(_sdf, 1, z_sort_indices) 131 | 132 | z_all = _z 133 | 134 | return z_all 135 | 136 | def get_sdf_vals_and_sdfs_edit(pts, model, idx, edit_param, edit_type): 137 | with torch.no_grad(): 138 | sdf_original = model.implicit_network.forward(pts)[:,:model.implicit_network.d_out] # [N_pts, K] 139 | 140 | if edit_type == 'translate': 141 | edit_pts = pts - edit_param 142 | 143 | sdf_edit = model.implicit_network.forward(edit_pts)[:,:model.implicit_network.d_out] # [N_pts, K] 144 | 145 | sdf_original[:, idx] = sdf_original[:, idx] * 0. + sdf_edit[:, idx] 146 | 147 | sdf = sdf_original 148 | 149 | sdf_all = sdf 150 | sdf = -model.implicit_network.pool(-sdf.unsqueeze(1)).squeeze(-1) 151 | return sdf, sdf_all 152 | 153 | def get_outputs_edit(points, model, idx, edit_param, edit_type): 154 | points.requires_grad_(True) 155 | 156 | # directly use the original geometry feature vector 157 | # fuse sdf together 158 | # then compute semantic, gradient, sdf 159 | 160 | original_output = model.implicit_network.forward(points) 161 | sdf_original = original_output[:,:model.implicit_network.d_out] 162 | feature_vectors = original_output[:,model.implicit_network.d_out:] 163 | 164 | if edit_type == 'translate': 165 | edit_pts = points - edit_param 166 | edit_output = model.implicit_network.forward(edit_pts) 167 | sdf_edit = edit_output[:, :model.implicit_network.d_out] 168 | 169 | sdf_raw = sdf_original 170 | sdf_raw[:, idx] = sdf_original[:, idx] * 0. 
+ sdf_edit[:, idx] 171 | 172 | sigmoid_value = model.implicit_network.sigmoid 173 | semantic = sigmoid_value * torch.sigmoid(-sigmoid_value * sdf_raw) 174 | 175 | sdf = -model.implicit_network.pool(-sdf_raw.unsqueeze(1)).squeeze(-1) 176 | 177 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) 178 | gradients = torch.autograd.grad( 179 | outputs=sdf, 180 | inputs=points, 181 | grad_outputs=d_output, 182 | create_graph=True, 183 | retain_graph=True, 184 | only_inputs=True)[0] 185 | 186 | return sdf, feature_vectors, gradients, semantic, sdf_raw 187 | 188 | 189 | def render_edit(model, input, indices, idx=0, edit_param=[0., 0., 0.], edit_type='translate'): 190 | ''' 191 | Currently only support one object 192 | if edit_type == 'translate', then edit_param is [dx, dy, dz] 193 | if edit_type == 'rotate', then edit_param is []: TODO 194 | just use neus 195 | ''' 196 | assert idx > 0 197 | edit_param = torch.tensor(edit_param).cuda() 198 | 199 | intrinsics = input["intrinsics"].cuda() 200 | uv = input["uv"].cuda() 201 | pose = input["pose"].cuda() 202 | 203 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics) 204 | # we should use unnormalized ray direction for depth 205 | ray_dirs_tmp, _ = rend_util.get_camera_params(uv, torch.eye(4).to(pose.device)[None], intrinsics) 206 | depth_scale = ray_dirs_tmp[0, :, 2:] # [N, 1] 207 | 208 | batch_size, num_pixels, _ = ray_dirs.shape 209 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3) 210 | ray_dirs = ray_dirs.reshape(-1, 3) 211 | 212 | ''' 213 | Sample points with edited forward 214 | ''' 215 | z_vals = neus_sample_edit(cam_loc, ray_dirs, model, idx, edit_param, edit_type) 216 | 217 | N_samples_tmp = z_vals.shape[1] 218 | 219 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) # [N_rays, N_samples_tmp, 3] 220 | points_flat_tmp = points.reshape(-1, 3) 221 | 222 | sdf_tmp, sdf_all_tmp = get_sdf_vals_and_sdfs_edit(points_flat_tmp, model, idx, edit_param, edit_type) 223 | sdf_tmp = sdf_tmp.reshape(-1, N_samples_tmp) 224 | s_value = model.get_s_value() 225 | 226 | cdf, opacity_alpha = sdf_to_alpha(sdf_tmp, s_value) # [N_rays, N_samples_tmp-1] 227 | 228 | sdf_all_tmp = sdf_all_tmp.reshape(-1, N_samples_tmp, model.num_semantic) 229 | 230 | z_mid_vals = 0.5 * (z_vals[..., 1:] + z_vals[..., :-1]) 231 | N_samples = z_mid_vals.shape[1] 232 | 233 | points_mid = cam_loc.unsqueeze(1) + z_mid_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) # [N_rays, N_samples, 3] 234 | points_flat = points_mid.reshape(-1, 3) 235 | 236 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1) 237 | dirs_flat = dirs.reshape(-1, 3) 238 | 239 | sdf, feature_vectors, gradients, semantic, sdf_raw = get_outputs_edit(points_flat, model, idx, edit_param, edit_type) 240 | 241 | # here the rgb output might be wrong 242 | rgb_flat = model.rendering_network(points_flat, gradients, dirs_flat, feature_vectors, indices) 243 | rgb = rgb_flat.reshape(-1, N_samples, 3) 244 | 245 | semantic = semantic.reshape(-1, N_samples, model.num_semantic) 246 | 247 | weights = alpha_to_w(opacity_alpha) 248 | 249 | rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1) 250 | semantic_values = torch.sum(weights.unsqueeze(-1)*semantic, 1) 251 | raw_depth_values = torch.sum(weights * z_mid_vals, 1, keepdims=True) / (weights.sum(dim=1, keepdims=True) +1e-8) 252 | depth_values = depth_scale * raw_depth_values 253 | 254 | output = { 255 | 'rgb_values': rgb_values, 256 | 'semantic_values': semantic_values, 257 | 'depth_values': depth_values, 
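        # rgb/semantic/depth are alpha-composited with the NeuS-style weights above;
        # depth is converted from along-ray distance to z-depth via depth_scale, and
        # normal_map is appended below after rotation into the camera frame.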
258 | } 259 | 260 | # compute normal map 261 | normals = gradients / (gradients.norm(2, -1, keepdim=True) + 1e-6) 262 | normals = normals.reshape(-1, N_samples, 3) 263 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1) 264 | 265 | # transform to local coordinate system 266 | rot = pose[0, :3, :3].permute(1, 0).contiguous() 267 | normal_map = rot @ normal_map.permute(1, 0) 268 | normal_map = normal_map.permute(1, 0).contiguous() 269 | 270 | output['normal_map'] = normal_map 271 | 272 | return output 273 | 274 | 275 | edit_idx = 1 276 | edit_param = [0., 0., 0.] 277 | edit_type = 'translate' 278 | 279 | exp_name = 'RICO_synthetic_1' 280 | scan_id = int(exp_name[-1]) 281 | 282 | exp_path = os.path.join('../exps/', exp_name) 283 | timestamp = os.listdir(exp_path)[-1] # use the latest if not other need 284 | exp_path = os.path.join(exp_path, timestamp) 285 | 286 | conf = ConfigFactory.parse_file(os.path.join(exp_path, 'runconf.conf')) 287 | dataset_conf = conf.get_config('dataset') 288 | dataset_conf['scan_id'] = scan_id 289 | conf_model = conf.get_config('model') 290 | 291 | train_dataset = utils.get_class(conf.get_string('train.dataset_class'))(**dataset_conf) 292 | plot_dataloader = torch.utils.data.DataLoader( 293 | train_dataset, 294 | batch_size=conf.get_int('plot.plot_nimgs'), 295 | shuffle=False, 296 | collate_fn=train_dataset.collate_fn) 297 | 298 | model = utils.get_class(conf.get_string('train.model_class'))(conf=conf_model) 299 | 300 | if torch.cuda.is_available(): 301 | model.cuda() 302 | 303 | ckpt_path = os.path.join(exp_path, 'checkpoints/ModelParameters', 'latest.pth') 304 | ckpt = torch.load(ckpt_path) 305 | print(ckpt['epoch']) 306 | 307 | # model.load_state_dict(ckpt['model_state_dict']) 308 | # load in a non-DDP fashion 309 | model.load_state_dict({k.replace('module.',''): v for k,v in ckpt['model_state_dict'].items()}) 310 | os.makedirs('./tmp_edit', exist_ok=True) 311 | 312 | model.eval() 313 | 314 | data_idx = 75 315 | vis_data = plot_dataloader.dataset[data_idx] 316 | 317 | indices, model_input, ground_truth = vis_data 318 | indices = torch.tensor([indices]) 319 | print(indices) 320 | for k, v in model_input.items(): 321 | model_input[k] = v.unsqueeze(0) 322 | for k, v in ground_truth.items(): 323 | ground_truth[k] = v.unsqueeze(0) 324 | 325 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 326 | model_input["uv"] = model_input["uv"].cuda() 327 | model_input['pose'] = model_input['pose'].cuda() 328 | 329 | split = utils.split_input(model_input, 384*384, n_pixels=128) 330 | res = [] 331 | 332 | for s in tqdm(split): 333 | # out = model(s, indices) 334 | out = render_edit(model, s, indices, edit_idx, edit_param, edit_type) 335 | d = {'rgb_values': out['rgb_values'].detach(), 336 | 'normal_map': out['normal_map'].detach(), 337 | 'depth_values': out['depth_values'].detach(), 338 | 'semantic_values': out['semantic_values'].detach()} 339 | if 'rgb_un_values' in out: 340 | d['rgb_un_values'] = out['rgb_un_values'].detach() 341 | res.append(d) 342 | 343 | batch_size = ground_truth['rgb'].shape[0] 344 | model_outputs = utils.merge_output(res, 384*384, batch_size) 345 | plot_data = get_plot_data(model_input, model_outputs, model_input['pose'], ground_truth['rgb'], ground_truth['normal'], ground_truth['depth'], ground_truth['instance_mask']) 346 | 347 | plot_conf = conf.get_config('plot') 348 | plot_conf['obj_boxes'] = None 349 | plts.plot_rico( 350 | None, 351 | indices, 352 | plot_data, 353 | './tmp_edit/', 354 | ckpt['epoch'], 355 | [384, 384], 356 | 
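    # implicit_network is passed as None because plot_mesh=False below skips the
    # marching-cubes branch, so only the image-space plots are written here.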
plot_mesh = False, 357 | **plot_conf 358 | ) -------------------------------------------------------------------------------- /scripts/extract_mesh_rico.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage import measure 4 | import torchvision 5 | import trimesh 6 | from PIL import Image 7 | import matplotlib.pyplot as plt 8 | import cv2 9 | import os 10 | import json 11 | from pyhocon import ConfigFactory 12 | 13 | import sys 14 | sys.path.append("../code") 15 | import utils.general as utils 16 | 17 | 18 | exp_name = 'RICO_synthetic_1' 19 | scan_id = int(exp_name[-1]) 20 | 21 | avg_pool_3d = torch.nn.AvgPool3d(2, stride=2) 22 | upsample = torch.nn.Upsample(scale_factor=2, mode='nearest') 23 | 24 | @torch.no_grad() 25 | def get_surface_sliding(send_path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0): 26 | if isinstance(send_path, list): 27 | path = send_path[0] 28 | mesh_name = send_path[1] 29 | else: 30 | path = send_path 31 | mesh_name = '' 32 | 33 | # assert resolution % 512 == 0 34 | resN = resolution 35 | cropN = resolution 36 | level = 0 37 | N = resN // cropN 38 | 39 | if len(grid_boundary) == 2: 40 | grid_min = [grid_boundary[0], grid_boundary[0], grid_boundary[0]] 41 | grid_max = [grid_boundary[1], grid_boundary[1], grid_boundary[1]] 42 | elif len(grid_boundary) == 6: # xmin, ymin, zmin, xmax, ymax, zmax 43 | grid_min = [grid_boundary[0], grid_boundary[1], grid_boundary[2]] 44 | grid_max = [grid_boundary[3], grid_boundary[4], grid_boundary[5]] 45 | xs = np.linspace(grid_min[0], grid_max[0], N+1) 46 | ys = np.linspace(grid_min[1], grid_max[1], N+1) 47 | zs = np.linspace(grid_min[2], grid_max[2], N+1) 48 | 49 | print(xs) 50 | print(ys) 51 | print(zs) 52 | meshes = [] 53 | for i in range(N): 54 | for j in range(N): 55 | for k in range(N): 56 | print(i, j, k) 57 | x_min, x_max = xs[i], xs[i+1] 58 | y_min, y_max = ys[j], ys[j+1] 59 | z_min, z_max = zs[k], zs[k+1] 60 | 61 | x = np.linspace(x_min, x_max, cropN) 62 | y = np.linspace(y_min, y_max, cropN) 63 | z = np.linspace(z_min, z_max, cropN) 64 | 65 | xx, yy, zz = np.meshgrid(x, y, z, indexing='ij') 66 | points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda() 67 | 68 | def evaluate(points): 69 | z = [] 70 | for _, pnts in enumerate(torch.split(points, 100000, dim=0)): 71 | z.append(sdf(pnts)) 72 | z = torch.cat(z, axis=0) 73 | return z 74 | 75 | # construct point pyramids 76 | points = points.reshape(cropN, cropN, cropN, 3).permute(3, 0, 1, 2) 77 | 78 | points_pyramid = [points] 79 | for _ in range(3): 80 | points = avg_pool_3d(points[None])[0] 81 | points_pyramid.append(points) 82 | points_pyramid = points_pyramid[::-1] 83 | 84 | # evalute pyramid with mask 85 | mask = None 86 | threshold = 2 * (x_max - x_min)/cropN * 8 87 | for pid, pts in enumerate(points_pyramid): 88 | coarse_N = pts.shape[-1] 89 | pts = pts.reshape(3, -1).permute(1, 0).contiguous() 90 | 91 | if mask is None: 92 | pts_sdf = evaluate(pts) 93 | else: 94 | mask = mask.reshape(-1) 95 | pts_to_eval = pts[mask] 96 | #import pdb; pdb.set_trace() 97 | if pts_to_eval.shape[0] > 0: 98 | pts_sdf_eval = evaluate(pts_to_eval.contiguous()) 99 | pts_sdf[mask] = pts_sdf_eval 100 | print("ratio", pts_to_eval.shape[0] / pts.shape[0]) 101 | 102 | if pid < 3: 103 | # update mask 104 | mask = torch.abs(pts_sdf) < threshold 105 | mask = mask.reshape(coarse_N, coarse_N, coarse_N)[None, None] 106 | mask = 
upsample(mask.float()).bool() 107 | 108 | pts_sdf = pts_sdf.reshape(coarse_N, coarse_N, coarse_N)[None, None] 109 | pts_sdf = upsample(pts_sdf) 110 | pts_sdf = pts_sdf.reshape(-1) 111 | 112 | threshold /= 2. 113 | 114 | z = pts_sdf.detach().cpu().numpy() 115 | 116 | if (not (np.min(z) > level or np.max(z) < level)): 117 | z = z.astype(np.float32) 118 | verts, faces, normals, values = measure.marching_cubes( 119 | volume=z.reshape(cropN, cropN, cropN), #.transpose([1, 0, 2]), 120 | level=level, 121 | spacing=( 122 | (x_max - x_min)/(cropN-1), 123 | (y_max - y_min)/(cropN-1), 124 | (z_max - z_min)/(cropN-1) )) 125 | print(np.array([x_min, y_min, z_min])) 126 | print(verts.min(), verts.max()) 127 | verts = verts + np.array([x_min, y_min, z_min]) 128 | print(verts.min(), verts.max()) 129 | 130 | meshcrop = trimesh.Trimesh(verts, faces, normals) 131 | #meshcrop.export(f"{i}_{j}_{k}.ply") 132 | meshes.append(meshcrop) 133 | 134 | combined = trimesh.util.concatenate(meshes) 135 | 136 | combined.export('{0}/surface_{1}_{2}.ply'.format(path, epoch, mesh_name), 'ply') 137 | 138 | 139 | exp_path = os.path.join('../exps/', exp_name) 140 | timestamp = os.listdir(exp_path)[-1] # use the latest if not other need 141 | exp_path = os.path.join(exp_path, timestamp) 142 | 143 | conf = ConfigFactory.parse_file(os.path.join(exp_path, 'runconf.conf')) 144 | dataset_conf = conf.get_config('dataset') 145 | conf_model = conf.get_config('model') 146 | 147 | model = utils.get_class(conf.get_string('train.model_class'))(conf=conf_model) 148 | 149 | if torch.cuda.is_available(): 150 | model.cuda() 151 | 152 | ckpt_path = os.path.join(exp_path, 'checkpoints/ModelParameters', 'latest.pth') 153 | ckpt = torch.load(ckpt_path) 154 | print(ckpt['epoch']) 155 | 156 | # model.load_state_dict(ckpt['model_state_dict']) 157 | # load in a non-DDP fashion 158 | model.load_state_dict({k.replace('module.',''): v for k,v in ckpt['model_state_dict'].items()}) 159 | os.makedirs('./tmp', exist_ok=True) 160 | 161 | sem_num = model.implicit_network.d_out 162 | f = torch.nn.MaxPool1d(sem_num) 163 | 164 | for indx in range(sem_num): 165 | obj_grid_boundary = [-1.1, 1.1] 166 | _ = get_surface_sliding( 167 | send_path=['./tmp/', str(indx)], 168 | epoch=ckpt['epoch'], 169 | sdf = lambda x: model.implicit_network(x)[:, indx], 170 | resolution=512, 171 | grid_boundary=obj_grid_boundary, 172 | level=0. 173 | ) 174 | 175 | _ = get_surface_sliding( 176 | send_path=['./tmp/', 'all'], 177 | epoch=ckpt['epoch'], 178 | sdf=lambda x: -f(-model.implicit_network(x)[:, :sem_num].unsqueeze(1)).squeeze(-1).squeeze(-1), 179 | resolution=512, 180 | grid_boundary=[-1.1, 1.1], 181 | level=0. 
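    # the sdf lambda above takes the minimum over the K per-object SDFs (written as
    # -maxpool(-sdf)), so this 'all' mesh is the union of background and object surfaces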
182 | ) 183 | 184 | print('finish') 185 | 186 | -------------------------------------------------------------------------------- /synthetic_eval/evaluate.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | import open3d as o3d 4 | from sklearn.neighbors import KDTree 5 | import trimesh 6 | import torch 7 | import glob 8 | import os 9 | import pyrender 10 | import os 11 | import cv2 12 | import json 13 | from tqdm import tqdm 14 | from pathlib import Path 15 | 16 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 17 | 18 | def nn_correspondance(verts1, verts2): 19 | indices = [] 20 | distances = [] 21 | if len(verts1) == 0 or len(verts2) == 0: 22 | return indices, distances 23 | 24 | kdtree = KDTree(verts1) 25 | distances, indices = kdtree.query(verts2) 26 | distances = distances.reshape(-1) 27 | 28 | return distances 29 | 30 | 31 | def evaluate(mesh_pred, mesh_trgt, obj_type='bg', threshold=.05, down_sample=.02): 32 | pcd_trgt = o3d.geometry.PointCloud() 33 | pcd_pred = o3d.geometry.PointCloud() 34 | 35 | trgt_pts = mesh_trgt.vertices[:, :3] 36 | pred_pts = mesh_pred.vertices[:, :3] 37 | 38 | if obj_type == 'obj': 39 | pts_mask = pred_pts[:, 2] < -0.9 40 | pred_pts = pred_pts[pts_mask] 41 | 42 | pcd_trgt.points = o3d.utility.Vector3dVector(trgt_pts) 43 | pcd_pred.points = o3d.utility.Vector3dVector(pred_pts) 44 | 45 | if down_sample: 46 | pcd_pred = pcd_pred.voxel_down_sample(down_sample) 47 | pcd_trgt = pcd_trgt.voxel_down_sample(down_sample) 48 | 49 | verts_pred = np.asarray(pcd_pred.points) 50 | verts_trgt = np.asarray(pcd_trgt.points) 51 | 52 | dist1 = nn_correspondance(verts_pred, verts_trgt) 53 | dist2 = nn_correspondance(verts_trgt, verts_pred) 54 | 55 | precision = np.mean((dist2 < threshold).astype('float')) 56 | recal = np.mean((dist1 < threshold).astype('float')) 57 | fscore = 2 * precision * recal / (precision + recal) 58 | chamfer = (np.mean(dist2) + np.mean(dist1)) / 2 59 | metrics = { 60 | 'Acc': np.mean(dist2), 61 | 'Comp': np.mean(dist1), 62 | 'Chamfer': chamfer, 63 | 'Prec': precision, 64 | 'Recal': recal, 65 | 'F-score': fscore, 66 | } 67 | return metrics 68 | 69 | # hard-coded image size 70 | H, W = 384, 384 71 | 72 | def average_dicts(dicts): 73 | # input is a list of dict 74 | # all the dict have same keys 75 | dict_num = len(dicts) 76 | keys = dicts[0].keys() 77 | ret = {} 78 | 79 | for k in keys: 80 | values = [x[k] for x in dicts] 81 | value = np.array(values).mean() 82 | ret[k] = value 83 | 84 | return ret 85 | 86 | 87 | root_dir = "../exps/" 88 | exp_name = "RICO_synthetic" 89 | out_dir = "evaluation/" + exp_name 90 | Path(out_dir).mkdir(parents=True, exist_ok=True) 91 | 92 | 93 | scenes = { 94 | 1: 'scene1', 95 | 2: 'scene2', 96 | 3: 'scene3', 97 | 4: 'scene4', 98 | 5: 'scene5', 99 | } 100 | 101 | all_obj_results = [] 102 | all_obj_results_dict = OrderedDict() 103 | 104 | for k, v in scenes.items(): 105 | 106 | cur_exp = f"{exp_name}_{k}" 107 | cur_root = os.path.join(root_dir, cur_exp) 108 | if not os.path.isdir(cur_root): 109 | continue 110 | # use last timestamps 111 | dirs = sorted(os.listdir(cur_root)) 112 | cur_root = os.path.join(cur_root, dirs[-1]) 113 | 114 | files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*.ply")))) 115 | 116 | # evalute the meshes for obj and bg, the first is bg and last is all 117 | files.sort(key=lambda x:os.path.getmtime(x)) 118 | 119 | cam_file = f"../data/syn_data/scene{k}/cameras.npz" 120 | scale_mat = 
np.load(cam_file)['scale_mat_0'] 121 | 122 | ply_files = files[1: -1] 123 | # print(ply_files) 124 | 125 | cnt = 1 126 | obj_results = [] 127 | obj_results_dict = OrderedDict() 128 | for ply_file in ply_files: 129 | 130 | mesh = trimesh.load(ply_file) 131 | mesh.vertices = (scale_mat[:3, :3] @ mesh.vertices.T + scale_mat[:3, 3:]).T 132 | 133 | gt_mesh = os.path.join(f"../data/syn_data/scene{k}/GT_mesh", f"object{cnt}.ply") 134 | 135 | gt_mesh = trimesh.load(gt_mesh) 136 | 137 | metrics = evaluate(mesh, gt_mesh, 'obj') 138 | obj_results.append(metrics) 139 | obj_results_dict[cnt] = metrics 140 | 141 | cnt += 1 142 | 143 | obj_results = average_dicts(obj_results) 144 | all_obj_results.append(obj_results) 145 | all_obj_results_dict[k] = obj_results_dict 146 | 147 | # the average result print 148 | all_obj_results = average_dicts(all_obj_results) 149 | print('objects:') 150 | print(all_obj_results) 151 | 152 | # all the result save 153 | obj_json_str = json.dumps(all_obj_results_dict, indent=4) 154 | obj_json_file = os.path.join('evaluation', exp_name + '_obj.json') 155 | 156 | with open(obj_json_file, 'w') as json_file: 157 | json_file.write(obj_json_str) 158 | json_file.close() -------------------------------------------------------------------------------- /synthetic_eval/evaluate_bgdepth.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import numpy as np 3 | import open3d as o3d 4 | from sklearn.neighbors import KDTree 5 | import trimesh 6 | import torch 7 | import glob 8 | import os 9 | import pyrender 10 | import os 11 | import cv2 12 | import json 13 | from tqdm import tqdm 14 | from pathlib import Path 15 | 16 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 17 | 18 | def load_K_Rt_from_P(filename, P=None): 19 | if P is None: 20 | lines = open(filename).read().splitlines() 21 | if len(lines) == 4: 22 | lines = lines[1:] 23 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 24 | P = np.asarray(lines).astype(np.float32).squeeze() 25 | 26 | out = cv2.decomposeProjectionMatrix(P) 27 | K = out[0] 28 | R = out[1] 29 | t = out[2] 30 | 31 | K = K/K[2,2] 32 | intrinsics = np.eye(4) 33 | intrinsics[:3, :3] = K 34 | 35 | pose = np.eye(4, dtype=np.float32) 36 | pose[:3, :3] = R.transpose() 37 | pose[:3,3] = (t[:3] / t[3])[:,0] 38 | 39 | return intrinsics, pose 40 | 41 | # hard-coded image size 42 | H, W = 384, 384 43 | 44 | # load pose 45 | def load_poses(scan_id, object_id): 46 | pose_path = os.path.join(f'../data/syn_data/scene{scan_id}', 'cameras.npz') 47 | 48 | camera_dict = np.load(pose_path) 49 | len_pose = len(camera_dict.files) // 2 50 | 51 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(len_pose)] 52 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(len_pose)] 53 | P = world_mats[0] @ scale_mats[0] 54 | P = P[:3, :4] 55 | intrinsics, pose = load_K_Rt_from_P(None, P) 56 | 57 | poses = [] 58 | cnt = 0 59 | 60 | masks_path = os.path.join(f'../data/syn_data/scene{scan_id}', 'instance_mask') 61 | mask_files = sorted(os.listdir(masks_path)) 62 | 63 | id_json = os.path.join(f'../data/syn_data/scene{scan_id}', 'instance_id.json') 64 | with open(id_json, 'r') as f: 65 | id_data = json.load(f) 66 | f.close() 67 | 68 | if object_id > 0: # the valid object id 69 | obj_idx = id_data[f'obj_{object_id-1}'] 70 | else: 71 | obj_idx = -1 # invalid id, maybe for bg, however we load all the poses in this situation 72 | 73 | for scale_mat, 
74 |         # first check whether the object is visible in this pose's corresponding image
75 |         mask = cv2.imread(os.path.join(masks_path, mask_files[cnt]))
76 |         mask = np.array(mask)
77 |         mask = np.unique(mask)
78 | 
79 |         if obj_idx == -1:
80 |             orig_pose = world_mat
81 |             pose = np.linalg.inv(orig_pose) @ intrinsics  # recover the camera-to-world pose
82 |             poses.append(np.array(pose))
83 | 
84 |         elif obj_idx in mask:
85 |             orig_pose = world_mat
86 |             pose = np.linalg.inv(orig_pose) @ intrinsics
87 |             poses.append(np.array(pose))
88 | 
89 |         cnt += 1
90 | 
91 |     poses = np.array(poses)
92 |     print(poses.shape)
93 |     return poses, intrinsics
94 | 
95 | 
96 | class Renderer():
97 |     def __init__(self, height=480, width=640):
98 |         self.renderer = pyrender.OffscreenRenderer(width, height)
99 |         self.scene = pyrender.Scene()
100 |         self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES
101 | 
102 |     def __call__(self, height, width, intrinsics, pose, mesh, need_flag=True):  # need_flag toggles SKIP_CULL_FACES
103 |         self.renderer.viewport_height = height
104 |         self.renderer.viewport_width = width
105 |         self.scene.clear()
106 |         self.scene.add(mesh)
107 |         cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2],
108 |                                         fx=intrinsics[0, 0], fy=intrinsics[1, 1])
109 |         self.scene.add(cam, pose=self.fix_pose(pose))
110 |         if need_flag:
111 |             return self.renderer.render(self.scene, self.render_flags)
112 |         else:
113 |             return self.renderer.render(self.scene)  # , self.render_flags)
114 | 
115 |     def fix_pose(self, pose):
116 |         # 3D Rotation about the x-axis.
117 |         t = np.pi
118 |         c = np.cos(t)
119 |         s = np.sin(t)
120 |         R = np.array([[1, 0, 0],
121 |                       [0, c, -s],
122 |                       [0, s, c]])
123 |         axis_transform = np.eye(4)
124 |         axis_transform[:3, :3] = R
125 |         return pose @ axis_transform
126 | 
127 |     def mesh_opengl(self, mesh):
128 |         return pyrender.Mesh.from_trimesh(mesh)
129 | 
130 |     def delete(self):
131 |         self.renderer.delete()
132 | 
133 | 
134 | def refuse_depth(mesh, poses, K, need_flag=False, scan_id=-1):  # render one depth map of the mesh per pose
135 |     renderer = Renderer()
136 |     mesh_opengl = renderer.mesh_opengl(mesh)
137 | 
138 |     depths = []
139 | 
140 |     for pose in tqdm(poses):
141 |         intrinsic = np.eye(4)
142 |         intrinsic[:3, :3] = K
143 | 
144 |         rgb = np.ones((H, W, 3))
145 |         rgb = (rgb * 255).astype(np.uint8)
146 |         rgb = o3d.geometry.Image(rgb)  # not used below; only the rendered depth is kept
147 |         _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl, need_flag=need_flag)
148 |         depths.append(depth_pred)
149 | 
150 |     return depths
151 | 
152 | 
153 | def average_dicts(dicts):
154 |     # input: a list of dicts
155 |     # all dicts share the same keys
156 |     dict_num = len(dicts)
157 |     keys = dicts[0].keys()
158 |     ret = {}
159 | 
160 |     for k in keys:
161 |         values = [x[k] for x in dicts]
162 |         value = np.array(values).mean()
163 |         ret[k] = value
164 | 
165 |     return ret
166 | 
167 | 
168 | root_dir = "../exps/"
169 | exp_name = "RICO_synthetic"
170 | out_dir = "evaluation/" + exp_name
171 | Path(out_dir).mkdir(parents=True, exist_ok=True)
172 | 
173 | 
174 | scenes = {
175 |     1: 'scene1',
176 |     2: 'scene2',
177 |     3: 'scene3',
178 |     4: 'scene4',
179 |     5: 'scene5',
180 | }
181 | 
182 | all_bg_results = []
183 | all_bg_results_dict = OrderedDict()
184 | 
185 | for k, v in scenes.items():
186 | 
187 |     cur_exp = f"{exp_name}_{k}"
188 |     cur_root = os.path.join(root_dir, cur_exp)
189 |     if not os.path.isdir(cur_root):
190 |         continue
191 |     # use the latest timestamp folder
192 |     dirs = sorted(os.listdir(cur_root))
193 |     cur_root = os.path.join(cur_root, dirs[-1])
194 | 
195 |     files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*.ply"))))
196 | 
197 |     # evaluate the meshes for objects and background: the first file is the background, the last is the full scene
198 |     files.sort(key=lambda x:os.path.getmtime(x))  # order by modification time
199 | 
200 |     bg_file = files[0]
201 |     print(bg_file)
202 |     bg_mesh = trimesh.load(bg_file)
203 | 
204 |     cam_file = f"../data/syn_data/scene{k}/cameras.npz"
205 |     scale_mat = np.load(cam_file)['scale_mat_0']
206 |     bg_mesh.vertices = (scale_mat[:3, :3] @ bg_mesh.vertices.T + scale_mat[:3, 3:]).T
207 | 
208 |     poses, K = load_poses(k, -1)
209 |     K = K[:3, :3]
210 |     bg_mesh_depth = refuse_depth(bg_mesh, poses, K, scan_id=k)
211 | 
212 |     gt_mesh = os.path.join(f"../data/syn_data/scene{k}/GT_mesh", "background.ply")
213 |     gt_mesh = trimesh.load(gt_mesh)
214 | 
215 |     gt_mesh.vertex_normals = -gt_mesh.vertex_normals  # flip the GT background normals before depth rendering
216 | 
217 |     gt_mesh_depth = refuse_depth(gt_mesh, poses, K, need_flag=True, scan_id=k)
218 | 
219 |     masks_path = os.path.join(f'../data/syn_data/scene{k}', 'instance_mask')
220 |     mask_files = sorted(os.listdir(masks_path))
221 |     masks = [cv2.imread(os.path.join(masks_path, x)) for x in mask_files]
222 | 
223 |     depth_errors = []
224 |     for gt_depth, pred_depth, seg_mask in zip(gt_mesh_depth, bg_mesh_depth, masks):
225 |         seg = seg_mask
226 |         seg = np.array(seg)
227 |         seg = seg[:, :, 0] > 0  # object (occluded) regions
228 | 
229 |         gtd = gt_depth[seg]
230 |         prd = pred_depth[seg]
231 | 
232 |         mse = (np.square(gtd - prd)).mean(axis=0)  # depth MSE of the background behind the objects
233 |         depth_errors.append(mse)
234 |     depth_errors = np.array(depth_errors)
235 |     metrics = {'bg_depth_error': depth_errors.mean().astype(float)}
236 | 
237 |     all_bg_results.append(metrics)
238 |     all_bg_results_dict[k] = metrics
239 | 
240 | # print the average result over all scenes
241 | all_bg_results = average_dicts(all_bg_results)
242 | print('background:')
243 | print(all_bg_results)
244 | # save the per-scene results
245 | bg_json_str = json.dumps(all_bg_results_dict, indent=4)
246 | bg_json_file = os.path.join('evaluation', exp_name + '_bg_depth.json')
247 | 
248 | with open(bg_json_file, 'w') as json_file:
249 |     json_file.write(bg_json_str)
250 |     json_file.close()
--------------------------------------------------------------------------------
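
Both evaluation scripts are driven entirely by the paths hard-coded at their top (../exps/, ../data/syn_data/, evaluation/) and write a per-scene JSON summary (evaluate_bgdepth.py, for example, writes evaluation/RICO_synthetic_bg_depth.json). The sketch below is only an illustration of how the helpers in evaluate_bgdepth.py fit together for one scene, not repository code: "bg.ply" is a placeholder path, and load_poses / refuse_depth are assumed to be copied or imported into scope (importing evaluate_bgdepth directly would also execute its top-level evaluation loop).

# Illustrative sketch only (not from the repository): render per-view depth maps of a
# single reconstructed background mesh, mirroring the per-scene body of the loop above.
# Assumes the ../data/syn_data layout used by the scripts and a hypothetical "bg.ply".
import numpy as np
import trimesh

scan_id = 1
scale_mat = np.load(f"../data/syn_data/scene{scan_id}/cameras.npz")["scale_mat_0"]

bg = trimesh.load("bg.ply")  # placeholder: a reconstructed background mesh in normalized coordinates
bg.vertices = (scale_mat[:3, :3] @ bg.vertices.T + scale_mat[:3, 3:]).T  # back to world scale

poses, K = load_poses(scan_id, -1)           # object_id = -1 keeps every camera
depths = refuse_depth(bg, poses, K[:3, :3])  # one H x W depth map per camera
print(len(depths), depths[0].shape)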