├── LICENSE
├── README.md
├── code
    ├── confs
    │   ├── replica_objsdfplus.conf
    │   ├── scannet_objsdfplus.conf
    │   └── scannet_objsdfplus_mlp.conf
    ├── datasets
    │   └── scene_dataset.py
    ├── hashencoder
    │   ├── __init__.py
    │   ├── backend.py
    │   ├── hashgrid.py
    │   └── src
    │   │   ├── bindings.cpp
    │   │   ├── hashencoder.cu
    │   │   └── hashencoder.h
    ├── model
    │   ├── density.py
    │   ├── embedder.py
    │   ├── loss.py
    │   ├── network.py
    │   └── ray_sampler.py
    ├── training
    │   ├── exp_runner.py
    │   └── objectsdfplus_train.py
    └── utils
    │   ├── general.py
    │   ├── plots.py
    │   ├── rend_util.py
    │   └── sem_util.py
├── media
    └── teaser.gif
├── preprocess
    ├── README.md
    ├── extract_monocular_cues.py
    ├── replica_to_objsdfpp.py
    └── scannet_to_objsdfpp.py
├── replica_eval
    ├── avg_metric.py
    ├── cull_mesh.py
    ├── cull_obj_gt.py
    ├── eval_3D_obj.py
    ├── eval_recon.py
    ├── evaluate.py
    ├── evaluate_single_scene.py
    └── metrics.py
├── requirements.txt
├── scannet_eval
    └── evaluate.py
└── scripts
    ├── download_dataset.sh
    └── download_meshes.sh

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Qianyi Wu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# ObjectSDF++: Improved Object-Compositional Neural Implicit Surfaces

Qianyi Wu · Kaisiyuan Wang · Kejie Li · Jianmin Zheng · Jianfei Cai

ICCV 2023

Paper | Project Page

TL;DR: We propose an occlusion-aware opacity rendering formulation to better use the instance mask supervision. Together with an object-distinction regularization term, the proposed ObjectSDF++ produces more accurate surface reconstruction at both scene and object levels.

# Setup

## Installation
This code has been tested on Ubuntu 22.04 with torch 2.0 & CUDA 11.7 on an RTX 3090.
Clone the repository and create an anaconda environment named objsdf:
```
git clone https://github.com/QianyiWu/objectsdf_plus.git
cd objectsdf_plus

conda create -y -n objsdf python=3.9
conda activate objsdf

pip install -r requirements.txt
```
The hash encoder will be compiled on the fly when running the code.

## Dataset
To download the preprocessed data, run the following script. The Replica and ScanNet data are adapted from [MonoSDF](https://github.com/autonomousvision/monosdf) and [vMAP](https://github.com/kxhit/vMAP).
```
bash scripts/download_dataset.sh
```
# Training

Run the following command to train ObjectSDF++:
```
cd ./code
CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 training/exp_runner.py --conf CONFIG --scan_id SCAN_ID
```
where CONFIG is a config file in `code/confs`, and SCAN_ID is the id of the scene to reconstruct.

Example commands for training on the Replica and ScanNet datasets:
```
# Replica scan 1 (room0)
CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 training/exp_runner.py --conf confs/replica_objsdfplus.conf --scan_id 1

# ScanNet scan 1 (scene_0050_00)
CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 training/exp_runner.py --conf confs/scannet_objsdfplus.conf --scan_id 1

```
The intermediate results and checkpoints will be saved in the ``exps`` folder.
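
To train several scenes back to back, the training command above can be wrapped in a small script. The following is a minimal sketch and not part of the repository: `build_train_command` and `train_scans` are hypothetical helper names, the script assumes it is launched from the repository root, and it uses only the flags shown in the commands above. Whether a given scan id (such as scan 2 below) exists depends on the downloaded data.

```python
import os
import shlex
import subprocess


def build_train_command(conf, scan_id):
    """Return the torchrun argument list for one scene (hypothetical helper)."""
    return [
        "torchrun",
        "--nproc_per_node=1",
        "--nnodes=1",
        "--node_rank=0",
        "training/exp_runner.py",
        "--conf", conf,
        "--scan_id", str(scan_id),
    ]


def train_scans(conf, scan_ids, gpu_id=0, code_dir="code", dry_run=True):
    """Train the listed scans one after another on a single GPU.

    With dry_run=True the commands are only printed, never executed.
    Returns the list of command strings.
    """
    # Pin the run to one GPU, mirroring CUDA_VISIBLE_DEVICES=0 in the README.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
    commands = []
    for scan_id in scan_ids:
        cmd = build_train_command(conf, scan_id)
        commands.append(shlex.join(cmd))
        print(commands[-1])
        if not dry_run:
            # Run from the code directory, as the README commands do.
            subprocess.run(cmd, cwd=code_dir, env=env, check=True)
    return commands


# Dry run: print the commands for two Replica scans (scan 2 is assumed to exist).
train_scans("confs/replica_objsdfplus.conf", scan_ids=[1, 2])
```

Passing `dry_run=False` launches the runs sequentially; `check=True` stops the loop at the first failed run instead of silently continuing.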

# Evaluations

## Replica
Evaluate one scene (scan 1, room0, as an example):
```
cd replica_eval
python evaluate_single_scene.py --input_mesh replica_scan1_mesh.ply --scan_id 1 --output_dir replica_scan1
```

We also provide scripts for evaluating all Replica scenes and objects:
```
cd replica_eval
python evaluate.py # scene-level evaluation
python eval_3D_obj.py # object-level evaluation
```
Please check the scripts for more details. To obtain the object ground truth, refer to the [vMAP dataset instructions](https://github.com/kxhit/vMAP#dataset).

## ScanNet
```
cd scannet_eval
python evaluate.py
```
Please check the script for more details.

# Acknowledgements
This project is built upon [MonoSDF](https://github.com/autonomousvision/monosdf). The monocular depth and normal images are obtained with [Omnidata](https://omnidata.vision). The evaluation of object reconstruction is inspired by [vMAP](https://github.com/kxhit/vMAP). The CUDA implementation of multi-resolution hash encoding is heavily based on [torch-ngp](https://github.com/ashawkey/torch-ngp). Kudos to these researchers.


# Citation
If you find our code or paper useful, please cite the series of ObjectSDF works.
103 | ```BibTeX 104 | @inproceedings{wu2022object, 105 | title = {Object-compositional neural implicit surfaces}, 106 | author = {Wu, Qianyi and Liu, Xian and Chen, Yuedong and Li, Kejie and Zheng, Chuanxia and Cai, Jianfei and Zheng, Jianmin}, 107 | booktitle = {European Conference on Computer Vision}, 108 | year = {2022}, 109 | } 110 | 111 | @inproceedings{wu2023objsdfplus, 112 | author = {Wu, Qianyi and Wang, Kaisiyuan and Li, Kejie and Zheng, Jianmin and Cai, Jianfei}, 113 | title = {ObjectSDF++: Improved Object-Compositional Neural Implicit Surfaces}, 114 | booktitle = {ICCV}, 115 | year = {2023}, 116 | } 117 | ``` 118 | 119 | -------------------------------------------------------------------------------- /code/confs/replica_objsdfplus.conf: -------------------------------------------------------------------------------- 1 | train{ 2 | expname = objectsdfplus_replica 3 | dataset_class = datasets.scene_dataset.SceneDatasetDN_segs 4 | model_class = model.network.ObjectSDFPlusNetwork 5 | loss_class = model.loss.ObjectSDFPlusLoss 6 | learning_rate = 5.0e-4 7 | lr_factor_for_grid = 20.0 8 | num_pixels = 1024 9 | checkpoint_freq = 100 10 | plot_freq = 50 11 | split_n_pixels = 1024 12 | add_objectvio_iter = 100000 13 | } 14 | plot{ 15 | plot_nimgs = 1 16 | resolution = 256 17 | grid_boundary = [-1.0, 1.0] 18 | } 19 | loss{ 20 | rgb_loss = torch.nn.L1Loss 21 | eikonal_weight = 0.1 22 | smooth_weight = 0.005 23 | depth_weight = 0.1 24 | normal_l1_weight = 0.05 25 | normal_cos_weight = 0.05 26 | semantic_loss = torch.nn.MSELoss 27 | use_obj_opacity = True 28 | semantic_weight = 1 29 | reg_vio_weight = 0.5 30 | } 31 | dataset{ 32 | data_dir = replica 33 | img_res = [384, 384] 34 | use_mask = True 35 | center_crop_type = center_crop_for_replica 36 | } 37 | model{ 38 | feature_vector_size = 256 39 | scene_bounding_sphere = 1.0 40 | 41 | Grid_MLP = True 42 | 43 | implicit_network 44 | { 45 | d_in = 3 46 | d_out = 46 47 | dims = [256, 256] 48 | geometric_init = True 49 | 
bias = 0.9 50 | skip_in = [4] 51 | weight_norm = True 52 | multires = 6 53 | inside_outside = True 54 | use_grid_feature = True 55 | divide_factor = 1.0 56 | sigmoid = 10 57 | } 58 | 59 | rendering_network 60 | { 61 | mode = idr 62 | d_in = 9 63 | d_out = 3 64 | dims = [256, 256] 65 | weight_norm = True 66 | multires_view = 4 67 | per_image_code = True 68 | } 69 | density 70 | { 71 | params_init{ 72 | beta = 0.1 73 | } 74 | beta_min = 0.0001 75 | } 76 | ray_sampler 77 | { 78 | near = 0.0 79 | N_samples = 64 80 | N_samples_eval = 128 81 | N_samples_extra = 32 82 | eps = 0.1 83 | beta_iters = 10 84 | max_total_iters = 5 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /code/confs/scannet_objsdfplus.conf: -------------------------------------------------------------------------------- 1 | train{ 2 | expname = objectsdfplus_grid_scannet 3 | dataset_class = datasets.scene_dataset.SceneDatasetDN_segs 4 | model_class = model.network.ObjectSDFPlusNetwork 5 | loss_class = model.loss.ObjectSDFPlusLoss 6 | learning_rate = 5.0e-4 7 | lr_factor_for_grid = 20.0 8 | num_pixels = 1024 9 | checkpoint_freq = 100 10 | plot_freq = 50 11 | split_n_pixels = 1024 12 | } 13 | plot{ 14 | plot_nimgs = 1 15 | resolution = 256 16 | grid_boundary = [-1.1, 1.1] 17 | } 18 | loss{ 19 | rgb_loss = torch.nn.L1Loss 20 | eikonal_weight = 0.05 21 | smooth_weight = 0.005 22 | depth_weight = 0.1 23 | normal_l1_weight = 0.05 24 | normal_cos_weight = 0.05 25 | semantic_loss = torch.nn.MSELoss 26 | use_obj_opacity = True 27 | semantic_weight = 0.5 28 | reg_vio_weight = 0.1 29 | } 30 | dataset{ 31 | data_dir = scannet 32 | img_res = [384, 384] 33 | center_crop_type = no_crop 34 | } 35 | model{ 36 | feature_vector_size = 256 37 | scene_bounding_sphere = 1.0 38 | 39 | Grid_MLP = True 40 | 41 | implicit_network 42 | { 43 | d_in = 3 44 | d_out = 32 45 | dims = [256, 256] 46 | geometric_init = True 47 | bias = 0.9 48 | skip_in = [4] 49 | weight_norm = True 50 | 
multires = 6 51 | inside_outside = True 52 | use_grid_feature = True 53 | divide_factor = 1.1 54 | sigmoid = 10 55 | } 56 | 57 | rendering_network 58 | { 59 | mode = idr 60 | d_in = 9 61 | d_out = 3 62 | dims = [256, 256] 63 | weight_norm = True 64 | multires_view = 4 65 | per_image_code = True 66 | } 67 | density 68 | { 69 | params_init{ 70 | beta = 0.1 71 | } 72 | beta_min = 0.0001 73 | } 74 | ray_sampler 75 | { 76 | near = 0.0 77 | N_samples = 64 78 | N_samples_eval = 128 79 | N_samples_extra = 32 80 | eps = 0.1 81 | beta_iters = 10 82 | max_total_iters = 5 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /code/confs/scannet_objsdfplus_mlp.conf: -------------------------------------------------------------------------------- 1 | train{ 2 | expname = objectsdfplus_mlp_scannet 3 | dataset_class = datasets.scene_dataset.SceneDatasetDN_segs 4 | model_class = model.network.ObjectSDFPlusNetwork 5 | loss_class = model.loss.ObjectSDFPlusLoss 6 | learning_rate = 5.0e-4 7 | lr_factor_for_grid = 20.0 8 | num_pixels = 1024 9 | checkpoint_freq = 100 10 | plot_freq = 50 11 | split_n_pixels = 1024 12 | } 13 | plot{ 14 | plot_nimgs = 1 15 | resolution = 256 16 | grid_boundary = [-1.1, 1.1] 17 | } 18 | loss{ 19 | rgb_loss = torch.nn.L1Loss 20 | eikonal_weight = 0.05 21 | smooth_weight = 0.005 22 | depth_weight = 0.1 23 | normal_l1_weight = 0.05 24 | normal_cos_weight = 0.05 25 | semantic_loss = torch.nn.MSELoss 26 | use_obj_opacity = True 27 | semantic_weight = 0.5 28 | reg_vio_weight = 0.1 29 | } 30 | dataset{ 31 | data_dir = scannet 32 | img_res = [384, 384] 33 | center_crop_type = no_crop 34 | } 35 | model{ 36 | feature_vector_size = 256 37 | scene_bounding_sphere = 1.0 38 | 39 | Grid_MLP = True 40 | 41 | implicit_network 42 | { 43 | d_in = 3 44 | d_out = 32 45 | dims = [256, 256, 256, 256, 256, 256, 256, 256] 46 | geometric_init = True 47 | bias = 0.9 48 | skip_in = [4] 49 | weight_norm = True 50 | multires = 6 51 | 
inside_outside = True 52 | use_grid_feature = False 53 | divide_factor = 1.1 54 | sigmoid = 10 55 | } 56 | 57 | rendering_network 58 | { 59 | mode = idr 60 | d_in = 9 61 | d_out = 3 62 | dims = [256, 256] 63 | weight_norm = True 64 | multires_view = 4 65 | per_image_code = True 66 | } 67 | density 68 | { 69 | params_init{ 70 | beta = 0.1 71 | } 72 | beta_min = 0.0001 73 | } 74 | ray_sampler 75 | { 76 | near = 0.0 77 | N_samples = 64 78 | N_samples_eval = 128 79 | N_samples_extra = 32 80 | eps = 0.1 81 | beta_iters = 10 82 | max_total_iters = 5 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /code/datasets/scene_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | import utils.general as utils 7 | from utils import rend_util 8 | from glob import glob 9 | import cv2 10 | import random 11 | 12 | # Dataset with monocular depth, normal and segmentation mask 13 | class SceneDatasetDN_segs(torch.utils.data.Dataset): 14 | 15 | def __init__(self, 16 | data_dir, 17 | img_res, 18 | scan_id=0, 19 | center_crop_type='xxxx', 20 | use_mask=False, 21 | num_views=-1 22 | ): 23 | 24 | self.instance_dir = os.path.join('../data', data_dir, 'scan{0}'.format(scan_id)) 25 | print(self.instance_dir) 26 | 27 | self.total_pixels = img_res[0] * img_res[1] 28 | self.img_res = img_res 29 | self.num_views = num_views 30 | assert num_views in [-1, 3, 6, 9] 31 | 32 | assert os.path.exists(self.instance_dir), "Data directory is empty" 33 | 34 | self.sampling_idx = None 35 | 36 | def glob_data(data_dir): 37 | data_paths = [] 38 | data_paths.extend(glob(data_dir)) 39 | data_paths = sorted(data_paths) 40 | return data_paths 41 | 42 | image_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_rgb.png")) 43 | depth_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_depth.npy")) 44 | 
normal_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_normal.npy")) 45 | # semantic_paths 46 | semantic_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "segs", "*_segs.png")) 47 | 48 | # mask is only used in the replica dataset as some monocular depth predictions have very large error and we ignore it 49 | if use_mask: 50 | mask_paths = glob_data(os.path.join('{0}'.format(self.instance_dir), "*_mask.npy")) 51 | else: 52 | mask_paths = None 53 | 54 | self.n_images = len(image_paths) 55 | print('[INFO]: Dataset Size ', self.n_images) 56 | 57 | # load instance_mapping_dict 58 | self.label_mapping = None 59 | self.instance_mapping_dict= {} 60 | # with open(os.path.join(data_dir, 'label_mapping_instance.txt'), 'r') as f: 61 | # content = f.readlines() 62 | # self.label_mapping = [int(a) for a in content[0].split(',')] 63 | 64 | # using the remapped instance label for training 65 | with open(os.path.join(self.instance_dir, 'instance_mapping.txt'), 'r') as f: 66 | for l in f: 67 | (k, v_sem, v_ins) = l.split(',') 68 | self.instance_mapping_dict[int(k)] = int(v_ins) 69 | # self.label_mapping = [] # get the sorted label mapping list 70 | # for k, v in self.instance_mapping_dict.items(): 71 | # if v not in self.label_mapping: # not a duplicate instance 72 | # self.label_mapping.append(v) 73 | self.label_mapping = sorted(set(self.instance_mapping_dict.values())) # get sorted label mapping. 
The first one is the background 74 | # sorted 75 | # print('Instance Label Mapping: ', self.label_mapping) 76 | self.obj_id = torch.from_numpy(np.array(range(len(self.label_mapping)))) 77 | print(self.obj_id) 78 | 79 | self.cam_file = '{0}/cameras.npz'.format(self.instance_dir) 80 | camera_dict = np.load(self.cam_file) 81 | scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 82 | world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)] 83 | 84 | self.intrinsics_all = [] 85 | self.pose_all = [] 86 | for scale_mat, world_mat in zip(scale_mats, world_mats): 87 | P = world_mat @ scale_mat 88 | P = P[:3, :4] 89 | intrinsics, pose = rend_util.load_K_Rt_from_P(None, P) 90 | 91 | # because we do resize and center crop 384x384 when using omnidata model, we need to adjust the camera intrinsic accordingly 92 | if center_crop_type == 'center_crop_for_replica': 93 | scale = 384 / 680 94 | offset = (1200 - 680 ) * 0.5 95 | intrinsics[0, 2] -= offset 96 | intrinsics[:2, :] *= scale 97 | elif center_crop_type == 'center_crop_for_tnt': 98 | scale = 384 / 540 99 | offset = (960 - 540) * 0.5 100 | intrinsics[0, 2] -= offset 101 | intrinsics[:2, :] *= scale 102 | elif center_crop_type == 'center_crop_for_dtu': 103 | scale = 384 / 1200 104 | offset = (1600 - 1200) * 0.5 105 | intrinsics[0, 2] -= offset 106 | intrinsics[:2, :] *= scale 107 | elif center_crop_type == 'padded_for_dtu': 108 | scale = 384 / 1200 109 | offset = 0 110 | intrinsics[0, 2] -= offset 111 | intrinsics[:2, :] *= scale 112 | elif center_crop_type == 'no_crop': # for scannet dataset, we already adjust the camera intrinsic duing preprocessing so nothing to be done here 113 | pass 114 | else: 115 | raise NotImplementedError 116 | 117 | self.intrinsics_all.append(torch.from_numpy(intrinsics).float()) 118 | self.pose_all.append(torch.from_numpy(pose).float()) 119 | 120 | self.rgb_images = [] 121 | for path in image_paths: 122 | rgb 
= rend_util.load_rgb(path) 123 | rgb = rgb.reshape(3, -1).transpose(1, 0) 124 | self.rgb_images.append(torch.from_numpy(rgb).float()) 125 | 126 | self.depth_images = [] 127 | self.normal_images = [] 128 | 129 | for dpath, npath in zip(depth_paths, normal_paths): 130 | depth = np.load(dpath) 131 | self.depth_images.append(torch.from_numpy(depth.reshape(-1, 1)).float()) 132 | 133 | normal = np.load(npath) 134 | normal = normal.reshape(3, -1).transpose(1, 0) 135 | # important as the output of omnidata is normalized 136 | normal = normal * 2. - 1. 137 | self.normal_images.append(torch.from_numpy(normal).float()) 138 | 139 | # load semantic 140 | self.semantic_images = [] 141 | for spath in semantic_paths: 142 | semantic_ori = cv2.imread(spath, cv2.IMREAD_UNCHANGED).astype(np.int32) 143 | semantic = np.copy(semantic_ori) 144 | ins_list = np.unique(semantic_ori) 145 | if self.label_mapping is not None: 146 | for i in ins_list: 147 | semantic[semantic_ori ==i] = self.label_mapping.index(self.instance_mapping_dict[i]) 148 | self.semantic_images.append(torch.from_numpy(semantic.reshape(-1, 1)).float()) 149 | 150 | # load mask 151 | self.mask_images = [] 152 | if mask_paths is None: 153 | for depth in self.depth_images: 154 | mask = torch.ones_like(depth) 155 | self.mask_images.append(mask) 156 | else: 157 | for path in mask_paths: 158 | mask = np.load(path) 159 | self.mask_images.append(torch.from_numpy(mask.reshape(-1, 1)).float()) 160 | 161 | def __len__(self): 162 | return self.n_images 163 | 164 | def __getitem__(self, idx): 165 | if self.num_views >= 0: 166 | image_ids = [25, 22, 28, 40, 44, 48, 0, 8, 13][:self.num_views] 167 | idx = image_ids[random.randint(0, self.num_views - 1)] 168 | 169 | uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32) 170 | uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float() 171 | uv = uv.reshape(2, -1).transpose(1, 0) 172 | 173 | sample = { 174 | "uv": uv, 175 | "intrinsics": self.intrinsics_all[idx], 176 | "pose": 
self.pose_all[idx] 177 | } 178 | 179 | ground_truth = { 180 | "rgb": self.rgb_images[idx], 181 | "depth": self.depth_images[idx], 182 | "mask": self.mask_images[idx], 183 | "normal": self.normal_images[idx], 184 | "segs": self.semantic_images[idx] 185 | } 186 | 187 | if self.sampling_idx is not None: 188 | if (self.random_image_for_path is None) or (idx not in self.random_image_for_path): 189 | # print('sampling_idx:', self.sampling_idx) 190 | ground_truth["rgb"] = self.rgb_images[idx][self.sampling_idx, :] 191 | ground_truth["full_rgb"] = self.rgb_images[idx] 192 | ground_truth["normal"] = self.normal_images[idx][self.sampling_idx, :] 193 | ground_truth["depth"] = self.depth_images[idx][self.sampling_idx, :] 194 | ground_truth["full_depth"] = self.depth_images[idx] 195 | ground_truth["mask"] = self.mask_images[idx][self.sampling_idx, :] 196 | ground_truth["full_mask"] = self.mask_images[idx] 197 | ground_truth["segs"] = self.semantic_images[idx][self.sampling_idx, :] 198 | 199 | sample["uv"] = uv[self.sampling_idx, :] 200 | sample["is_patch"] = torch.tensor([False]) 201 | else: 202 | # sampling a patch from the image, this could be used for training with depth total variational loss 203 | # a fix patch sampling, which require the sampling_size should be a H*H continuous path 204 | patch_size = np.floor(np.sqrt(len(self.sampling_idx))).astype(np.int32) 205 | start = np.random.randint(self.img_res[1]-patch_size +1)*self.img_res[0] + np.random.randint(self.img_res[1]-patch_size +1) # the start coordinate 206 | idx_row = torch.arange(start, start + patch_size) 207 | patch_sampling_idx = torch.cat([idx_row + self.img_res[1]*m for m in range(patch_size)]) 208 | ground_truth["rgb"] = self.rgb_images[idx][patch_sampling_idx, :] 209 | ground_truth["full_rgb"] = self.rgb_images[idx] 210 | ground_truth["normal"] = self.normal_images[idx][patch_sampling_idx, :] 211 | ground_truth["depth"] = self.depth_images[idx][patch_sampling_idx, :] 212 | ground_truth["full_depth"] = 
self.depth_images[idx] 213 | ground_truth["mask"] = self.mask_images[idx][patch_sampling_idx, :] 214 | ground_truth["full_mask"] = self.mask_images[idx] 215 | ground_truth["segs"] = self.semantic_images[idx][patch_sampling_idx, :] 216 | 217 | sample["uv"] = uv[patch_sampling_idx, :] 218 | sample["is_patch"] = torch.tensor([True]) 219 | 220 | return idx, sample, ground_truth 221 | 222 | 223 | def collate_fn(self, batch_list): 224 | # get list of dictionaries and returns input, ground_true as dictionary for all batch instances 225 | batch_list = zip(*batch_list) 226 | 227 | all_parsed = [] 228 | for entry in batch_list: 229 | if type(entry[0]) is dict: 230 | # make them all into a new dict 231 | ret = {} 232 | for k in entry[0].keys(): 233 | ret[k] = torch.stack([obj[k] for obj in entry]) 234 | all_parsed.append(ret) 235 | else: 236 | all_parsed.append(torch.LongTensor(entry)) 237 | 238 | return tuple(all_parsed) 239 | 240 | def change_sampling_idx(self, sampling_size, sampling_pattern='random'): 241 | if sampling_size == -1: 242 | self.sampling_idx = None 243 | self.random_image_for_path = None 244 | else: 245 | if sampling_pattern == 'random': 246 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size] 247 | self.random_image_for_path = None 248 | elif sampling_pattern == 'patch': 249 | self.sampling_idx = torch.randperm(self.total_pixels)[:sampling_size] 250 | self.random_image_for_path = torch.randperm(self.n_images, )[:int(self.n_images/10)] 251 | else: 252 | raise NotImplementedError('the sampling pattern is not implemented.') 253 | 254 | def get_scale_mat(self): 255 | return np.load(self.cam_file)['scale_mat_0'] -------------------------------------------------------------------------------- /code/hashencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .hashgrid import HashEncoder -------------------------------------------------------------------------------- 
/code/hashencoder/backend.py:
--------------------------------------------------------------------------------
import os
from pathlib import Path

from torch.utils.cpp_extension import load

# Build directory for the JIT-compiled extension (relative to the working directory).
Path('./tmp_build/').mkdir(parents=True, exist_ok=True)

_src_path = os.path.dirname(os.path.abspath(__file__))

# Compile (on first import) and load the CUDA hash encoder.
_backend = load(name='_hash_encoder',
                extra_cflags=['-O3', '-std=c++14'],
                extra_cuda_cflags=[
                    '-O3', '-std=c++14', '-allow-unsupported-compiler',
                    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
                ],
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'hashencoder.cu',
                    'bindings.cpp',
                ]],
                build_directory='./tmp_build/',
                verbose=True,
                )

__all__ = ['_backend']
--------------------------------------------------------------------------------
/code/hashencoder/hashgrid.py:
--------------------------------------------------------------------------------
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd

from .backend import _backend

class _hash_encode(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.half)
    def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, calc_grad_inputs=False):
        # inputs: [B, D], float in [0, 1]
        # embeddings: [sO, C], float
        # offsets: [L + 1], int
        # RETURN: [B, F], float

        inputs = inputs.contiguous()
        embeddings = embeddings.contiguous()
        offsets = offsets.contiguous()

        B, D = inputs.shape # batch size, coord dim
        L = offsets.shape[0] - 1 # level
        C = embeddings.shape[1] # embedding dim for each level
        S = np.log2(per_level_scale) # resolution multiplier at each level, apply log2 for later CUDA exp2f
        H = base_resolution # base resolution

        # L first, optimize cache for cuda kernel, but needs an extra permute later
        outputs = torch.empty(L, B, C, device=inputs.device, dtype=inputs.dtype)

        if calc_grad_inputs:
            dy_dx = torch.empty(B, L * D * C, device=inputs.device, dtype=inputs.dtype)
        else:
            dy_dx = torch.empty(1, device=inputs.device, dtype=inputs.dtype)

        _backend.hash_encode_forward(inputs, embeddings, offsets, outputs, B, D, C, L, S, H, calc_grad_inputs, dy_dx)

        # permute back to [B, L * C]
        outputs = outputs.permute(1, 0, 2).reshape(B, L * C)

        ctx.save_for_backward(inputs, embeddings, offsets, dy_dx)
        ctx.dims = [B, D, C, L, S, H]
        ctx.calc_grad_inputs = calc_grad_inputs

        return outputs

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):

        inputs, embeddings, offsets, dy_dx = ctx.saved_tensors
        B, D, C, L, S, H = ctx.dims
        calc_grad_inputs = ctx.calc_grad_inputs

        # grad: [B, L * C] --> [L, B, C]
        grad = grad.view(B, L, C).permute(1, 0, 2).contiguous()

        grad_inputs, grad_embeddings = _hash_encode_second_backward.apply(grad, inputs, embeddings, offsets, B, D, C, L, S, H, calc_grad_inputs, dy_dx)

        if calc_grad_inputs:
            return grad_inputs, grad_embeddings, None, None, None, None
        else:
            return None, grad_embeddings, None, None, None, None


class _hash_encode_second_backward(Function):
    @staticmethod
    def forward(ctx, grad, inputs, embeddings, offsets, B, D, C, L, S, H, calc_grad_inputs, dy_dx):

        grad_inputs = torch.zeros_like(inputs)
        grad_embeddings = torch.zeros_like(embeddings)

        ctx.save_for_backward(grad, inputs, embeddings, offsets, dy_dx, grad_inputs, grad_embeddings)
        ctx.dims = [B, D, C, L, S, H]
        ctx.calc_grad_inputs = calc_grad_inputs

        _backend.hash_encode_backward(grad, inputs, embeddings, offsets, grad_embeddings, B, D, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs)

        return grad_inputs, grad_embeddings

    @staticmethod
    def backward(ctx, grad_grad_inputs, grad_grad_embeddings):

        grad, inputs, embeddings, offsets, dy_dx, grad_inputs, grad_embeddings = ctx.saved_tensors
        B, D, C, L, S, H = ctx.dims
        calc_grad_inputs = ctx.calc_grad_inputs

        grad_grad = torch.zeros_like(grad)
        grad2_embeddings = torch.zeros_like(embeddings)

        _backend.hash_encode_second_backward(grad, inputs, embeddings, offsets,
                                             B, D, C, L, S, H, calc_grad_inputs, dy_dx,
                                             grad_grad_inputs,
                                             grad_grad, grad2_embeddings)

        # one gradient slot per forward() argument: grad, inputs, embeddings, then nine non-tensor arguments
        return grad_grad, None, grad2_embeddings, None, None, None, None, None, None, None, None, None


hash_encode = _hash_encode.apply


class HashEncoder(nn.Module):
    def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, log2_hashmap_size=19, desired_resolution=None):
        super().__init__()

        # the finest resolution desired at the last level; if provided, overrides per_level_scale
        if desired_resolution is not None:
            per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))

        self.input_dim = input_dim # coord dims, 2 or 3
        self.num_levels = num_levels # number of resolution levels
        self.level_dim = level_dim # encode channels per level
        self.per_level_scale = per_level_scale # multiply resolution by this scale at each level.
        self.log2_hashmap_size = log2_hashmap_size
        self.base_resolution = base_resolution
        self.output_dim = num_levels * level_dim

        if level_dim % 2 != 0:
            print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause a very slow backward pass if fp16 is also enabled! (maybe fix later)')

        # allocate parameters
        offsets = []
        offset = 0
        self.max_params = 2 ** log2_hashmap_size
        for i in range(num_levels):
            resolution = int(np.ceil(base_resolution * per_level_scale ** i))
            params_in_level = min(self.max_params, (resolution) ** input_dim) # limit max number
            #params_in_level = np.ceil(params_in_level / 8) * 8 # make divisible
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        self.register_buffer('offsets', offsets)

        self.n_params = offsets[-1] * level_dim

        # parameters
        self.embeddings = nn.Parameter(torch.empty(offset, level_dim))

        self.reset_parameters()

    def reset_parameters(self):
        std = 1e-4
        self.embeddings.data.uniform_(-std, std)

    def __repr__(self):
        return f"HashEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} base_resolution={self.base_resolution} per_level_scale={self.per_level_scale} params={tuple(self.embeddings.shape)}"

    def forward(self, inputs, size=1):
        # inputs: [..., input_dim], normalized real world positions in [-size, size]
        # return: [..., num_levels * level_dim]

        inputs = (inputs + size) / (2 * size) # map to [0, 1]

        prefix_shape = list(inputs.shape[:-1])
        inputs = inputs.view(-1, self.input_dim)

        outputs = hash_encode(inputs, self.embeddings, self.offsets, self.per_level_scale, self.base_resolution, inputs.requires_grad)
        outputs = outputs.view(prefix_shape + [self.output_dim])

        return outputs
--------------------------------------------------------------------------------
/code/hashencoder/src/bindings.cpp:
--------------------------------------------------------------------------------
#include

#include "hashencoder.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("hash_encode_forward", &hash_encode_forward, "hash encode forward (CUDA)"); 7 | m.def("hash_encode_backward", &hash_encode_backward, "hash encode backward (CUDA)"); 8 | m.def("hash_encode_second_backward", &hash_encode_second_backward, "hash encode second backward (CUDA)"); 9 | } -------------------------------------------------------------------------------- /code/hashencoder/src/hashencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | // inputs: [B, D], float, in [0, 1] 9 | // embeddings: [sO, C], float 10 | // offsets: [L + 1], uint32_t 11 | // outputs: [B, L * C], float 12 | // H: base resolution 13 | void hash_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx); 14 | void hash_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs); 15 | void hash_encode_second_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, const at::Tensor grad_grad_inputs, at::Tensor grad_grad, at::Tensor grad2_embeddings); 16 | 17 | #endif -------------------------------------------------------------------------------- /code/model/density.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class Density(nn.Module): 6 | def __init__(self, params_init={}): 7 | super().__init__() 8 | for p in params_init: 9 | param = nn.Parameter(torch.tensor(params_init[p])) 10 | setattr(self, p, param) 11 | 12 | def forward(self, sdf, beta=None): 13 | return self.density_func(sdf, beta=beta) 14 | 15 | 16 | class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf) 17 | def __init__(self, params_init={}, beta_min=0.0001): 18 | super().__init__(params_init=params_init) 19 | self.beta_min = torch.tensor(beta_min).cuda() 20 | 21 | def density_func(self, sdf, beta=None): 22 | if beta is None: 23 | beta = self.get_beta() 24 | 25 | alpha = 1 / beta 26 | return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta)) 27 | 28 | def get_beta(self): 29 | beta = self.beta.abs() + self.beta_min 30 | return beta 31 | 32 | 33 | class AbsDensity(Density): # like NeRF++ 34 | def density_func(self, sdf, beta=None): 35 | return torch.abs(sdf) 36 | 37 | 38 | class SimpleDensity(Density): # like NeRF 39 | def __init__(self, params_init={}, noise_std=1.0): 40 | super().__init__(params_init=params_init) 41 | self.noise_std = noise_std 42 | 43 | def density_func(self, sdf, beta=None): 44 | if self.training and self.noise_std > 0.0: 45 | noise = torch.randn(sdf.shape).cuda() * self.noise_std 46 | sdf = sdf + noise 47 | return torch.relu(sdf) 48 | -------------------------------------------------------------------------------- /code/model/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 
""" 4 | 5 | class Embedder: 6 | def __init__(self, **kwargs): 7 | self.kwargs = kwargs 8 | self.create_embedding_fn() 9 | 10 | def create_embedding_fn(self): 11 | embed_fns = [] 12 | d = self.kwargs['input_dims'] 13 | out_dim = 0 14 | if self.kwargs['include_input']: 15 | embed_fns.append(lambda x: x) 16 | out_dim += d 17 | 18 | max_freq = self.kwargs['max_freq_log2'] 19 | N_freqs = self.kwargs['num_freqs'] 20 | 21 | if self.kwargs['log_sampling']: 22 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) 23 | else: 24 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 25 | 26 | for freq in freq_bands: 27 | for p_fn in self.kwargs['periodic_fns']: 28 | embed_fns.append(lambda x, p_fn=p_fn, 29 | freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | def get_embedder(multires, input_dims=3): 39 | embed_kwargs = { 40 | 'include_input': True, 41 | 'input_dims': input_dims, 42 | 'max_freq_log2': multires-1, 43 | 'num_freqs': multires, 44 | 'log_sampling': True, 45 | 'periodic_fns': [torch.sin, torch.cos], 46 | } 47 | 48 | embedder_obj = Embedder(**embed_kwargs) 49 | def embed(x, eo=embedder_obj): return eo.embed(x) 50 | return embed, embedder_obj.out_dim 51 | -------------------------------------------------------------------------------- /code/model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import utils.general as utils 4 | import math 5 | import torch.nn.functional as F 6 | 7 | # copy from MiDaS 8 | def compute_scale_and_shift(prediction, target, mask): 9 | # system matrix: A = [[a_00, a_01], [a_10, a_11]] 10 | a_00 = torch.sum(mask * prediction * prediction, (1, 2)) 11 | a_01 = torch.sum(mask * prediction, (1, 2)) 12 | a_11 = torch.sum(mask, (1, 2)) 13 | 14 | # right hand side: b = [b_0, b_1] 15 | 
b_0 = torch.sum(mask * prediction * target, (1, 2)) 16 | b_1 = torch.sum(mask * target, (1, 2)) 17 | 18 | # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b 19 | x_0 = torch.zeros_like(b_0) 20 | x_1 = torch.zeros_like(b_1) 21 | 22 | det = a_00 * a_11 - a_01 * a_01 23 | valid = det.nonzero() 24 | 25 | x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] 26 | x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] 27 | 28 | return x_0, x_1 29 | 30 | 31 | def reduction_batch_based(image_loss, M): 32 | # average of all valid pixels of the batch 33 | 34 | # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0) 35 | divisor = torch.sum(M) 36 | 37 | if divisor == 0: 38 | return 0 39 | else: 40 | return torch.sum(image_loss) / divisor 41 | 42 | 43 | def reduction_image_based(image_loss, M): 44 | # mean of average of valid pixels of an image 45 | 46 | # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0) 47 | valid = M.nonzero() 48 | 49 | image_loss[valid] = image_loss[valid] / M[valid] 50 | 51 | return torch.mean(image_loss) 52 | 53 | 54 | def mse_loss(prediction, target, mask, reduction=reduction_batch_based): 55 | 56 | M = torch.sum(mask, (1, 2)) 57 | res = prediction - target 58 | image_loss = torch.sum(mask * res * res, (1, 2)) 59 | 60 | return reduction(image_loss, 2 * M) 61 | 62 | 63 | def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): 64 | 65 | M = torch.sum(mask, (1, 2)) 66 | 67 | diff = prediction - target 68 | diff = torch.mul(mask, diff) 69 | 70 | grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) 71 | mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) 72 | grad_x = torch.mul(mask_x, grad_x) 73 | 74 | grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) 75 | mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) 76 | grad_y = torch.mul(mask_y, grad_y) 77 | 78 | image_loss = torch.sum(grad_x, (1, 2)) + 
torch.sum(grad_y, (1, 2)) 79 | 80 | return reduction(image_loss, M) 81 | 82 | 83 | def convert_number_to_digits(x): 84 | assert isinstance(x, int), 'the input value {} should be int'.format(x) 85 | v = 2**x 86 | # convert to 0-1 digits 87 | 88 | 89 | 90 | 91 | class MSELoss(nn.Module): 92 | def __init__(self, reduction='batch-based'): 93 | super().__init__() 94 | 95 | if reduction == 'batch-based': 96 | self.__reduction = reduction_batch_based 97 | else: 98 | self.__reduction = reduction_image_based 99 | 100 | def forward(self, prediction, target, mask): 101 | return mse_loss(prediction, target, mask, reduction=self.__reduction) 102 | 103 | 104 | class GradientLoss(nn.Module): 105 | def __init__(self, scales=4, reduction='batch-based'): 106 | super().__init__() 107 | 108 | if reduction == 'batch-based': 109 | self.__reduction = reduction_batch_based 110 | else: 111 | self.__reduction = reduction_image_based 112 | 113 | self.__scales = scales 114 | 115 | def forward(self, prediction, target, mask): 116 | total = 0 117 | 118 | for scale in range(self.__scales): 119 | step = pow(2, scale) 120 | 121 | total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step], 122 | mask[:, ::step, ::step], reduction=self.__reduction) 123 | 124 | return total 125 | 126 | 127 | class ScaleAndShiftInvariantLoss(nn.Module): 128 | def __init__(self, alpha=0.5, scales=4, reduction='batch-based'): 129 | super().__init__() 130 | 131 | self.__data_loss = MSELoss(reduction=reduction) 132 | self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction) 133 | self.__alpha = alpha 134 | 135 | self.__prediction_ssi = None 136 | 137 | def forward(self, prediction, target, mask): 138 | 139 | scale, shift = compute_scale_and_shift(prediction, target, mask) 140 | self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) 141 | 142 | total = self.__data_loss(self.__prediction_ssi, target, mask) 143 | if self.__alpha > 0: 144 | total += 
self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) 145 | 146 | return total 147 | 148 | def __get_prediction_ssi(self): 149 | return self.__prediction_ssi 150 | 151 | prediction_ssi = property(__get_prediction_ssi) 152 | # end copy 153 | 154 | 155 | class MonoSDFLoss(nn.Module): 156 | def __init__(self, rgb_loss, 157 | eikonal_weight, 158 | smooth_weight = 0.005, 159 | depth_weight = 0.1, 160 | normal_l1_weight = 0.05, 161 | normal_cos_weight = 0.05, 162 | end_step = -1): 163 | super().__init__() 164 | self.eikonal_weight = eikonal_weight 165 | self.smooth_weight = smooth_weight 166 | self.depth_weight = depth_weight 167 | self.normal_l1_weight = normal_l1_weight 168 | self.normal_cos_weight = normal_cos_weight 169 | self.rgb_loss = utils.get_class(rgb_loss)(reduction='mean') 170 | 171 | self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1) 172 | 173 | # print(f"using weight for loss RGB_1.0 EK_{self.eikonal_weight} SM_{self.smooth_weight} Depth_{self.depth_weight} NormalL1_{self.normal_l1_weight} NormalCos_{self.normal_cos_weight}") 174 | 175 | self.step = 0 176 | self.end_step = end_step 177 | 178 | def get_rgb_loss(self,rgb_values, rgb_gt): 179 | rgb_gt = rgb_gt.reshape(-1, 3) 180 | rgb_loss = self.rgb_loss(rgb_values, rgb_gt) 181 | return rgb_loss 182 | 183 | def get_eikonal_loss(self, grad_theta): 184 | eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean() 185 | return eikonal_loss 186 | 187 | def get_smooth_loss(self,model_outputs): 188 | # smoothness loss as unisurf 189 | g1 = model_outputs['grad_theta'] 190 | g2 = model_outputs['grad_theta_nei'] 191 | 192 | normals_1 = g1 / (g1.norm(2, dim=1).unsqueeze(-1) + 1e-5) 193 | normals_2 = g2 / (g2.norm(2, dim=1).unsqueeze(-1) + 1e-5) 194 | smooth_loss = torch.norm(normals_1 - normals_2, dim=-1).mean() 195 | return smooth_loss 196 | 197 | def get_depth_loss(self, depth_pred, depth_gt, mask): 198 | # TODO remove hard-coded scaling for depth 199 | return 
self.depth_loss(depth_pred.reshape(1, 32, 32), (depth_gt * 50 + 0.5).reshape(1, 32, 32), mask.reshape(1, 32, 32)) 200 | 201 | def get_normal_loss(self, normal_pred, normal_gt): 202 | normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1) 203 | normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1) 204 | l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean() 205 | cos = (1. - torch.sum(normal_pred * normal_gt, dim = -1)).mean() 206 | return l1, cos 207 | 208 | def forward(self, model_outputs, ground_truth): 209 | # import pdb; pdb.set_trace() 210 | rgb_gt = ground_truth['rgb'].cuda() 211 | # monocular depth and normal 212 | depth_gt = ground_truth['depth'].cuda() 213 | normal_gt = ground_truth['normal'].cuda() 214 | 215 | depth_pred = model_outputs['depth_values'] 216 | normal_pred = model_outputs['normal_map'][None] 217 | 218 | rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt) 219 | 220 | if 'grad_theta' in model_outputs: 221 | eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta']) 222 | else: 223 | eikonal_loss = torch.tensor(0.0).cuda().float() 224 | 225 | # only supervised the foreground normal 226 | mask = ((model_outputs['sdf'] > 0.).any(dim=-1) & (model_outputs['sdf'] < 0.).any(dim=-1))[None, :, None] 227 | # combine with GT 228 | mask = (ground_truth['mask'] > 0.5).cuda() & mask 229 | 230 | depth_loss = self.get_depth_loss(depth_pred, depth_gt, mask) if self.depth_weight > 0 else torch.tensor(0.0).cuda().float() 231 | if isinstance(depth_loss, float): 232 | depth_loss = torch.tensor(0.0).cuda().float() 233 | 234 | normal_l1, normal_cos = self.get_normal_loss(normal_pred * mask, normal_gt) 235 | 236 | smooth_loss = self.get_smooth_loss(model_outputs) 237 | 238 | # compute decay weights 239 | if self.end_step > 0: 240 | decay = math.exp(-self.step / self.end_step * 10.) 
241 | else: 242 | decay = 1.0 243 | 244 | self.step += 1 245 | 246 | loss = rgb_loss + \ 247 | self.eikonal_weight * eikonal_loss +\ 248 | self.smooth_weight * smooth_loss +\ 249 | decay * self.depth_weight * depth_loss +\ 250 | decay * self.normal_l1_weight * normal_l1 +\ 251 | decay * self.normal_cos_weight * normal_cos 252 | 253 | output = { 254 | 'loss': loss, 255 | 'rgb_loss': rgb_loss, 256 | 'eikonal_loss': eikonal_loss, 257 | 'smooth_loss': smooth_loss, 258 | 'depth_loss': depth_loss, 259 | 'normal_l1': normal_l1, 260 | 'normal_cos': normal_cos 261 | } 262 | 263 | return output 264 | 265 | 266 | class ObjectSDFPlusLoss(MonoSDFLoss): 267 | def __init__(self, rgb_loss, 268 | eikonal_weight, 269 | semantic_weight = 0.04, 270 | smooth_weight = 0.005, 271 | semantic_loss = torch.nn.CrossEntropyLoss(ignore_index = -1), 272 | depth_weight = 0.1, 273 | normal_l1_weight = 0.05, 274 | normal_cos_weight = 0.05, 275 | reg_vio_weight = 0.1, 276 | use_obj_opacity = True, 277 | bg_reg_weight = 0.1, 278 | end_step = -1): 279 | super().__init__( 280 | rgb_loss = rgb_loss, 281 | eikonal_weight = eikonal_weight, 282 | smooth_weight = smooth_weight, 283 | depth_weight = depth_weight, 284 | normal_l1_weight = normal_l1_weight, 285 | normal_cos_weight = normal_cos_weight, 286 | end_step = end_step) 287 | self.semantic_weight = semantic_weight 288 | self.bg_reg_weight = bg_reg_weight 289 | self.semantic_loss = utils.get_class(semantic_loss)(reduction='mean') if semantic_loss is not torch.nn.CrossEntropyLoss else torch.nn.CrossEntropyLoss(ignore_index = -1) 290 | self.reg_vio_weight = reg_vio_weight 291 | self.use_obj_opacity = use_obj_opacity 292 | 293 | print(f"[INFO]: using weight for loss RGB_1.0 EK_{self.eikonal_weight} SM_{self.smooth_weight} Depth_{self.depth_weight} NormalL1_{self.normal_l1_weight} NormalCos_{self.normal_cos_weight}\ 294 | Semantic_{self.semantic_weight}, semantic_loss_type_{self.semantic_loss} Use_object_opacity_{self.use_obj_opacity}") 295 | 296 | def 
get_semantic_loss(self, semantic_value, semantic_gt): 297 | semantic_gt = semantic_gt.squeeze() 298 | semantic_loss = self.semantic_loss(semantic_value, semantic_gt) 299 | return semantic_loss 300 | 301 | # violation loss 302 | def get_violation_reg_loss(self, sdf_value): 303 | # turn to vector, sdf_value: [#rays, #objects] 304 | min_value, min_indice = torch.min(sdf_value, dim=1, keepdims=True) 305 | input = -sdf_value-min_value.detach() # negate the SDFs and subtract the detached per-ray minimum 306 | res = torch.relu(input).sum(dim=1, keepdims=True) - torch.relu(torch.gather(input, 1, min_indice)) 307 | loss = res.sum() 308 | return loss 309 | 310 | def object_distinct_loss(self, sdf_value, min_sdf): 311 | _, min_indice = torch.min(sdf_value.squeeze(), dim=1, keepdims=True) 312 | input = -sdf_value.squeeze() - min_sdf.detach() 313 | res = torch.relu(input).sum(dim=1, keepdims=True) - torch.relu(torch.gather(input, 1, min_indice)) 314 | loss = res.mean() 315 | return loss 316 | 317 | def object_opacity_loss(self, predict_opacity, gt_opacity, weight=None): 318 | # normalize predict_opacity 319 | # predict_opacity = torch.nn.functional.normalize(predict_opacity, p=1, dim=-1) 320 | target = torch.nn.functional.one_hot(gt_opacity.squeeze(), num_classes=predict_opacity.shape[1]).float() 321 | # weight: optional per-element rescaling tensor; None (the default) gives unweighted BCE, so `loss` is always defined 322 | loss = F.binary_cross_entropy(predict_opacity.clamp(1e-4, 1-1e-4), target, weight=weight) 323 | return loss 324 | 325 | # background regularization loss following the design in Sec 3.2 of RICO (https://arxiv.org/pdf/2303.08605.pdf) 326 | def bg_tv_loss(self, depth_pred, normal_pred, gt_mask): 327 | # the depth_pred and normal_pred should form a patch in image space, depth_pred: [ray, 1], normal_pred: [ray, 3], gt_mask: [1, ray, 1] 328 | size = int(math.sqrt(gt_mask.shape[1])) 329 | mask = gt_mask.reshape(size, size, -1) 330 | depth = depth_pred.reshape(size, size, -1) 331 | normal = torch.nn.functional.normalize(normal_pred, p=2, dim=-1).reshape(size, size, -1) 332 | loss = 0 333 | for stride 
in [1, 2, 4]: 334 | hd_d = torch.abs(depth[:, :-stride, :] - depth[:, stride:, :]) 335 | wd_d = torch.abs(depth[:-stride, :, :] - depth[stride:, :, :]) 336 | hd_n = torch.abs(normal[:, :-stride, :] - normal[:, stride:, :]) 337 | wd_n = torch.abs(normal[:-stride, :, :] - normal[stride:, :, :]) 338 | loss += torch.mean(hd_d*mask[:, :-stride, :]) + torch.mean(wd_d*mask[:-stride, :, :]) 339 | loss += torch.mean(hd_n*mask[:, :-stride, :]) + torch.mean(wd_n*mask[:-stride, :, :]) 340 | return loss 341 | 342 | 343 | def forward(self, model_outputs, ground_truth, call_reg=False, call_bg_reg=False): 344 | output = super().forward(model_outputs, ground_truth) 345 | if 'semantic_values' in model_outputs and not self.use_obj_opacity: # ObjectSDF loss: semantic field + cross entropy 346 | semantic_gt = ground_truth['segs'].cuda().long() 347 | semantic_loss = self.get_semantic_loss(model_outputs['semantic_values'], semantic_gt) 348 | elif "object_opacity" in model_outputs and self.use_obj_opacity: # ObjectSDF++ loss: occlusion-aware object opacity + binary cross entropy 349 | semantic_gt = ground_truth['segs'].cuda().long() 350 | semantic_loss = self.object_opacity_loss(model_outputs['object_opacity'], semantic_gt) 351 | else: 352 | semantic_loss = torch.tensor(0.0).cuda().float() 353 | 354 | if "sample_sdf" in model_outputs and call_reg: 355 | sample_sdf_loss = self.object_distinct_loss(model_outputs["sample_sdf"], model_outputs["sample_minsdf"]) 356 | else: 357 | sample_sdf_loss = torch.tensor(0.0).cuda().float() 358 | 359 | background_reg_loss = torch.tensor(0.0).cuda().float() 360 | # if call_bg_reg: 361 | # mask = (ground_truth['segs'] != 0).cuda() 362 | # background_reg_loss = self.bg_tv_loss(model_outputs['bg_depth_values'], model_outputs['bg_normal_map'], mask) 363 | # else: 364 | # background_reg_loss = torch.tensor(0.0).cuda().float() 365 | output['semantic_loss'] = semantic_loss 366 | output['collision_reg_loss'] = sample_sdf_loss 367 | output['background_reg_loss'] = 
background_reg_loss 368 | output['loss'] = output['loss'] + self.semantic_weight * semantic_loss + self.reg_vio_weight * sample_sdf_loss 369 | # + self.bg_reg_weight * background_reg_loss # <- this one is not used in the paper, but is helpful for background regularization 370 | return output -------------------------------------------------------------------------------- /code/model/network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | from utils import rend_util 6 | from model.embedder import * 7 | from model.density import LaplaceDensity 8 | from model.ray_sampler import ErrorBoundSampler 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | from torch import vmap 13 | 14 | class ImplicitNetwork(nn.Module): 15 | def __init__( 16 | self, 17 | feature_vector_size, 18 | sdf_bounding_sphere, 19 | d_in, 20 | d_out, 21 | dims, 22 | geometric_init=True, 23 | bias=1.0, 24 | skip_in=(), 25 | weight_norm=True, 26 | multires=0, 27 | sphere_scale=1.0, 28 | inside_outside=True, 29 | sigmoid = 10 30 | ): 31 | super().__init__() 32 | 33 | self.sdf_bounding_sphere = sdf_bounding_sphere 34 | self.sphere_scale = sphere_scale 35 | dims = [d_in] + dims + [d_out + feature_vector_size] 36 | 37 | self.embed_fn = None 38 | if multires > 0: 39 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 40 | self.embed_fn = embed_fn 41 | dims[0] = input_ch 42 | print(multires, dims) 43 | self.num_layers = len(dims) 44 | self.skip_in = skip_in 45 | self.d_out = d_out 46 | self.sigmoid = sigmoid 47 | 48 | for l in range(0, self.num_layers - 1): 49 | if l + 1 in self.skip_in: 50 | out_dim = dims[l + 1] - dims[0] 51 | else: 52 | out_dim = dims[l + 1] 53 | 54 | lin = nn.Linear(dims[l], out_dim) 55 | 56 | if geometric_init: 57 | if l == self.num_layers - 2: 58 | # Geometric initialization for compositional scene, bg SDF sign: inside + outside -, fg 
SDF sign: outside + inside - 59 | # Index 0 is the background SDF; the rest are the object SDFs 60 | # background SDF with positive value inside and negative value outside 61 | torch.nn.init.normal_(lin.weight[:1, :], mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 62 | torch.nn.init.constant_(lin.bias[:1], bias) 63 | # inner objects: SDF initialized negative inside and positive outside, at ~0.6 of the background radius 64 | torch.nn.init.normal_(lin.weight[1:,:], mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 65 | torch.nn.init.constant_(lin.bias[1:], -0.6*bias) 66 | 67 | elif multires > 0 and l == 0: 68 | torch.nn.init.constant_(lin.bias, 0.0) 69 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 70 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 71 | elif multires > 0 and l in self.skip_in: 72 | torch.nn.init.constant_(lin.bias, 0.0) 73 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 74 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 75 | else: 76 | torch.nn.init.constant_(lin.bias, 0.0) 77 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 78 | 79 | if weight_norm: 80 | lin = nn.utils.weight_norm(lin) 81 | 82 | setattr(self, "lin" + str(l), lin) 83 | 84 | self.softplus = nn.Softplus(beta=100) 85 | self.pool = nn.MaxPool1d(self.d_out) 86 | 87 | def forward(self, input): 88 | if self.embed_fn is not None: 89 | input = self.embed_fn(input) 90 | 91 | x = input 92 | 93 | for l in range(0, self.num_layers - 1): 94 | lin = getattr(self, "lin" + str(l)) 95 | 96 | if l in self.skip_in: 97 | x = torch.cat([x, input], 1) / np.sqrt(2) 98 | 99 | x = lin(x) 100 | 101 | if l < self.num_layers - 2: 102 | x = self.softplus(x) 103 | 104 | return x 105 | 106 | def gradient(self, x): 107 | x.requires_grad_(True) 108 | y = self.forward(x)[:,:self.d_out] 109 | d_output = torch.ones_like(y[:, :1], requires_grad=False, device=y.device) 110 | g = [] 111 | for idx in 
range(y.shape[1]): 112 | gradients = torch.autograd.grad( 113 | outputs=y[:, idx:idx+1], 114 | inputs=x, 115 | grad_outputs=d_output, 116 | create_graph=True, 117 | retain_graph=True, 118 | only_inputs=True)[0] 119 | g.append(gradients) 120 | g = torch.cat(g) 121 | # add the gradient of minimum sdf 122 | sdf = -self.pool(-y.unsqueeze(1)).squeeze(-1) 123 | g_min_sdf = torch.autograd.grad( 124 | outputs=sdf, 125 | inputs=x, 126 | grad_outputs=d_output, 127 | create_graph=True, 128 | retain_graph=True, 129 | only_inputs=True)[0] 130 | g = torch.cat([g, g_min_sdf]) 131 | return g 132 | 133 | def get_outputs(self, x, beta=None): 134 | x.requires_grad_(True) 135 | output = self.forward(x) 136 | sdf_raw = output[:,:self.d_out] 137 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded ''' 138 | if self.sdf_bounding_sphere > 0.0: 139 | sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) 140 | sdf_raw = torch.minimum(sdf_raw, sphere_sdf.expand(sdf_raw.shape)) 141 | if beta is None: 142 | semantic = self.sigmoid * torch.sigmoid(-self.sigmoid * sdf_raw) 143 | else: 144 | semantic = 0.5/beta * torch.exp(-sdf_raw.abs()/beta) 145 | sdf = -self.pool(-sdf_raw.unsqueeze(1)).squeeze(-1) # take the minimum over the per-object SDFs 146 | feature_vectors = output[:, self.d_out:] 147 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) 148 | gradients = torch.autograd.grad( 149 | outputs=sdf, 150 | inputs=x, 151 | grad_outputs=d_output, 152 | create_graph=True, 153 | retain_graph=True, 154 | only_inputs=True)[0] 155 | 156 | return sdf, feature_vectors, gradients, semantic, sdf_raw 157 | 158 | def get_sdf_vals(self, x): 159 | sdf = self.forward(x)[:,:self.d_out] 160 | ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded ''' 161 | # sdf = -self.pool(-sdf) # take the minimum SDF here if the bound is instead applied at the end 162 | if self.sdf_bounding_sphere > 0.0: 163 | 
sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) 164 | sdf = torch.minimum(sdf, sphere_sdf.expand(sdf.shape)) 165 | sdf = -self.pool(-sdf.unsqueeze(1)).squeeze(-1) # take the minimum SDF value; the bound is applied before the min 166 | return sdf 167 | 168 | def get_sdf_raw(self, x): 169 | return self.forward(x)[:, :self.d_out] 170 | 171 | 172 | from hashencoder.hashgrid import HashEncoder 173 | class ObjectImplicitNetworkGrid(nn.Module): 174 | def __init__( 175 | self, 176 | feature_vector_size, 177 | sdf_bounding_sphere, 178 | d_in, 179 | d_out, 180 | dims, 181 | geometric_init=True, 182 | bias=1.0, # radius of the sphere in geometric initialization 183 | skip_in=(), 184 | weight_norm=True, 185 | multires=0, 186 | sphere_scale=1.0, 187 | inside_outside=False, 188 | base_size = 16, 189 | end_size = 2048, 190 | logmap = 19, 191 | num_levels=16, 192 | level_dim=2, 193 | divide_factor = 1.5, # used to normalize the points range for multi-res grid 194 | use_grid_feature = True, # use hash grid embedding or not, if not, it is a pure MLP with sin/cos embedding 195 | sigmoid = 20 196 | ): 197 | super().__init__() 198 | 199 | self.d_out = d_out 200 | self.sigmoid = sigmoid 201 | self.sdf_bounding_sphere = sdf_bounding_sphere 202 | self.sphere_scale = sphere_scale 203 | dims = [d_in] + dims + [d_out + feature_vector_size] 204 | self.embed_fn = None 205 | self.divide_factor = divide_factor 206 | self.grid_feature_dim = num_levels * level_dim 207 | self.use_grid_feature = use_grid_feature 208 | dims[0] += self.grid_feature_dim 209 | 210 | print(f"[INFO]: using hash encoder with {num_levels} levels, each level with feature dim {level_dim}") 211 | print(f"[INFO]: resolution:{base_size} -> {end_size} with hash map size {logmap}") 212 | self.encoding = HashEncoder(input_dim=3, num_levels=num_levels, level_dim=level_dim, 213 | per_level_scale=2, base_resolution=base_size, 214 | log2_hashmap_size=logmap, desired_resolution=end_size) 215 | 216 | ''' 217 
| # can also use tcnn for multi-res grid as it now supports eikonal loss 218 | base_size = 16 219 | hash = True 220 | smoothstep = True 221 | self.encoding = tcnn.Encoding(3, { 222 | "otype": "HashGrid" if hash else "DenseGrid", 223 | "n_levels": 16, 224 | "n_features_per_level": 2, 225 | "log2_hashmap_size": 19, 226 | "base_resolution": base_size, 227 | "per_level_scale": 1.34, 228 | "interpolation": "Smoothstep" if smoothstep else "Linear" 229 | }) 230 | ''' 231 | 232 | if multires > 0: 233 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 234 | self.embed_fn = embed_fn 235 | dims[0] += input_ch - 3 236 | # print("network architecture") 237 | # print(dims) 238 | 239 | self.num_layers = len(dims) 240 | self.skip_in = skip_in 241 | for l in range(0, self.num_layers - 1): 242 | if l + 1 in self.skip_in: 243 | out_dim = dims[l + 1] - dims[0] 244 | else: 245 | out_dim = dims[l + 1] 246 | 247 | lin = nn.Linear(dims[l], out_dim) 248 | 249 | if geometric_init: 250 | if l == self.num_layers - 2: 251 | # Geometric initialization for compositional scene, bg SDF sign: inside + outside -, fg SDF sign: outside + inside - 252 | # Index 0 is the background SDF; the rest are the object SDFs 253 | # background SDF with positive value inside and negative value outside 254 | torch.nn.init.normal_(lin.weight[:1, :], mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 255 | torch.nn.init.constant_(lin.bias[:1], bias) 256 | # inner objects: SDF initialized negative inside and positive outside, at ~0.5 of the background radius 257 | torch.nn.init.normal_(lin.weight[1:,:], mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 258 | torch.nn.init.constant_(lin.bias[1:], -0.5*bias) 259 | 260 | elif multires > 0 and l == 0: 261 | torch.nn.init.constant_(lin.bias, 0.0) 262 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 263 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 264 | elif multires > 0 and l in self.skip_in: 265 | 
torch.nn.init.constant_(lin.bias, 0.0) 266 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 267 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 268 | else: 269 | torch.nn.init.constant_(lin.bias, 0.0) 270 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 271 | 272 | if weight_norm: 273 | lin = nn.utils.weight_norm(lin) 274 | 275 | setattr(self, "lin" + str(l), lin) 276 | 277 | self.softplus = nn.Softplus(beta=100) 278 | self.cache_sdf = None 279 | 280 | self.pool = nn.MaxPool1d(d_out) 281 | self.relu = nn.ReLU() 282 | 283 | def forward(self, input): 284 | if self.use_grid_feature: 285 | # normalize point range as encoding assume points are in [-1, 1] 286 | # assert torch.max(input / self.divide_factor)<1 and torch.min(input / self.divide_factor)>-1, 'range out of [-1, 1], max: {}, min: {}'.format(torch.max(input / self.divide_factor), torch.min(input / self.divide_factor)) 287 | feature = self.encoding(input / self.divide_factor) 288 | else: 289 | feature = torch.zeros_like(input[:, :1].repeat(1, self.grid_feature_dim)) 290 | 291 | if self.embed_fn is not None: 292 | embed = self.embed_fn(input) 293 | input = torch.cat((embed, feature), dim=-1) 294 | else: 295 | input = torch.cat((input, feature), dim=-1) 296 | 297 | x = input 298 | 299 | for l in range(0, self.num_layers - 1): 300 | lin = getattr(self, "lin" + str(l)) 301 | 302 | if l in self.skip_in: 303 | x = torch.cat([x, input], 1) / np.sqrt(2) 304 | 305 | x = lin(x) 306 | 307 | if l < self.num_layers - 2: 308 | x = self.softplus(x) 309 | 310 | return x 311 | 312 | def gradient(self, x): 313 | x.requires_grad_(True) 314 | y = self.forward(x)[:,:self.d_out] 315 | d_output = torch.ones_like(y[:, :1], requires_grad=False, device=y.device) 316 | f = lambda v: torch.autograd.grad(outputs=y, 317 | inputs=x, 318 | grad_outputs=v.repeat(y.shape[0], 1), 319 | create_graph=True, 320 | retain_graph=True, 321 | only_inputs=True)[0] 322 | 323 | N = 
torch.eye(y.shape[1], requires_grad=False).to(y.device) 324 | 325 | # start_time = time.time() 326 | if self.use_grid_feature: # the hash grid feature does not support vmap yet, so loop over outputs 327 | g = torch.cat([torch.autograd.grad( 328 | outputs=y, 329 | inputs=x, 330 | grad_outputs=idx.repeat(y.shape[0], 1), 331 | create_graph=True, 332 | retain_graph=True, 333 | only_inputs=True)[0] for idx in N.unbind()]) 334 | # torch.cuda.synchronize() 335 | # print("time for computing gradient by for loop: ", time.time() - start_time, "s") 336 | 337 | # using vmap for batched gradient computation, if not using grid feature (pure MLP) 338 | else: 339 | g = vmap(f, in_dims=1)(N).reshape(-1, 3) 340 | 341 | # add the gradient of scene sdf 342 | sdf = -self.pool(-y.unsqueeze(1)).squeeze(-1) 343 | g_min_sdf = torch.autograd.grad( 344 | outputs=sdf, 345 | inputs=x, 346 | grad_outputs=d_output, 347 | create_graph=True, 348 | retain_graph=True, 349 | only_inputs=True)[0] 350 | g = torch.cat([g, g_min_sdf]) 351 | return g 352 | 353 | def get_outputs(self, x, beta=None): 354 | x.requires_grad_(True) 355 | output = self.forward(x) 356 | sdf_raw = output[:,:self.d_out] 357 | # if self.sdf_bounding_sphere > 0.0: 358 | # sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True)) 359 | # sdf_raw = torch.minimum(sdf_raw, sphere_sdf.expand(sdf_raw.shape)) 360 | 361 | if beta is None: 362 | semantic = self.sigmoid * torch.sigmoid(-self.sigmoid * sdf_raw) 363 | else: 364 | # per-object semantic term: same form as LaplaceDensity.density_func, evaluated on each object's SDF 365 | semantic = 1/beta * (0.5 + 0.5 * sdf_raw.sign() * torch.expm1(-sdf_raw.abs() / beta)) 366 | sdf = -self.pool(-sdf_raw.unsqueeze(1)).squeeze(-1) # take the minimum over all object SDFs 367 | feature_vectors = output[:, self.d_out:] 368 | 369 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) 370 | gradients = torch.autograd.grad( 371 | outputs=sdf, 372 | inputs=x, 373 | grad_outputs=d_output, 374 | create_graph=True, 375 | 
retain_graph=True, 376 | only_inputs=True)[0] 377 | 378 | return sdf, feature_vectors, gradients, semantic, sdf_raw 379 | 380 | def get_specific_outputs(self, x, idx): 381 | x.requires_grad_(True) 382 | output = self.forward(x) 383 | sdf_raw = output[:,:self.d_out] 384 | semantic = self.sigmoid * torch.sigmoid(-self.sigmoid * sdf_raw) 385 | sdf = -self.pool(-sdf_raw.unsqueeze(1)).squeeze(-1) 386 | 387 | feature_vectors = output[:, self.d_out:] 388 | d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device) 389 | gradients = torch.autograd.grad( 390 | outputs=sdf, 391 | inputs=x, 392 | grad_outputs=d_output, 393 | create_graph=True, 394 | retain_graph=True, 395 | only_inputs=True)[0] 396 | 397 | return sdf, feature_vectors, gradients, semantic, output[:,:self.d_out] 398 | 399 | 400 | def get_sdf_vals(self, x): 401 | sdf = -self.pool(-self.forward(x)[:,:self.d_out].unsqueeze(1)).squeeze(-1) 402 | return sdf 403 | 404 | def get_sdf_raw(self, x): 405 | return self.forward(x)[:, :self.d_out] 406 | 407 | 408 | def mlp_parameters(self): 409 | parameters = [] 410 | for l in range(0, self.num_layers - 1): 411 | lin = getattr(self, "lin" + str(l)) 412 | parameters += list(lin.parameters()) 413 | return parameters 414 | 415 | def grid_parameters(self, verbose=False): 416 | if verbose: 417 | print("[INFO]: grid parameters", len(list(self.encoding.parameters()))) 418 | for p in self.encoding.parameters(): 419 | print(p.shape) 420 | return self.encoding.parameters() 421 | 422 | 423 | class RenderingNetwork(nn.Module): 424 | def __init__( 425 | self, 426 | feature_vector_size, 427 | mode, 428 | d_in, 429 | d_out, 430 | dims, 431 | weight_norm=True, 432 | multires_view=0, 433 | per_image_code = False 434 | ): 435 | super().__init__() 436 | 437 | self.mode = mode 438 | dims = [d_in + feature_vector_size] + dims + [d_out] 439 | 440 | self.embedview_fn = None 441 | if multires_view > 0: 442 | embedview_fn, input_ch = get_embedder(multires_view) 443 | self.embedview_fn = 
embedview_fn 444 | dims[0] += (input_ch - 3) 445 | 446 | self.per_image_code = per_image_code 447 | if self.per_image_code: 448 | # nerf in the wild parameter 449 | # parameters 450 | # maximum 1024 images 451 | self.embeddings = nn.Parameter(torch.empty(1024, 32)) 452 | std = 1e-4 453 | self.embeddings.data.uniform_(-std, std) 454 | dims[0] += 32 455 | 456 | # print("rendering network architecture:") 457 | # print(dims) 458 | 459 | self.num_layers = len(dims) 460 | 461 | for l in range(0, self.num_layers - 1): 462 | out_dim = dims[l + 1] 463 | lin = nn.Linear(dims[l], out_dim) 464 | 465 | if weight_norm: 466 | lin = nn.utils.weight_norm(lin) 467 | 468 | setattr(self, "lin" + str(l), lin) 469 | 470 | self.relu = nn.ReLU() 471 | self.sigmoid = torch.nn.Sigmoid() 472 | 473 | def forward(self, points, normals, view_dirs, feature_vectors, indices): 474 | if self.embedview_fn is not None: 475 | view_dirs = self.embedview_fn(view_dirs) 476 | 477 | if self.mode == 'idr': 478 | rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) 479 | elif self.mode == 'nerf': 480 | rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1) 481 | else: 482 | raise NotImplementedError 483 | 484 | if self.per_image_code: 485 | image_code = self.embeddings[indices].expand(rendering_input.shape[0], -1) 486 | rendering_input = torch.cat([rendering_input, image_code], dim=-1) 487 | 488 | x = rendering_input 489 | 490 | for l in range(0, self.num_layers - 1): 491 | lin = getattr(self, "lin" + str(l)) 492 | 493 | x = lin(x) 494 | 495 | if l < self.num_layers - 2: 496 | x = self.relu(x) 497 | 498 | x = self.sigmoid(x) 499 | return x 500 | 501 | 502 | class ObjectSDFPlusNetwork(nn.Module): 503 | def __init__(self, conf): 504 | super().__init__() 505 | self.feature_vector_size = conf.get_int('feature_vector_size') 506 | self.scene_bounding_sphere = conf.get_float('scene_bounding_sphere', default=1.0) 507 | self.white_bkgd = conf.get_bool('white_bkgd', 
default=False) 508 | self.bg_color = torch.tensor(conf.get_list("bg_color", default=[1.0, 1.0, 1.0])).float().cuda() 509 | 510 | Grid_MLP = conf.get_bool('Grid_MLP', default=False) 511 | self.Grid_MLP = Grid_MLP 512 | if Grid_MLP: 513 | self.implicit_network = ObjectImplicitNetworkGrid(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network')) 514 | else: 515 | self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0 if self.white_bkgd else self.scene_bounding_sphere, **conf.get_config('implicit_network')) 516 | 517 | self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.get_config('rendering_network')) 518 | 519 | self.density = LaplaceDensity(**conf.get_config('density')) 520 | self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, **conf.get_config('ray_sampler')) 521 | 522 | self.num_semantic = conf.get_int('implicit_network.d_out') 523 | 524 | def forward(self, input, indices): 525 | # Parse model input 526 | intrinsics = input["intrinsics"] 527 | uv = input["uv"] 528 | pose = input["pose"] 529 | 530 | ray_dirs, cam_loc = rend_util.get_camera_params(uv, pose, intrinsics) 531 | 532 | # we should use unnormalized ray direction for depth 533 | ray_dirs_tmp, _ = rend_util.get_camera_params(uv, torch.eye(4).to(pose.device)[None], intrinsics) 534 | depth_scale = ray_dirs_tmp[0, :, 2:] 535 | 536 | batch_size, num_pixels, _ = ray_dirs.shape 537 | 538 | cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3) 539 | ray_dirs = ray_dirs.reshape(-1, 3) 540 | 541 | z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self) 542 | N_samples = z_vals.shape[1] 543 | 544 | points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) 545 | points_flat = points.reshape(-1, 3) 546 | 547 | dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1) 548 | dirs_flat = dirs.reshape(-1, 3) 549 | 550 | beta_cur = self.density.get_beta() 551 | 
sdf, feature_vectors, gradients, semantic, sdf_raw = self.implicit_network.get_outputs(points_flat, beta=None) 552 | 553 | rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors, indices) 554 | rgb = rgb_flat.reshape(-1, N_samples, 3) 555 | 556 | semantic = semantic.reshape(-1, N_samples, self.num_semantic) 557 | weights, transmittance, dists = self.volume_rendering(z_vals, sdf) 558 | 559 | # render the occlusion-aware object opacity 560 | object_opacity = self.occlusion_opacity(z_vals, transmittance, dists, sdf_raw).sum(-1).transpose(0, 1) 561 | 562 | 563 | rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1) 564 | semantic_values = torch.sum(weights.unsqueeze(-1)*semantic, 1) 565 | depth_values = torch.sum(weights * z_vals, 1, keepdims=True) / (weights.sum(dim=1, keepdims=True) +1e-8) 566 | # scale the rendered ray distance to depth along the z direction 567 | depth_values = depth_scale * depth_values 568 | 569 | # white background assumption 570 | if self.white_bkgd: 571 | acc_map = torch.sum(weights, -1) 572 | rgb_values = rgb_values + (1. 
- acc_map[..., None]) * self.bg_color.unsqueeze(0) 573 | 574 | output = { 575 | 'rgb':rgb, 576 | 'semantic_values': semantic_values, # semantic values computed as in ObjectSDF 577 | 'object_opacity': object_opacity, 578 | 'rgb_values': rgb_values, 579 | 'depth_values': depth_values, 580 | 'z_vals': z_vals, 581 | 'depth_vals': z_vals * depth_scale, 582 | 'sdf': sdf.reshape(z_vals.shape), 583 | 'weights': weights, 584 | } 585 | 586 | if self.training: 587 | # Sample points for the eikonal loss 588 | n_eik_points = batch_size * num_pixels 589 | 590 | eikonal_points = torch.empty(n_eik_points, 3).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere).cuda() 591 | 592 | # add some of the near surface points 593 | eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3) 594 | eikonal_points = torch.cat([eikonal_points, eik_near_points], 0) 595 | # add jittered neighbour points, as in UNISURF 596 | neighbour_points = eikonal_points + (torch.rand_like(eikonal_points) - 0.5) * 0.01 597 | eikonal_points = torch.cat([eikonal_points, neighbour_points], 0) 598 | 599 | grad_theta = self.implicit_network.gradient(eikonal_points) 600 | 601 | sample_sdf = self.implicit_network.get_sdf_raw(eikonal_points) 602 | sdf_value = self.implicit_network.get_sdf_vals(eikonal_points) 603 | output['sample_sdf'] = sample_sdf 604 | output['sample_minsdf'] = sdf_value 605 | 606 | # split the gradients into eikonal points and neighbour points 607 | output['grad_theta'] = grad_theta[:grad_theta.shape[0]//2] 608 | output['grad_theta_nei'] = grad_theta[grad_theta.shape[0]//2:] 609 | 610 | # compute normal map 611 | normals = gradients / (gradients.norm(2, -1, keepdim=True) + 1e-6) 612 | normals = normals.reshape(-1, N_samples, 3) 613 | normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1) 614 | 615 | # transform to local coordinate system 616 | rot = pose[0, :3, :3].permute(1, 0).contiguous() 617 | normal_map = rot @ normal_map.permute(1, 0) 
618 | normal_map = normal_map.permute(1, 0).contiguous() 619 | 620 | output['normal_map'] = normal_map 621 | 622 | return output 623 | 624 | def volume_rendering(self, z_vals, sdf): 625 | density_flat = self.density(sdf) 626 | density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples 627 | 628 | dists = z_vals[:, 1:] - z_vals[:, :-1] 629 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1) 630 | 631 | # LOG SPACE 632 | free_energy = dists * density 633 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) # shift one step 634 | alpha = 1 - torch.exp(-free_energy) # probability that this interval is not empty 635 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability that everything is empty up to this point 636 | weights = alpha * transmittance # probability that the ray hits something here 637 | 638 | return weights, transmittance, dists 639 | 640 | def occlusion_opacity(self, z_vals, transmittance, dists, sdf_raw): 641 | obj_density = self.density(sdf_raw).transpose(0, 1).reshape(-1, dists.shape[0], dists.shape[1]) # [#object, #ray, #sample points] 642 | free_energy = dists * obj_density 643 | alpha = 1 - torch.exp(-free_energy) # probability that this interval is not empty 644 | object_weight = alpha * transmittance 645 | return object_weight -------------------------------------------------------------------------------- /code/model/ray_sampler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | 5 | from utils import rend_util 6 | 7 | class RaySampler(metaclass=abc.ABCMeta): 8 | def __init__(self, near, far): 9 | self.near = near 10 | self.far = far 11 | 12 | @abc.abstractmethod 13 | def get_z_vals(self, ray_dirs, cam_loc, model): 14 | pass 15 | 16 | class UniformSampler(RaySampler): 17 | def __init__(self, 
scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1): 18 | #super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R 19 | super().__init__(near, 2.0 * scene_bounding_sphere * 1.75 if far == -1 else far) # default far is 2*R*1.75 = 3.5*R 20 | self.N_samples = N_samples 21 | self.scene_bounding_sphere = scene_bounding_sphere 22 | self.take_sphere_intersection = take_sphere_intersection 23 | 24 | # dtu and bmvs 25 | def get_z_vals_dtu_bmvs(self, ray_dirs, cam_loc, model): 26 | if not self.take_sphere_intersection: 27 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda() 28 | else: 29 | sphere_intersections = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere) 30 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda() 31 | far = sphere_intersections[:,1:] 32 | 33 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda() 34 | z_vals = near * (1. 
- t_vals) + far * (t_vals) 35 | 36 | if model.training: 37 | # get intervals between samples 38 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 39 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 40 | lower = torch.cat([z_vals[..., :1], mids], -1) 41 | # stratified samples in those intervals 42 | t_rand = torch.rand(z_vals.shape).cuda() 43 | 44 | z_vals = lower + (upper - lower) * t_rand 45 | 46 | return z_vals, near, far 47 | 48 | def near_far_from_cube(self, rays_o, rays_d, bound): 49 | tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3] 50 | tmax = (bound - rays_o) / (rays_d + 1e-15) 51 | near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0] 52 | far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0] 53 | # if far < near, means no intersection, set both near and far to inf (1e9 here) 54 | mask = far < near 55 | near[mask] = 1e9 56 | far[mask] = 1e9 57 | # restrict near to a minimal value 58 | near = torch.clamp(near, min=self.near) 59 | far = torch.clamp(far, max=self.far) 60 | return near, far 61 | 62 | # currently this is used for replica scannet and T&T 63 | def get_z_vals(self, ray_dirs, cam_loc, model): 64 | if not self.take_sphere_intersection: 65 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda() 66 | else: 67 | _, far = self.near_far_from_cube(cam_loc, ray_dirs, bound=self.scene_bounding_sphere) 68 | near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda() 69 | 70 | t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda() 71 | z_vals = near * (1. 
- t_vals) + far * (t_vals) 72 | 73 | if model.training: 74 | # get intervals between samples 75 | mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) 76 | upper = torch.cat([mids, z_vals[..., -1:]], -1) 77 | lower = torch.cat([z_vals[..., :1], mids], -1) 78 | # stratified samples in those intervals 79 | t_rand = torch.rand(z_vals.shape).cuda() 80 | 81 | z_vals = lower + (upper - lower) * t_rand 82 | 83 | return z_vals, near, far 84 | 85 | 86 | class ErrorBoundSampler(RaySampler): 87 | def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra, 88 | eps, beta_iters, max_total_iters, 89 | inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=1.0e-6): 90 | #super().__init__(near, 2.0 * scene_bounding_sphere) 91 | super().__init__(near, 2.0 * scene_bounding_sphere * 1.75) 92 | 93 | self.N_samples = N_samples 94 | self.N_samples_eval = N_samples_eval 95 | self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=True) # replica scannet and T&T courtroom 96 | #self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg) # dtu and bmvs 97 | 98 | self.N_samples_extra = N_samples_extra 99 | 100 | self.eps = eps 101 | self.beta_iters = beta_iters 102 | self.max_total_iters = max_total_iters 103 | self.scene_bounding_sphere = scene_bounding_sphere 104 | self.add_tiny = add_tiny 105 | 106 | self.inverse_sphere_bg = inverse_sphere_bg 107 | if inverse_sphere_bg: 108 | self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0) 109 | 110 | def get_z_vals(self, ray_dirs, cam_loc, model): 111 | beta0 = model.density.get_beta().detach() 112 | 113 | # Start with uniform sampling 114 | z_vals, near, far = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model) 115 | samples, samples_idx = z_vals, None 116 | 117 | # Get maximum beta from the upper bound (Lemma 2) 118 | dists = z_vals[:, 1:] - 
z_vals[:, :-1] 119 | bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1) 120 | beta = torch.sqrt(bound) 121 | 122 | total_iters, not_converge = 0, True 123 | 124 | # Algorithm 1 125 | while not_converge and total_iters < self.max_total_iters: 126 | points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1) 127 | points_flat = points.reshape(-1, 3) 128 | 129 | # Calculating the SDF only for the new sampled points 130 | with torch.no_grad(): 131 | samples_sdf = model.implicit_network.get_sdf_vals(points_flat) 132 | if samples_idx is not None: 133 | sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]), 134 | samples_sdf.reshape(-1, samples.shape[1])], -1) 135 | sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1) 136 | else: 137 | sdf = samples_sdf 138 | 139 | 140 | # Calculating the bound d* (Theorem 1) 141 | d = sdf.reshape(z_vals.shape) 142 | dists = z_vals[:, 1:] - z_vals[:, :-1] 143 | a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs() 144 | first_cond = a.pow(2) + b.pow(2) <= c.pow(2) 145 | second_cond = a.pow(2) + c.pow(2) <= b.pow(2) 146 | d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda() 147 | d_star[first_cond] = b[first_cond] 148 | d_star[second_cond] = c[second_cond] 149 | s = (a + b + c) / 2.0 150 | area_before_sqrt = s * (s - a) * (s - b) * (s - c) 151 | mask = ~first_cond & ~second_cond & (b + c - a > 0) 152 | d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask]) 153 | d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign 154 | 155 | 156 | # Updating beta using line search 157 | curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star) 158 | beta[curr_error <= self.eps] = beta0 159 | beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta 160 | for j in range(self.beta_iters): 161 | beta_mid = (beta_min + beta_max) / 2. 
162 | curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star) 163 | beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps] 164 | beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps] 165 | beta = beta_max 166 | 167 | # Upsample more points 168 | density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1)) 169 | 170 | dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1) 171 | free_energy = dists * density 172 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1) 173 | alpha = 1 - torch.exp(-free_energy) 174 | transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) 175 | weights = alpha * transmittance # probability of the ray hits something here 176 | 177 | # Check if we are done and this is the last sampling 178 | total_iters += 1 179 | not_converge = beta.max() > beta0 180 | 181 | if not_converge and total_iters < self.max_total_iters: 182 | ''' Sample more points proportional to the current error bound''' 183 | 184 | N = self.N_samples_eval 185 | 186 | bins = z_vals 187 | error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) 
/ (4 * beta.unsqueeze(-1) ** 2) 188 | error_integral = torch.cumsum(error_per_section, dim=-1) 189 | bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1] 190 | 191 | pdf = bound_opacity + self.add_tiny 192 | pdf = pdf / torch.sum(pdf, -1, keepdim=True) 193 | cdf = torch.cumsum(pdf, -1) 194 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) 195 | 196 | else: 197 | ''' Sample the final sample set to be used in the volume rendering integral ''' 198 | 199 | N = self.N_samples 200 | 201 | bins = z_vals 202 | pdf = weights[..., :-1] 203 | pdf = pdf + 1e-5 # prevent NaNs 204 | pdf = pdf / torch.sum(pdf, -1, keepdim=True) 205 | cdf = torch.cumsum(pdf, -1) 206 | cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins)) 207 | 208 | 209 | # Invert CDF 210 | if (not_converge and total_iters < self.max_total_iters) or (not model.training): 211 | u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1) 212 | else: 213 | u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda() 214 | u = u.contiguous() 215 | 216 | inds = torch.searchsorted(cdf, u, right=True) 217 | below = torch.max(torch.zeros_like(inds - 1), inds - 1) 218 | above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) 219 | inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) 220 | 221 | matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] 222 | cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) 223 | bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) 224 | 225 | denom = (cdf_g[..., 1] - cdf_g[..., 0]) 226 | denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) 227 | t = (u - cdf_g[..., 0]) / denom 228 | samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) 229 | 230 | 231 | # Add the new samples if we have not converged 232 | if not_converge and total_iters < self.max_total_iters: 233 | z_vals, samples_idx = 
torch.sort(torch.cat([z_vals, samples], -1), -1) 234 | 235 | 236 | z_samples = samples 237 | #TODO Use near and far from intersection 238 | near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda() 239 | if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection 240 | far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:] 241 | 242 | if self.N_samples_extra > 0: 243 | if model.training: 244 | sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra] 245 | else: 246 | sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long() 247 | z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1) 248 | else: 249 | z_vals_extra = torch.cat([near, far], -1) 250 | 251 | z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1) 252 | 253 | # add some of the near surface points 254 | idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda() 255 | z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1)) 256 | 257 | if self.inverse_sphere_bg: 258 | z_vals_inverse_sphere, _, _ = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model) 259 | z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere) 260 | z_vals = (z_vals, z_vals_inverse_sphere) 261 | 262 | return z_vals, z_samples_eik 263 | 264 | def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star): 265 | density = model.density(sdf.reshape(z_vals.shape), beta=beta) 266 | shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1) 267 | integral_estimation = torch.cumsum(shifted_free_energy, dim=-1) 268 | error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) 
/ (4 * beta ** 2) 269 | error_integral = torch.cumsum(error_per_section, dim=-1) 270 | bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1]) 271 | 272 | return bound_opacity.max(-1)[0] 273 | 274 | -------------------------------------------------------------------------------- /code/training/exp_runner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../code') 4 | import argparse 5 | import torch 6 | 7 | import os 8 | from training.objectsdfplus_train import ObjectSDFPlusTrainRunner 9 | import datetime 10 | 11 | if __name__ == '__main__': 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 15 | parser.add_argument('--nepoch', type=int, default=2000, help='number of epochs to train for') 16 | parser.add_argument('--conf', type=str, default='./confs/dtu.conf') 17 | parser.add_argument('--expname', type=str, default='') 18 | parser.add_argument("--exps_folder", type=str, default="exps") 19 | #parser.add_argument('--gpu', type=str, default='auto', help='GPU to use [default: GPU auto]') 20 | parser.add_argument('--is_continue', default=False, action="store_true", 21 | help='If set, indicates continuing from a previous run.') 22 | parser.add_argument('--timestamp', default='latest', type=str, 23 | help='The timestamp of the run to be used in case of continuing from a previous run.') 24 | parser.add_argument('--checkpoint', default='latest', type=str, 25 | help='The checkpoint epoch of the run to be used in case of continuing from a previous run.') 26 | parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.') 27 | parser.add_argument('--cancel_vis', default=False, action="store_true", 28 | help='If set, cancel visualization in intermediate epochs.') 29 | # parser.add_argument("--local_rank", type=int, required=True, 
help='local rank for DistributedDataParallel') # this is not required in torch 2.0 30 | parser.add_argument("--ft_folder", type=str, default=None, help='If set, finetune model from the given folder path') 31 | 32 | opt = parser.parse_args() 33 | 34 | ''' 35 | # if using GPUtil 36 | if opt.gpu == "auto": 37 | deviceIDs = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False, 38 | excludeID=[], excludeUUID=[]) 39 | gpu = deviceIDs[0] 40 | else: 41 | gpu = opt.gpu 42 | ''' 43 | # gpu = opt.local_rank 44 | 45 | # set distributed training 46 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 47 | rank = int(os.environ["RANK"]) 48 | world_size = int(os.environ['WORLD_SIZE']) 49 | local_rank = int(os.environ['LOCAL_RANK']) 50 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 51 | else: 52 | rank = -1 53 | world_size = -1 54 | local_rank = -1 55 | 56 | # print(opt.local_rank) 57 | torch.cuda.set_device(local_rank) 58 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank, timeout=datetime.timedelta(1, 1800)) 59 | torch.distributed.barrier() 60 | 61 | 62 | trainrunner = ObjectSDFPlusTrainRunner(conf=opt.conf, 63 | batch_size=opt.batch_size, 64 | nepochs=opt.nepoch, 65 | expname=opt.expname, 66 | gpu_index=local_rank, 67 | exps_folder_name=opt.exps_folder, 68 | is_continue=opt.is_continue, 69 | timestamp=opt.timestamp, 70 | checkpoint=opt.checkpoint, 71 | scan_id=opt.scan_id, 72 | do_vis=not opt.cancel_vis, 73 | ft_folder = opt.ft_folder 74 | ) 75 | 76 | trainrunner.run() 77 | -------------------------------------------------------------------------------- /code/training/objectsdfplus_train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from datetime import datetime 4 | from pyhocon import ConfigFactory 5 | import sys 6 | import torch 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | 
import utils.general as utils 11 | import utils.plots as plt 12 | from utils import rend_util 13 | from utils.general import get_time 14 | from torch.utils.tensorboard import SummaryWriter 15 | from model.loss import compute_scale_and_shift 16 | from utils.general import BackprojectDepth 17 | 18 | class ObjectSDFPlusTrainRunner(): 19 | def __init__(self,**kwargs): 20 | torch.set_default_dtype(torch.float32) 21 | torch.set_num_threads(1) 22 | 23 | self.conf = ConfigFactory.parse_file(kwargs['conf']) 24 | self.batch_size = kwargs['batch_size'] 25 | self.nepochs = kwargs['nepochs'] 26 | self.exps_folder_name = kwargs['exps_folder_name'] 27 | self.GPU_INDEX = kwargs['gpu_index'] 28 | 29 | self.expname = self.conf.get_string('train.expname') + kwargs['expname'] 30 | scan_id = kwargs['scan_id'] if kwargs['scan_id'] != -1 else self.conf.get_int('dataset.scan_id', default=-1) 31 | if scan_id != -1: 32 | self.expname = self.expname + '_{0}'.format(scan_id) 33 | 34 | self.finetune_folder = kwargs['ft_folder'] if kwargs['ft_folder'] is not None else None 35 | if kwargs['is_continue'] and kwargs['timestamp'] == 'latest': 36 | if os.path.exists(os.path.join('../',kwargs['exps_folder_name'],self.expname)): 37 | timestamps = os.listdir(os.path.join('../',kwargs['exps_folder_name'],self.expname)) 38 | if (len(timestamps)) == 0: 39 | is_continue = False 40 | timestamp = None 41 | else: 42 | timestamp = sorted(timestamps)[-1] 43 | is_continue = True 44 | else: 45 | is_continue = False 46 | timestamp = None 47 | else: 48 | timestamp = kwargs['timestamp'] 49 | is_continue = kwargs['is_continue'] 50 | 51 | if self.GPU_INDEX == 0: 52 | utils.mkdir_ifnotexists(os.path.join('../',self.exps_folder_name)) 53 | self.expdir = os.path.join('../', self.exps_folder_name, self.expname) 54 | utils.mkdir_ifnotexists(self.expdir) 55 | self.timestamp = '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now()) 56 | utils.mkdir_ifnotexists(os.path.join(self.expdir, self.timestamp)) 57 | 58 | self.plots_dir = 
os.path.join(self.expdir, self.timestamp, 'plots') 59 | utils.mkdir_ifnotexists(self.plots_dir) 60 | 61 | # create checkpoints dirs 62 | self.checkpoints_path = os.path.join(self.expdir, self.timestamp, 'checkpoints') 63 | utils.mkdir_ifnotexists(self.checkpoints_path) 64 | self.model_params_subdir = "ModelParameters" 65 | self.optimizer_params_subdir = "OptimizerParameters" 66 | self.scheduler_params_subdir = "SchedulerParameters" 67 | 68 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.model_params_subdir)) 69 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.optimizer_params_subdir)) 70 | utils.mkdir_ifnotexists(os.path.join(self.checkpoints_path, self.scheduler_params_subdir)) 71 | 72 | os.system("""cp -r {0} "{1}" """.format(kwargs['conf'], os.path.join(self.expdir, self.timestamp, 'runconf.conf'))) 73 | 74 | # if (not self.GPU_INDEX == 'ignore'): 75 | # os.environ["CUDA_VISIBLE_DEVICES"] = '{0}'.format(self.GPU_INDEX) 76 | 77 | print('[INFO]: shell command : {0}'.format(' '.join(sys.argv))) 78 | 79 | print('[INFO]: Loading data ...') 80 | 81 | dataset_conf = self.conf.get_config('dataset') 82 | if kwargs['scan_id'] != -1: 83 | dataset_conf['scan_id'] = kwargs['scan_id'] 84 | 85 | self.all_dataset = utils.get_class(self.conf.get_string('train.dataset_class'))(**dataset_conf) 86 | if hasattr(self.all_dataset, 'i_split'): 87 | # if you would like to split the dataset into train and test, assign 'i_split' attribute to the all_dataset 88 | self.train_dataset = torch.utils.data.Subset(self.all_dataset, self.all_dataset.i_split[0]) 89 | self.test_dataset = torch.utils.data.Subset(self.all_dataset, self.all_dataset.i_split[1]) 90 | else: 91 | self.train_dataset = torch.utils.data.Subset(self.all_dataset, range(self.all_dataset.n_images)) 92 | self.test_dataset = torch.utils.data.Subset(self.all_dataset, range(self.all_dataset.n_images)) 93 | 94 | self.max_total_iters = self.conf.get_int('train.max_total_iters', default=200000) 95 | 
self.ds_len = len(self.train_dataset) 96 | self.nepochs = int(self.max_total_iters / self.ds_len) # update nepochs as iters/len(dataset) 97 | print('[INFO]: Finish loading data. Data-set size: {0}'.format(self.ds_len)) 98 | 99 | self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset, 100 | batch_size=self.batch_size, 101 | shuffle=True, 102 | collate_fn=self.all_dataset.collate_fn, 103 | num_workers=8, 104 | pin_memory=True) 105 | self.plot_dataloader = torch.utils.data.DataLoader(self.train_dataset, 106 | batch_size=self.conf.get_int('plot.plot_nimgs'), 107 | shuffle=True, 108 | collate_fn=self.all_dataset.collate_fn 109 | ) 110 | 111 | conf_model = self.conf.get_config('model') 112 | self.model = utils.get_class(self.conf.get_string('train.model_class'))(conf=conf_model) 113 | 114 | self.Grid_MLP = self.model.Grid_MLP 115 | if torch.cuda.is_available(): 116 | self.model.cuda() 117 | 118 | self.loss = utils.get_class(self.conf.get_string('train.loss_class'))(**self.conf.get_config('loss')) 119 | 120 | # The MLP and hash grid should have different learning rates 121 | self.lr = self.conf.get_float('train.learning_rate') 122 | self.lr_factor_for_grid = self.conf.get_float('train.lr_factor_for_grid', default=1.0) 123 | 124 | if self.Grid_MLP: 125 | self.optimizer = torch.optim.Adam([ 126 | {'name': 'encoding', 'params': list(self.model.implicit_network.grid_parameters()), 127 | 'lr': self.lr * self.lr_factor_for_grid}, 128 | {'name': 'net', 'params': list(self.model.implicit_network.mlp_parameters()) +\ 129 | list(self.model.rendering_network.parameters()), 130 | 'lr': self.lr}, 131 | {'name': 'density', 'params': list(self.model.density.parameters()), 132 | 'lr': self.lr}, 133 | ], betas=(0.9, 0.99), eps=1e-15) 134 | else: 135 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) 136 | 137 | # Exponential learning rate scheduler 138 | decay_rate = self.conf.get_float('train.sched_decay_rate', default=0.1) 139 | decay_steps = 
self.nepochs * len(self.train_dataset) 140 | self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, decay_rate ** (1./decay_steps)) 141 | 142 | self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[self.GPU_INDEX], broadcast_buffers=False, find_unused_parameters=True) 143 | 144 | self.do_vis = kwargs['do_vis'] 145 | 146 | self.start_epoch = 0 147 | # Loading a pretrained model for finetuning, the model path can be provided by self.finetune_folder 148 | if is_continue or self.finetune_folder is not None: 149 | old_checkpnts_dir = os.path.join(self.expdir, timestamp, 'checkpoints') if self.finetune_folder is None\ 150 | else os.path.join(self.finetune_folder, 'checkpoints') 151 | 152 | print('[INFO]: Loading pretrained model from {}'.format(old_checkpnts_dir)) 153 | saved_model_state = torch.load( 154 | os.path.join(old_checkpnts_dir, 'ModelParameters', str(kwargs['checkpoint']) + ".pth")) 155 | self.model.load_state_dict(saved_model_state["model_state_dict"]) 156 | self.start_epoch = saved_model_state['epoch'] 157 | 158 | data = torch.load( 159 | os.path.join(old_checkpnts_dir, 'OptimizerParameters', str(kwargs['checkpoint']) + ".pth")) 160 | self.optimizer.load_state_dict(data["optimizer_state_dict"]) 161 | 162 | data = torch.load( 163 | os.path.join(old_checkpnts_dir, self.scheduler_params_subdir, str(kwargs['checkpoint']) + ".pth")) 164 | self.scheduler.load_state_dict(data["scheduler_state_dict"]) 165 | 166 | self.num_pixels = self.conf.get_int('train.num_pixels') 167 | self.total_pixels = self.all_dataset.total_pixels 168 | self.img_res = self.all_dataset.img_res 169 | self.n_batches = len(self.train_dataloader) 170 | self.plot_freq = self.conf.get_int('train.plot_freq') 171 | self.checkpoint_freq = self.conf.get_int('train.checkpoint_freq', default=100) 172 | self.split_n_pixels = self.conf.get_int('train.split_n_pixels', default=10000) 173 | self.plot_conf = self.conf.get_config('plot') 174 | self.backproject = 
BackprojectDepth(1, self.img_res[0], self.img_res[1]).cuda() 175 | 176 | self.add_objectvio_iter = self.conf.get_int('train.add_objectvio_iter', default=0) 177 | 178 | def save_checkpoints(self, epoch): 179 | torch.save( 180 | {"epoch": epoch, "model_state_dict": self.model.state_dict()}, 181 | os.path.join(self.checkpoints_path, self.model_params_subdir, str(epoch) + ".pth")) 182 | torch.save( 183 | {"epoch": epoch, "model_state_dict": self.model.state_dict()}, 184 | os.path.join(self.checkpoints_path, self.model_params_subdir, "latest.pth")) 185 | 186 | torch.save( 187 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()}, 188 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, str(epoch) + ".pth")) 189 | torch.save( 190 | {"epoch": epoch, "optimizer_state_dict": self.optimizer.state_dict()}, 191 | os.path.join(self.checkpoints_path, self.optimizer_params_subdir, "latest.pth")) 192 | 193 | torch.save( 194 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()}, 195 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, str(epoch) + ".pth")) 196 | torch.save( 197 | {"epoch": epoch, "scheduler_state_dict": self.scheduler.state_dict()}, 198 | os.path.join(self.checkpoints_path, self.scheduler_params_subdir, "latest.pth")) 199 | 200 | def run(self): 201 | print("training...") 202 | if self.GPU_INDEX == 0 : 203 | self.writer = SummaryWriter(log_dir=os.path.join(self.plots_dir, 'logs')) 204 | 205 | self.iter_step = 0 206 | for epoch in range(self.start_epoch, self.nepochs + 1): 207 | 208 | if self.GPU_INDEX == 0 and epoch % self.checkpoint_freq == 0: 209 | self.save_checkpoints(epoch) 210 | 211 | if self.GPU_INDEX == 0 and self.do_vis and epoch % self.plot_freq == 0: 212 | self.model.eval() 213 | 214 | self.all_dataset.change_sampling_idx(-1) 215 | 216 | indices, model_input, ground_truth = next(iter(self.plot_dataloader)) 217 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 218 | 
model_input["uv"] = model_input["uv"].cuda() 219 | model_input['pose'] = model_input['pose'].cuda() 220 | 221 | split = utils.split_input(model_input, self.total_pixels, n_pixels=self.split_n_pixels) 222 | res = [] 223 | for s in tqdm(split): 224 | out = self.model(s, indices) 225 | d = {'rgb_values': out['rgb_values'].detach(), 226 | 'normal_map': out['normal_map'].detach(), 227 | 'depth_values': out['depth_values'].detach()} 228 | if 'rgb_un_values' in out: 229 | d['rgb_un_values'] = out['rgb_un_values'].detach() 230 | if 'semantic_values' in out: 231 | d['semantic_values'] = torch.argmax(out['semantic_values'].detach(),dim=1) 232 | res.append(d) 233 | 234 | batch_size = ground_truth['rgb'].shape[0] 235 | model_outputs = utils.merge_output(res, self.total_pixels, batch_size) 236 | plot_data = self.get_plot_data(model_input, model_outputs, model_input['pose'], ground_truth['rgb'], ground_truth['normal'], ground_truth['depth'], ground_truth['segs']) 237 | 238 | plt.plot(self.model.module.implicit_network, 239 | indices, 240 | plot_data, 241 | self.plots_dir, 242 | epoch, 243 | self.img_res, 244 | **self.plot_conf 245 | ) 246 | 247 | self.model.train() 248 | self.all_dataset.change_sampling_idx(self.num_pixels) 249 | 250 | for data_index, (indices, model_input, ground_truth) in enumerate(self.train_dataloader): 251 | model_input["intrinsics"] = model_input["intrinsics"].cuda() 252 | model_input["uv"] = model_input["uv"].cuda() 253 | model_input['pose'] = model_input['pose'].cuda() 254 | 255 | self.optimizer.zero_grad() 256 | 257 | model_outputs = self.model(model_input, indices) 258 | 259 | loss_output = self.loss(model_outputs, ground_truth, call_reg=True) if\ 260 | self.iter_step >= self.add_objectvio_iter else self.loss(model_outputs, ground_truth, call_reg=False) 261 | # if change the pixel sampling pattern to patch, then you can add a TV loss to enforce some smoothness constraint 262 | loss = loss_output['loss'] 263 | loss.backward() 264 | self.optimizer.step() 
265 | 266 | psnr = rend_util.get_psnr(model_outputs['rgb_values'], 267 | ground_truth['rgb'].cuda().reshape(-1,3)) 268 | 269 | self.iter_step += 1 270 | 271 | if self.GPU_INDEX == 0 and data_index % 20 == 0: 272 | print( 273 | '{0}_{1} [{2}] ({3}/{4}): loss = {5}, rgb_loss = {6}, eikonal_loss = {7}, psnr = {8}, beta={9}, alpha={10}, semantic_loss = {11}, reg_loss = {12}' 274 | .format(self.expname, self.timestamp, epoch, data_index, self.n_batches, loss.item(), 275 | loss_output['rgb_loss'].item(), 276 | loss_output['eikonal_loss'].item(), 277 | psnr.item(), 278 | self.model.module.density.get_beta().item(), 279 | 1. / self.model.module.density.get_beta().item(), 280 | loss_output['semantic_loss'].item(), 281 | loss_output['collision_reg_loss'].item())) 282 | 283 | self.writer.add_scalar('Loss/loss', loss.item(), self.iter_step) 284 | self.writer.add_scalar('Loss/color_loss', loss_output['rgb_loss'].item(), self.iter_step) 285 | self.writer.add_scalar('Loss/eikonal_loss', loss_output['eikonal_loss'].item(), self.iter_step) 286 | self.writer.add_scalar('Loss/smooth_loss', loss_output['smooth_loss'].item(), self.iter_step) 287 | self.writer.add_scalar('Loss/depth_loss', loss_output['depth_loss'].item(), self.iter_step) 288 | self.writer.add_scalar('Loss/normal_l1_loss', loss_output['normal_l1'].item(), self.iter_step) 289 | self.writer.add_scalar('Loss/normal_cos_loss', loss_output['normal_cos'].item(), self.iter_step) 290 | if 'semantic_loss' in loss_output: 291 | self.writer.add_scalar('Loss/semantic_loss', loss_output['semantic_loss'].item(), self.iter_step) 292 | if 'collision_reg_loss' in loss_output: 293 | self.writer.add_scalar('Loss/collision_reg_loss', loss_output['collision_reg_loss'].item(), self.iter_step) 294 | 295 | self.writer.add_scalar('Statistics/beta', self.model.module.density.get_beta().item(), self.iter_step) 296 | self.writer.add_scalar('Statistics/alpha', 1. 
/ self.model.module.density.get_beta().item(), self.iter_step) 297 | self.writer.add_scalar('Statistics/psnr', psnr.item(), self.iter_step) 298 | 299 | if self.Grid_MLP: 300 | self.writer.add_scalar('Statistics/lr0', self.optimizer.param_groups[0]['lr'], self.iter_step) 301 | self.writer.add_scalar('Statistics/lr1', self.optimizer.param_groups[1]['lr'], self.iter_step) 302 | self.writer.add_scalar('Statistics/lr2', self.optimizer.param_groups[2]['lr'], self.iter_step) 303 | 304 | self.all_dataset.change_sampling_idx(self.num_pixels) 305 | self.scheduler.step() 306 | 307 | if self.GPU_INDEX == 0: 308 | self.save_checkpoints(epoch) 309 | 310 | 311 | def get_plot_data(self, model_input, model_outputs, pose, rgb_gt, normal_gt, depth_gt, seg_gt): 312 | batch_size, num_samples, _ = rgb_gt.shape 313 | 314 | rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3) 315 | normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3) 316 | normal_map = (normal_map + 1.) / 2. 317 | 318 | depth_map = model_outputs['depth_values'].reshape(batch_size, num_samples) 319 | depth_gt = depth_gt.to(depth_map.device) 320 | scale, shift = compute_scale_and_shift(depth_map[..., None], depth_gt, depth_gt > 0.) 
321 | depth_map = depth_map * scale + shift 322 | 323 | seg_map = model_outputs['semantic_values'].reshape(batch_size, num_samples) 324 | seg_gt = seg_gt.to(seg_map.device) 325 | 326 | # save point cloud 327 | depth = depth_map.reshape(1, 1, self.img_res[0], self.img_res[1]) 328 | pred_points = self.get_point_cloud(depth, model_input, model_outputs) 329 | 330 | gt_depth = depth_gt.reshape(1, 1, self.img_res[0], self.img_res[1]) 331 | gt_points = self.get_point_cloud(gt_depth, model_input, model_outputs) 332 | 333 | plot_data = { 334 | 'rgb_gt': rgb_gt, 335 | 'normal_gt': (normal_gt + 1.)/ 2., 336 | 'depth_gt': depth_gt, 337 | 'seg_gt': seg_gt, 338 | 'pose': pose, 339 | 'rgb_eval': rgb_eval, 340 | 'normal_map': normal_map, 341 | 'depth_map': depth_map, 342 | 'seg_map': seg_map, 343 | "pred_points": pred_points, 344 | "gt_points": gt_points, 345 | } 346 | 347 | return plot_data 348 | 349 | def get_point_cloud(self, depth, model_input, model_outputs): 350 | color = model_outputs["rgb_values"].reshape(-1, 3) 351 | 352 | K_inv = torch.inverse(model_input["intrinsics"][0])[None] 353 | points = self.backproject(depth, K_inv)[0, :3, :].permute(1, 0) 354 | points = torch.cat([points, color], dim=-1) 355 | return points.detach().cpu().numpy() 356 | -------------------------------------------------------------------------------- /code/utils/general.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import time 7 | from torchvision import transforms 8 | import numpy as np 9 | 10 | def mkdir_ifnotexists(directory): 11 | if not os.path.exists(directory): 12 | os.mkdir(directory) 13 | 14 | def get_class(kls): 15 | parts = kls.split('.') 16 | module = ".".join(parts[:-1]) 17 | m = __import__(module) 18 | for comp in parts[1:]: 19 | m = getattr(m, comp) 20 | return m 21 | 22 | def glob_imgs(path): 23 | imgs = [] 24 | for ext in 
['*.png', '*.jpg', '*.JPEG', '*.JPG']: 25 | imgs.extend(glob(os.path.join(path, ext))) 26 | return imgs 27 | 28 | def split_input(model_input, total_pixels, n_pixels=10000): 29 | ''' 30 | Split the input to fit Cuda memory for large resolution. 31 | Can decrease the value of n_pixels in case of cuda out of memory error. 32 | ''' 33 | split = [] 34 | for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)): 35 | data = model_input.copy() 36 | data['uv'] = torch.index_select(model_input['uv'], 1, indx) 37 | if 'object_mask' in data: 38 | data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx) 39 | if 'depth' in data: 40 | data['depth'] = torch.index_select(model_input['depth'], 1, indx) 41 | split.append(data) 42 | return split 43 | 44 | def merge_output(res, total_pixels, batch_size): 45 | ''' Merge the split output. ''' 46 | 47 | model_outputs = {} 48 | for entry in res[0]: 49 | if res[0][entry] is None: 50 | continue 51 | if len(res[0][entry].shape) == 1: 52 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res], 53 | 1).reshape(batch_size * total_pixels) 54 | else: 55 | model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res], 56 | 1).reshape(batch_size * total_pixels, -1) 57 | 58 | return model_outputs 59 | 60 | def concat_home_dir(path): 61 | return os.path.join(os.environ['HOME'],'data',path) 62 | 63 | def get_time(): 64 | torch.cuda.synchronize() 65 | return time.time() 66 | 67 | trans_topil = transforms.ToPILImage() 68 | 69 | 70 | class BackprojectDepth(nn.Module): 71 | """Layer to transform a depth image into a point cloud 72 | """ 73 | def __init__(self, batch_size, height, width): 74 | super(BackprojectDepth, self).__init__() 75 | 76 | self.batch_size = batch_size 77 | self.height = height 78 | self.width = width 79 | 80 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') 81 | self.id_coords = 
np.stack(meshgrid, axis=0).astype(np.float32) 82 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), 83 | requires_grad=False) 84 | 85 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), 86 | requires_grad=False) 87 | 88 | self.pix_coords = torch.unsqueeze(torch.stack( 89 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 90 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 91 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), 92 | requires_grad=False) 93 | 94 | def forward(self, depth, inv_K): 95 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 96 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 97 | cam_points = torch.cat([cam_points, self.ones], 1) 98 | return cam_points 99 | -------------------------------------------------------------------------------- /code/utils/rend_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | import skimage 4 | import cv2 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | 9 | def get_psnr(img1, img2, normalize_rgb=False): 10 | if normalize_rgb: # [-1,1] --> [0,1] 11 | img1 = (img1 + 1.) / 2. 12 | img2 = (img2 + 1.) / 2. 13 | 14 | mse = torch.mean((img1 - img2) ** 2) 15 | psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda()) 16 | 17 | return psnr 18 | 19 | 20 | def load_rgb(path, normalize_rgb = False): 21 | img = imageio.imread(path) 22 | img = skimage.img_as_float32(img) 23 | 24 | if normalize_rgb: # [0,1] --> [-1,1] 25 | img -= 0.5 26 | img *= 2. 
27 | img = img.transpose(2, 0, 1) 28 | return img 29 | 30 | 31 | def load_K_Rt_from_P(filename, P=None): 32 | if P is None: 33 | lines = open(filename).read().splitlines() 34 | if len(lines) == 4: 35 | lines = lines[1:] 36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 37 | P = np.asarray(lines).astype(np.float32).squeeze() 38 | 39 | out = cv2.decomposeProjectionMatrix(P) 40 | K = out[0] 41 | R = out[1] 42 | t = out[2] 43 | 44 | # import pdb; pdb.set_trace() 45 | 46 | K = K/K[2,2] 47 | intrinsics = np.eye(4) 48 | intrinsics[:3, :3] = K 49 | 50 | pose = np.eye(4, dtype=np.float32) 51 | pose[:3, :3] = R.transpose() 52 | pose[:3,3] = (t[:3] / t[3])[:,0] 53 | 54 | return intrinsics, pose 55 | 56 | 57 | def get_camera_params(uv, pose, intrinsics): 58 | if pose.shape[1] == 7: #In case of quaternion vector representation 59 | cam_loc = pose[:, 4:] 60 | R = quat_to_rot(pose[:,:4]) 61 | p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float() 62 | p[:, :3, :3] = R 63 | p[:, :3, 3] = cam_loc 64 | else: # In case of pose matrix representation 65 | cam_loc = pose[:, :3, 3] 66 | p = pose 67 | 68 | batch_size, num_samples, _ = uv.shape 69 | 70 | depth = torch.ones((batch_size, num_samples)).cuda() 71 | x_cam = uv[:, :, 0].view(batch_size, -1) 72 | y_cam = uv[:, :, 1].view(batch_size, -1) 73 | z_cam = depth.view(batch_size, -1) 74 | 75 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics) 76 | 77 | # permute for batch matrix product 78 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1) 79 | 80 | # import pdb; pdb.set_trace(); 81 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3] 82 | ray_dirs = world_coords - cam_loc[:, None, :] 83 | ray_dirs = F.normalize(ray_dirs, dim=2) 84 | 85 | return ray_dirs, cam_loc 86 | 87 | 88 | def get_camera_for_plot(pose): 89 | if pose.shape[1] == 7: #In case of quaternion vector representation 90 | cam_loc = pose[:, 4:].detach() 91 | R = quat_to_rot(pose[:,:4].detach()) 92 | 
else: # In case of pose matrix representation 93 | cam_loc = pose[:, :3, 3] 94 | R = pose[:, :3, :3] 95 | cam_dir = R[:, :3, 2] 96 | return cam_loc, cam_dir 97 | 98 | 99 | def lift(x, y, z, intrinsics): 100 | # parse intrinsics 101 | intrinsics = intrinsics.cuda() 102 | fx = intrinsics[:, 0, 0] 103 | fy = intrinsics[:, 1, 1] 104 | cx = intrinsics[:, 0, 2] 105 | cy = intrinsics[:, 1, 2] 106 | sk = intrinsics[:, 0, 1] 107 | 108 | x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z 109 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z 110 | 111 | # homogeneous 112 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1) 113 | 114 | 115 | def quat_to_rot(q): 116 | batch_size, _ = q.shape 117 | q = F.normalize(q, dim=1) 118 | R = torch.ones((batch_size, 3,3)).cuda() 119 | qr=q[:,0] 120 | qi = q[:, 1] 121 | qj = q[:, 2] 122 | qk = q[:, 3] 123 | R[:, 0, 0]=1-2 * (qj**2 + qk**2) 124 | R[:, 0, 1] = 2 * (qj *qi -qk*qr) 125 | R[:, 0, 2] = 2 * (qi * qk + qr * qj) 126 | R[:, 1, 0] = 2 * (qj * qi + qk * qr) 127 | R[:, 1, 1] = 1-2 * (qi**2 + qk**2) 128 | R[:, 1, 2] = 2*(qj*qk - qi*qr) 129 | R[:, 2, 0] = 2 * (qk * qi-qj * qr) 130 | R[:, 2, 1] = 2 * (qj*qk + qi*qr) 131 | R[:, 2, 2] = 1-2 * (qi**2 + qj**2) 132 | return R 133 | 134 | 135 | def rot_to_quat(R): 136 | batch_size, _,_ = R.shape 137 | q = torch.ones((batch_size, 4)).cuda() 138 | 139 | R00 = R[:, 0,0] 140 | R01 = R[:, 0, 1] 141 | R02 = R[:, 0, 2] 142 | R10 = R[:, 1, 0] 143 | R11 = R[:, 1, 1] 144 | R12 = R[:, 1, 2] 145 | R20 = R[:, 2, 0] 146 | R21 = R[:, 2, 1] 147 | R22 = R[:, 2, 2] 148 | 149 | q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2 150 | q[:, 1]=(R21-R12)/(4*q[:,0]) 151 | q[:, 2] = (R02 - R20) / (4 * q[:, 0]) 152 | q[:, 3] = (R10 - R01) / (4 * q[:, 0]) 153 | return q 154 | 155 | 156 | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0): 157 | # Input: n_rays x 3 ; n_rays x 3 158 | # Output: 
n_rays x 1, n_rays x 1 (close and far) 159 | 160 | ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3), 161 | cam_loc.view(-1, 3, 1)).squeeze(-1) 162 | under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2) 163 | 164 | # sanity check 165 | if (under_sqrt <= 0).sum() > 0: 166 | print('BOUNDING SPHERE PROBLEM!') 167 | exit() 168 | 169 | sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot 170 | sphere_intersections = sphere_intersections.clamp_min(0.0) 171 | 172 | return sphere_intersections 173 | -------------------------------------------------------------------------------- /code/utils/sem_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | COLOR_MAP = { 5 | 0: [0, 0, 0], 6 | 1: [204, 0, 0], 7 | 2: [76, 153, 0], 8 | 3: [204, 204, 0], 9 | 4: [51, 51, 255], 10 | 5: [204, 0, 204], 11 | 6: [0, 255, 255], 12 | 7: [255, 204, 204], 13 | 8: [102, 51, 0], 14 | 9: [255, 0, 0], 15 | 10: [102, 204, 0], 16 | 11: [255, 255, 0], 17 | 12: [0, 0, 153], 18 | 13: [0, 0, 204], 19 | 14: [255, 51, 153], 20 | 15: [0, 204, 204], 21 | 16: [0, 51, 0], 22 | 17: [255, 153, 51], 23 | 18: [0, 204, 0]} 24 | 25 | 26 | COLOR_MAP_COMPLETE = { 27 | 0: [0, 0, 0], 28 | 1: [204, 0, 0], 29 | 2: [76, 153, 0], 30 | 3: [204, 204, 0], 31 | 4: [51, 51, 255], 32 | 5: [204, 0, 204], 33 | 6: [0, 255, 255], 34 | 7: [255, 204, 204], 35 | 8: [102, 51, 0], 36 | 9: [255, 0, 0], 37 | 10: [102, 204, 0], 38 | 11: [255, 255, 0], 39 | 12: [0, 0, 153], 40 | 13: [0, 0, 204], 41 | 14: [255, 51, 153], 42 | 15: [0, 204, 204], 43 | 16: [0, 51, 0], 44 | 17: [255, 153, 51], 45 | 18: [0, 204, 0]} 46 | 47 | 48 | 49 | def mask2color(masks, is_argmax=True): 50 | if is_argmax is True: 51 | masks = torch.argmax(masks, dim=1).float() 52 | # import pdb; pdb.set_trace() 53 | masks = masks.squeeze(1) 54 | sample_mask = torch.zeros((masks.shape[0], masks.shape[1], masks.shape[2], 3), 
dtype=torch.float) 55 | for key in COLOR_MAP: 56 | sample_mask[masks==key] = torch.tensor(COLOR_MAP[key], dtype=torch.float) 57 | sample_mask = sample_mask.permute(0,3,1,2) 58 | return sample_mask -------------------------------------------------------------------------------- /media/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QianyiWu/objectsdf_plus/4c93f016d553ed3251a527ef1a32821cf53af90c/media/teaser.gif -------------------------------------------------------------------------------- /preprocess/README.md: -------------------------------------------------------------------------------- 1 | # Data Preprocessing 2 | Here we provide scripts to preprocess the data if you start from scratch. Take the ScanNet dataset as an example. You can download one scene from the ScanNet dataset, including RGB images, camera poses/intrinsics, and semantic/instance segmentations. Please refer to the [ScanNet data format](https://github.com/ScanNet/ScanNet#data-organization) and [here](https://github.com/ScanNet/ScanNet/tree/master/SensReader/python) for more details. 3 | 4 | 5 | First, you need to run the Omnidata model (please install the [omnidata model](https://github.com/EPFL-VILAB/omnidata) before running the command) to predict monocular cues for each image. Note that Omnidata is trained at 384*384 resolution; we follow MonoSDF and apply a center crop to the original image before extracting this information. 6 | ``` 7 | cd preprocess 8 | python extract_monocular_cues.py --task depth --img_path PATH_to_your_image --output_path PATH_to_SAVE --omnidata_path YOUR_OMNIDATA_PATH --pretrained_models PRETRAINED_MODELS 9 | python extract_monocular_cues.py --task normal --img_path PATH_to_your_image --output_path PATH_to_SAVE --omnidata_path YOUR_OMNIDATA_PATH --pretrained_models PRETRAINED_MODELS 10 | ``` 11 | 12 | You now have the monocular supervision. Next, organize the dataset to make it ready for training. 
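
As an aside on the monocular-cue step above: `extract_monocular_cues.py` resizes each image so its shorter side is 384 pixels and then center-crops a 384*384 window. The sketch below only illustrates that geometry in plain Python, which may help when cropping other per-pixel data (e.g. segmentation masks) consistently. The helper name `resize_then_center_crop_box` is hypothetical and not part of this repository, and the rounding follows our understanding of torchvision's `Resize` and `CenterCrop`, so it may differ by a pixel in edge cases.

```python
def resize_then_center_crop_box(width, height, size=384):
    """Illustrative helper (not part of the repository).

    Returns (new_w, new_h, left, top): the image size after scaling the
    shorter side to `size`, and the top-left corner of the centered
    `size` x `size` crop window inside that resized image.
    """
    if width <= height:
        # Width is the shorter side: it becomes `size`.
        new_w, new_h = size, int(size * height / width)
    else:
        # Height is the shorter side: it becomes `size`.
        new_w, new_h = int(size * width / height), size
    # Center the crop window; assumed to match CenterCrop's rounding.
    left = int(round((new_w - size) / 2.0))
    top = int(round((new_h - size) / 2.0))
    return new_w, new_h, left, top


if __name__ == "__main__":
    # A 640x480 frame is resized to 512x384, then cropped from x=64.
    print(resize_then_center_crop_box(640, 480))
```

Note that the crop discards pixels along the longer side, so any camera intrinsics used with the cropped images must be adjusted accordingly. The command below then organizes the dataset, as described above.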
13 | 14 | ``` 15 | python scannet_to_objsdfpp.py 16 | ``` 17 | 18 | You can perform a similar operation on the Replica dataset to get the instance label mapping file. 19 | ``` 20 | python replica_to_objsdfpp.py 21 | ``` 22 | 23 | Here are some notes: 24 | 1. We merge some instance classes (such as ceiling, wall...) into the background. You can edit `instance_mapping.txt` to define the objects you want. 25 | 2. The center and scale parameters are mainly used to normalize the entire scene into a cube. This is widely used in NeRF projects when preparing the camera poses for training. 26 | 3. We assume that the masks are view-consistent (i.e., the index of an instance does not change across views). This can be achieved with a front-end segmentation algorithm (e.g., a video segmentation model). 27 | 28 | Please check these scripts for more details. -------------------------------------------------------------------------------- /preprocess/extract_monocular_cues.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/EPFL-VILAB/omnidata 2 | import torch 3 | import torch.nn.functional as F 4 | from torchvision import transforms 5 | 6 | import PIL 7 | from PIL import Image 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | import argparse 12 | import os.path 13 | from pathlib import Path 14 | import glob 15 | import sys 16 | 17 | 18 | parser = argparse.ArgumentParser(description='Visualize output for depth or surface normals') 19 | 20 | parser.add_argument('--omnidata_path', dest='omnidata_path', help="path to omnidata model") 21 | parser.set_defaults(omnidata_path='/media/hdd/omnidata/omnidata_tools/torch/') 22 | 23 | parser.add_argument('--pretrained_models', dest='pretrained_models', help="path to pretrained models") 24 | parser.set_defaults(pretrained_models='/media/hdd/omnidata/omnidata_tools/torch/pretrained_models/') 25 | 26 | parser.add_argument('--task', dest='task', help="normal or depth") 27 | 
parser.set_defaults(task='NONE') 28 | 29 | parser.add_argument('--img_path', dest='img_path', help="path to rgb image") 30 | parser.set_defaults(img_path='NONE') # default must use the same dest name as the argument 31 | 32 | parser.add_argument('--output_path', dest='output_path', help="path to where output image should be stored") 33 | parser.set_defaults(output_path='NONE') # default must use the same dest name as the argument 34 | 35 | args = parser.parse_args() 36 | 37 | root_dir = args.pretrained_models 38 | omnidata_path = args.omnidata_path 39 | 40 | sys.path.append(args.omnidata_path) 41 | print(sys.path) 42 | from modules.unet import UNet 43 | from modules.midas.dpt_depth import DPTDepthModel 44 | from data.transforms import get_transform 45 | 46 | trans_topil = transforms.ToPILImage() 47 | os.makedirs(args.output_path, exist_ok=True) # portable replacement for `mkdir -p` 48 | map_location = (lambda storage, loc: storage.cuda()) if torch.cuda.is_available() else torch.device('cpu') 49 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 50 | 51 | 52 | # get target task and model 53 | if args.task == 'normal': 54 | image_size = 384 55 | 56 | ## Version 1 model 57 | # pretrained_weights_path = root_dir + 'omnidata_unet_normal_v1.pth' 58 | # model = UNet(in_channels=3, out_channels=3) 59 | # checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 60 | 61 | # if 'state_dict' in checkpoint: 62 | # state_dict = {} 63 | # for k, v in checkpoint['state_dict'].items(): 64 | # state_dict[k.replace('model.', '')] = v 65 | # else: 66 | # state_dict = checkpoint 67 | 68 | 69 | pretrained_weights_path = root_dir + '/omnidata_dpt_normal_v2.ckpt' 70 | model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) # DPT Hybrid 71 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 72 | if 'state_dict' in checkpoint: 73 | state_dict = {} 74 | for k, v in checkpoint['state_dict'].items(): 75 | state_dict[k[6:]] = v # drop the first 6 characters of each key (presumably the 'model.' prefix) 76 | else: 77 | state_dict = checkpoint 78 | 79 | model.load_state_dict(state_dict) 80 | model.to(device) 81 | trans_totensor = 
transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR), 82 | transforms.CenterCrop(image_size), 83 | get_transform('rgb', image_size=None)]) 84 | 85 | elif args.task == 'depth': 86 | image_size = 384 87 | pretrained_weights_path = root_dir + '/omnidata_dpt_depth_v2.ckpt' # 'omnidata_dpt_depth_v1.ckpt' 88 | # model = DPTDepthModel(backbone='vitl16_384') # DPT Large 89 | model = DPTDepthModel(backbone='vitb_rn50_384') # DPT Hybrid 90 | checkpoint = torch.load(pretrained_weights_path, map_location=map_location) 91 | if 'state_dict' in checkpoint: 92 | state_dict = {} 93 | for k, v in checkpoint['state_dict'].items(): 94 | state_dict[k[6:]] = v 95 | else: 96 | state_dict = checkpoint 97 | model.load_state_dict(state_dict) 98 | model.to(device) 99 | trans_totensor = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR), 100 | transforms.CenterCrop(image_size), 101 | transforms.ToTensor(), 102 | transforms.Normalize(mean=0.5, std=0.5)]) 103 | 104 | else: 105 | print("task should be one of the following: normal, depth") 106 | sys.exit() 107 | 108 | trans_rgb = transforms.Compose([transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR), 109 | transforms.CenterCrop(image_size), 110 | ]) 111 | 112 | 113 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1): 114 | if mask_valid is not None: 115 | img[~mask_valid] = torch.nan 116 | sorted_img = torch.sort(torch.flatten(img))[0] 117 | # Remove nan, nan at the end of sort 118 | num_nan = sorted_img.isnan().sum() 119 | if num_nan > 0: 120 | sorted_img = sorted_img[:-num_nan] 121 | # Remove outliers 122 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))] 123 | trunc_mean = trunc_img.mean() 124 | trunc_var = trunc_img.var() 125 | eps = 1e-6 126 | # Replace nan by mean 127 | img = torch.nan_to_num(img, nan=trunc_mean) 128 | # Standardize 129 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps) 130 | 
return img 131 | 132 | 133 | def save_outputs(img_path, output_file_name): 134 | with torch.no_grad(): 135 | save_path = os.path.join(args.output_path, f'{output_file_name}_{args.task}.png') 136 | 137 | print(f'Reading input {img_path} ...') 138 | img = Image.open(img_path) 139 | # import pdb; pdb.set_trace(); 140 | img_tensor = trans_totensor(img)[:3].unsqueeze(0).to(device) 141 | 142 | rgb_path = os.path.join(args.output_path, f'{output_file_name}_rgb.png') 143 | trans_rgb(img).save(rgb_path) 144 | 145 | if img_tensor.shape[1] == 1: 146 | img_tensor = img_tensor.repeat_interleave(3,1) 147 | 148 | output = model(img_tensor).clamp(min=0, max=1) 149 | 150 | if args.task == 'depth': 151 | #output = F.interpolate(output.unsqueeze(0), (512, 512), mode='bicubic').squeeze(0) 152 | output = output.clamp(0,1) 153 | 154 | np.save(save_path.replace('.png', '.npy'), output.detach().cpu().numpy()[0]) 155 | 156 | #output = 1 - output 157 | # output = standardize_depth_map(output) 158 | plt.imsave(save_path, output.detach().cpu().squeeze(),cmap='viridis') 159 | 160 | else: 161 | #import pdb; pdb.set_trace() 162 | np.save(save_path.replace('.png', '.npy'), output.detach().cpu().numpy()[0]) 163 | trans_topil(output[0]).save(save_path) 164 | 165 | print(f'Writing output {save_path} ...') 166 | 167 | 168 | img_path = Path(args.img_path) 169 | if img_path.is_file(): 170 | save_outputs(args.img_path, os.path.splitext(os.path.basename(args.img_path))[0]) 171 | elif img_path.is_dir(): 172 | for f in glob.glob(args.img_path+'/*'): 173 | save_outputs(f, os.path.splitext(os.path.basename(f))[0]) 174 | else: 175 | print("invalid file path!") 176 | sys.exit() 177 | -------------------------------------------------------------------------------- /preprocess/replica_to_objsdfpp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import os 5 | from scipy.spatial.transform import Slerp 6 | from 
scipy.interpolate import interp1d 7 | from scipy.spatial.transform import Rotation as R 8 | import json 9 | import trimesh 10 | import glob 11 | import PIL 12 | from PIL import Image 13 | from torchvision import transforms 14 | import matplotlib.pyplot as plt 15 | import imageio 16 | 17 | # For the Replica dataset, we adopt the camera intrinsics/extrinsics/rgb/depth/normal from the MonoSDF dataset 18 | # For instance labels, we use the instance labels from the vMAP processed dataset. 19 | 20 | # map the instance segmentation result to semantic segmentation result 21 | 22 | image_size = 384 23 | # trans_totensor = transforms.Compose([ 24 | # transforms.CenterCrop(image_size), 25 | # transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR), 26 | # ]) 27 | # depth_trans_totensor = transforms.Compose([ 28 | # # transforms.Resize([680, 1200], interpolation=PIL.Image.NEAREST), 29 | # transforms.CenterCrop(image_size*2), 30 | # transforms.Resize(image_size, interpolation=PIL.Image.NEAREST), 31 | # ]) 32 | 33 | seg_trans_totensor = transforms.Compose([ 34 | transforms.CenterCrop(680), 35 | transforms.Resize(image_size, interpolation=PIL.Image.NEAREST), 36 | ]) 37 | 38 | out_path_prefix = '../data/replica/Replica/' 39 | data_root = lambda x: '/media/hdd/Replica-Dataset/vmap/{}/imap/00'.format(x) 40 | # scenes = ['scene0050_00', 'scene0084_00', 'scene0580_00', 'scene0616_00'] 41 | scenes = ['room_0', 'room_1', 'room_2', 'office_0', 'office_1', 'office_2', 'office_3', 'office_4'] 42 | out_names = ['scan1', 'scan2', 'scan3', 'scan4', 'scan5', 'scan6', 'scan7', 'scan8'] 43 | 44 | background_cls_list = [5,12,30,31,40,60,92,93,95,97,98,79] 45 | # merge more classes into the background 46 | background_cls_list.append(37) # door 47 | # background_cls_list.append(0) # undefined: 0 for this class, we manually organize the data and the result can be found in each instance-mapping.txt 48 | background_cls_list.append(56) # panel 49 | background_cls_list.append(62) # pipe 50 | 51 | for scene, out_name 
in zip(scenes, out_names): 52 | # for scene, out_name in zip() 53 | out_path = os.path.join(out_path_prefix, out_name) 54 | os.makedirs(out_path, exist_ok=True) 55 | print(out_path) 56 | 57 | # folders = ["image", "mask", "depth", "segs"] 58 | folders = ['segs'] 59 | for folder in folders: 60 | out_folder = os.path.join(out_path, folder) 61 | os.makedirs(out_folder, exist_ok=True) 62 | 63 | # process segmentation 64 | segs_path = os.path.join(data_root(scene), 'semantic_instance') 65 | segs_paths = sorted(glob.glob(os.path.join(segs_path, 'semantic_instance_*.png')), 66 | key=lambda x: int(os.path.basename(x).split('.')[0].split('_')[-1])) 67 | print(segs_paths) 68 | 69 | labels_path = os.path.join(data_root(scene), 'semantic_class') 70 | labels_paths = sorted(glob.glob(os.path.join(labels_path, 'semantic_class_*.png')), 71 | key = lambda x: int(os.path.basename(x).split('.')[0].split('_')[-1])) 72 | print(labels_paths) 73 | 74 | instance_semantic_mapping = {} 75 | instance_label_mapping = {} 76 | label_instance_counts = {} 77 | 78 | out_index = 0 79 | for idx, (seg_path, label_path) in enumerate(zip(segs_paths, labels_paths)): 80 | if idx % 20 !=0: continue 81 | print(idx, seg_path, label_path) 82 | # save segs 83 | target_image = os.path.join(out_path, 'segs', '{:06d}_segs.png'.format(out_index)) 84 | print(target_image) 85 | seg = Image.open(seg_path) 86 | seg_tensor = seg_trans_totensor(seg) 87 | seg_tensor.save(target_image) 88 | 89 | # label_mapping 90 | label_np = cv2.imread(label_path, cv2.IMREAD_UNCHANGED).astype(np.int32).transpose(1,0) 91 | # if 30 in np.unique(_np): 92 | # import pdb; pdb.set_trace() 93 | # label_np = label_np. 
94 | 95 | segs_np = cv2.imread(seg_path, cv2.IMREAD_UNCHANGED).astype(np.int32).transpose(1,0) 96 | insts = np.unique(segs_np) 97 | for inst_id in insts: 98 | inst_mask = segs_np == inst_id 99 | sem_cls = int(np.unique(label_np[inst_mask])) 100 | # import pdb; pdb.set_trace() 101 | instance_semantic_mapping[inst_id] = sem_cls 102 | if sem_cls in background_cls_list: 103 | instance_semantic_mapping[inst_id] = 0 104 | instance_label_mapping[inst_id] = 0 105 | # assert sem_cls.shape[0]!=0 106 | elif sem_cls in label_instance_counts: 107 | if inst_id not in instance_label_mapping: 108 | inst_count = label_instance_counts[sem_cls] + 1 109 | label_instance_counts[sem_cls] = inst_count 110 | # change the instance label mapping index to 100*label + inst_count 111 | instance_label_mapping[inst_id] = sem_cls * 100 + inst_count 112 | else: 113 | continue # already saved 114 | else: 115 | inst_count = 1 116 | label_instance_counts[sem_cls] = inst_count 117 | instance_label_mapping[inst_id] = sem_cls * 100 + inst_count 118 | 119 | out_index += 1 120 | 121 | # save the instance mapping file to output path 122 | print({k: v for k, v in sorted(instance_label_mapping.items())}) 123 | with open(os.path.join(out_path, 'instance_mapping_new.txt'), 'w') as f: 124 | # f.write(str(sorted(label_set)).strip('[').strip(']')) 125 | for k, v in instance_label_mapping.items(): 126 | # instance_id, semantic_label, updated_instance_label (according to the number in this semantic class) 127 | f.write(str(k)+','+str(instance_semantic_mapping[k])+','+str(v)+'\n') 128 | f.close() 129 | 130 | #np.savez(os.path.join(out_path, "cameras_sphere.npz"), **cameras) 131 | # np.savez(os.path.join(out_path, "cameras.npz"), **cameras) 132 | -------------------------------------------------------------------------------- /preprocess/scannet_to_objsdfpp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | import os 5 | from 
scipy.spatial.transform import Slerp 6 | from scipy.interpolate import interp1d 7 | from scipy.spatial.transform import Rotation as R 8 | import json 9 | import trimesh 10 | import glob 11 | import PIL 12 | from PIL import Image 13 | from torchvision import transforms 14 | import matplotlib.pyplot as plt 15 | import imageio 16 | 17 | import csv 18 | # Code from ScanNet script to convert instance images from the *_2d-instance.zip or *_2d-instance-filt.zip data for each scan. 19 | def read_label_mapping(filename, label_from='id', label_to='nyu40id'): 20 | assert os.path.isfile(filename) 21 | mapping = dict() 22 | with open(filename) as csvfile: 23 | reader = csv.DictReader(csvfile, delimiter='\t') 24 | for row in reader: 25 | mapping[row[label_from]] = int(row[label_to]) 26 | # if ints convert 27 | # if represents_int(mapping.keys()[0]): 28 | mapping = {int(k):v for k,v in mapping.items()} 29 | return mapping 30 | 31 | def map_label_image(image, label_mapping): 32 | mapped = np.copy(image) 33 | for k,v in label_mapping.items(): 34 | mapped[image==k] = v 35 | # merge some label like bg, wall, ceiling, floor 36 | # bg: 0, wall: 1, floor: 2, ceiling: 22, door: 8 37 | mapped[mapped==1] = 0 38 | mapped[mapped==2] = 0 39 | mapped[mapped==22] = 0 40 | # add windows 41 | mapped[mapped==9] = 0 42 | # add door 43 | mapped[mapped==8] = 0 44 | # add mirror; curtain to 0 45 | mapped[mapped==19] = 0 # mirror 46 | mapped[mapped==16] = 0 # curtain 47 | return mapped.astype(np.uint8) 48 | 49 | def make_instance_image(label_image, instance_image): 50 | output = np.zeros_like(instance_image, dtype=np.uint16) 51 | # oldinst2labelinst = {} 52 | label_instance_counts = {} 53 | old_insts = np.unique(instance_image) 54 | for inst in old_insts: 55 | label = label_image[instance_image==inst][0] 56 | if label in label_instance_counts and label !=0: 57 | inst_count = label_instance_counts[label] + 1 58 | label_instance_counts[label] = inst_count 59 | else: 60 | inst_count = 1 61 | 
label_instance_counts[label] = inst_count 62 | # oldinst2labelinst[inst] = (label, inst_count) 63 | output[instance_image==inst] = label * 1000 + inst_count 64 | return output 65 | 66 | image_size = 384 67 | trans_totensor = transforms.Compose([ 68 | transforms.CenterCrop(image_size*2), 69 | transforms.Resize(image_size, interpolation=PIL.Image.BILINEAR), 70 | ]) 71 | depth_trans_totensor = transforms.Compose([ 72 | transforms.Resize([968, 1296], interpolation=PIL.Image.NEAREST), 73 | transforms.CenterCrop(image_size*2), 74 | transforms.Resize(image_size, interpolation=PIL.Image.NEAREST), 75 | ]) 76 | 77 | seg_trans_totensor = transforms.Compose([ 78 | transforms.CenterCrop(image_size*2), 79 | transforms.Resize(image_size, interpolation=PIL.Image.NEAREST), 80 | ]) 81 | 82 | 83 | 84 | out_path_prefix = '../data/custom' 85 | data_root = '/media/hdd/monodata_ours/scans' 86 | scenes = ['scene0050_00', 'scene0084_00', 'scene0580_00', 'scene0616_00'] 87 | # scenes = ['scene0050_00'] 88 | # out_names = ['scan1'] 89 | out_names = ['scan1', 'scan2', 'scan3', 'scan4'] 90 | 91 | label_map = read_label_mapping('/media/hdd/scannet/tasks/scannetv2-labels.combined.tsv') # path to scannet_labels.csv 92 | 93 | for scene, out_name in zip(scenes, out_names): 94 | out_path = os.path.join(out_path_prefix, out_name) 95 | os.makedirs(out_path, exist_ok=True) 96 | print(out_path) 97 | 98 | folders = ["image", "mask", "depth", "segs"] 99 | for folder in folders: 100 | out_folder = os.path.join(out_path, folder) 101 | os.makedirs(out_folder, exist_ok=True) 102 | 103 | # load color 104 | color_path = os.path.join(data_root, scene, 'color') 105 | color_paths = sorted(glob.glob(os.path.join(color_path, '*.jpg')), 106 | key=lambda x: int(os.path.basename(x)[:-4])) 107 | print(color_paths) 108 | 109 | # load depth 110 | depth_path = os.path.join(data_root, scene, 'depth') 111 | depth_paths = sorted(glob.glob(os.path.join(depth_path, '*.png')), 112 | key=lambda x: int(os.path.basename(x)[:-4])) 
113 | print(depth_paths) 114 | 115 | segs_path = os.path.join(data_root, scene, 'segs', 'instance-filt') 116 | segs_paths = sorted(glob.glob(os.path.join(segs_path, '*.png')), 117 | key=lambda x: int(os.path.basename(x)[:-4])) 118 | print(segs_paths) 119 | 120 | labels_path = os.path.join(data_root, scene, 'segs', 'label-filt') 121 | labels_paths = sorted(glob.glob(os.path.join(labels_path, '*.png')), 122 | key = lambda x: int(os.path.basename(x)[:-4])) 123 | print(labels_paths) 124 | 125 | # load intrinsic 126 | intrinsic_path = os.path.join(data_root, scene, 'intrinsic', 'intrinsic_color.txt') 127 | camera_intrinsic = np.loadtxt(intrinsic_path) 128 | print(camera_intrinsic) 129 | 130 | # load pose 131 | pose_path = os.path.join(data_root, scene, 'pose') 132 | poses = [] 133 | pose_paths = sorted(glob.glob(os.path.join(pose_path, '*.txt')), 134 | key=lambda x: int(os.path.basename(x)[:-4])) 135 | for pose_path in pose_paths: 136 | c2w = np.loadtxt(pose_path) 137 | poses.append(c2w) 138 | poses = np.array(poses) 139 | 140 | # deal with invalid poses 141 | valid_poses = np.isfinite(poses).all(axis=2).all(axis=1) 142 | min_vertices = poses[:, :3, 3][valid_poses].min(axis=0) 143 | max_vertices = poses[:, :3, 3][valid_poses].max(axis=0) 144 | 145 | center = (min_vertices + max_vertices) / 2. 146 | scale = 2. / (np.max(max_vertices - min_vertices) + 3.) 
147 | print(center, scale) 148 | 149 | # we should normalized to unit cube 150 | scale_mat = np.eye(4).astype(np.float32) 151 | scale_mat[:3, 3] = -center 152 | scale_mat[:3 ] *= scale 153 | scale_mat = np.linalg.inv(scale_mat) 154 | 155 | # copy image 156 | out_index = 0 157 | cameras = {} 158 | pcds = [] 159 | H, W = 968, 1296 160 | 161 | # center crop by 2 * image_size 162 | offset_x = (W - image_size * 2) * 0.5 163 | offset_y = (H - image_size * 2) * 0.5 164 | camera_intrinsic[0, 2] -= offset_x 165 | camera_intrinsic[1, 2] -= offset_y 166 | # resize from 384*2 to 384 167 | resize_factor = 0.5 168 | camera_intrinsic[:2, :] *= resize_factor 169 | 170 | K = camera_intrinsic 171 | print(K) 172 | 173 | instance_semantic_mapping = {} 174 | instance_label_mapping = {} 175 | label_instance_counts = {} 176 | 177 | for idx, (valid, pose, depth_path, image_path, seg_path, label_path) in enumerate(zip(valid_poses, poses, depth_paths, color_paths, segs_paths, labels_paths)): 178 | print(idx, valid) 179 | if idx % 10 != 0: continue 180 | if not valid : continue 181 | 182 | target_image = os.path.join(out_path, "image/%06d.png"%(out_index)) 183 | print(target_image) 184 | img = Image.open(image_path) 185 | img_tensor = trans_totensor(img) 186 | img_tensor.save(target_image) 187 | 188 | mask = (np.ones((image_size, image_size, 3)) * 255.).astype(np.uint8) 189 | 190 | target_image = os.path.join(out_path, "mask/%06d_mask.png"%(out_index)) 191 | cv2.imwrite(target_image, mask) 192 | 193 | # load depth 194 | target_image = os.path.join(out_path, "depth/%06d_depth.png"%(out_index)) 195 | depth = cv2.imread(depth_path, -1).astype(np.float32) / 1000. 
196 | #import pdb; pdb.set_trace() 197 | depth_PIL = Image.fromarray(depth) 198 | new_depth = depth_trans_totensor(depth_PIL) 199 | new_depth = np.asarray(new_depth) 200 | plt.imsave(target_image, new_depth, cmap='viridis') 201 | np.save(target_image.replace(".png", ".npy"), new_depth) 202 | 203 | # segs 204 | target_image = os.path.join(out_path, "segs/%06d_segs.png"%(out_index)) 205 | print(target_image) 206 | seg = Image.open(seg_path) 207 | # seg_tensor = trans_totensor(seg) 208 | seg_tensor = seg_trans_totensor(seg) 209 | seg_tensor.save(target_image) 210 | # np.save(target_image) 211 | 212 | # label_mapping 213 | # label = Image.open(label_path) 214 | # label_tensor = trans_totensor(label) 215 | label_np = imageio.imread(label_path) 216 | segs_np = imageio.imread(seg_path) 217 | # import pdb; pdb.set_trace() 218 | mapped_labels = map_label_image(label_np, label_map) 219 | old_insts = np.unique(segs_np) 220 | for inst in old_insts: 221 | label = int(mapped_labels[segs_np==inst][0]) # cast to a Python int so label * 1000 cannot overflow uint8 222 | # import pdb; pdb.set_trace() 223 | instance_semantic_mapping[inst] = label 224 | if label == 0: 225 | instance_label_mapping[inst] = 0 226 | elif label in label_instance_counts: 227 | if inst not in instance_label_mapping: 228 | inst_count = label_instance_counts[label] + 1 # increment the instance count for this label 229 | label_instance_counts[label] = inst_count 230 | # change the instance label mapping index of this instance to 1000*label + counted number 231 | instance_label_mapping[inst] = label * 1000 + inst_count 232 | else: 233 | continue 234 | else: 235 | inst_count = 1 236 | label_instance_counts[label] = inst_count # first instance of this label: start its count at 1 in label_instance_counts 237 | instance_label_mapping[inst] = label * 1000 + inst_count 238 | 239 | 240 | 241 | 242 | # save pose 243 | pcds.append(pose[:3, 3]) 244 | pose = K @ np.linalg.inv(pose) 245 | 246 | #cameras["scale_mat_%d"%(out_index)] = np.eye(4).astype(np.float32) 247 | 
cameras["scale_mat_%d"%(out_index)] = scale_mat 248 | cameras["world_mat_%d"%(out_index)] = pose 249 | 250 | out_index += 1 251 | 252 | # save the instance mapping file to output path 253 | print({k: v for k, v in sorted(instance_label_mapping.items())}) 254 | with open(os.path.join(out_path, 'instance_mapping.txt'), 'w') as f: 255 | # f.write(str(sorted(label_set)).strip('[').strip(']')) 256 | for k, v in instance_label_mapping.items(): 257 | f.write(str(k)+','+str(instance_semantic_mapping[k])+','+str(v)+'\n') 258 | f.close() 259 | 260 | #np.savez(os.path.join(out_path, "cameras_sphere.npz"), **cameras) 261 | np.savez(os.path.join(out_path, "cameras.npz"), **cameras) 262 | -------------------------------------------------------------------------------- /replica_eval/avg_metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | root_dir = './evaluation_results' # path to the evaluation results folder 6 | 7 | all_metric = [] 8 | for fname in os.listdir(root_dir): 9 | metric_result = np.load(os.path.join(root_dir, fname, 'metrics_3D_obj.npy')) 10 | if np.isnan(metric_result).any(): # flag scenes whose metrics contain NaN 11 | print(fname) 12 | # print(metric_result.shape) 13 | all_metric.append(metric_result) 14 | print(fname, 0.5*(metric_result.mean(1)[0] + metric_result.mean(1)[1])) 15 | result = np.concatenate(all_metric, 1) 16 | # # recompute the f-score from precision and recall 17 | precision_1, completion_1 = result[4], result[2] 18 | _sum = precision_1 + completion_1 19 | _prod = 2*precision_1 * completion_1 20 | 21 | 22 | print('var F_score 5cm: {}'.format(result.var(1)[7])) 23 | print('var Chamfer: {}'.format((0.5*(result[0]+result[1])).var(0))) 24 | # print('chamfer mean each scene: {}'.format(0.5*(result.mean(1)+result.mean(2)))) 25 | print('Acc mean: {}, Comp: {}, chamfer: {}, Ratio 1cm: {}, Ratio 5cm: {}, F_score 1cm: {}, F_score 5cm: {}'.format(result.mean(1)[0], result.mean(1)[1], 
0.5*(result.mean(1)[0]+result.mean(1)[1]), result.mean(1)[2], result.mean(1)[3], result.mean(1)[6], result.mean(1)[7])) 26 | -------------------------------------------------------------------------------- /replica_eval/cull_mesh.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/cvg/nice-slam 2 | import os 3 | import numpy as np 4 | import argparse 5 | import pickle 6 | import os 7 | import glob 8 | import open3d as o3d 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import trimesh 12 | 13 | 14 | def load_poses(path): 15 | poses = [] 16 | with open(path, "r") as f: 17 | lines = f.readlines() 18 | for line in lines: 19 | c2w = np.array(list(map(float, line.split()))).reshape(4, 4) 20 | c2w[:3, 1] *= -1 21 | c2w[:3, 2] *= -1 22 | c2w = torch.from_numpy(c2w).float() 23 | poses.append(c2w) 24 | return poses 25 | 26 | 27 | parser = argparse.ArgumentParser( 28 | description='Arguments to cull the mesh.' 29 | ) 30 | 31 | parser.add_argument('--input_mesh', type=str, help='path to the mesh to be culled') 32 | parser.add_argument('--input_scalemat', type=str, help='path to the scale mat') 33 | parser.add_argument('--traj', type=str, help='path to the trajectory') 34 | parser.add_argument('--output_mesh', type=str, help='path to the output mesh') 35 | args = parser.parse_args() 36 | 37 | H = 680 38 | W = 1200 39 | fx = 600.0 40 | fy = 600.0 41 | fx = 600.0 42 | cx = 599.5 43 | cy = 339.5 44 | # scale = 6553.5 45 | 46 | poses = load_poses(args.traj) 47 | n_imgs = len(poses) 48 | mesh = trimesh.load(args.input_mesh, process=False) 49 | # mesh.export(args.output_mesh+'_raw.ply') 50 | 51 | # transform to original coordinate system with scale mat 52 | if args.input_scalemat is not None: 53 | scalemat = np.load(args.input_scalemat)['scale_mat_0'] 54 | mesh.vertices = mesh.vertices @ scalemat[:3, :3].T + scalemat[:3, 3] 55 | else: 56 | # print('not input scalemat') 57 | scalemat = np.eye(4) 58 | 
mesh.vertices = mesh.vertices @ scalemat[:3, :3].T + scalemat[:3, 3] 59 | 60 | pc = mesh.vertices 61 | faces = mesh.faces 62 | 63 | 64 | # delete mesh vertices that are not inside any camera's viewing frustum 65 | whole_mask = np.ones(pc.shape[0]).astype(bool) 66 | for i in range(0, n_imgs, 1): 67 | c2w = poses[i] 68 | points = pc.copy() 69 | points = torch.from_numpy(points).cuda() 70 | w2c = np.linalg.inv(c2w) 71 | w2c = torch.from_numpy(w2c).cuda().float() 72 | K = torch.from_numpy( 73 | np.array([[fx, .0, cx], [.0, fy, cy], [.0, .0, 1.0]]).reshape(3, 3)).cuda() 74 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda() 75 | homo_points = torch.cat( 76 | [points, ones], dim=1).reshape(-1, 4, 1).cuda().float() 77 | cam_cord_homo = w2c@homo_points 78 | cam_cord = cam_cord_homo[:, :3] 79 | 80 | cam_cord[:, 0] *= -1 81 | uv = K.float()@cam_cord.float() 82 | z = uv[:, -1:]+1e-5 83 | uv = uv[:, :2]/z 84 | uv = uv.float().squeeze(-1).cpu().numpy() 85 | edge = 0 86 | mask = (0 <= -z[:, 0, 0].cpu().numpy()) & (uv[:, 0] < W - 87 | edge) & (uv[:, 0] > edge) & (uv[:, 1] < H-edge) & (uv[:, 1] > edge) 88 | whole_mask &= ~mask 89 | pc = mesh.vertices 90 | faces = mesh.faces 91 | face_mask = whole_mask[mesh.faces].all(axis=1) 92 | mesh.update_faces(~face_mask) 93 | mesh.export(args.output_mesh) 94 | -------------------------------------------------------------------------------- /replica_eval/cull_obj_gt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | 5 | scans = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 6 | 7 | gt_data_dir = '/media/hdd/Replica-Dataset/vmap/' 8 | 9 | for idx, exp in enumerate(scans): 10 | idx = idx + 1 11 | folder_name = os.path.join(gt_data_dir, exp[:-1]+'_'+exp[-1], 'habitat') 12 | files = list(filter(os.path.isfile, glob.glob(os.path.join(folder_name, 'mesh_semantic.*.ply')))) 13 | # print(files) 14 | # exit() 15 | 
print(files[0].split('/')[-1]) 16 | os.makedirs(os.path.join(folder_name, 'cull_object_mesh'), exist_ok=True) 17 | for name in files: 18 | cull_mesh_out = os.path.join(folder_name, 'cull_object_mesh', 'cull_'+name.split('/')[-1]) 19 | # print(cull_mesh_out) 20 | cmd = f"python cull_mesh.py --input_mesh {name} --traj ../data/replica/Replica/scan{idx}/traj.txt --output_mesh {cull_mesh_out}" 21 | print(cmd) 22 | os.system(cmd) 23 | # exit() 24 | -------------------------------------------------------------------------------- /replica_eval/eval_3D_obj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import trimesh 4 | from metrics import accuracy, completion, completion_ratio 5 | import os 6 | import json 7 | import glob 8 | 9 | def calc_3d_metric(mesh_rec, mesh_gt, N=200000): 10 | """ 11 | 3D reconstruction metric. 12 | """ 13 | metrics = [[] for _ in range(8)] 14 | transform, extents = trimesh.bounds.oriented_bounds(mesh_gt) 15 | extents = extents / 0.9 if N != 200000 else extents # enlarge 0.9 for objects 16 | # extents = extents *1.2 if N != 200000 else extents # enlarge 0.9 for objects 17 | box = trimesh.creation.box(extents=extents, transform=np.linalg.inv(transform)) 18 | mesh_rec = mesh_rec.slice_plane(box.facets_origin, -box.facets_normal) 19 | if mesh_rec.vertices.shape[0] == 0: 20 | print("no mesh found") 21 | return 22 | rec_pc = trimesh.sample.sample_surface(mesh_rec, N) 23 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0]) 24 | 25 | gt_pc = trimesh.sample.sample_surface(mesh_gt, N) 26 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0]) 27 | accuracy_rec = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices) 28 | completion_rec = completion(gt_pc_tri.vertices, rec_pc_tri.vertices) 29 | completion_ratio_rec = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.05) 30 | completion_ratio_rec_1 = completion_ratio(gt_pc_tri.vertices, rec_pc_tri.vertices, 0.01) 31 | 
32 | precision_ratio_rec = completion_ratio(rec_pc_tri.vertices, gt_pc_tri.vertices, 0.05) 33 | precision_ratio_rec_1 = completion_ratio(rec_pc_tri.vertices, gt_pc_tri.vertices, 0.01) 34 | 35 | f_score = 2*precision_ratio_rec*completion_ratio_rec / (completion_ratio_rec + precision_ratio_rec) 36 | f_score_1 = 2 * precision_ratio_rec_1*completion_ratio_rec_1 / (completion_ratio_rec_1 + precision_ratio_rec_1) 37 | 38 | 39 | # accuracy_rec *= 100 # convert to cm 40 | # completion_rec *= 100 # convert to cm 41 | # completion_ratio_rec *= 100 # convert to % 42 | # print('accuracy: ', accuracy_rec) 43 | # print('completion: ', completion_rec) 44 | # print('completion ratio: ', completion_ratio_rec) 45 | # print("completion_ratio_rec_1cm ", completion_ratio_rec_1) 46 | metrics[0].append(accuracy_rec) 47 | metrics[1].append(completion_rec) 48 | metrics[2].append(completion_ratio_rec_1) 49 | metrics[3].append(completion_ratio_rec) 50 | metrics[4].append(precision_ratio_rec_1) 51 | metrics[5].append(precision_ratio_rec) 52 | metrics[6].append(np.nan_to_num(f_score_1)) 53 | metrics[7].append(np.nan_to_num(f_score)) 54 | 55 | return metrics 56 | 57 | def get_gt_bg_mesh(gt_dir, background_cls_list): 58 | with open(os.path.join(gt_dir, "info_semantic.json")) as f: 59 | label_obj_list = json.load(f)["objects"] 60 | 61 | bg_meshes = [] 62 | for obj in label_obj_list: 63 | if int(obj["class_id"]) in background_cls_list: 64 | obj_file = os.path.join(gt_dir, "mesh_semantic.ply_" + str(int(obj["id"])) + ".ply") 65 | obj_mesh = trimesh.load(obj_file) 66 | bg_meshes.append(obj_mesh) 67 | 68 | bg_mesh = trimesh.util.concatenate(bg_meshes) 69 | return bg_mesh 70 | 71 | def get_obj_ids(obj_dir): 72 | files = os.listdir(obj_dir) 73 | obj_ids = [] 74 | for f in files: 75 | obj_id = f.split("obj")[1][:-1] 76 | if obj_id == '': 77 | continue 78 | obj_ids.append(int(obj_id)) 79 | return obj_ids 80 | 81 | def get_obj_ids_ours(obj_dir): 82 | files = list(filter(os.path.isfile, 
glob.glob(os.path.join(obj_dir, '*[0-9].ply')))) 83 | epoch_count = set([int(os.path.basename(f).split('_')[1]) for f in files]) 84 | max_epoch = max(epoch_count) 85 | # print(max_epoch) 86 | obj_file = [int(os.path.basename(f).split('.')[0].split('_')[2]) for f in files if int(os.path.basename(f).split('_')[1])==max_epoch] 87 | return sorted(obj_file), max_epoch 88 | 89 | 90 | 91 | def get_gt_mesh_from_objid(gt_dir, mapping_list): 92 | combined_meshes = [] 93 | if len(mapping_list) == 1: 94 | return trimesh.load(os.path.join(gt_dir, 'cull_mesh_semantic.ply_'+str(mapping_list[0])+'.ply')) 95 | else: 96 | for idx in mapping_list: 97 | if os.path.isfile(os.path.join(gt_dir, 'cull_mesh_semantic.ply_'+str(int(idx))+'.ply')): 98 | obj_file = os.path.join(gt_dir, 'cull_mesh_semantic.ply_'+str(int(idx))+'.ply') 99 | obj_mesh = trimesh.load(obj_file) 100 | combined_meshes.append(obj_mesh) 101 | else: 102 | continue 103 | combine_mesh = trimesh.util.concatenate(combined_meshes) 104 | return combine_mesh 105 | 106 | 107 | 108 | if __name__ == "__main__": 109 | exp_name = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 110 | data_dir = "/media/hdd/Replica-Dataset/vmap" # where to store the data 111 | log_dir = './evaluation_results' # where to store the evaluation results 112 | info_dir = '../data/replica' # path to dataset information 113 | mesh_rec_root = "../exps/objectsdfplus_replica" # path to the reconstruction results 114 | os.makedirs(log_dir, exist_ok=True) 115 | for idx, exp in enumerate(exp_name): 116 | if exp != "room2": continue 117 | idx = idx + 1 118 | gt_dir = os.path.join(data_dir, exp[:-1]+"_"+exp[-1]+"/habitat") 119 | info_text = os.path.join(info_dir, 'scan'+str(idx), 'instance_mapping.txt') 120 | # mesh_dir = os.path.join() 121 | # mesh_rec_dir = os.path.join('') # path to reconstruction results 122 | # get the latest folder for evaluation 123 | dirs = sorted(os.listdir(mesh_rec_root+f'_{idx}')) 124 | mesh_rec_dir = 
os.path.join(mesh_rec_root+f'_{idx}', dirs[-1], "plots") 125 | print(mesh_rec_dir) 126 | 127 | output_dir = os.path.join(log_dir, exp+'_{}'.format(mesh_rec_root.split('/')[-1])) 128 | os.makedirs(output_dir, exist_ok=True) 129 | metrics_3D = [[] for _ in range(8)] 130 | 131 | # only calculate the valid mesh in experiment 132 | instance_mapping = {} 133 | with open(info_text, 'r') as f: 134 | for l in f: 135 | (k, v_sem, v_ins) = l.split(',') 136 | instance_mapping[int(k)] = int(v_ins) 137 | label_mapping = sorted(set(instance_mapping.values())) 138 | # print(label_mapping) 139 | # get all valid obj index 140 | obj_ids, max_epoch = get_obj_ids_ours(mesh_rec_dir) 141 | # print(obj_ids) 142 | for obj_id in tqdm(obj_ids): 143 | inst_id = label_mapping[obj_id] 144 | # merge the gt mesh with the same instance_id 145 | gt_inst_list = [] # a list used to store the index in gt_object mesh that are the same object defined by instance_mapping 146 | for k, v in instance_mapping.items(): 147 | if v == inst_id: 148 | gt_inst_list.append(k) 149 | mesh_gt = get_gt_mesh_from_objid(os.path.join(gt_dir, 'cull_object_mesh'), gt_inst_list) 150 | if obj_id == 0: 151 | N=200000 152 | else: 153 | N=10000 154 | 155 | # in order to evaluate the result, we need to cull the object mesh first and then evaluate the metric 156 | # mesh_rec = trimesh.load(os.path.join(mesh_dir, 'surface_{max_epoch}_{obj_id}.ply')) 157 | cull_rec_mesh_path = os.path.join(output_dir, f"{exp}_cull_surface_{max_epoch}_{obj_id}.ply") 158 | 159 | rec_mesh = os.path.join(mesh_rec_dir, f'surface_{max_epoch}_{obj_id}.ply') 160 | # print(rec_mesh) 161 | cmd = f"python cull_mesh.py --input_mesh {rec_mesh} --input_scalemat ../data/replica/scan{idx}/cameras.npz --traj ../data/replica/scan{idx}/traj.txt --output_mesh {cull_rec_mesh_path}" 162 | os.system(cmd) 163 | # evaluate the metric 164 | mesh_rec = trimesh.load(cull_rec_mesh_path) 165 | # use the biggest connected component for evaluation. 
166 | 167 | metrics = calc_3d_metric(mesh_rec, mesh_gt, N=N) 168 | if metrics is None: 169 | continue 170 | np.save(output_dir + '/metric_obj{}.npy'.format(obj_id), np.array(metrics)) 171 | metrics_3D[0].append(metrics[0]) # acc 172 | metrics_3D[1].append(metrics[1]) # comp 173 | metrics_3D[2].append(metrics[2]) # comp ratio 1cm 174 | metrics_3D[3].append(metrics[3]) # comp ratio 5cm 175 | metrics_3D[4].append(metrics[4]) # precision ratio 1cm 176 | metrics_3D[5].append(metrics[5]) # precision ratio 5cm 177 | metrics_3D[6].append(metrics[6]) # f_score 1cm 178 | metrics_3D[7].append(metrics[7]) # f_score 5cm 179 | metrics_3D = np.array(metrics_3D) 180 | np.save(output_dir + '/metrics_3D_obj.npy', metrics_3D) 181 | print("metrics 3D obj \n Acc | Comp | Comp Ratio 1cm | Comp Ratio 5cm \n", metrics_3D.mean(axis=1)) 182 | print("-----------------------------------------") 183 | print("finish exp ", exp) 184 | # exit() 185 | 186 | 187 | 188 | 189 | # use culled object mesh for evaluation 190 | 191 | 192 | 193 | 194 | 195 | 196 | # background_cls_list = [5, 12, 30, 31, 40, 60, 92, 93, 95, 97, 98, 79] 197 | # exp_name = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 198 | # # data_dir = "/home/xin/data/vmap/" 199 | # data_dir = "/media/hdd/Replica-Dataset/vmap/" 200 | # # log_dir = "../logs/iMAP/" 201 | # # log_dir = "../logs/vMAP/" 202 | # log_dir = "/media/hdd/Replica-Dataset/vmap/vMAP_Replica_Results/" 203 | 204 | # for exp in tqdm(exp_name): 205 | # gt_dir = os.path.join(data_dir, exp[:-1]+"_"+exp[-1]+"/habitat") 206 | # exp_dir = os.path.join(log_dir, exp+'_vmap') 207 | # mesh_dir = os.path.join(exp_dir, "scene_mesh") 208 | # output_path = os.path.join(exp_dir, "eval_mesh") 209 | # os.makedirs(output_path, exist_ok=True) 210 | # metrics_3D = [[] for _ in range(4)] 211 | 212 | # # get obj ids 213 | # # obj_ids = np.loadtxt() # todo use a pre-defined obj list or use vMAP results 214 | # obj_ids = get_obj_ids(mesh_dir.replace("imap", 
"vmap")) 215 | # for obj_id in tqdm(obj_ids): 216 | # if obj_id == 0: # for bg 217 | # N = 200000 218 | # mesh_gt = get_gt_bg_mesh(gt_dir, background_cls_list) 219 | # else: # for obj 220 | # N = 10000 221 | # obj_file = os.path.join(gt_dir, "mesh_semantic.ply_" + str(obj_id) + ".ply") 222 | # mesh_gt = trimesh.load(obj_file) 223 | 224 | # if "vMAP" in exp_dir: 225 | # rec_meshfile = os.path.join(mesh_dir, "imap_frame2000_obj"+str(obj_id)+".obj") 226 | # # rec_meshfile = os.path.join(mesh_dir, ) 227 | # elif "iMAP" in exp_dir: 228 | # rec_meshfile = os.path.join(mesh_dir, "frame_1999_obj0.obj") 229 | # else: 230 | # print("Not Implement") 231 | # exit(-1) 232 | 233 | # mesh_rec = trimesh.load(rec_meshfile) 234 | # # mesh_rec.invert() # niceslam mesh face needs invert 235 | # metrics = calc_3d_metric(mesh_rec, mesh_gt, N=N) # for objs use 10k, for scene use 200k points 236 | # if metrics is None: 237 | # continue 238 | # np.save(output_path + '/metric_obj{}.npy'.format(obj_id), np.array(metrics)) 239 | # metrics_3D[0].append(metrics[0]) # acc 240 | # metrics_3D[1].append(metrics[1]) # comp 241 | # metrics_3D[2].append(metrics[2]) # comp ratio 1cm 242 | # metrics_3D[3].append(metrics[3]) # comp ratio 5cm 243 | # metrics_3D = np.array(metrics_3D) 244 | # np.save(output_path + '/metrics_3D_obj.npy', metrics_3D) 245 | # print("metrics 3D obj \n Acc | Comp | Comp Ratio 1cm | Comp Ratio 5cm \n", metrics_3D.mean(axis=1)) 246 | # print("-----------------------------------------") 247 | # print("finish exp ", exp) 248 | 249 | # calculate the avaerage result over all 8 scenes 250 | -------------------------------------------------------------------------------- /replica_eval/eval_recon.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/cvg/nice-slam 2 | import argparse 3 | import random 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | import torch 8 | import trimesh 9 | from scipy.spatial import 
cKDTree as KDTree 10 | import cv2 11 | 12 | def normalize(x): 13 | return x / np.linalg.norm(x) 14 | 15 | 16 | def viewmatrix(z, up, pos): 17 | vec2 = normalize(z) 18 | vec1_avg = up 19 | vec0 = normalize(np.cross(vec1_avg, vec2)) 20 | vec1 = normalize(np.cross(vec2, vec0)) 21 | m = np.stack([vec0, vec1, vec2, pos], 1) 22 | return m 23 | 24 | 25 | def completion_ratio(gt_points, rec_points, dist_th=0.05): 26 | gen_points_kd_tree = KDTree(rec_points) 27 | distances, _ = gen_points_kd_tree.query(gt_points) 28 | comp_ratio = np.mean((distances < dist_th).astype(float)) # np.float was removed in NumPy 1.24; use the builtin float 29 | return comp_ratio 30 | 31 | 32 | def accuracy(gt_points, rec_points): 33 | gt_points_kd_tree = KDTree(gt_points) 34 | distances, _ = gt_points_kd_tree.query(rec_points) 35 | acc = np.mean(distances) 36 | return acc, distances 37 | 38 | 39 | def completion(gt_points, rec_points): 40 | gt_points_kd_tree = KDTree(rec_points) 41 | distances, _ = gt_points_kd_tree.query(gt_points) 42 | comp = np.mean(distances) 43 | return comp, distances 44 | 45 | def write_vis_pcd(file, points, colors): 46 | pcd = o3d.geometry.PointCloud() 47 | pcd.points = o3d.utility.Vector3dVector(points) 48 | pcd.colors = o3d.utility.Vector3dVector(colors) 49 | o3d.io.write_point_cloud(file, pcd) 50 | 51 | def get_align_transformation(rec_meshfile, gt_meshfile): 52 | """ 53 | Get the transformation matrix to align the reconstructed mesh to the ground truth mesh. 
54 | """ 55 | o3d_rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile) 56 | o3d_gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile) 57 | o3d_rec_pc = o3d.geometry.PointCloud(points=o3d_rec_mesh.vertices) 58 | o3d_gt_pc = o3d.geometry.PointCloud(points=o3d_gt_mesh.vertices) 59 | trans_init = np.eye(4) 60 | threshold = 0.1 61 | reg_p2p = o3d.pipelines.registration.registration_icp( 62 | o3d_rec_pc, o3d_gt_pc, threshold, trans_init, 63 | o3d.pipelines.registration.TransformationEstimationPointToPoint()) 64 | transformation = reg_p2p.transformation 65 | return transformation 66 | 67 | 68 | def check_proj(points, W, H, fx, fy, cx, cy, c2w): 69 | """ 70 | Check if points can be projected into the camera view. 71 | 72 | """ 73 | c2w = c2w.copy() 74 | c2w[:3, 1] *= -1.0 75 | c2w[:3, 2] *= -1.0 76 | points = torch.from_numpy(points).cuda().clone() 77 | w2c = np.linalg.inv(c2w) 78 | w2c = torch.from_numpy(w2c).cuda().float() 79 | K = torch.from_numpy( 80 | np.array([[fx, .0, cx], [.0, fy, cy], [.0, .0, 1.0]]).reshape(3, 3)).cuda() 81 | ones = torch.ones_like(points[:, 0]).reshape(-1, 1).cuda() 82 | homo_points = torch.cat( 83 | [points, ones], dim=1).reshape(-1, 4, 1).cuda().float() # (N, 4) 84 | cam_cord_homo = w2c@homo_points # (N, 4, 1)=(4,4)*(N, 4, 1) 85 | cam_cord = cam_cord_homo[:, :3] # (N, 3, 1) 86 | cam_cord[:, 0] *= -1 87 | uv = K.float()@cam_cord.float() 88 | z = uv[:, -1:]+1e-5 89 | uv = uv[:, :2]/z 90 | uv = uv.float().squeeze(-1).cpu().numpy() 91 | edge = 0 92 | mask = (0 <= -z[:, 0, 0].cpu().numpy()) & (uv[:, 0] < W - 93 | edge) & (uv[:, 0] > edge) & (uv[:, 1] < H-edge) & (uv[:, 1] > edge) 94 | return mask.sum() > 0 95 | 96 | def nn_correspondance(verts1, verts2): 97 | indices = [] 98 | distances = [] 99 | if len(verts1) == 0 or len(verts2) == 0: 100 | return indices, distances 101 | 102 | kdtree = KDTree(verts1) 103 | distances, indices = kdtree.query(verts2) 104 | distances = distances.reshape(-1) 105 | 106 | return distances, indices 107 | 108 | 109 | def 
calc_3d_metric(rec_meshfile, gt_meshfile, align=False): 110 | """ 111 | 3D reconstruction metric. 112 | 113 | """ 114 | mesh_rec = trimesh.load(rec_meshfile, process=False) 115 | mesh_gt = trimesh.load(gt_meshfile, process=False) 116 | 117 | if align: 118 | transformation = get_align_transformation(rec_meshfile, gt_meshfile) 119 | mesh_rec = mesh_rec.apply_transform(transformation) 120 | 121 | # found the aligned bbox for the mesh 122 | to_align, _ = trimesh.bounds.oriented_bounds(mesh_gt) 123 | mesh_gt.vertices = (to_align[:3, :3] @ mesh_gt.vertices.T + to_align[:3, 3:]).T 124 | mesh_rec.vertices = (to_align[:3, :3] @ mesh_rec.vertices.T + to_align[:3, 3:]).T 125 | 126 | min_points = mesh_gt.vertices.min(axis=0) * 1.005 127 | max_points = mesh_gt.vertices.max(axis=0) * 1.005 128 | 129 | mask_min = (mesh_rec.vertices - min_points[None]) > 0 130 | mask_max = (mesh_rec.vertices - max_points[None]) < 0 131 | 132 | mask = np.concatenate((mask_min, mask_max), axis=1).all(axis=1) 133 | face_mask = mask[mesh_rec.faces].all(axis=1) 134 | 135 | mesh_rec.update_vertices(mask) 136 | mesh_rec.update_faces(face_mask) 137 | 138 | rec_pc = trimesh.sample.sample_surface(mesh_rec, 200000) 139 | rec_pc_tri = trimesh.PointCloud(vertices=rec_pc[0]) 140 | 141 | gt_pc = trimesh.sample.sample_surface(mesh_gt, 200000) 142 | gt_pc_tri = trimesh.PointCloud(vertices=gt_pc[0]) 143 | accuracy_rec, dist_d2s = accuracy(gt_pc_tri.vertices, rec_pc_tri.vertices) 144 | completion_rec, dist_s2d = completion(gt_pc_tri.vertices, rec_pc_tri.vertices) 145 | completion_ratio_rec = completion_ratio( 146 | gt_pc_tri.vertices, rec_pc_tri.vertices) 147 | 148 | precision_ratio_rec = completion_ratio( 149 | rec_pc_tri.vertices, gt_pc_tri.vertices) 150 | 151 | fscore = 2 * precision_ratio_rec * completion_ratio_rec / (completion_ratio_rec + precision_ratio_rec) 152 | 153 | # normal consistency 154 | N = 200000 155 | pointcloud_pred, idx = mesh_rec.sample(N, return_index=True) 156 | pointcloud_pred = 
pointcloud_pred.astype(np.float32) 157 | normal_pred = mesh_rec.face_normals[idx] 158 | 159 | pointcloud_trgt, idx = mesh_gt.sample(N, return_index=True) 160 | pointcloud_trgt = pointcloud_trgt.astype(np.float32) 161 | normal_trgt = mesh_gt.face_normals[idx] 162 | 163 | _, index1 = nn_correspondance(pointcloud_pred, pointcloud_trgt) 164 | _, index2 = nn_correspondance(pointcloud_trgt, pointcloud_pred) 165 | 166 | normal_acc = np.abs((normal_pred * normal_trgt[index2.reshape(-1)]).sum(axis=-1)).mean() 167 | normal_comp = np.abs((normal_trgt * normal_pred[index1.reshape(-1)]).sum(axis=-1)).mean() 168 | normal_avg = (normal_acc + normal_comp) * 0.5 169 | 170 | accuracy_rec *= 100 # convert to cm 171 | completion_rec *= 100 # convert to cm 172 | completion_ratio_rec *= 100 # convert to % 173 | precision_ratio_rec *= 100 # convert to % 174 | fscore *= 100 175 | normal_acc *= 100 176 | normal_comp *= 100 177 | normal_avg *= 100 178 | 179 | print('Acc: ', accuracy_rec, 'Comp:', completion_rec, 'Precision: ', precision_ratio_rec, 'Comp_ratio: ', completion_ratio_rec, 'F_score: ', fscore, 'Normal_acc: ', normal_acc, 'Normal_comp: ', normal_comp, 'Normal_avg: ',normal_avg) 180 | 181 | # add visualization and save the mesh output 182 | vis_dist = 0.05 # dist_th = 5 cm 183 | data_alpha = (dist_d2s.clip(max=vis_dist) / vis_dist).reshape(-1, 1) 184 | #data_color = R * data_alpha + W * (1-data_alpha) 185 | im_gray = (data_alpha * 255).astype(np.uint8) 186 | data_color = cv2.applyColorMap(im_gray, cv2.COLORMAP_JET)[:,0,[2, 0, 1]] / 255. 187 | write_vis_pcd(f'{rec_meshfile}_d2s.ply', rec_pc_tri.vertices, data_color) 188 | 189 | stl_alpha = (dist_s2d.clip(max=vis_dist) / vis_dist).reshape(-1, 1) 190 | #stl_color = R * stl_alpha + W * (1-stl_alpha) 191 | im_gray = (stl_alpha * 255).astype(np.uint8) 192 | stl_color = cv2.applyColorMap(im_gray, cv2.COLORMAP_JET)[:,0,[2, 0, 1]] / 255. 
193 | write_vis_pcd(f'{rec_meshfile}_s2d.ply', gt_pc_tri.vertices, stl_color) 194 | 195 | 196 | def get_cam_position(gt_meshfile): 197 | mesh_gt = trimesh.load(gt_meshfile) 198 | to_origin, extents = trimesh.bounds.oriented_bounds(mesh_gt) 199 | extents[2] *= 0.7 200 | extents[1] *= 0.7 201 | extents[0] *= 0.3 202 | transform = np.linalg.inv(to_origin) 203 | transform[2, 3] += 0.4 204 | return extents, transform 205 | 206 | 207 | def calc_2d_metric(rec_meshfile, gt_meshfile, align=False, n_imgs=1000): 208 | """ 209 | 2D reconstruction metric, depth L1 loss. 210 | 211 | """ 212 | H = 500 213 | W = 500 214 | focal = 300 215 | fx = focal 216 | fy = focal 217 | cx = W/2.0-0.5 218 | cy = H/2.0-0.5 219 | 220 | gt_mesh = o3d.io.read_triangle_mesh(gt_meshfile) 221 | rec_mesh = o3d.io.read_triangle_mesh(rec_meshfile) 222 | unseen_gt_pointcloud_file = gt_meshfile.replace('.ply', '_pc_unseen.npy') 223 | pc_unseen = np.load(unseen_gt_pointcloud_file) 224 | if align: 225 | transformation = get_align_transformation(rec_meshfile, gt_meshfile) 226 | rec_mesh = rec_mesh.transform(transformation) 227 | 228 | # get vacant area inside the room 229 | extents, transform = get_cam_position(gt_meshfile) 230 | 231 | vis = o3d.visualization.Visualizer() 232 | vis.create_window(width=W, height=H) 233 | vis.get_render_option().mesh_show_back_face = True 234 | errors = [] 235 | for i in range(n_imgs): 236 | while True: 237 | # sample a view and check whether the unseen region falls inside the camera view; if it does, resample 238 | up = [0, 0, -1] 239 | origin = trimesh.sample.volume_rectangular( 240 | extents, 1, transform=transform) 241 | origin = origin.reshape(-1) 242 | tx = round(random.uniform(-10000, +10000), 2) 243 | ty = round(random.uniform(-10000, +10000), 2) 244 | tz = round(random.uniform(-10000, +10000), 2) 245 | target = [tx, ty, tz] 246 | target = np.array(target)-np.array(origin) 247 | c2w = viewmatrix(target, up, origin) 248 | tmp = np.eye(4) 249 | tmp[:3, :] = c2w 250 | 
c2w = tmp 251 | seen = check_proj(pc_unseen, W, H, fx, fy, cx, cy, c2w) 252 | if (~seen): 253 | break 254 | 255 | param = o3d.camera.PinholeCameraParameters() 256 | param.extrinsic = np.linalg.inv(c2w) # 4x4 numpy array 257 | 258 | param.intrinsic = o3d.camera.PinholeCameraIntrinsic( 259 | W, H, fx, fy, cx, cy) 260 | 261 | ctr = vis.get_view_control() 262 | ctr.set_constant_z_far(20) 263 | ctr.convert_from_pinhole_camera_parameters(param) 264 | 265 | vis.add_geometry(gt_mesh, reset_bounding_box=True,) 266 | ctr.convert_from_pinhole_camera_parameters(param) 267 | vis.poll_events() 268 | vis.update_renderer() 269 | gt_depth = vis.capture_depth_float_buffer(True) 270 | gt_depth = np.asarray(gt_depth) 271 | vis.remove_geometry(gt_mesh, reset_bounding_box=True,) 272 | 273 | vis.add_geometry(rec_mesh, reset_bounding_box=True,) 274 | ctr.convert_from_pinhole_camera_parameters(param) 275 | vis.poll_events() 276 | vis.update_renderer() 277 | ours_depth = vis.capture_depth_float_buffer(True) 278 | ours_depth = np.asarray(ours_depth) 279 | vis.remove_geometry(rec_mesh, reset_bounding_box=True,) 280 | 281 | errors += [np.abs(gt_depth-ours_depth).mean()] 282 | 283 | errors = np.array(errors) 284 | # from m to cm 285 | print('Depth L1: ', errors.mean()*100) 286 | 287 | 288 | if __name__ == '__main__': 289 | 290 | parser = argparse.ArgumentParser( 291 | description='Arguments to evaluate the reconstruction.' 
292 | ) 293 | parser.add_argument('--rec_mesh', type=str, 294 | help='reconstructed mesh file path') 295 | parser.add_argument('--gt_mesh', type=str, 296 | help='ground truth mesh file path') 297 | args = parser.parse_args() 298 | calc_3d_metric(args.rec_mesh, args.gt_mesh) 299 | #calc_2d_metric(args.rec_mesh, args.gt_mesh, n_imgs=1000) 300 | -------------------------------------------------------------------------------- /replica_eval/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import cv2 as cv 5 | import numpy as np 6 | import os 7 | import glob 8 | 9 | import trimesh 10 | from pathlib import Path 11 | import subprocess 12 | 13 | 14 | scans = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 15 | 16 | 17 | root_dir = "../exps" # path to the experiment results 18 | exp_name = 'objectsdfplus_replica' # experiment name 19 | out_dir = "evaluation/scene_results" # path to save the scene evaluation results 20 | Path(out_dir).mkdir(parents=True, exist_ok=True) 21 | 22 | # evaluation_txt_file = "evaluation/replica_objsdf_star.csv" 23 | evaluation_txt_file = "evaluation/rebuttal_oneobject.csv" 24 | evaluation_txt_file = open(evaluation_txt_file, 'w') 25 | 26 | 27 | for idx, scan in enumerate(scans): 28 | idx = idx + 1 29 | if scan != "room2": 30 | continue 31 | # test set 32 | #if not (idx in [4, 6, 7]): 33 | # continue 34 | 35 | cur_exp = f"{exp_name}_{idx}" 36 | cur_root = os.path.join(root_dir, cur_exp) 37 | # use the latest timestamp (last entry after sorting) 38 | dirs = sorted(os.listdir(cur_root)) 39 | cur_root = os.path.join(cur_root, dirs[-1]) 40 | files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*_whole.ply")))) 41 | 42 | files.sort(key=lambda x:os.path.getmtime(x)) 43 | ply_file = files[-1] 44 | print(ply_file) 45 | 46 | # cull mesh 47 | cull_mesh_out = os.path.join(out_dir, f"{scan}.ply") 48 | cmd = 
f"python cull_mesh.py --input_mesh {ply_file} --input_scalemat ../data/replica/Replica/scan{idx}/cameras.npz --traj ../data/replica/Replica/scan{idx}/traj.txt --output_mesh {cull_mesh_out}" 49 | print(cmd) 50 | os.system(cmd) 51 | 52 | cmd = f"python eval_recon.py --rec_mesh {cull_mesh_out} --gt_mesh ../data/replica/Replica/cull_GTmesh/{scan}.ply" 53 | print(cmd) 54 | # accuracy_rec, completion_rec, precision_ratio_rec, completion_ratio_rec, fscore, normal_acc, normal_comp, normal_avg 55 | output = subprocess.check_output(cmd, shell=True).decode("utf-8") 56 | output = output.replace(" ", ",") 57 | print(output) 58 | 59 | evaluation_txt_file.write(f"{scan},{Path(ply_file).name},{output}") 60 | evaluation_txt_file.flush() -------------------------------------------------------------------------------- /replica_eval/evaluate_single_scene.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import subprocess 4 | import argparse 5 | 6 | 7 | scans = ["room0", "room1", "room2", "office0", "office1", "office2", "office3", "office4"] 8 | 9 | if __name__ == "__main__": 10 | 11 | parser = argparse.ArgumentParser( 12 | description='Arguments to evaluate the mesh.' 
13 | ) 14 | 15 | parser.add_argument('--input_mesh', type=str, help='path to the mesh to be evaluated') 16 | parser.add_argument('--scan_id', type=str, help='scan id of the input mesh') 17 | parser.add_argument('--output_dir', type=str, default='evaluation_results_single', help='path to the output folder') 18 | args = parser.parse_args() 19 | 20 | 21 | 22 | out_dir = args.output_dir 23 | Path(out_dir).mkdir(parents=True, exist_ok=True) 24 | 25 | idx = args.scan_id 26 | scan = scans[int(idx) - 1] 27 | 28 | ply_file = args.input_mesh 29 | 30 | result_mesh_file = os.path.join(out_dir, "culled_mesh.ply") 31 | 32 | # cull mesh 33 | cull_mesh_out = os.path.join(out_dir, f"cull_{scan}.ply") 34 | # cmd = f"python cull_mesh.py --input_mesh {ply_file} --input_scalemat ../data/Replica/scan{idx}/cameras.npz --traj ../data/Replica/scan{idx}/traj.txt --output_mesh {cull_mesh_out}" 35 | cmd = f"python cull_mesh.py --input_mesh {ply_file} --input_scalemat ../data/replica/scan{idx}/cameras.npz --traj ../data/replica/scan{idx}/traj.txt --output_mesh {cull_mesh_out}" 36 | print(cmd) 37 | os.system(cmd) 38 | 39 | cmd = f"python eval_recon.py --rec_mesh {cull_mesh_out} --gt_mesh ../data/replica/cull_GTmesh/{scan}.ply" 40 | print(cmd) 41 | # accuracy_rec, completion_rec, precision_ratio_rec, completion_ratio_rec, fscore, normal_acc, normal_comp, normal_avg 42 | output = subprocess.check_output(cmd, shell=True).decode("utf-8") 43 | output = output.replace(" ", ",") 44 | print(output) -------------------------------------------------------------------------------- /replica_eval/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import cKDTree as KDTree 3 | 4 | def completion_ratio(gt_points, rec_points, dist_th=0.01): 5 | gen_points_kd_tree = KDTree(rec_points) 6 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points) 7 | completion = np.mean((one_distances < dist_th).astype(float)) 8 | 
return completion 9 | 10 | 11 | def accuracy(gt_points, rec_points): 12 | gt_points_kd_tree = KDTree(gt_points) 13 | two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points) 14 | gen_to_gt_chamfer = np.mean(two_distances) 15 | return gen_to_gt_chamfer 16 | 17 | 18 | def completion(gt_points, rec_points): 19 | gt_points_kd_tree = KDTree(rec_points) 20 | one_distances, two_vertex_ids = gt_points_kd_tree.query(gt_points) 21 | gt_to_gen_chamfer = np.mean(one_distances) 22 | return gt_to_gen_chamfer 23 | 24 | 25 | def chamfer(gt_points, rec_points): 26 | # one direction 27 | gen_points_kd_tree = KDTree(rec_points) 28 | one_distances, one_vertex_ids = gen_points_kd_tree.query(gt_points) 29 | gt_to_gen_chamfer = np.mean(one_distances) 30 | 31 | # other direction 32 | gt_points_kd_tree = KDTree(gt_points) 33 | two_distances, two_vertex_ids = gt_points_kd_tree.query(rec_points) 34 | gen_to_gt_chamfer = np.mean(two_distances) 35 | 36 | return (gt_to_gen_chamfer + gen_to_gt_chamfer) / 2. 
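A minimal usage sketch for the metrics above, not part of the original repository. It assumes the same definitions (accuracy = mean reconstruction-to-GT distance, completion = mean GT-to-reconstruction distance, completion ratio = fraction of GT points within `dist_th`, Chamfer = average of the two means) but swaps the `cKDTree` lookup for a brute-force search so it runs with the standard library alone. The helper names `nearest_distances` and `toy_metrics` and the toy point sets are hypothetical.

```python
# Hedged sketch: brute-force stand-in for the KD-tree metrics, for toy inputs only.
import math


def nearest_distances(src_points, dst_points):
    """For each point in src_points, distance to its nearest neighbour in dst_points."""
    return [min(math.dist(p, q) for q in dst_points) for p in src_points]


def toy_metrics(gt_points, rec_points, dist_th=0.01):
    # GT -> reconstruction: how well the reconstruction covers the ground truth.
    gt_to_rec = nearest_distances(gt_points, rec_points)
    # Reconstruction -> GT: how close reconstructed points lie to the ground truth.
    rec_to_gt = nearest_distances(rec_points, gt_points)
    completion = sum(gt_to_rec) / len(gt_to_rec)
    accuracy = sum(rec_to_gt) / len(rec_to_gt)
    return {
        "accuracy": accuracy,
        "completion": completion,
        "completion_ratio": sum(d < dist_th for d in gt_to_rec) / len(gt_to_rec),
        "chamfer": (completion + accuracy) / 2.0,
    }


# Toy data (hypothetical): four GT corners of a unit square; the reconstruction
# misses one corner and shifts the remaining three by 5 mm along x.
gt = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 1.0, 0.0)]
rec = [(x + 0.005, y, z) for (x, y, z) in gt[:3]]
result = toy_metrics(gt, rec)
print(result)
```

On this toy input the accuracy stays near 0.005 because every reconstructed point is close to the ground truth, while the missing corner pushes completion to about 0.2525 and drops the completion ratio to 0.75. This is why the two directions are reported separately.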
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchvision 3 | torchaudio 4 | trimesh 5 | pyhocon==0.3.59 6 | tqdm 7 | scikit-image 8 | matplotlib 9 | opencv-python 10 | tensorboard 11 | scipy 12 | numpy 13 | plotly 14 | ninja 15 | gdown 16 | -------------------------------------------------------------------------------- /scannet_eval/evaluate.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/zju3dv/manhattan_sdf 2 | import numpy as np 3 | import open3d as o3d 4 | from sklearn.neighbors import KDTree 5 | import trimesh 6 | import torch 7 | import glob 8 | import os 9 | import pyrender 10 | import os 11 | from tqdm import tqdm 12 | from pathlib import Path 13 | 14 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 15 | 16 | def nn_correspondance(verts1, verts2): 17 | indices = [] 18 | distances = [] 19 | if len(verts1) == 0 or len(verts2) == 0: 20 | return np.asarray(distances) 21 | 22 | kdtree = KDTree(verts1) 23 | distances, indices = kdtree.query(verts2) 24 | distances = distances.reshape(-1) 25 | 26 | return distances 27 | 28 | 29 | def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02): 30 | pcd_trgt = o3d.geometry.PointCloud() 31 | pcd_pred = o3d.geometry.PointCloud() 32 | 33 | pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3]) 34 | pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3]) 35 | 36 | if down_sample: 37 | pcd_pred = pcd_pred.voxel_down_sample(down_sample) 38 | pcd_trgt = pcd_trgt.voxel_down_sample(down_sample) 39 | 40 | verts_pred = np.asarray(pcd_pred.points) 41 | verts_trgt = np.asarray(pcd_trgt.points) 42 | 43 | dist1 = nn_correspondance(verts_pred, verts_trgt) 44 | dist2 = nn_correspondance(verts_trgt, verts_pred) 45 | 46 | precision = np.mean((dist2 < 
threshold).astype('float')) 47 | recal = np.mean((dist1 < threshold).astype('float')) 48 | fscore = 2 * precision * recal / (precision + recal) 49 | metrics = { 50 | 'Acc': np.mean(dist2), 51 | 'Comp': np.mean(dist1), 52 | 'Prec': precision, 53 | 'Recal': recal, 54 | 'F-score': fscore, 55 | } 56 | return metrics 57 | 58 | # hard-coded image size 59 | H, W = 968, 1296 60 | 61 | # load pose 62 | def load_poses(scan_id): 63 | pose_path = os.path.join(f'../data/scannet/scan{scan_id}', 'pose') 64 | # pose_path = os.path.join(f'../data/custom/scan{scan_id}', 'pose') 65 | poses = [] 66 | pose_paths = sorted(glob.glob(os.path.join(pose_path, '*.txt')), 67 | key=lambda x: int(os.path.basename(x)[:-4])) 68 | for pose_path in pose_paths[::10]: 69 | c2w = np.loadtxt(pose_path) 70 | if np.isfinite(c2w).any(): 71 | poses.append(c2w) 72 | poses = np.array(poses) 73 | 74 | return poses 75 | 76 | 77 | class Renderer(): 78 | def __init__(self, height=480, width=640): 79 | self.renderer = pyrender.OffscreenRenderer(width, height) 80 | self.scene = pyrender.Scene() 81 | # self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES 82 | 83 | def __call__(self, height, width, intrinsics, pose, mesh): 84 | self.renderer.viewport_height = height 85 | self.renderer.viewport_width = width 86 | self.scene.clear() 87 | self.scene.add(mesh) 88 | cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2], 89 | fx=intrinsics[0, 0], fy=intrinsics[1, 1]) 90 | self.scene.add(cam, pose=self.fix_pose(pose)) 91 | return self.renderer.render(self.scene) # , self.render_flags) 92 | 93 | def fix_pose(self, pose): 94 | # 3D Rotation about the x-axis. 
95 | t = np.pi 96 | c = np.cos(t) 97 | s = np.sin(t) 98 | R = np.array([[1, 0, 0], 99 | [0, c, -s], 100 | [0, s, c]]) 101 | axis_transform = np.eye(4) 102 | axis_transform[:3, :3] = R 103 | return pose @ axis_transform 104 | 105 | def mesh_opengl(self, mesh): 106 | return pyrender.Mesh.from_trimesh(mesh) 107 | 108 | def delete(self): 109 | self.renderer.delete() 110 | 111 | 112 | def refuse(mesh, poses, K): 113 | renderer = Renderer() 114 | mesh_opengl = renderer.mesh_opengl(mesh) 115 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 116 | voxel_length=0.01, 117 | sdf_trunc=3 * 0.01, 118 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 119 | ) 120 | 121 | for pose in tqdm(poses): 122 | intrinsic = np.eye(4) 123 | intrinsic[:3, :3] = K 124 | 125 | rgb = np.ones((H, W, 3)) 126 | rgb = (rgb * 255).astype(np.uint8) 127 | rgb = o3d.geometry.Image(rgb) 128 | _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl) 129 | depth_pred = o3d.geometry.Image(depth_pred) 130 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 131 | rgb, depth_pred, depth_scale=1.0, depth_trunc=5.0, convert_rgb_to_intensity=False 132 | ) 133 | fx, fy, cx, cy = intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2] 134 | intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy) 135 | extrinsic = np.linalg.inv(pose) 136 | volume.integrate(rgbd, intrinsic, extrinsic) 137 | 138 | return volume.extract_triangle_mesh() 139 | 140 | 141 | root_dir = "../exps/objectsdfplus_scannet" # path to your experiments 142 | exp_name = "scan" # experiment name 143 | out_dir ="evaluation/eval_mesh" # output directory for evaluation results 144 | Path(out_dir).mkdir(parents=True, exist_ok=True) 145 | 146 | 147 | scenes = ["scene0050_00", "scene0084_00", "scene0580_00", "scene0616_00"] 148 | all_results = [] 149 | for idx, scan in enumerate(scenes): 150 | idx = idx + 1 151 | cur_exp = f"{exp_name}_{idx}" 152 | cur_root = 
os.path.join(root_dir, cur_exp) 153 | print(cur_root) 154 | if not os.path.exists(cur_root): 155 | print('Current experiment folder {} does not exist, skipped'.format(cur_root)) 156 | continue 157 | # use the latest timestamp 158 | dirs = sorted(os.listdir(cur_root)) 159 | cur_root = os.path.join(cur_root, dirs[-1]) 160 | files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*_whole.ply")))) # the whole mesh path 161 | # files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "*_whole.ply")))) 162 | # files = list(filter(os.path.isfile, glob.glob(os.path.join(cur_root, "plots/*.ply")))) 163 | print(files) 164 | 165 | # evaluate the latest mesh 166 | files.sort(key=lambda x:os.path.getmtime(x)) 167 | ply_file = files[-1] 168 | print(ply_file) 169 | 170 | mesh = trimesh.load(ply_file) 171 | 172 | # transform to world coordinate 173 | cam_file = f"../data/scannet/scan{idx}/cameras.npz" 174 | # cam_file = f"../data/custom/scan{idx}/cameras.npz" 175 | scale_mat = np.load(cam_file)['scale_mat_0'] 176 | mesh.vertices = (scale_mat[:3, :3] @ mesh.vertices.T + scale_mat[:3, 3:]).T 177 | 178 | # load poses and intrinsics for depth rendering 179 | poses = load_poses(idx) 180 | 181 | intrinsic_path = os.path.join(f'../data/scannet/scan{idx}/intrinsic/intrinsic_color.txt') 182 | # intrinsic_path = os.path.join(f'../data/custom/scan{idx}/intrinsic/intrinsic_color.txt') 183 | K = np.loadtxt(intrinsic_path)[:3, :3] 184 | 185 | mesh = refuse(mesh, poses, K) 186 | 187 | # save mesh 188 | out_mesh_path = os.path.join(out_dir, f"{exp_name}_scan_{idx}_{scan}.ply") 189 | o3d.io.write_triangle_mesh(out_mesh_path, mesh) 190 | mesh = trimesh.load(out_mesh_path) 191 | 192 | 193 | gt_mesh = os.path.join("../data/scannet/GTmesh", f"{scan}_vh_clean_2.ply") 194 | # gt_mesh = os.path.join("../data/scannet/GTmesh_lowres", f"{scan[5:]}.obj") # a low-res version of mesh for evaluation 195 | 196 | gt_mesh = trimesh.load(gt_mesh) 197 | 198 | metrics = evaluate(mesh, gt_mesh) 199 
| print(metrics) 200 | all_results.append(metrics) 201 | 202 | # print all results 203 | for scan, metric in zip(scenes, all_results): 204 | values = [scan] + [str(metric[k]) for k in metric.keys()] 205 | out = ",".join(values) 206 | print(out) 207 | 208 | 209 | # average all results 210 | mean_dict = {} 211 | for key in all_results[0].keys(): 212 | mean_dict[key] = sum(d[key] for d in all_results) / len(all_results) 213 | print(mean_dict) 214 | -------------------------------------------------------------------------------- /scripts/download_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | echo "0 - Replica sample (one scene)" 4 | echo "1 - Replica" 5 | echo "2 - ScanNet" 6 | read -p "Enter the dataset ID you want to download: " ds_id 7 | 8 | mkdir -p data 9 | cd data 10 | 11 | if [ "$ds_id" == 0 ] 12 | then 13 | echo "Download Replica sample dataset, room00" 14 | # Download Replica scenes used in objectsdf++, including rgb, depth, normal, semantic, instance, pose 15 | gdown --no-cookies 17U8RzDWCtUCNPTDF16pEhFqznbz5u8bc -O replica_sample.zip 16 | echo "done, start unzipping" 17 | unzip -o replica_sample.zip -d replica 18 | rm -rf replica_sample.zip 19 | 20 | elif [ "$ds_id" == 1 ] 21 | then 22 | echo "Download All Replica dataset, 8 scenes" 23 | # Download Replica scenes used in objectsdf++, including rgb, depth, normal, semantic, instance, pose 24 | gdown --no-cookies 1IAFNQE3TNyE_ZNdJhDCcPPebWqbTuzYl -O replica.zip 25 | echo "done, start unzipping" 26 | unzip -o replica.zip 27 | rm -rf replica.zip 28 | 29 | elif [ "$ds_id" == 2 ] 30 | then 31 | # Download ScanNet scenes, following MonoSDF 32 | echo "Download ScanNet dataset, 4 scenes" 33 | gdown --no-cookies 1w-HZHhhvc71xOYhFBdZrLYu8FBsNWBhU -O scannet.zip 34 | echo "done, start unzipping" 35 | unzip -o scannet.zip 36 | rm -rf scannet.zip 37 | 38 | else 39 | echo "Invalid dataset ID" 40 | fi 41 | 42 | cd .. 
-------------------------------------------------------------------------------- /scripts/download_meshes.sh: -------------------------------------------------------------------------------- 1 | gdown --no-cookies 1vDb3JN9lHTvjkBkmj8CQtQxng8fxPl39 -O results.zip 2 | unzip results.zip 3 | rm -rf results.zip --------------------------------------------------------------------------------