├── .gitignore ├── LICENSE ├── README.md ├── configs ├── data │ ├── 7scenes_default.yaml │ ├── 7scenes_dense.yaml │ ├── colmap_dense.yaml │ ├── neucon_arkit_default.yaml │ ├── neucon_arkit_dense.yaml │ ├── scannet_default_test.yaml │ ├── scannet_default_train.yaml │ ├── scannet_default_val.yaml │ ├── scannet_dense_test.yaml │ ├── scannet_dense_val.yaml │ ├── scanniverse_dense.yaml │ ├── vdr_default.yaml │ ├── vdr_dense.yaml │ └── vdr_dense_offline.yaml └── models │ ├── dot_product_model.yaml │ └── hero_model.yaml ├── data_scripts ├── 7scenes_preprocessing.py ├── IOS_LOGGER_ARKIT_README.md ├── generate_test_tuples.py ├── generate_train_tuples.py ├── ios_logger_preprocessing.py ├── precompute_valid_frames.py └── scannet_wrangling_scripts │ ├── LICENSE │ ├── README.md │ ├── SensorData.py │ ├── download_scannet.py │ ├── env.yml │ ├── reader.py │ └── splits │ ├── scannetv2_test.txt │ ├── scannetv2_train.txt │ └── scannetv2_val.txt ├── data_splits ├── 7Scenes │ └── dvmvs_split │ │ ├── dvmvs_test_split.txt │ │ └── test_eight_view_deepvmvs.txt ├── ScanNetv2 │ ├── dvmvs_split │ │ ├── dvmvs_train.txt │ │ ├── dvmvs_val.txt │ │ ├── test_eight_view_deepvmvs.txt │ │ └── test_eight_view_deepvmvs_dense.txt │ └── standard_split │ │ ├── scannetv2_test.txt │ │ ├── scannetv2_train.txt │ │ ├── scannetv2_val.txt │ │ ├── test_eight_view_deepvmvs.txt │ │ ├── test_eight_view_deepvmvs_dense.txt │ │ ├── test_eight_view_deepvmvs_offline.txt │ │ ├── train_eight_view_deepvmvs.txt │ │ └── val_eight_view_deepvmvs.txt ├── arkit │ ├── scans.txt │ ├── test_eight_view_deepvmvs.txt │ └── test_eight_view_deepvmvs_dense.txt └── vdr │ ├── scans.txt │ ├── test_eight_view_deepvmvs.txt │ ├── test_eight_view_deepvmvs_dense.txt │ └── test_eight_view_deepvmvs_dense_offline.txt ├── datasets ├── arkit_dataset.py ├── colmap_dataset.py ├── generic_mvs_dataset.py ├── scannet_dataset.py ├── scanniverse_dataset.py ├── seven_scenes_dataset.py └── vdr_dataset.py ├── experiment_modules └── depth_model.py ├── losses.py ├── media ├── arkit_ioslogger_snapshot.png └── teaser.jpeg ├── modules ├── cost_volume.py ├── layers.py └── networks.py ├── options.py ├── pc_fusion.py ├── simplerecon_env.yml ├── test.py ├── tools ├── fusers_helper.py ├── keyframe_buffer.py ├── mesh_renderer.py ├── torch_point_cloud_fusion.py └── tsdf.py ├── train.py ├── utils ├── dataset_utils.py ├── generic_utils.py ├── geometry_utils.py ├── metrics_utils.py └── visualization_utils.py ├── visualization_scripts ├── generate_gt_min_max_cache.py ├── load_meshes_and_include_normals.py └── visualize_scene_depth_output.py ├── visualize_live_meshing.py └── weights └── strip_checkpoint.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode/ 3 | debug_scripts/ 4 | logs/ 5 | logs 6 | models/*.ckpt 7 | weights/*.ckpt 8 | **__pycache__** 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright © Niantic Spatial, Inc. 2022. Patent Pending. 2 | 3 | All rights reserved. 4 | 5 | 6 | 7 | ================================================================================ 8 | 9 | 10 | 11 | This Software is licensed under the terms of the following MVSAnywhere license 12 | which allows for non-commercial use only. For any other use of the software not 13 | covered by the terms of this license, please complete the form at 14 | https://www.nianticspatial.com/partner-with-us. 
15 | 16 | 17 | 18 | ================================================================================ 19 | 20 | 21 | 22 | SimpleRecon License 23 | 24 | 25 | This Agreement is made by and between the Licensor and the Licensee as 26 | defined and identified below. 27 | 28 | 29 | 1. Definitions. 30 | 31 | In this Agreement (“the Agreement”) the following words shall have the 32 | following meanings: 33 | 34 | "Authors" shall mean M. Sayed, J. Gibson, J. Watson, V. Prisacariu, 35 | M. Firman, C. Godard 36 | "Licensee" Shall mean the person or organization agreeing to use the 37 | Software in accordance with these terms and conditions. 38 | "Licensor" shall mean Niantic Inc., a company organized and existing under 39 | the laws of Delaware, whose principal place of business is at 1 Ferry Building, 40 | Suite 200, San Francisco, 94111. 41 | "Software" shall mean the SimpleRecon Software uploaded by Licensor to the 42 | GitHub repository at https://github.com/nianticlabs/SimpleRecon 43 | on September 1st 2022 in source code or object code form and any 44 | accompanying documentation as well as any modifications or additions uploaded 45 | to the same GitHub repository by Licensor. 46 | 47 | 48 | 2. License. 49 | 50 | 2.1 The Licensor has all necessary rights to grant a license under: (i) 51 | copyright and rights in the nature of copyright subsisting in the Software; and 52 | (ii) certain patent rights resulting from a patent application(s) filed by the 53 | Licensor in the United States and/or other jurisdictions in connection with the 54 | Software. The Licensor grants the Licensee for the duration of this Agreement, 55 | a free of charge, non-sublicenseable, non-exclusive, non-transferable copyright 56 | and patent license (in consequence of said patent application(s)) to use the 57 | Software for non-commercial purpose only, including teaching and research at 58 | educational institutions and research at not-for-profit research institutions 59 | in accordance with the provisions of this Agreement. Non-commercial use 60 | expressly excludes any profit-making or commercial activities, including without 61 | limitation sale, license, manufacture or development of commercial products, use in 62 | commercially-sponsored research, use at a laboratory or other facility owned or 63 | controlled (whether in whole or in part) by a commercial entity, provision of 64 | consulting service, use for or on behalf of any commercial entity, use in 65 | research where a commercial party obtains rights to research results or any 66 | other benefit, and use of the code in any models, model weights or code 67 | resulting from such procedure in any commercial product. Notwithstanding the 68 | foregoing restrictions, you can use this code for publishing comparison results 69 | for academic papers, including retraining on your own data. Any use of the 70 | Software for any purpose other than pursuant to the license grant set forth 71 | above shall automatically terminate this License. 72 | 73 | 74 | 2.2 The Licensee is permitted to make modifications to the Software 75 | provided that any distribution of such modifications is in accordance with 76 | Clause 3. 77 | 78 | 2.3 Except as expressly permitted by this Agreement and save to the 79 | extent and in the circumstances expressly required to be permitted by law, the 80 | Licensee is not permitted to rent, lease, sell, offer to sell, or loan the 81 | Software or its associated documentation. 82 | 83 | 84 | 3. 
Redistribution and modifications 85 | 86 | 3.1 The Licensee may reproduce and distribute copies of the Software, with 87 | or without modifications, in source format only and provided that any and every 88 | distribution is accompanied by an unmodified copy of this License and that the 89 | following copyright notice is always displayed in an obvious manner: Copyright 90 | © Niantic Spatial, Inc. 2022. All rights reserved. 91 | 92 | 93 | 3.2 In the case where the Software has been modified, any distribution must 94 | include prominent notices indicating which files have been changed. 95 | 96 | 3.3 The Licensee shall cause any work that it distributes or publishes, 97 | that in whole or in part contains or is derived from the Software or any part 98 | thereof (“Work based on the Software”), to be licensed as a whole at no charge 99 | to all third parties entitled to a license to the Software under the terms of 100 | this License and on the same terms provided in this License. 101 | 102 | 103 | 4. Duration. 104 | 105 | This Agreement is effective until the Licensee terminates it by destroying 106 | the Software, any Work based on the Software, and its documentation together 107 | with all copies. It will also terminate automatically if the Licensee fails to 108 | abide by its terms. Upon automatic termination the Licensee agrees to destroy 109 | all copies of the Software, Work based on the Software, and its documentation. 110 | 111 | 112 | 5. Disclaimer of Warranties. 113 | 114 | The Software is provided as is. To the maximum extent permitted by law, 115 | Licensor provides no warranties or conditions of any kind, either express or 116 | implied, including without limitation, any warranties or condition of title, 117 | non-infringement or fitness for a particular purpose. 118 | 119 | 120 | 6. LIMITATION OF LIABILITY. 121 | 122 | IN NO EVENT SHALL THE LICENSOR AND/OR AUTHORS BE LIABLE FOR ANY DIRECT, 123 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING 124 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 125 | DATA OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 126 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 127 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 128 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 129 | 130 | 131 | 7. Indemnity. 132 | 133 | The Licensee shall indemnify the Licensor and/or Authors against all third 134 | party claims that may be asserted against or suffered by the Licensor and/or 135 | Authors and which relate to use of the Software by the Licensee. 136 | 137 | 138 | 8. Intellectual Property. 139 | 140 | 8.1 As between the Licensee and Licensor, copyright and all other 141 | intellectual property rights subsisting in or in connection with the Software 142 | and supporting information shall remain at all times the property of the 143 | Licensor. The Licensee shall acquire no rights in any such material except as 144 | expressly provided in this Agreement. 145 | 146 | 8.2 No permission is granted to use the trademarks or product names of the 147 | Licensor except as required for reasonable and customary use in describing the 148 | origin of the Software and for the purposes of abiding by the terms of Clause 149 | 3.1. 150 | 151 | 8.3 The Licensee shall promptly notify the Licensor of any improvement or 152 | new use of the Software (“Improvements”) in sufficient detail for Licensor to 153 | evaluate the Improvements. 
The Licensee hereby grants the Licensor and its 154 | affiliates a non-exclusive, fully paid-up, royalty-free, irrevocable and 155 | perpetual license to all Improvements for non-commercial academic research and 156 | teaching purposes upon creation of such improvements. 157 | 158 | 8.4 The Licensee grants an exclusive first option to the Licensor to be 159 | exercised by the Licensor within three (3) years of the date of notification of 160 | an Improvement under Clause 8.3 to use any the Improvement for commercial 161 | purposes on terms to be negotiated and agreed by Licensee and Licensor in good 162 | faith within a period of six (6) months from the date of exercise of the said 163 | option (including without limitation any royalty share in net income from such 164 | commercialization payable to the Licensee, as the case may be). 165 | 166 | 167 | 9. Acknowledgements. 168 | 169 | The Licensee shall acknowledge the Authors and use of the Software in the 170 | publication of any work that uses, or results that are achieved through, the 171 | use of the Software. The following citation shall be included in the 172 | acknowledgement: “SimpleRecon: 3D Reconstruction Without 3D Convolutions", 173 | by M. Sayed, J. Gibson, J. Watson, V. Prisacariu, M. Firman, C. Godard, 174 | arXiv:2208.14743”. 175 | 176 | 177 | 10. Governing Law. 178 | 179 | This Agreement shall be governed by, construed and interpreted in 180 | accordance with English law and the parties submit to the exclusive 181 | jurisdiction of the English courts. 182 | 183 | 184 | 11. Termination. 185 | 186 | Upon termination of this Agreement, the licenses granted hereunder will 187 | terminate and Sections 5, 6, 7, 8, 9, 10 and 11 shall survive any termination 188 | of this Agreement. 189 | -------------------------------------------------------------------------------- /configs/data/7scenes_default.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/res_nas/shared/datasets/7scenes/ 3 | tuple_info_file_location: data_splits/7Scenes/dvmvs_split/ 4 | dataset_scan_split_file: data_splits/7Scenes/dvmvs_split/dvmvs_test_split.txt 5 | dataset: 7scenes 6 | mv_tuple_file_suffix: _eight_view_deepvmvs.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: default 9 | split: test -------------------------------------------------------------------------------- /configs/data/7scenes_dense.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/res_nas/shared/datasets/7scenes/ 3 | tuple_info_file_location: data_splits/7Scenes/dvmvs_split/ 4 | dataset_scan_split_file: data_splits/7Scenes/dvmvs_split/dvmvs_test_split.txt 5 | dataset: 7scenes 6 | mv_tuple_file_suffix: _eight_view_deepvmvs_dense.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: dense 9 | split: test -------------------------------------------------------------------------------- /configs/data/colmap_dense.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/res_nas/mohameds/datasets/colmap/ 3 | tuple_info_file_location: /mnt/res_nas/mohameds/datasets/colmap/tuples 4 | dataset_scan_split_file: /mnt/res_nas/mohameds/datasets/colmap/test.txt 5 | dataset: colmap 6 | mv_tuple_file_suffix: _eight_view_deepvmvs_dense.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: dense 9 | split: test 
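These data configs are plain YAML files tagged `!!python/object:options.Options`, so each one deserializes straight into the repository's `Options` object defined in `options.py`; note that the `dataset_path` values in these files point at the authors' storage and need to be edited for your own setup. A minimal loading sketch, assuming you run from the repo root so that `options` is importable (the config filename below is just an example) — the repo's own entry point for this is `options.OptionsHandler.parse_and_merge_options`:

```python
import yaml

# Sketch only: UnsafeLoader is needed so PyYAML can instantiate the
# options.Options object named by the `!!python/object` tag.
with open("configs/data/scannet_default_test.yaml", "r") as f:
    data_opts = yaml.load(f, Loader=yaml.UnsafeLoader)

print(data_opts.dataset, data_opts.split, data_opts.num_images_in_tuple)
```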
-------------------------------------------------------------------------------- /configs/data/neucon_arkit_default.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/res_nas/mohameds/datasets/arkit 3 | tuple_info_file_location: data_splits/arkit/ 4 | dataset_scan_split_file: data_splits/arkit/scans.txt 5 | dataset: arkit 6 | mv_tuple_file_suffix: _eight_view_deepvmvs.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: default 9 | split: test -------------------------------------------------------------------------------- /configs/data/neucon_arkit_dense.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/res_nas/mohameds/datasets/arkit 3 | tuple_info_file_location: data_splits/arkit/ 4 | dataset_scan_split_file: data_splits/arkit/scans.txt 5 | dataset: arkit 6 | mv_tuple_file_suffix: _eight_view_deepvmvs_dense.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: dense 9 | split: test -------------------------------------------------------------------------------- /configs/data/scannet_default_test.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/scannet-data-png2/ 3 | tuple_info_file_location: data_splits/ScanNetv2/standard_split 4 | dataset_scan_split_file: data_splits/ScanNetv2/standard_split/scannetv2_test.txt 5 | dataset: scannet 6 | mv_tuple_file_suffix: _eight_view_deepvmvs.txt 7 | num_images_in_tuple: 8 8 | # default means every keyframe defined by DVMVS. 9 | frame_tuple_type: default 10 | split: test -------------------------------------------------------------------------------- /configs/data/scannet_default_train.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/scannet-data-png2/ 3 | tuple_info_file_location: data_splits/ScanNetv2/standard_split/ 4 | dataset_scan_split_file: data_splits/ScanNetv2/standard_split/scannetv2_train.txt 5 | dataset: scannet 6 | mv_tuple_file_suffix: _eight_view_deepvmvs.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: default 9 | split: train -------------------------------------------------------------------------------- /configs/data/scannet_default_val.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/scannet-data-png2/ 3 | tuple_info_file_location: data_splits/ScanNetv2/standard_split/ 4 | dataset_scan_split_file: data_splits/ScanNetv2/standard_split/scannetv2_val.txt 5 | dataset: scannet 6 | mv_tuple_file_suffix: _eight_view_deepvmvs.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: default 9 | split: val -------------------------------------------------------------------------------- /configs/data/scannet_dense_test.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/scannet-data-png2/ 3 | tuple_info_file_location: data_splits/ScanNetv2/standard_split/ 4 | dataset_scan_split_file: data_splits/ScanNetv2/standard_split/scannetv2_test.txt 5 | dataset: scannet 6 | mv_tuple_file_suffix: _eight_view_deepvmvs_dense.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: dense 9 | split: test -------------------------------------------------------------------------------- 
/configs/data/scannet_dense_val.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/scannet-data-png2/ 3 | tuple_info_file_location: data_splits/ScanNetv2/standard_split/ 4 | dataset_scan_split_file: data_splits/ScanNetv2/standard_split/scannetv2_val.txt 5 | dataset: scannet 6 | mv_tuple_file_suffix: _eight_view_deepvmvs_dense.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: dense 9 | split: val -------------------------------------------------------------------------------- /configs/data/scanniverse_dense.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /mnt/res_nas/mohameds/datasets/scanniverse/ 3 | tuple_info_file_location: /mnt/res_nas/mohameds/datasets/scanniverse/tuples 4 | dataset_scan_split_file: /mnt/res_nas/mohameds/datasets/scanniverse/scans.txt 5 | dataset: scanniverse 6 | mv_tuple_file_suffix: _eight_view_deepvmvs_dense.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: dense 9 | split: test -------------------------------------------------------------------------------- /configs/data/vdr_default.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /media/data/mosayed/datasets/vdr 3 | tuple_info_file_location: data_splits/vdr/ 4 | dataset_scan_split_file: data_splits/vdr/scans.txt 5 | dataset: vdr 6 | mv_tuple_file_suffix: _eight_view_deepvmvs.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: default 9 | split: test -------------------------------------------------------------------------------- /configs/data/vdr_dense.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /media/data/mosayed/datasets/vdr 3 | tuple_info_file_location: data_splits/vdr/ 4 | dataset_scan_split_file: data_splits/vdr/scans.txt 5 | dataset: vdr 6 | mv_tuple_file_suffix: _eight_view_deepvmvs_dense.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: dense 9 | split: test -------------------------------------------------------------------------------- /configs/data/vdr_dense_offline.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | dataset_path: /media/data/mosayed/datasets/vdr 3 | tuple_info_file_location: data_splits/vdr/ 4 | dataset_scan_split_file: data_splits/vdr/scans.txt 5 | dataset: vdr 6 | mv_tuple_file_suffix: _eight_view_deepvmvs_dense_offline.txt 7 | num_images_in_tuple: 8 8 | frame_tuple_type: dense_offline 9 | split: test -------------------------------------------------------------------------------- /configs/models/dot_product_model.yaml: -------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | feature_volume_type: simple_cost_volume 3 | batch_size: 16 4 | cost_volume_aggregation: dot 5 | cv_encoder_type: multi_scale_encoder 6 | depth_decoder_name: unet_pp 7 | gpus: 2 8 | image_encoder_name: efficientnet 9 | log_interval: 100 10 | loss_type: log_l1 11 | lr: 0.0001 12 | wd: 0.0001 13 | matching_encoder_type: resnet 14 | name: dot_product_model 15 | num_sanity_val_steps: 0 16 | num_workers: 12 17 | precision: 16 18 | random_seed: 0 -------------------------------------------------------------------------------- /configs/models/hero_model.yaml: 
-------------------------------------------------------------------------------- 1 | !!python/object:options.Options 2 | feature_volume_type: mlp_feature_volume 3 | batch_size: 16 4 | cost_volume_aggregation: dot 5 | cv_encoder_type: multi_scale_encoder 6 | depth_decoder_name: unet_pp 7 | gpus: 2 8 | image_encoder_name: efficientnet 9 | log_interval: 100 10 | loss_type: log_l1 11 | lr: 0.0001 12 | wd: 0.0001 13 | matching_encoder_type: resnet 14 | name: hero_model 15 | num_sanity_val_steps: 0 16 | num_workers: 12 17 | precision: 16 18 | random_seed: 0 -------------------------------------------------------------------------------- /data_scripts/7scenes_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/tsattler/visloc_pseudo_gt_limitations/ 2 | 3 | import os 4 | import warnings 5 | 6 | import numpy as np 7 | from skimage import io 8 | from joblib import Parallel, delayed 9 | 10 | # name of the folder where we download the original 7scenes dataset to 11 | # we restructure the dataset by creating symbolic links to that folder 12 | src_folder = '/mnt/res_nas/shared/datasets/7scenes' 13 | focal_length = 525.0 14 | 15 | # focal length of the depth sensor (source: https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/) 16 | d_focal_length = 585.0 17 | 18 | # RGB image dimensions 19 | img_w = 640 20 | img_h = 480 21 | 22 | # sub sampling factor of eye coordinate tensor 23 | nn_subsampling = 8 24 | 25 | #transformation from depth sensor to RGB sensor 26 | #calibration according to https://projet.liris.cnrs.fr/voir/activities-dataset/kinect-calibration.html 27 | d_to_rgb = np.array([ 28 | [ 9.9996518012567637e-01, 2.6765126468950343e-03, -7.9041012313000904e-03, -2.5558943178152542e-02], 29 | [-2.7409311281316700e-03, 9.9996302803027592e-01, -8.1504520778013286e-03, 1.0109636268061706e-04], 30 | [ 7.8819942130445332e-03, 8.1718328771890631e-03, 9.9993554558014031e-01, 2.0318321729487039e-03], 31 | [0, 0, 0, 1] 32 | ]) 33 | 34 | def mkdir(directory): 35 | """Checks whether the directory exists and creates it if necessacy.""" 36 | if not os.path.exists(directory): 37 | os.makedirs(directory) 38 | 39 | # download the original 7 scenes dataset for poses and images 40 | # mkdir(src_folder) 41 | # os.chdir(src_folder) 42 | 43 | # for ds in ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']: 44 | 45 | # print("=== Downloading 7scenes Data:", ds, "===============================") 46 | 47 | # os.system('wget http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/' + ds + '.zip') 48 | # os.system('unzip ' + ds + '.zip') 49 | # os.system('rm ' + ds + '.zip') 50 | 51 | # sequences = os.listdir(ds) 52 | 53 | # for file in sequences: 54 | # if file.endswith('.zip'): 55 | 56 | # print("Unpacking", file) 57 | # os.system('unzip ' + ds + '/' + file + ' -d ' + ds) 58 | # os.system('rm ' + ds + '/' + file) 59 | 60 | # print("Processing frames...") 61 | 62 | def process_scene(ds): 63 | 64 | def process_frames(split_file): 65 | 66 | # read the split file 67 | with open(ds + '/' + split_file, 'r') as f: 68 | split = f.readlines() 69 | # map sequences to folder names 70 | split = ['seq-' + s.strip()[8:].zfill(2) for s in split] 71 | 72 | for seq in split: 73 | files = os.listdir(ds + '/' + seq) 74 | 75 | # adjust depth files by mapping to RGB sensor 76 | depth_files = [f for f in files if f.endswith('depth.png')] 77 | 78 | for d_index, d_file in enumerate(depth_files): 79 | 
if d_index % 1000 == 0: 80 | print(d_index, ds, split_file) 81 | 82 | depth = io.imread(ds + '/' + seq + '/' + d_file) 83 | depth = depth.astype(np.float32) 84 | depth /= 1000 # from millimeters to meters 85 | 86 | d_h = depth.shape[0] 87 | d_w = depth.shape[1] 88 | 89 | # reproject depth map to 3D eye coordinates 90 | eye_coords = np.zeros((4, d_h, d_w)) 91 | # set x and y coordinates 92 | eye_coords[0] = 0.5 + np.dstack([np.arange(0, d_w)] * d_h)[0].T 93 | eye_coords[1] = 0.5 + np.dstack([np.arange(0, d_h)] * d_w)[0] 94 | 95 | eye_coords = eye_coords.reshape(4, -1) 96 | depth = depth.reshape(-1) 97 | 98 | # filter pixels with invalid depth 99 | depth_mask = (depth > 0) & (depth < 100) 100 | eye_coords = eye_coords[:, depth_mask] 101 | depth = depth[depth_mask] 102 | 103 | # substract depth principal point (assume image center) 104 | eye_coords[0] -= d_w / 2 105 | eye_coords[1] -= d_h / 2 106 | # reproject 107 | eye_coords[0:2] /= d_focal_length 108 | eye_coords[0] *= depth 109 | eye_coords[1] *= depth 110 | eye_coords[2] = depth 111 | eye_coords[3] = 1 112 | 113 | # transform from depth sensor to RGB sensor 114 | eye_coords = np.matmul(d_to_rgb, eye_coords) 115 | 116 | # project 117 | depth = eye_coords[2] 118 | 119 | eye_coords[0] /= depth 120 | eye_coords[1] /= depth 121 | eye_coords[0:2] *= focal_length 122 | 123 | # add RGB principal point (assume image center) 124 | eye_coords[0] += img_w / 2 125 | eye_coords[1] += img_h / 2 126 | 127 | registered_depth = np.ones((img_h, img_w), dtype=np.float32) * 2e3 128 | 129 | for pt in range(eye_coords.shape[1]): 130 | x = round(eye_coords[0, pt]) 131 | y = round(eye_coords[1, pt]) 132 | d = eye_coords[2, pt] 133 | 134 | if x < 0 or y < 0 or y >= d_h or x >= d_w: 135 | continue 136 | 137 | registered_depth[y, x] = min(registered_depth[y, x], d) 138 | 139 | registered_depth[registered_depth > 1e3] = 0 140 | registered_depth = (1000 * registered_depth).astype(np.uint16) 141 | 142 | # store calibrated depth 143 | with warnings.catch_warnings(): 144 | warnings.simplefilter("ignore") 145 | io.imsave(ds + '/' + seq + '/' + d_file.replace("depth.png", "depth.proj.png"), registered_depth) 146 | 147 | process_frames('TrainSplit.txt') 148 | process_frames('TestSplit.txt') 149 | 150 | Parallel(n_jobs=7, verbose=0)( 151 | map(delayed(process_scene), [os.path.join(src_folder, scan_name) for scan_name in ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs']])) -------------------------------------------------------------------------------- /data_scripts/IOS_LOGGER_ARKIT_README.md: -------------------------------------------------------------------------------- 1 | # Running with NeuralRecon's demo data and data from ios-logger. 2 | 3 | Download the demo scene out of ios-logger from here: https://github.com/zju3dv/NeuralRecon/blob/master/DEMO.md 4 | 5 | Follow the instructions in the NeuralRecon repo for how to use ios-logger to make your own captures. 6 | 7 | Unzip the folder into your arkit dataset path so that it looks somthing like this: 8 | 9 | ``` 10 | dataset_path 11 | scans 12 | neucon_demodata_b5f1 13 | ... 14 | .... 15 | ``` 16 | 17 | Run the extraction script that uses modified versions of the functions provided by the NeuralRecon authors: 18 | 19 | ```bash 20 | python ./data_scripts/ios_logger_preprocessing.py --data_config configs/data/neucon_arkit_default.yaml 21 | ``` 22 | 23 | Make sure you set your correct `dataset_path` folder. 
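For example, the relevant fields of `configs/data/neucon_arkit_default.yaml` would look something like this after editing (the path below is only a placeholder):

```yaml
# placeholder location; point this at the folder that contains your scans/ directory
dataset_path: /path/to/arkit_data
tuple_info_file_location: data_splits/arkit/
dataset_scan_split_file: data_splits/arkit/scans.txt
```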
24 | 25 | Run tuple file generation (we've already computed one for you in data_splits): 26 | 27 | ```bash 28 | python ./data_scripts/generate_test_tuples.py --num_workers 16 --data_config configs/data/neucon_arkit_default.yaml 29 | ``` 30 | 31 | Then run the model using this config file, see the full readme for more. 32 | 33 | There is unfortunately a break in the pose in the NR demo scene, so you'll to trim the first 350 frames using `--skip_to_frame 350` when running dense frames and `--skip_to_frame 83` when running default. 34 | 35 | Run: 36 | 37 | ```bash 38 | CUDA_VISIBLE_DEVICES=0 python test.py --name HERO_MODEL \ 39 | --output_base_path OUTPUT_PATH \ 40 | --config_file configs/models/hero_model.yaml \ 41 | --load_weights_from_checkpoint weights/hero_model.ckpt \ 42 | --data_config configs/data/neucon_arkit_default.yaml \ 43 | --num_workers 8 \ 44 | --batch_size 2 \ 45 | --fast_cost_volume \ 46 | --run_fusion \ 47 | --depth_fuser open3d \ 48 | --fuse_color \ 49 | --skip_to_frame 83; 50 | ``` 51 | 52 | ```bash 53 | CUDA_VISIBLE_DEVICES=0 python test.py --name HERO_MODEL \ 54 | --output_base_path OUTPUT_PATH \ 55 | --config_file configs/models/hero_model.yaml \ 56 | --load_weights_from_checkpoint weights/hero_model.ckpt \ 57 | --data_config configs/data/neucon_arkit_dense.yaml \ 58 | --num_workers 8 \ 59 | --batch_size 2 \ 60 | --fast_cost_volume \ 61 | --run_fusion \ 62 | --depth_fuser open3d \ 63 | --fuse_color \ 64 | --skip_to_frame 350; 65 | ``` 66 | 67 | Should get an output that looks like this for default frames: 68 | 69 | ![alt text](../media/arkit_ioslogger_snapshot.png) -------------------------------------------------------------------------------- /data_scripts/ios_logger_preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append("/".join(sys.path[0].split("/")[:-1])) 4 | from datasets.arkit_dataset import process_data 5 | import options 6 | 7 | """ Download scans and extract each to a folder with their name like so: 8 | dataset_path 9 | scans 10 | neucon_demodata_b5f1 11 | ... 12 | .... 13 | """ 14 | 15 | option_handler = options.OptionsHandler() 16 | option_handler.parse_and_merge_options() 17 | opts = option_handler.options 18 | 19 | 20 | if opts.dataset_scan_split_file is not None: 21 | f = open(opts.dataset_scan_split_file, "r") 22 | scans = f.readlines() 23 | scans = [scan.strip() for scan in scans] 24 | f.close() 25 | elif opts.single_debug_scan_id is not None: 26 | scans = [opts.single_debug_scan_id] 27 | else: 28 | print("No valid scans pointers.") 29 | 30 | for scan in scans: 31 | path_dir = os.path.join(opts.dataset_path, "scans", scan) 32 | process_data(path_dir) -------------------------------------------------------------------------------- /data_scripts/precompute_valid_frames.py: -------------------------------------------------------------------------------- 1 | 2 | """Script for precomputing and storing a list of valid frames per scan. A valid 3 | frame is defined as one that has an existing RGB frame, an existing depth 4 | map, and a valid pose. 
5 | 6 | Run like so for test (default): 7 | 8 | python ./scripts/precompute_valid_frames.py 9 | --data_config configs/data/scannet_default_test.yaml 10 | --num_workers 16 11 | --split test 12 | 13 | where scannet_default_test.yaml looks like: 14 | !!python/object:options.Options 15 | dataset_path: SCANNET_PATH/ 16 | tuple_info_file_location: SCANNET_PATH/tuples 17 | dataset_scan_split_file: SCANNET_PATH/scannetv2_test.txt 18 | dataset: scannet 19 | mv_tuple_file_suffix: _eight_view_deepvmvs.txt 20 | num_images_in_tuple: 8 21 | frame_tuple_type: default 22 | 23 | For validation, use scannet_default_val.yaml, and for train use 24 | scannet_default_train.yaml 25 | 26 | It will save a valid_frames.txt file where the dataset class defines. 27 | 28 | """ 29 | 30 | import os 31 | import sys 32 | 33 | sys.path.append("/".join(sys.path[0].split("/")[:-1])) 34 | import random 35 | import string 36 | from functools import partial 37 | from multiprocessing import Manager 38 | from multiprocessing.pool import Pool 39 | from pathlib import Path 40 | 41 | import numpy as np 42 | import options 43 | from utils.dataset_utils import get_dataset 44 | 45 | def process_scan(opts_temp_filepath, scan, count, progress): 46 | """ 47 | Precomputes a scan's valid frames by calling the dataset's appropriate 48 | function. 49 | 50 | Args: 51 | opts_temp_filepath: filepath for an options config file. 52 | scan: scan to operate on. 53 | count: total count of multi process scans. 54 | progress: a Pool() progress value for tracking progress. For debugging 55 | you can pass 56 | multiprocessing.Manager().Value('i', 0) 57 | for this. 58 | 59 | """ 60 | item_list = [] 61 | 62 | # load options file 63 | option_handler = options.OptionsHandler() 64 | option_handler.parse_and_merge_options(config_filepaths=opts_temp_filepath, 65 | ignore_cl_args=True) 66 | opts = option_handler.options 67 | 68 | # get dataset 69 | dataset_class, _ = get_dataset( 70 | opts.dataset, 71 | opts.dataset_scan_split_file, 72 | opts.single_debug_scan_id, 73 | verbose=False, 74 | ) 75 | 76 | ds = dataset_class( 77 | dataset_path=opts.dataset_path, 78 | mv_tuple_file_suffix=None, 79 | split=opts.split, 80 | tuple_info_file_location=opts.tuple_info_file_location, 81 | pass_frame_id=True, 82 | verbose_init=False, 83 | ) 84 | 85 | _ = ds.get_valid_frame_ids(opts.split, scan) 86 | 87 | progress.value += 1 88 | print(f"Completed scan {scan}, {progress.value} of total {count}.") 89 | 90 | return item_list 91 | 92 | def multi_process_scans(opts_temp_filepath, opts, scans): 93 | """ 94 | Multiprocessing helper for crawl_subprocess_long and crawl_subprocess_long. 95 | 96 | Precomputes a scan's valid frames by calling the dataset's appropriate 97 | function. 98 | 99 | Args: 100 | opts_temp_filepath: filepath for an options config file. 101 | opts: options dataclass. 102 | scans: scans to multiprocess. 
103 | """ 104 | pool = Pool(opts.num_workers) 105 | manager = Manager() 106 | 107 | count = len(scans) 108 | progress = manager.Value('i', 0) 109 | 110 | item_list = [] 111 | 112 | for scan_item_list in pool.imap_unordered( 113 | partial( 114 | process_scan, 115 | opts_temp_filepath, 116 | count=count, 117 | progress=progress 118 | ), 119 | scans, 120 | ): 121 | item_list.extend(scan_item_list) 122 | 123 | return item_list 124 | 125 | if __name__ == '__main__': 126 | 127 | 128 | # load options file 129 | option_handler = options.OptionsHandler() 130 | option_handler.parse_and_merge_options(ignore_cl_args=False) 131 | option_handler.pretty_print_options() 132 | opts = option_handler.options 133 | opts_temp_filepath = os.path.join( 134 | os.path.expanduser("~"), 135 | "tmp/", 136 | ''.join(random.choices(string.ascii_uppercase + string.digits, k=10)) 137 | + ".yaml", 138 | ) 139 | option_handler.save_options_as_yaml(opts_temp_filepath, opts) 140 | 141 | np.random.seed(42) 142 | random.seed(42) 143 | 144 | if opts.gpus == 0: 145 | print("Setting precision to 32 bits since --gpus is set to 0.") 146 | opts.precision = 32 147 | 148 | # get dataset 149 | dataset_class, scan_names = get_dataset( 150 | opts.dataset, 151 | opts.dataset_scan_split_file, 152 | opts.single_debug_scan_id, 153 | ) 154 | 155 | item_list = [] 156 | 157 | Path(opts.tuple_info_file_location).mkdir(exist_ok=True, parents=True) 158 | split_filename = f"{opts.split}{opts.mv_tuple_file_suffix}" 159 | split_filepath = os.path.join(opts.tuple_info_file_location, split_filename) 160 | print(f"Processing valid frames.\n") 161 | 162 | if opts.single_debug_scan_id is not None: 163 | item_list = process_scan( 164 | opts_temp_filepath, 165 | opts.single_debug_scan_id, 166 | 0, 167 | Manager().Value('i', 0), 168 | ) 169 | else: 170 | item_list = multi_process_scans(opts_temp_filepath, opts, scan_names) 171 | 172 | with open(split_filepath, "w") as f: 173 | for line in item_list: 174 | f.write(line + "\n") 175 | 176 | print(f"Complete") 177 | -------------------------------------------------------------------------------- /data_scripts/scannet_wrangling_scripts/LICENSE: -------------------------------------------------------------------------------- 1 | The LICENSE applies only to reader.py and SensorData.py. 2 | 3 | Copyright 2017 4 | Angela Dai, Angel X. Chang, Manolis Savva, Maciej Halber, Thomas Funkhouser, Matthias Niessner 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | 8 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 9 | 10 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /data_scripts/scannet_wrangling_scripts/README.md: -------------------------------------------------------------------------------- 1 | # Downloading and Extracting ScanNetv2 2 | 3 | 4 | Developed and tested with python 3.9. 5 | 6 | The included license at LICENSE applies only to `reader.py` and `SensorData.py`. 7 | 8 | 9 | These scripts should help you export ScanNetv2 to the following format: 10 | 11 | SCANNET_ROOT 12 | scans_test (test scans) 13 | scene0707 14 | scene0707_00_vh_clean_2.ply (gt mesh) 15 | sensor_data 16 | frame-000261.pose.txt 17 | frame-000261.color.jpg 18 | frame-000261.color.512.png (optional, image at 512x384) 19 | frame-000261.color.640.png (optional, image at 640x480) 20 | frame-000261.depth.png (full res depth, stored scale *1000) 21 | frame-000261.depth.256.png (optional, depth at 256x192 also 22 | scaled) 23 | scene0707.txt (scan metadata and image sizes) 24 | intrinsic 25 | intrinsic_depth.txt 26 | intrinsic_color.txt 27 | 28 | ... 29 | scans (val and train scans) 30 | scene0000_00 31 | (see above) 32 | scene0000_01 33 | .... 34 | 35 | Make sure all the packages in `env.yml` are installed in your environment. 36 | 37 | ## Downloading ScanNetv2 38 | 39 | The `download_scannet.py` script is from https://kaldir.vc.in.tum.de/scannet/download-scannet.py 40 | 41 | Please make sure you fill in this form before downloading the data: 42 | https://kaldir.vc.in.tum.de/scannet/ScanNet_TOS.pdf 43 | 44 | Download the dataset by running: 45 | ``` 46 | python download_scannet.py -o SCANNET_ROOT 47 | ``` 48 | 49 | For one scan debug use: 50 | ``` 51 | python download_scannet.py -o SCANNET_ROOT --id scene0707_00 52 | ``` 53 | 54 | This will download a `.sens` file, `.txt` file, the high resolution mesh `ply`, and a lower resolution mesh `ply`. 55 | 56 | `.txt` will include meta information for the scan. See the next section for extracting the `.sens` file. 57 | 58 | ## Extracting data from .sens files 59 | 60 | Please use the intrinsics directly from the downloaded `.txt` file from the dataset. 61 | 62 | This is a modified version of the SensReader python script at 63 | https://github.com/ScanNet/ScanNet/tree/master/SensReader/python 64 | 65 | 66 | `reader.py` will extract depth, jpg, and intrinics files from ScanNetv2's downloaded `.sens` files. It will dump the `jpg` data directly to disk without uncompressing/compressing. 67 | 68 | To extract all scans for test: 69 | ``` 70 | python reader.py --scans_folder SCANNET_ROOT/scans_test \ 71 | --output_path OUTPUT_PATH/scans_test \ 72 | --scan_list_file splits/scannetv2_test.txt \ 73 | --num_workers 12 \ 74 | --export_poses \ 75 | --export_depth_images \ 76 | --export_color_images \ 77 | --export_intrinsics; 78 | ``` 79 | 80 | For train and val 81 | ``` 82 | python reader.py --scans_folder SCANNET_ROOT/scans \ 83 | --output_path OUTPUT_PATH/scans \ 84 | --scan_list_file splits/scannetv2_train.txt \ 85 | --num_workers 12 \ 86 | --export_poses \ 87 | --export_depth_images \ 88 | --export_color_images \ 89 | --export_intrinsics; 90 | 91 | python reader.py --scans_folder SCANNET_ROOT/scans \ 92 | --output_path OUTPUT_PATH/scans \ 93 | --scan_list_file splits/scannetv2_val.txt \ 94 | --num_workers 12 \ 95 | --export_poses \ 96 | --export_depth_images \ 97 | --export_color_images \ 98 | --export_intrinsics; 99 | ``` 100 | 101 | `OUTPUT_PATH` can be the same directory as the ScanNet root directory `SCANNET_ROOT`. 
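A quick sanity check after extraction is to load one of the exported depth PNGs: they are 16-bit images storing depth in millimeters (metric depth scaled by 1000). The snippet below is only a sketch; the scan and frame names are placeholders.

```
import cv2
import numpy as np

# placeholder scan/frame; any exported frame-XXXXXX.depth.png will do
depth_path = "OUTPUT_PATH/scans/scene0000_00/sensor_data/frame-000000.depth.png"
depth_mm = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)  # uint16, millimeters
depth_m = depth_mm.astype(np.float32) / 1000.0           # meters; 0 means invalid depth
print(depth_m.shape, float(depth_m.max()))
```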
102 | 103 | For one scan use `--single_debug_scan_id`. 104 | 105 | For caching resized pngs for depth and color files, run: 106 | 107 | ``` 108 | python reader.py --scans_folder SCANNET_ROOT/scans \ 109 | --output_path OUTPUT_PATH/scans \ 110 | --scan_list_file splits/scannetv2_train.txt \ 111 | --num_workers 12 \ 112 | --export_depth_images \ 113 | --export_color_images \ 114 | --rgb_resize 512 384 \ 115 | --depth_resize 256 192; 116 | ``` 117 | 118 | and for images at `640x480`: 119 | 120 | ``` 121 | python reader.py --scans_folder SCANNET_ROOT/scans \ 122 | --output_path OUTPUT_PATH/scans \ 123 | --scan_list_file splits/scannetv2_train.txt \ 124 | --num_workers 12 \ 125 | --export_color_images \ 126 | --rgb_resize 640 480 \ 127 | ``` 128 | -------------------------------------------------------------------------------- /data_scripts/scannet_wrangling_scripts/SensorData.py: -------------------------------------------------------------------------------- 1 | 2 | import os, struct 3 | import numpy as np 4 | import zlib 5 | import imageio 6 | import cv2 7 | import png 8 | from PIL import Image 9 | from contextlib import contextmanager 10 | 11 | COMPRESSION_TYPE_COLOR = {-1:'unknown', 0:'raw', 1:'png', 2:'jpeg'} 12 | COMPRESSION_TYPE_DEPTH = {-1:'unknown', 0:'raw_ushort', 1:'zlib_ushort', 2:'occi_ushort'} 13 | 14 | 15 | @contextmanager 16 | def print_array_on_one_line(): 17 | oldoptions = np.get_printoptions() 18 | np.set_printoptions(linewidth=np.inf) 19 | np.set_printoptions(linewidth=np.inf) 20 | yield 21 | np.set_printoptions(**oldoptions) 22 | 23 | class RGBDFrame(): 24 | 25 | def load(self, file_handle): 26 | self.camera_to_world = np.asarray(struct.unpack('f'*16, file_handle.read(16*4)), dtype=np.float32).reshape(4, 4) 27 | self.timestamp_color = struct.unpack('Q', file_handle.read(8))[0] 28 | self.timestamp_depth = struct.unpack('Q', file_handle.read(8))[0] 29 | self.color_size_bytes = struct.unpack('Q', file_handle.read(8))[0] 30 | self.depth_size_bytes = struct.unpack('Q', file_handle.read(8))[0] 31 | self.color_data = b''.join(struct.unpack('c'*self.color_size_bytes, file_handle.read(self.color_size_bytes))) 32 | self.depth_data = b''.join(struct.unpack('c'*self.depth_size_bytes, file_handle.read(self.depth_size_bytes))) 33 | 34 | 35 | def decompress_depth(self, compression_type): 36 | if compression_type == 'zlib_ushort': 37 | return self.decompress_depth_zlib() 38 | else: 39 | raise 40 | 41 | 42 | def decompress_depth_zlib(self): 43 | return zlib.decompress(self.depth_data) 44 | 45 | 46 | def decompress_color(self, compression_type): 47 | if compression_type == 'jpeg': 48 | return self.decompress_color_jpeg() 49 | else: 50 | raise 51 | 52 | def dump_color_to_file(self, compression_type, filepath): 53 | if compression_type == 'jpeg': 54 | filepath += ".jpg" 55 | else: 56 | raise 57 | f = open(filepath, "wb") 58 | f.write(self.color_data) 59 | f.close() 60 | 61 | def decompress_color_jpeg(self): 62 | return imageio.imread(self.color_data) 63 | 64 | 65 | class SensorData: 66 | 67 | def __init__(self, filename): 68 | self.version = 4 69 | self.load(filename) 70 | 71 | 72 | def load(self, filename): 73 | with open(filename, 'rb') as f: 74 | version = struct.unpack('I', f.read(4))[0] 75 | assert self.version == version 76 | strlen = struct.unpack('Q', f.read(8))[0] 77 | self.sensor_name = b''.join(struct.unpack('c'*strlen, f.read(strlen))) 78 | self.intrinsic_color = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 79 | self.extrinsic_color = 
np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 80 | self.intrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 81 | self.extrinsic_depth = np.asarray(struct.unpack('f'*16, f.read(16*4)), dtype=np.float32).reshape(4, 4) 82 | self.color_compression_type = COMPRESSION_TYPE_COLOR[struct.unpack('i', f.read(4))[0]] 83 | self.depth_compression_type = COMPRESSION_TYPE_DEPTH[struct.unpack('i', f.read(4))[0]] 84 | self.color_width = struct.unpack('I', f.read(4))[0] 85 | self.color_height = struct.unpack('I', f.read(4))[0] 86 | self.depth_width = struct.unpack('I', f.read(4))[0] 87 | self.depth_height = struct.unpack('I', f.read(4))[0] 88 | self.depth_shift = struct.unpack('f', f.read(4))[0] 89 | self.num_frames = struct.unpack('Q', f.read(8))[0] 90 | self.frames = [] 91 | for i in range(self.num_frames): 92 | frame = RGBDFrame() 93 | frame.load(f) 94 | self.frames.append(frame) 95 | self.num_IMU_frames = struct.unpack('Q', f.read(8))[0] 96 | 97 | 98 | def export_depth_images(self, output_path, image_size=None, frame_skip=1): 99 | if not os.path.exists(output_path): 100 | os.makedirs(output_path) 101 | print('exporting', len(self.frames), frame_skip, ' depth frames to', output_path) 102 | for f in range(0, len(self.frames), frame_skip): 103 | depth_data = self.frames[f].decompress_depth(self.depth_compression_type) 104 | depth = np.fromstring(depth_data, dtype=np.uint16).reshape(self.depth_height, self.depth_width) 105 | 106 | if image_size is not None: 107 | depth = cv2.resize(depth, (image_size[0], image_size[1]), interpolation=cv2.INTER_NEAREST) 108 | filepath = os.path.join(output_path, f"frame-{f:06d}.depth.{int(image_size[0])}.png") 109 | else: 110 | filepath = os.path.join(output_path, f"frame-{f:06d}.depth.png") 111 | 112 | with open(filepath, 'wb') as f: # write 16-bit 113 | writer = png.Writer(width=depth.shape[1], height=depth.shape[0], bitdepth=16) 114 | depth = depth.reshape(-1, depth.shape[1]).tolist() 115 | writer.write(f, depth) 116 | 117 | def export_color_images(self, output_path, image_size=None, frame_skip=1): 118 | if not os.path.exists(output_path): 119 | os.makedirs(output_path) 120 | print('exporting', len(self.frames), frame_skip, 'color frames to', output_path) 121 | for f in range(0, len(self.frames), frame_skip): 122 | color = self.frames[f].decompress_color(self.color_compression_type) 123 | 124 | if image_size is not None: 125 | resized = Image.fromarray(color).resize((image_size[0], image_size[1]), resample=Image.BILINEAR) 126 | filepath = os.path.join(output_path, f"frame-{f:06d}.color.{int(image_size[0])}.png") 127 | resized.save(filepath) 128 | else: 129 | filepath = os.path.join(output_path, f"frame-{f:06d}.color") 130 | self.frames[f].dump_color_to_file(self.color_compression_type, filepath) 131 | 132 | 133 | def save_mat_to_file(self, matrix, filename): 134 | with open(filename, 'w') as f: 135 | for line in matrix: 136 | np.savetxt(f, line[np.newaxis], fmt='%f') 137 | 138 | 139 | def export_poses(self, output_path, frame_skip=1): 140 | if not os.path.exists(output_path): 141 | os.makedirs(output_path) 142 | print('exporting', len(self.frames), frame_skip, 'camera poses to', output_path) 143 | for f in range(0, len(self.frames), frame_skip): 144 | self.save_mat_to_file(self.frames[f].camera_to_world, os.path.join(output_path, f"frame-{f:06d}.pose.txt")) 145 | 146 | 147 | def export_intrinsics(self, output_path, scan_name): 148 | default_intrinsics_path = os.path.join(output_path, 
'intrinsic') 149 | if not os.path.exists(default_intrinsics_path): 150 | os.makedirs(default_intrinsics_path) 151 | print('exporting camera intrinsics to', default_intrinsics_path) 152 | self.save_mat_to_file(self.intrinsic_color, os.path.join(default_intrinsics_path, 'intrinsic_color.txt')) 153 | self.save_mat_to_file(self.extrinsic_color, os.path.join(default_intrinsics_path, 'extrinsic_color.txt')) 154 | self.save_mat_to_file(self.intrinsic_depth, os.path.join(default_intrinsics_path, 'intrinsic_depth.txt')) 155 | self.save_mat_to_file(self.extrinsic_depth, os.path.join(default_intrinsics_path, 'extrinsic_depth.txt')) -------------------------------------------------------------------------------- /data_scripts/scannet_wrangling_scripts/download_scannet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Downloads ScanNet public data release 3 | # Run with ./download-scannet.py (or python download-scannet.py on Windows) 4 | # -*- coding: utf-8 -*- 5 | import argparse 6 | import os 7 | import urllib.request 8 | import tempfile 9 | 10 | import ssl 11 | ssl._create_default_https_context = ssl._create_unverified_context 12 | 13 | BASE_URL = 'http://kaldir.vc.in.tum.de/scannet/' 14 | TOS_URL = BASE_URL + 'ScanNet_TOS.pdf' 15 | FILETYPES = ['.aggregation.json', '.sens', '.txt', '_vh_clean.ply', '_vh_clean_2.0.010000.segs.json', '_vh_clean_2.ply', '_vh_clean.segs.json', '_vh_clean.aggregation.json', '_vh_clean_2.labels.ply', '_2d-instance.zip', '_2d-instance-filt.zip', '_2d-label.zip', '_2d-label-filt.zip'] 16 | FILETYPES_TEST = ['.sens', '.txt', '_vh_clean.ply', '_vh_clean_2.ply'] 17 | PREPROCESSED_FRAMES_FILE = ['scannet_frames_25k.zip', '5.6GB'] 18 | TEST_FRAMES_FILE = ['scannet_frames_test.zip', '610MB'] 19 | LABEL_MAP_FILES = ['scannetv2-labels.combined.tsv', 'scannet-labels.combined.tsv'] 20 | DATA_EFFICIENT_FILES = ['limited-reconstruction-scenes.zip', 'limited-annotation-points.zip', 'limited-bboxes.zip', '1.7MB'] 21 | GRIT_FILES = ['ScanNet-GRIT.zip'] 22 | RELEASES = ['v2/scans', 'v1/scans'] 23 | RELEASES_TASKS = ['v2/tasks', 'v1/tasks'] 24 | RELEASES_NAMES = ['v2', 'v1'] 25 | RELEASE = RELEASES[0] 26 | RELEASE_TASKS = RELEASES_TASKS[0] 27 | RELEASE_NAME = RELEASES_NAMES[0] 28 | LABEL_MAP_FILE = LABEL_MAP_FILES[0] 29 | RELEASE_SIZE = '1.2TB' 30 | V1_IDX = 1 31 | 32 | 33 | def get_release_scans(release_file): 34 | scan_lines = urllib.request.urlopen(release_file) 35 | scans = [] 36 | for scan_line in scan_lines: 37 | scan_id = scan_line.decode('utf8').rstrip('\n') 38 | scans.append(scan_id) 39 | return scans 40 | 41 | 42 | def download_release(release_scans, out_dir, file_types, use_v1_sens): 43 | if len(release_scans) == 0: 44 | return 45 | print('Downloading ScanNet ' + RELEASE_NAME + ' release to ' + out_dir + '...') 46 | for scan_id in release_scans: 47 | scan_out_dir = os.path.join(out_dir, scan_id) 48 | download_scan(scan_id, scan_out_dir, file_types, use_v1_sens) 49 | print('Downloaded ScanNet ' + RELEASE_NAME + ' release.') 50 | 51 | 52 | def download_file(url, out_file): 53 | out_dir = os.path.dirname(out_file) 54 | if not os.path.isdir(out_dir): 55 | os.makedirs(out_dir) 56 | if not os.path.isfile(out_file): 57 | print('\t' + url + ' > ' + out_file) 58 | fh, out_file_tmp = tempfile.mkstemp(dir=out_dir) 59 | f = os.fdopen(fh, 'w') 60 | f.close() 61 | urllib.request.urlretrieve(url, out_file_tmp) 62 | os.rename(out_file_tmp, out_file) 63 | else: 64 | print('WARNING: skipping download of existing file ' + out_file) 
65 | 66 | def download_scan(scan_id, out_dir, file_types, use_v1_sens): 67 | print('Downloading ScanNet ' + RELEASE_NAME + ' scan ' + scan_id + ' ...') 68 | if not os.path.isdir(out_dir): 69 | os.makedirs(out_dir) 70 | for ft in file_types: 71 | v1_sens = use_v1_sens and ft == '.sens' 72 | url = BASE_URL + RELEASE + '/' + scan_id + '/' + scan_id + ft if not v1_sens else BASE_URL + RELEASES[V1_IDX] + '/' + scan_id + '/' + scan_id + ft 73 | out_file = out_dir + '/' + scan_id + ft 74 | download_file(url, out_file) 75 | print('Downloaded scan ' + scan_id) 76 | 77 | 78 | def download_task_data(out_dir): 79 | print('Downloading ScanNet v1 task data...') 80 | files = [ 81 | LABEL_MAP_FILES[V1_IDX], 'obj_classification/data.zip', 82 | 'obj_classification/trained_models.zip', 'voxel_labeling/data.zip', 83 | 'voxel_labeling/trained_models.zip' 84 | ] 85 | for file in files: 86 | url = BASE_URL + RELEASES_TASKS[V1_IDX] + '/' + file 87 | localpath = os.path.join(out_dir, file) 88 | localdir = os.path.dirname(localpath) 89 | if not os.path.isdir(localdir): 90 | os.makedirs(localdir) 91 | download_file(url, localpath) 92 | print('Downloaded task data.') 93 | 94 | def download_tfrecords(in_dir, out_dir): 95 | print('Downloading tf records (302 GB)...') 96 | if not os.path.exists(out_dir): 97 | os.makedirs(out_dir) 98 | split_to_num_shards = {'train': 100, 'val': 25, 'test': 10} 99 | 100 | for folder_name in ['hires_tfrecords', 'lores_tfrecords']: 101 | folder_dir = '%s/%s' % (in_dir, folder_name) 102 | save_dir = '%s/%s' % (out_dir, folder_name) 103 | if not os.path.exists(save_dir): 104 | os.makedirs(save_dir) 105 | for split, num_shards in split_to_num_shards.items(): 106 | for i in range(num_shards): 107 | file_name = '%s-%05d-of-%05d.tfrecords' % (split, i, num_shards) 108 | url = '%s/%s' % (folder_dir, file_name) 109 | localpath = '%s/%s/%s' % (out_dir, folder_name, file_name) 110 | download_file(url, localpath) 111 | 112 | def download_label_map(out_dir): 113 | print('Downloading ScanNet ' + RELEASE_NAME + ' label mapping file...') 114 | files = [ LABEL_MAP_FILE ] 115 | for file in files: 116 | url = BASE_URL + RELEASE_TASKS + '/' + file 117 | localpath = os.path.join(out_dir, file) 118 | localdir = os.path.dirname(localpath) 119 | if not os.path.isdir(localdir): 120 | os.makedirs(localdir) 121 | download_file(url, localpath) 122 | print('Downloaded ScanNet ' + RELEASE_NAME + ' label mapping file.') 123 | 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser(description='Downloads ScanNet public data release.') 127 | parser.add_argument('-o', '--out_dir', required=True, help='directory in which to download') 128 | parser.add_argument('--task_data', action='store_true', help='download task data (v1)') 129 | parser.add_argument('--label_map', action='store_true', help='download label map file') 130 | parser.add_argument('--v1', action='store_true', help='download ScanNet v1 instead of v2') 131 | parser.add_argument('--id', help='specific scan id to download') 132 | parser.add_argument('--preprocessed_frames', action='store_true', help='download preprocessed subset of ScanNet frames (' + PREPROCESSED_FRAMES_FILE[1] + ')') 133 | parser.add_argument('--test_frames_2d', action='store_true', help='download 2D test frames (' + TEST_FRAMES_FILE[1] + '; also included with whole dataset download)') 134 | parser.add_argument('--data_efficient', action='store_true', help='download data efficient task files; also included with whole dataset download)') 135 | parser.add_argument('--tf_semantic', 
action='store_true', help='download google tensorflow records for 3D segmentation / detection') 136 | parser.add_argument('--grit', action='store_true', help='download ScanNet files for General Robust Image Task') 137 | parser.add_argument('--type', help='specific file type to download (.aggregation.json, .sens, .txt, _vh_clean.ply, _vh_clean_2.0.010000.segs.json, _vh_clean_2.ply, _vh_clean.segs.json, _vh_clean.aggregation.json, _vh_clean_2.labels.ply, _2d-instance.zip, _2d-instance-filt.zip, _2d-label.zip, _2d-label-filt.zip)') 138 | args = parser.parse_args() 139 | 140 | print('By pressing any key to continue you confirm that you have agreed to the ScanNet terms of use as described at:') 141 | print(TOS_URL) 142 | print('***') 143 | print('Press any key to continue, or CTRL-C to exit.') 144 | key = input('') 145 | 146 | if args.v1: 147 | global RELEASE 148 | global RELEASE_TASKS 149 | global RELEASE_NAME 150 | global LABEL_MAP_FILE 151 | RELEASE = RELEASES[V1_IDX] 152 | RELEASE_TASKS = RELEASES_TASKS[V1_IDX] 153 | RELEASE_NAME = RELEASES_NAMES[V1_IDX] 154 | LABEL_MAP_FILE = LABEL_MAP_FILES[V1_IDX] 155 | assert((not args.tf_semantic) and (not args.grit)), "Task files specified invalid for v1" 156 | 157 | release_file = BASE_URL + RELEASE + '.txt' 158 | release_scans = get_release_scans(release_file) 159 | file_types = FILETYPES; 160 | release_test_file = BASE_URL + RELEASE + '_test.txt' 161 | release_test_scans = get_release_scans(release_test_file) 162 | file_types_test = FILETYPES_TEST; 163 | out_dir_scans = os.path.join(args.out_dir, 'scans') 164 | out_dir_test_scans = os.path.join(args.out_dir, 'scans_test') 165 | out_dir_tasks = os.path.join(args.out_dir, 'tasks') 166 | 167 | if args.type: # download file type 168 | file_type = args.type 169 | if file_type not in FILETYPES: 170 | print('ERROR: Invalid file type: ' + file_type) 171 | return 172 | file_types = [file_type] 173 | if file_type in FILETYPES_TEST: 174 | file_types_test = [file_type] 175 | else: 176 | file_types_test = [] 177 | if args.task_data: # download task data 178 | download_task_data(out_dir_tasks) 179 | elif args.label_map: # download label map file 180 | download_label_map(args.out_dir) 181 | elif args.preprocessed_frames: # download preprocessed scannet_frames_25k.zip file 182 | if args.v1: 183 | print('ERROR: Preprocessed frames only available for ScanNet v2') 184 | print('You are downloading the preprocessed subset of frames ' + PREPROCESSED_FRAMES_FILE[0] + ' which requires ' + PREPROCESSED_FRAMES_FILE[1] + ' of space.') 185 | download_file(os.path.join(BASE_URL, RELEASE_TASKS, PREPROCESSED_FRAMES_FILE[0]), os.path.join(out_dir_tasks, PREPROCESSED_FRAMES_FILE[0])) 186 | elif args.test_frames_2d: # download test scannet_frames_test.zip file 187 | if args.v1: 188 | print('ERROR: 2D test frames only available for ScanNet v2') 189 | print('You are downloading the 2D test set ' + TEST_FRAMES_FILE[0] + ' which requires ' + TEST_FRAMES_FILE[1] + ' of space.') 190 | download_file(os.path.join(BASE_URL, RELEASE_TASKS, TEST_FRAMES_FILE[0]), os.path.join(out_dir_tasks, TEST_FRAMES_FILE[0])) 191 | elif args.data_efficient: # download data efficient task files 192 | print('You are downloading the data efficient task files' + ' which requires ' + DATA_EFFICIENT_FILES[-1] + ' of space.') 193 | for k in range(len(DATA_EFFICIENT_FILES)-1): 194 | download_file(os.path.join(BASE_URL, RELEASE_TASKS, DATA_EFFICIENT_FILES[k]), os.path.join(out_dir_tasks, DATA_EFFICIENT_FILES[k])) 195 | elif args.tf_semantic: # download google tf 
records 196 | download_tfrecords(os.path.join(BASE_URL, RELEASE_TASKS, 'tf3d'), os.path.join(out_dir_tasks, 'tf3d')) 197 | elif args.grit: # download GRIT file 198 | download_file(os.path.join(BASE_URL, RELEASE_TASKS, GRIT_FILES[0]), os.path.join(out_dir_tasks, GRIT_FILES[0])) 199 | elif args.id: # download single scan 200 | scan_id = args.id 201 | is_test_scan = scan_id in release_test_scans 202 | if scan_id not in release_scans and (not is_test_scan or args.v1): 203 | print('ERROR: Invalid scan id: ' + scan_id) 204 | else: 205 | out_dir = os.path.join(out_dir_scans, scan_id) if not is_test_scan else os.path.join(out_dir_test_scans, scan_id) 206 | scan_file_types = file_types if not is_test_scan else file_types_test 207 | use_v1_sens = not is_test_scan 208 | if not is_test_scan and not args.v1 and '.sens' in scan_file_types: 209 | print('Note: ScanNet v2 uses the same .sens files as ScanNet v1: Press \'n\' to exclude downloading .sens files for each scan') 210 | key = input('') 211 | if key.strip().lower() == 'n': 212 | scan_file_types.remove('.sens') 213 | download_scan(scan_id, out_dir, scan_file_types, use_v1_sens) 214 | else: # download entire release 215 | if len(file_types) == len(FILETYPES): 216 | print('WARNING: You are downloading the entire ScanNet ' + RELEASE_NAME + ' release which requires ' + RELEASE_SIZE + ' of space.') 217 | else: 218 | print('WARNING: You are downloading all ScanNet ' + RELEASE_NAME + ' scans of type ' + file_types[0]) 219 | print('Note that existing scan directories will be skipped. Delete partially downloaded directories to re-download.') 220 | print('***') 221 | print('Press any key to continue, or CTRL-C to exit.') 222 | key = input('') 223 | if not args.v1 and '.sens' in file_types: 224 | print('Note: ScanNet v2 uses the same .sens files as ScanNet v1: Press \'n\' to exclude downloading .sens files for each scan') 225 | key = input('') 226 | if key.strip().lower() == 'n': 227 | file_types.remove('.sens') 228 | download_release(release_scans, out_dir_scans, file_types, use_v1_sens=True) 229 | if not args.v1: 230 | download_label_map(args.out_dir) 231 | download_release(release_test_scans, out_dir_test_scans, file_types_test, use_v1_sens=False) 232 | download_file(os.path.join(BASE_URL, RELEASE_TASKS, TEST_FRAMES_FILE[0]), os.path.join(out_dir_tasks, TEST_FRAMES_FILE[0])) 233 | for k in range(len(DATA_EFFICIENT_FILES)-1): 234 | download_file(os.path.join(BASE_URL, RELEASE_TASKS, DATA_EFFICIENT_FILES[k]), os.path.join(out_dir_tasks, DATA_EFFICIENT_FILES[k])) 235 | 236 | 237 | if __name__ == "__main__": main() -------------------------------------------------------------------------------- /data_scripts/scannet_wrangling_scripts/env.yml: -------------------------------------------------------------------------------- 1 | name: scannet_extraction 2 | dependencies: 3 | - python=3.9.7 4 | - numpy 5 | - imageio 6 | - pillow 7 | - tqdm 8 | - pip 9 | - pip: 10 | - opencv-python 11 | - pypng 12 | -------------------------------------------------------------------------------- /data_scripts/scannet_wrangling_scripts/reader.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from concurrent.futures import process 3 | import os, sys 4 | from tqdm import tqdm 5 | from multiprocessing.pool import Pool 6 | from functools import partial 7 | from multiprocessing import Manager 8 | 9 | from SensorData import SensorData 10 | 11 | # params 12 | parser = argparse.ArgumentParser() 13 | # data paths 14 | 
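# Example invocation (paths and worker count below are illustrative only; all
# flags correspond to the arguments defined just below):
#   python reader.py --scans_folder /path/to/scannet/scans \
#       --scan_list_file splits/scannetv2_train.txt \
#       --output_path /path/to/extracted_scans \
#       --export_depth_images --export_color_images \
#       --export_poses --export_intrinsics \
#       --num_workers 8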
parser.add_argument('--scans_folder', required=True, help='dataset root') 15 | parser.add_argument('--scan_list_file', required=False, default=None, help='scan list file') 16 | parser.add_argument('--single_debug_scan_id', required=False, default=None, help='single scan to debug') 17 | parser.add_argument('--output_path', required=True, help='path to output folder') 18 | parser.add_argument('--export_depth_images', dest='export_depth_images', action='store_true') 19 | parser.add_argument('--export_color_images', dest='export_color_images', action='store_true') 20 | parser.add_argument('--export_poses', dest='export_poses', action='store_true') 21 | parser.add_argument('--export_intrinsics', dest='export_intrinsics', action='store_true') 22 | parser.add_argument('--num_workers', type=int, default=1) 23 | parser.add_argument('--rgb_resize', nargs='+', type=int, default=None, help='width height') 24 | parser.add_argument('--depth_resize', nargs='+', type=int, default=None, help='width height') 25 | parser.set_defaults(export_depth_images=False, export_color_images=False, export_poses=False, export_intrinsics=False) 26 | 27 | opt = parser.parse_args() 28 | print(opt) 29 | 30 | def process_scan(opt, scan_job, count=None, progress=None): 31 | filename = scan_job[0] 32 | output_path = scan_job[1] 33 | scan_name = scan_job[2] 34 | 35 | if not os.path.exists(output_path): 36 | os.makedirs(output_path) 37 | # load the data 38 | sys.stdout.write('loading %s...' % opt.scans_folder) 39 | sd = SensorData(filename) 40 | sys.stdout.write('loaded!\n') 41 | if opt.export_depth_images: 42 | sd.export_depth_images(os.path.join(output_path, 'sensor_data'), image_size=opt.depth_resize) 43 | if opt.export_color_images: 44 | sd.export_color_images(os.path.join(output_path, 'sensor_data'), image_size=opt.rgb_resize) 45 | if opt.export_poses: 46 | sd.export_poses(os.path.join(output_path, 'sensor_data')) 47 | if opt.export_intrinsics: 48 | sd.export_intrinsics(output_path, scan_name) 49 | 50 | if progress is not None: 51 | progress.value += 1 52 | print(f"Completed scan {filename}, {progress.value} of total {count}.") 53 | 54 | def main(): 55 | 56 | 57 | if opt.single_debug_scan_id is not None: 58 | scans = [opt.single_debug_scan_id] 59 | else: 60 | f = open(opt.scan_list_file, "r") 61 | scans = f.readlines() 62 | scans = [scan.strip() for scan in scans] 63 | 64 | input_files = [os.path.join(opt.scans_folder, f"{scan}/{scan}.sens") for 65 | scan in scans] 66 | 67 | output_dirs = [os.path.join(opt.output_path, scan) for 68 | scan in scans] 69 | 70 | scan_jobs = list(zip(input_files, output_dirs, scans)) 71 | 72 | if opt.num_workers == 1: 73 | for scan_job in tqdm(scan_jobs): 74 | process_scan(opt, scan_job) 75 | else: 76 | 77 | pool = Pool(opt.num_workers) 78 | manager = Manager() 79 | 80 | count = len(scan_jobs) 81 | progress = manager.Value('i', 0) 82 | 83 | 84 | pool.map( 85 | partial( 86 | process_scan, 87 | opt, 88 | count=count, 89 | progress=progress 90 | ), 91 | scan_jobs, 92 | ) 93 | 94 | if __name__ == '__main__': 95 | main() -------------------------------------------------------------------------------- /data_scripts/scannet_wrangling_scripts/splits/scannetv2_test.txt: -------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | scene0721_00 
16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 | scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 101 | -------------------------------------------------------------------------------- /data_scripts/scannet_wrangling_scripts/splits/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0568_00 2 | scene0568_01 3 | scene0568_02 4 | scene0304_00 5 | scene0488_00 6 | scene0488_01 7 | scene0412_00 8 | scene0412_01 9 | scene0217_00 10 | scene0019_00 11 | scene0019_01 12 | scene0414_00 13 | scene0575_00 14 | scene0575_01 15 | scene0575_02 16 | scene0426_00 17 | scene0426_01 18 | scene0426_02 19 | scene0426_03 20 | scene0549_00 21 | scene0549_01 22 | scene0578_00 23 | scene0578_01 24 | scene0578_02 25 | scene0665_00 26 | scene0665_01 27 | scene0050_00 28 | scene0050_01 29 | scene0050_02 30 | scene0257_00 31 | scene0025_00 32 | scene0025_01 33 | scene0025_02 34 | scene0583_00 35 | scene0583_01 36 | scene0583_02 37 | scene0701_00 38 | scene0701_01 39 | scene0701_02 40 | scene0580_00 41 | scene0580_01 42 | scene0565_00 43 | scene0169_00 44 | scene0169_01 45 | scene0655_00 46 | scene0655_01 47 | scene0655_02 48 | scene0063_00 49 | scene0221_00 50 | scene0221_01 51 | scene0591_00 52 | scene0591_01 53 | scene0591_02 54 | scene0678_00 55 | scene0678_01 56 | scene0678_02 57 | scene0462_00 58 | scene0427_00 59 | scene0595_00 60 | scene0193_00 61 | scene0193_01 62 | scene0164_00 63 | scene0164_01 64 | scene0164_02 65 | scene0164_03 66 | scene0598_00 67 | scene0598_01 68 | scene0598_02 69 | scene0599_00 70 | scene0599_01 71 | scene0599_02 72 | scene0328_00 73 | scene0300_00 74 | scene0300_01 75 | scene0354_00 76 | scene0458_00 77 | scene0458_01 78 | scene0423_00 79 | scene0423_01 80 | scene0423_02 81 | scene0307_00 82 | scene0307_01 83 | scene0307_02 84 | scene0606_00 85 | scene0606_01 86 | scene0606_02 87 | scene0432_00 88 | scene0432_01 89 | scene0608_00 90 | scene0608_01 91 | scene0608_02 92 | scene0651_00 93 | scene0651_01 94 | scene0651_02 95 | scene0430_00 96 | scene0430_01 97 | scene0689_00 98 | scene0357_00 99 | scene0357_01 100 | 
scene0574_00 101 | scene0574_01 102 | scene0574_02 103 | scene0329_00 104 | scene0329_01 105 | scene0329_02 106 | scene0153_00 107 | scene0153_01 108 | scene0616_00 109 | scene0616_01 110 | scene0671_00 111 | scene0671_01 112 | scene0618_00 113 | scene0382_00 114 | scene0382_01 115 | scene0490_00 116 | scene0621_00 117 | scene0607_00 118 | scene0607_01 119 | scene0149_00 120 | scene0695_00 121 | scene0695_01 122 | scene0695_02 123 | scene0695_03 124 | scene0389_00 125 | scene0377_00 126 | scene0377_01 127 | scene0377_02 128 | scene0342_00 129 | scene0139_00 130 | scene0629_00 131 | scene0629_01 132 | scene0629_02 133 | scene0496_00 134 | scene0633_00 135 | scene0633_01 136 | scene0518_00 137 | scene0652_00 138 | scene0406_00 139 | scene0406_01 140 | scene0406_02 141 | scene0144_00 142 | scene0144_01 143 | scene0494_00 144 | scene0278_00 145 | scene0278_01 146 | scene0316_00 147 | scene0609_00 148 | scene0609_01 149 | scene0609_02 150 | scene0609_03 151 | scene0084_00 152 | scene0084_01 153 | scene0084_02 154 | scene0696_00 155 | scene0696_01 156 | scene0696_02 157 | scene0351_00 158 | scene0351_01 159 | scene0643_00 160 | scene0644_00 161 | scene0645_00 162 | scene0645_01 163 | scene0645_02 164 | scene0081_00 165 | scene0081_01 166 | scene0081_02 167 | scene0647_00 168 | scene0647_01 169 | scene0535_00 170 | scene0353_00 171 | scene0353_01 172 | scene0353_02 173 | scene0559_00 174 | scene0559_01 175 | scene0559_02 176 | scene0593_00 177 | scene0593_01 178 | scene0246_00 179 | scene0653_00 180 | scene0653_01 181 | scene0064_00 182 | scene0064_01 183 | scene0356_00 184 | scene0356_01 185 | scene0356_02 186 | scene0030_00 187 | scene0030_01 188 | scene0030_02 189 | scene0222_00 190 | scene0222_01 191 | scene0338_00 192 | scene0338_01 193 | scene0338_02 194 | scene0378_00 195 | scene0378_01 196 | scene0378_02 197 | scene0660_00 198 | scene0553_00 199 | scene0553_01 200 | scene0553_02 201 | scene0527_00 202 | scene0663_00 203 | scene0663_01 204 | scene0663_02 205 | scene0664_00 206 | scene0664_01 207 | scene0664_02 208 | scene0334_00 209 | scene0334_01 210 | scene0334_02 211 | scene0046_00 212 | scene0046_01 213 | scene0046_02 214 | scene0203_00 215 | scene0203_01 216 | scene0203_02 217 | scene0088_00 218 | scene0088_01 219 | scene0088_02 220 | scene0088_03 221 | scene0086_00 222 | scene0086_01 223 | scene0086_02 224 | scene0670_00 225 | scene0670_01 226 | scene0256_00 227 | scene0256_01 228 | scene0256_02 229 | scene0249_00 230 | scene0441_00 231 | scene0658_00 232 | scene0704_00 233 | scene0704_01 234 | scene0187_00 235 | scene0187_01 236 | scene0131_00 237 | scene0131_01 238 | scene0131_02 239 | scene0207_00 240 | scene0207_01 241 | scene0207_02 242 | scene0461_00 243 | scene0011_00 244 | scene0011_01 245 | scene0343_00 246 | scene0251_00 247 | scene0077_00 248 | scene0077_01 249 | scene0684_00 250 | scene0684_01 251 | scene0550_00 252 | scene0686_00 253 | scene0686_01 254 | scene0686_02 255 | scene0208_00 256 | scene0500_00 257 | scene0500_01 258 | scene0552_00 259 | scene0552_01 260 | scene0648_00 261 | scene0648_01 262 | scene0435_00 263 | scene0435_01 264 | scene0435_02 265 | scene0435_03 266 | scene0690_00 267 | scene0690_01 268 | scene0693_00 269 | scene0693_01 270 | scene0693_02 271 | scene0700_00 272 | scene0700_01 273 | scene0700_02 274 | scene0699_00 275 | scene0231_00 276 | scene0231_01 277 | scene0231_02 278 | scene0697_00 279 | scene0697_01 280 | scene0697_02 281 | scene0697_03 282 | scene0474_00 283 | scene0474_01 284 | scene0474_02 285 | scene0474_03 286 | scene0474_04 287 | 
scene0474_05 288 | scene0355_00 289 | scene0355_01 290 | scene0146_00 291 | scene0146_01 292 | scene0146_02 293 | scene0196_00 294 | scene0702_00 295 | scene0702_01 296 | scene0702_02 297 | scene0314_00 298 | scene0277_00 299 | scene0277_01 300 | scene0277_02 301 | scene0095_00 302 | scene0095_01 303 | scene0015_00 304 | scene0100_00 305 | scene0100_01 306 | scene0100_02 307 | scene0558_00 308 | scene0558_01 309 | scene0558_02 310 | scene0685_00 311 | scene0685_01 312 | scene0685_02 313 | -------------------------------------------------------------------------------- /data_splits/7Scenes/dvmvs_split/dvmvs_test_split.txt: -------------------------------------------------------------------------------- 1 | redkitchen/seq-01 2 | redkitchen/seq-07 3 | chess/seq-01 4 | chess/seq-02 5 | heads/seq-02 6 | fire/seq-01 7 | fire/seq-02 8 | office/seq-01 9 | office/seq-03 10 | pumpkin/seq-03 11 | pumpkin/seq-06 12 | stairs/seq-02 13 | stairs/seq-06 14 | -------------------------------------------------------------------------------- /data_splits/ScanNetv2/dvmvs_split/dvmvs_val.txt: -------------------------------------------------------------------------------- 1 | scene0690_00 2 | scene0690_01 3 | scene0453_00 4 | scene0453_01 5 | scene0603_00 6 | scene0603_01 7 | scene0672_00 8 | scene0672_01 9 | scene0364_00 10 | scene0364_01 11 | scene0687_00 12 | scene0433_00 13 | scene0331_00 14 | scene0331_01 15 | scene0591_00 16 | scene0591_01 17 | scene0591_02 18 | scene0050_00 19 | scene0050_01 20 | scene0050_02 21 | scene0499_00 22 | scene0134_00 23 | scene0134_01 24 | scene0134_02 25 | scene0564_00 26 | scene0297_00 27 | scene0297_01 28 | scene0297_02 29 | scene0459_00 30 | scene0459_01 31 | scene0168_00 32 | scene0168_01 33 | scene0168_02 34 | scene0069_00 35 | scene0341_00 36 | scene0341_01 37 | scene0077_00 38 | scene0077_01 39 | scene0422_00 40 | scene0265_00 41 | scene0265_01 42 | scene0265_02 43 | scene0273_00 44 | scene0273_01 45 | scene0315_00 46 | scene0231_00 47 | scene0231_01 48 | scene0231_02 49 | scene0270_00 50 | scene0270_01 51 | scene0270_02 52 | scene0034_00 53 | scene0034_01 54 | scene0034_02 55 | scene0253_00 56 | scene0365_00 57 | scene0365_01 58 | scene0365_02 59 | scene0300_00 60 | scene0300_01 61 | scene0667_00 62 | scene0667_01 63 | scene0667_02 64 | scene0444_00 65 | scene0444_01 66 | scene0471_00 67 | scene0471_01 68 | scene0471_02 69 | scene0537_00 70 | scene0359_00 71 | scene0359_01 72 | scene0527_00 73 | scene0579_00 74 | scene0579_01 75 | scene0579_02 76 | scene0490_00 77 | scene0099_00 78 | scene0099_01 79 | scene0286_00 80 | scene0286_01 81 | scene0286_02 82 | scene0286_03 83 | scene0028_00 84 | scene0225_00 85 | scene0049_00 86 | scene0398_00 87 | scene0398_01 88 | scene0181_00 89 | scene0181_01 90 | scene0181_02 91 | scene0181_03 92 | scene0611_00 93 | scene0611_01 94 | scene0048_00 95 | scene0048_01 96 | scene0100_00 97 | scene0100_01 98 | scene0100_02 99 | scene0438_00 100 | scene0290_00 101 | scene0454_00 102 | scene0595_00 103 | scene0061_00 104 | scene0061_01 105 | scene0338_00 106 | scene0338_01 107 | scene0338_02 108 | scene0428_00 109 | scene0428_01 110 | scene0345_00 111 | scene0345_01 112 | scene0141_00 113 | scene0141_01 114 | scene0141_02 115 | scene0138_00 116 | scene0425_00 117 | scene0425_01 118 | scene0120_00 119 | scene0120_01 120 | scene0400_00 121 | scene0400_01 122 | scene0594_00 123 | -------------------------------------------------------------------------------- /data_splits/ScanNetv2/standard_split/scannetv2_test.txt: 
-------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | scene0721_00 16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 | scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 101 | -------------------------------------------------------------------------------- /data_splits/ScanNetv2/standard_split/scannetv2_val.txt: -------------------------------------------------------------------------------- 1 | scene0568_00 2 | scene0568_01 3 | scene0568_02 4 | scene0304_00 5 | scene0488_00 6 | scene0488_01 7 | scene0412_00 8 | scene0412_01 9 | scene0217_00 10 | scene0019_00 11 | scene0019_01 12 | scene0414_00 13 | scene0575_00 14 | scene0575_01 15 | scene0575_02 16 | scene0426_00 17 | scene0426_01 18 | scene0426_02 19 | scene0426_03 20 | scene0549_00 21 | scene0549_01 22 | scene0578_00 23 | scene0578_01 24 | scene0578_02 25 | scene0665_00 26 | scene0665_01 27 | scene0050_00 28 | scene0050_01 29 | scene0050_02 30 | scene0257_00 31 | scene0025_00 32 | scene0025_01 33 | scene0025_02 34 | scene0583_00 35 | scene0583_01 36 | scene0583_02 37 | scene0701_00 38 | scene0701_01 39 | scene0701_02 40 | scene0580_00 41 | scene0580_01 42 | scene0565_00 43 | scene0169_00 44 | scene0169_01 45 | scene0655_00 46 | scene0655_01 47 | scene0655_02 48 | scene0063_00 49 | scene0221_00 50 | scene0221_01 51 | scene0591_00 52 | scene0591_01 53 | scene0591_02 54 | scene0678_00 55 | scene0678_01 56 | scene0678_02 57 | scene0462_00 58 | scene0427_00 59 | scene0595_00 60 | scene0193_00 61 | scene0193_01 62 | scene0164_00 63 | scene0164_01 64 | scene0164_02 65 | scene0164_03 66 | scene0598_00 67 | scene0598_01 68 | scene0598_02 69 | scene0599_00 70 | scene0599_01 71 | scene0599_02 72 | scene0328_00 73 | scene0300_00 74 | scene0300_01 75 | scene0354_00 76 | scene0458_00 77 | scene0458_01 78 | scene0423_00 79 | scene0423_01 80 | scene0423_02 81 | scene0307_00 82 | 
scene0307_01 83 | scene0307_02 84 | scene0606_00 85 | scene0606_01 86 | scene0606_02 87 | scene0432_00 88 | scene0432_01 89 | scene0608_00 90 | scene0608_01 91 | scene0608_02 92 | scene0651_00 93 | scene0651_01 94 | scene0651_02 95 | scene0430_00 96 | scene0430_01 97 | scene0689_00 98 | scene0357_00 99 | scene0357_01 100 | scene0574_00 101 | scene0574_01 102 | scene0574_02 103 | scene0329_00 104 | scene0329_01 105 | scene0329_02 106 | scene0153_00 107 | scene0153_01 108 | scene0616_00 109 | scene0616_01 110 | scene0671_00 111 | scene0671_01 112 | scene0618_00 113 | scene0382_00 114 | scene0382_01 115 | scene0490_00 116 | scene0621_00 117 | scene0607_00 118 | scene0607_01 119 | scene0149_00 120 | scene0695_00 121 | scene0695_01 122 | scene0695_02 123 | scene0695_03 124 | scene0389_00 125 | scene0377_00 126 | scene0377_01 127 | scene0377_02 128 | scene0342_00 129 | scene0139_00 130 | scene0629_00 131 | scene0629_01 132 | scene0629_02 133 | scene0496_00 134 | scene0633_00 135 | scene0633_01 136 | scene0518_00 137 | scene0652_00 138 | scene0406_00 139 | scene0406_01 140 | scene0406_02 141 | scene0144_00 142 | scene0144_01 143 | scene0494_00 144 | scene0278_00 145 | scene0278_01 146 | scene0316_00 147 | scene0609_00 148 | scene0609_01 149 | scene0609_02 150 | scene0609_03 151 | scene0084_00 152 | scene0084_01 153 | scene0084_02 154 | scene0696_00 155 | scene0696_01 156 | scene0696_02 157 | scene0351_00 158 | scene0351_01 159 | scene0643_00 160 | scene0644_00 161 | scene0645_00 162 | scene0645_01 163 | scene0645_02 164 | scene0081_00 165 | scene0081_01 166 | scene0081_02 167 | scene0647_00 168 | scene0647_01 169 | scene0535_00 170 | scene0353_00 171 | scene0353_01 172 | scene0353_02 173 | scene0559_00 174 | scene0559_01 175 | scene0559_02 176 | scene0593_00 177 | scene0593_01 178 | scene0246_00 179 | scene0653_00 180 | scene0653_01 181 | scene0064_00 182 | scene0064_01 183 | scene0356_00 184 | scene0356_01 185 | scene0356_02 186 | scene0030_00 187 | scene0030_01 188 | scene0030_02 189 | scene0222_00 190 | scene0222_01 191 | scene0338_00 192 | scene0338_01 193 | scene0338_02 194 | scene0378_00 195 | scene0378_01 196 | scene0378_02 197 | scene0660_00 198 | scene0553_00 199 | scene0553_01 200 | scene0553_02 201 | scene0527_00 202 | scene0663_00 203 | scene0663_01 204 | scene0663_02 205 | scene0664_00 206 | scene0664_01 207 | scene0664_02 208 | scene0334_00 209 | scene0334_01 210 | scene0334_02 211 | scene0046_00 212 | scene0046_01 213 | scene0046_02 214 | scene0203_00 215 | scene0203_01 216 | scene0203_02 217 | scene0088_00 218 | scene0088_01 219 | scene0088_02 220 | scene0088_03 221 | scene0086_00 222 | scene0086_01 223 | scene0086_02 224 | scene0670_00 225 | scene0670_01 226 | scene0256_00 227 | scene0256_01 228 | scene0256_02 229 | scene0249_00 230 | scene0441_00 231 | scene0658_00 232 | scene0704_00 233 | scene0704_01 234 | scene0187_00 235 | scene0187_01 236 | scene0131_00 237 | scene0131_01 238 | scene0131_02 239 | scene0207_00 240 | scene0207_01 241 | scene0207_02 242 | scene0461_00 243 | scene0011_00 244 | scene0011_01 245 | scene0343_00 246 | scene0251_00 247 | scene0077_00 248 | scene0077_01 249 | scene0684_00 250 | scene0684_01 251 | scene0550_00 252 | scene0686_00 253 | scene0686_01 254 | scene0686_02 255 | scene0208_00 256 | scene0500_00 257 | scene0500_01 258 | scene0552_00 259 | scene0552_01 260 | scene0648_00 261 | scene0648_01 262 | scene0435_00 263 | scene0435_01 264 | scene0435_02 265 | scene0435_03 266 | scene0690_00 267 | scene0690_01 268 | scene0693_00 269 | scene0693_01 270 | 
scene0693_02 271 | scene0700_00 272 | scene0700_01 273 | scene0700_02 274 | scene0699_00 275 | scene0231_00 276 | scene0231_01 277 | scene0231_02 278 | scene0697_00 279 | scene0697_01 280 | scene0697_02 281 | scene0697_03 282 | scene0474_00 283 | scene0474_01 284 | scene0474_02 285 | scene0474_03 286 | scene0474_04 287 | scene0474_05 288 | scene0355_00 289 | scene0355_01 290 | scene0146_00 291 | scene0146_01 292 | scene0146_02 293 | scene0196_00 294 | scene0702_00 295 | scene0702_01 296 | scene0702_02 297 | scene0314_00 298 | scene0277_00 299 | scene0277_01 300 | scene0277_02 301 | scene0095_00 302 | scene0095_01 303 | scene0015_00 304 | scene0100_00 305 | scene0100_01 306 | scene0100_02 307 | scene0558_00 308 | scene0558_01 309 | scene0558_02 310 | scene0685_00 311 | scene0685_01 312 | scene0685_02 313 | -------------------------------------------------------------------------------- /data_splits/arkit/scans.txt: -------------------------------------------------------------------------------- 1 | neucon_demodata_b5f1 -------------------------------------------------------------------------------- /data_splits/vdr/scans.txt: -------------------------------------------------------------------------------- 1 | living_room 2 | house -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import kornia 2 | import torch 3 | import torch.jit as jit 4 | import torch.nn.functional as F 5 | from torch import Tensor, nn 6 | 7 | from utils.geometry_utils import (BackprojectDepth, Project3D) 8 | from utils.generic_utils import pyrdown 9 | 10 | 11 | class MSGradientLoss(nn.Module): 12 | def __init__(self, num_scales: int = 4): 13 | super().__init__() 14 | 15 | self.num_scales = num_scales 16 | 17 | def forward(self, depth_gt: Tensor, depth_pred: Tensor) -> Tensor: 18 | 19 | # Create the gradient pyramids 20 | depth_pred_pyr = pyrdown(depth_pred, self.num_scales) 21 | depth_gtn_pyr = pyrdown(depth_gt, self.num_scales) 22 | 23 | grad_loss = torch.tensor(0, dtype=depth_gt.dtype, device=depth_gt.device) 24 | for depth_pred_down, depth_gtn_down in zip(depth_pred_pyr, depth_gtn_pyr): 25 | 26 | depth_gtn_grad = kornia.filters.spatial_gradient(depth_gtn_down) 27 | 28 | mask_down_b = depth_gtn_grad.isfinite().all(dim=1, keepdim=True) 29 | 30 | depth_pred_grad = kornia.filters.spatial_gradient( 31 | depth_pred_down).masked_select(mask_down_b) 32 | 33 | grad_error = torch.abs(depth_pred_grad - 34 | depth_gtn_grad.masked_select(mask_down_b)) 35 | grad_loss += torch.mean(grad_error) 36 | 37 | return grad_loss 38 | 39 | class ScaleInvariantLoss(jit.ScriptModule): 40 | def __init__(self, si_lambda: float = 0.85): 41 | super().__init__() 42 | 43 | self.si_lambda = si_lambda 44 | 45 | @jit.script_method 46 | def forward(self, log_depth_gt: Tensor, log_depth_pred: Tensor) -> Tensor: 47 | 48 | # Scale invariant loss from Eigen, implementation is from AdaBins 49 | log_diff = log_depth_gt - log_depth_pred 50 | si_loss = torch.sqrt( 51 | (log_diff ** 2).mean() - self.si_lambda * (log_diff.mean() ** 2) 52 | ) 53 | 54 | return si_loss 55 | 56 | 57 | class NormalsLoss(nn.Module): 58 | def forward(self, normals_gt_b3hw: Tensor, normals_pred_b3hw: Tensor) -> Tensor: 59 | 60 | normals_mask_b1hw = torch.logical_and( 61 | normals_gt_b3hw.isfinite().all(dim=1, keepdim=True), 62 | normals_pred_b3hw.isfinite().all(dim=1, keepdim=True)) 63 | 64 | normals_pred_b3hw = normals_pred_b3hw.masked_fill(~normals_mask_b1hw, 
1.0) 65 | normals_gt_b3hw = normals_gt_b3hw.masked_fill(~normals_mask_b1hw, 1.0) 66 | 67 | with torch.cuda.amp.autocast(enabled=False): 68 | normals_dot_b1hw = 0.5 * ( 69 | 1.0 - torch.einsum( 70 | "bchw, bchw -> bhw", 71 | normals_pred_b3hw, 72 | normals_gt_b3hw, 73 | ) 74 | ).unsqueeze(1) 75 | normals_loss = normals_dot_b1hw.masked_select(normals_mask_b1hw).mean() 76 | 77 | return normals_loss 78 | 79 | class MVDepthLoss(nn.Module): 80 | def __init__(self, height, width): 81 | super().__init__() 82 | 83 | self.height = height 84 | self.width = width 85 | 86 | self.backproject = BackprojectDepth(self.height, self.width) 87 | self.project = Project3D() 88 | 89 | 90 | def get_valid_mask( 91 | self, 92 | cur_depth_b1hw, 93 | src_depth_b1hw, 94 | cur_invK_b44, 95 | src_K_b44, 96 | cur_world_T_cam_b44, 97 | src_cam_T_world_b44, 98 | ): 99 | 100 | depth_height, depth_width = cur_depth_b1hw.shape[2:] 101 | 102 | cur_cam_points_b4N = self.backproject(cur_depth_b1hw, cur_invK_b44) 103 | world_points_b4N = cur_world_T_cam_b44 @ cur_cam_points_b4N 104 | 105 | # Compute valid mask 106 | src_cam_points_b3N = self.project(world_points_b4N, src_K_b44, src_cam_T_world_b44) 107 | 108 | cam_points_b3hw = src_cam_points_b3N.view(-1, 3, depth_height, depth_width) 109 | pix_coords_b2hw = cam_points_b3hw[:, :2] 110 | proj_src_depths_b1hw = cam_points_b3hw[:, 2:] 111 | 112 | uv_coords = (pix_coords_b2hw.permute(0, 2, 3, 1) 113 | / torch.tensor( 114 | [depth_width, depth_height] 115 | ).view(1, 1, 1, 2).type_as(pix_coords_b2hw) 116 | ) 117 | uv_coords = 2 * uv_coords - 1 118 | 119 | src_depth_sampled_b1hw = F.grid_sample( 120 | input=src_depth_b1hw, 121 | grid=uv_coords, 122 | padding_mode='zeros', 123 | mode='nearest', 124 | align_corners=False, 125 | ) 126 | 127 | valid_mask_b1hw = proj_src_depths_b1hw < 1.05 * src_depth_sampled_b1hw 128 | valid_mask_b1hw = torch.logical_and(valid_mask_b1hw, 129 | proj_src_depths_b1hw > 0) 130 | valid_mask_b1hw = torch.logical_and(valid_mask_b1hw, 131 | src_depth_sampled_b1hw > 0) 132 | 133 | return valid_mask_b1hw, src_depth_sampled_b1hw 134 | 135 | 136 | def get_error_for_pair(self, 137 | depth_pred_b1hw, 138 | cur_depth_b1hw, 139 | src_depth_b1hw, 140 | cur_invK_b44, 141 | src_K_b44, 142 | cur_world_T_cam_b44, 143 | src_cam_T_world_b44): 144 | 145 | depth_height, depth_width = cur_depth_b1hw.shape[2:] 146 | 147 | valid_mask_b1hw, src_depth_sampled_b1hw = self.get_valid_mask( 148 | cur_depth_b1hw, 149 | src_depth_b1hw, 150 | cur_invK_b44, 151 | src_K_b44, 152 | cur_world_T_cam_b44, 153 | src_cam_T_world_b44, 154 | ) 155 | 156 | pred_cam_points_b4N = self.backproject(depth_pred_b1hw, cur_invK_b44) 157 | pred_world_points_b4N = cur_world_T_cam_b44 @ pred_cam_points_b4N 158 | 159 | src_cam_points_b3N = self.project(pred_world_points_b4N, src_K_b44, 160 | src_cam_T_world_b44) 161 | 162 | pred_cam_points_b3hw = src_cam_points_b3N.view(-1, 3, depth_height, 163 | depth_width) 164 | pred_src_depths_b1hw = pred_cam_points_b3hw[:, 2:] 165 | 166 | depth_diff_b1hw = torch.abs( 167 | torch.log(src_depth_sampled_b1hw) - 168 | torch.log(pred_src_depths_b1hw) 169 | ).masked_select(valid_mask_b1hw) 170 | 171 | depth_loss = depth_diff_b1hw.nanmean() 172 | 173 | return depth_loss 174 | 175 | def forward( 176 | self, 177 | depth_pred_b1hw, 178 | cur_depth_b1hw, 179 | src_depth_bk1hw, 180 | cur_invK_b44, 181 | src_K_bk44, 182 | cur_world_T_cam_b44, 183 | src_cam_T_world_bk44, 184 | ): 185 | 186 | src_to_iterate = [ 187 | torch.unbind(src_depth_bk1hw, dim=1), 188 | torch.unbind(src_K_bk44, 
dim=1), 189 | torch.unbind(src_cam_T_world_bk44, dim=1) 190 | ] 191 | 192 | num_src_frames = src_depth_bk1hw.shape[1] 193 | 194 | loss = 0 195 | for src_depth_b1hw, src_K_b44, src_cam_T_world_b44 in zip(*src_to_iterate): 196 | 197 | error = self.get_error_for_pair( 198 | depth_pred_b1hw, 199 | cur_depth_b1hw, 200 | src_depth_b1hw, 201 | cur_invK_b44, 202 | src_K_b44, 203 | cur_world_T_cam_b44, 204 | src_cam_T_world_b44, 205 | ) 206 | loss += error 207 | 208 | return loss / num_src_frames -------------------------------------------------------------------------------- /media/arkit_ioslogger_snapshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nianticlabs/simplerecon/113b825084cc62166c6f5584d526650928127317/media/arkit_ioslogger_snapshot.png -------------------------------------------------------------------------------- /media/teaser.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nianticlabs/simplerecon/113b825084cc62166c6f5584d526650928127317/media/teaser.jpeg -------------------------------------------------------------------------------- /modules/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional 2 | 3 | import torch.nn as nn 4 | from torch import Tensor 5 | 6 | 7 | def conv3x3( 8 | in_planes: int, 9 | out_planes: int, 10 | stride: int = 1, 11 | groups: int = 1, 12 | dilation: int = 1, 13 | bias: bool = False 14 | ) -> nn.Conv2d: 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=dilation, groups=groups, bias=bias, dilation=dilation) 18 | 19 | 20 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d: 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias) 23 | 24 | class BasicBlock(nn.Module): 25 | expansion: int = 1 26 | 27 | def __init__( 28 | self, 29 | inplanes: int, 30 | planes: int, 31 | stride: int = 1, 32 | groups: int = 1, 33 | base_width: int = 64, 34 | dilation: int = 1, 35 | norm_layer: Optional[Callable[..., nn.Module]] = nn.Identity, 36 | ) -> None: 37 | super(BasicBlock, self).__init__() 38 | if norm_layer is None: 39 | norm_layer = nn.BatchNorm2d 40 | if groups != 1 or base_width != 64: 41 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 42 | if dilation > 1: 43 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 44 | if norm_layer == nn.Identity: 45 | bias = True 46 | else: 47 | bias = False 48 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 49 | self.conv1 = conv3x3(inplanes, planes, stride, bias=bias) 50 | self.bn1 = norm_layer(planes) 51 | # self.relu = nn.ReLU(inplace=True) 52 | self.relu = nn.LeakyReLU(0.2, inplace=True) 53 | # self.relu = nn.ReLU6(True) 54 | # self.relu = nn.SiLU(inplace=True) 55 | # self.relu = nn.ELU(inplace=True) 56 | self.conv2 = conv3x3(planes, planes, bias=bias) 57 | self.bn2 = norm_layer(planes) 58 | if inplanes == planes * self.expansion and stride == 1: 59 | self.downsample = None 60 | else: 61 | conv = conv1x1 if stride == 1 else conv3x3 62 | self.downsample = nn.Sequential( 63 | conv(inplanes, planes * self.expansion, bias=bias, stride=stride), 64 | norm_layer(planes * self.expansion) 65 | ) 66 | self.stride = stride 67 | 68 | def forward(self, x: Tensor) 
-> Tensor: 69 | identity = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | # out = self.se(out) 78 | 79 | if self.downsample is not None: 80 | identity = self.downsample(x) 81 | 82 | out += identity 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | class TensorFormatter(nn.Module): 88 | """Helper to format, apply operation, format back tensor. 89 | 90 | Class to format tensors of shape B x D x C_i x H x W into B*D x C_i x H x W, 91 | apply an operation, and reshape back into B x D x C_o x H x W. 92 | 93 | Used for multidepth - batching feature extraction on source images""" 94 | 95 | def __init__(self): 96 | super().__init__() 97 | 98 | self.batch_size = None 99 | self.depth_chns = None 100 | 101 | def _expand_batch_with_channels(self, x): 102 | if x.dim() != 5: 103 | raise ValueError('TensorFormatter expects tensors with 5 dimensions, ' 104 | 'not {}!'.format(len(x.shape))) 105 | self.batch_size, self.depth_chns, chns, height, width = x.shape 106 | x = x.view(self.batch_size * self.depth_chns, chns, height, width) 107 | return x 108 | 109 | def _reduce_batch_to_channels(self, x): 110 | if self.batch_size is None or self.depth_chns is None: 111 | raise ValueError('Cannot call _reduce_batch_to_channels without first calling' 112 | '_expand_batch_with_channels!') 113 | _, chns, height, width = x.shape 114 | x = x.view(self.batch_size, self.depth_chns, chns, height, width) 115 | return x 116 | 117 | def forward(self, x, apply_func): 118 | x = self._expand_batch_with_channels(x) 119 | x = apply_func(x) 120 | x = self._reduce_batch_to_channels(x) 121 | return x 122 | -------------------------------------------------------------------------------- /modules/networks.py: -------------------------------------------------------------------------------- 1 | import antialiased_cnns 2 | from torchvision import models 3 | import numpy as np 4 | import timm 5 | import torch 6 | from torch import nn 7 | from torchvision.ops import FeaturePyramidNetwork 8 | 9 | from modules.layers import BasicBlock 10 | from utils.generic_utils import upsample 11 | 12 | 13 | def double_basic_block(num_ch_in, num_ch_out, num_repeats=2): 14 | layers = nn.Sequential(BasicBlock(num_ch_in, num_ch_out)) 15 | for i in range(num_repeats - 1): 16 | layers.add_module(f"conv_{i}", BasicBlock(num_ch_out, num_ch_out)) 17 | return layers 18 | 19 | 20 | class DepthDecoderPP(nn.Module): 21 | def __init__( 22 | self, 23 | num_ch_enc, 24 | scales=range(4), 25 | num_output_channels=1, 26 | use_skips=True 27 | ): 28 | super().__init__() 29 | 30 | self.num_output_channels = num_output_channels 31 | self.use_skips = use_skips 32 | self.upsample_mode = 'nearest' 33 | self.scales = scales 34 | 35 | self.num_ch_enc = num_ch_enc 36 | self.num_ch_dec = np.array([64, 64, 128, 256]) 37 | 38 | # decoder 39 | self.convs = nn.ModuleDict() 40 | # i is encoder depth (top to bottom) 41 | # j is decoder depth (left to right) 42 | for j in range(1, 5): 43 | max_i = 4 - j 44 | for i in range(max_i, -1, -1): 45 | 46 | num_ch_out = self.num_ch_dec[i] 47 | total_num_ch_in = 0 48 | 49 | num_ch_in = self.num_ch_enc[i + 1] if j == 1 else self.num_ch_dec[i + 1] 50 | self.convs[f"diag_conv_{i + 1}{j - 1}"] = BasicBlock(num_ch_in, 51 | num_ch_out) 52 | total_num_ch_in += num_ch_out 53 | 54 | num_ch_in = self.num_ch_enc[i] if j == 1 else self.num_ch_dec[i] 55 | self.convs[f"right_conv_{i}{j - 1}"] = BasicBlock(num_ch_in, 56 | num_ch_out) 57 | total_num_ch_in += 
num_ch_out 58 | 59 | if i + j != 4: 60 | num_ch_in = self.num_ch_dec[i + 1] 61 | self.convs[f"up_conv_{i + 1}{j}"] = BasicBlock(num_ch_in, 62 | num_ch_out) 63 | total_num_ch_in += num_ch_out 64 | 65 | self.convs[f"in_conv_{i}{j}"] = double_basic_block( 66 | total_num_ch_in, 67 | num_ch_out, 68 | ) 69 | 70 | self.convs[f"output_{i}"] = nn.Sequential( 71 | BasicBlock(num_ch_out, num_ch_out) if i != 0 else nn.Identity(), 72 | nn.Conv2d(num_ch_out, self.num_output_channels, 1), 73 | ) 74 | 75 | def forward(self, input_features): 76 | prev_outputs = input_features 77 | outputs = [] 78 | depth_outputs = {} 79 | for j in range(1, 5): 80 | max_i = 4 - j 81 | for i in range(max_i, -1, -1): 82 | 83 | inputs = [self.convs[f"right_conv_{i}{j - 1}"](prev_outputs[i])] 84 | inputs += [upsample(self.convs[f"diag_conv_{i + 1}{j - 1}"](prev_outputs[i + 1]))] 85 | 86 | if i + j != 4: 87 | inputs += [upsample(self.convs[f"up_conv_{i + 1}{j}"](outputs[-1]))] 88 | 89 | output = self.convs[f"in_conv_{i}{j}"](torch.cat(inputs, dim=1)) 90 | outputs += [output] 91 | 92 | depth_outputs[f"log_depth_pred_s{i}_b1hw"] = self.convs[f"output_{i}"](output) 93 | 94 | prev_outputs = outputs[::-1] 95 | 96 | return depth_outputs 97 | 98 | 99 | class CVEncoder(nn.Module): 100 | def __init__(self, num_ch_cv, num_ch_enc, num_ch_outs): 101 | super().__init__() 102 | 103 | self.convs = nn.ModuleDict() 104 | self.num_ch_enc = [] 105 | 106 | self.num_blocks = len(num_ch_outs) 107 | 108 | for i in range(self.num_blocks): 109 | num_ch_in = num_ch_cv if i == 0 else num_ch_outs[i - 1] 110 | num_ch_out = num_ch_outs[i] 111 | self.convs[f"ds_conv_{i}"] = BasicBlock(num_ch_in, num_ch_out, 112 | stride=1 if i == 0 else 2) 113 | 114 | self.convs[f"conv_{i}"] = nn.Sequential( 115 | BasicBlock(num_ch_enc[i] + num_ch_out, num_ch_out, stride=1), 116 | BasicBlock(num_ch_out, num_ch_out, stride=1), 117 | ) 118 | self.num_ch_enc.append(num_ch_out) 119 | 120 | def forward(self, x, img_feats): 121 | outputs = [] 122 | for i in range(self.num_blocks): 123 | x = self.convs[f"ds_conv_{i}"](x) 124 | x = torch.cat([x, img_feats[i]], dim=1) 125 | x = self.convs[f"conv_{i}"](x) 126 | outputs.append(x) 127 | return outputs 128 | 129 | class MLP(nn.Module): 130 | def __init__(self, channel_list, disable_final_activation = False): 131 | super(MLP, self).__init__() 132 | 133 | layer_list = [] 134 | for layer_index in list(range(len(channel_list)))[:-1]: 135 | layer_list.append( 136 | nn.Linear(channel_list[layer_index], 137 | channel_list[layer_index+1]) 138 | ) 139 | layer_list.append(nn.LeakyReLU(inplace=True)) 140 | 141 | if disable_final_activation: 142 | layer_list = layer_list[:-1] 143 | 144 | self.net = nn.Sequential(*layer_list) 145 | 146 | def forward(self, x): 147 | return self.net(x) 148 | 149 | class ResnetMatchingEncoder(nn.Module): 150 | """Pytorch module for a resnet encoder 151 | """ 152 | def __init__( 153 | self, 154 | num_layers, 155 | num_ch_out, 156 | pretrained=True, 157 | antialiased=True, 158 | ): 159 | super().__init__() 160 | 161 | self.num_ch_enc = np.array([64, 64]) 162 | 163 | model_source = antialiased_cnns if antialiased else models 164 | resnets = {18: model_source.resnet18, 165 | 34: model_source.resnet34, 166 | 50: model_source.resnet50, 167 | 101: model_source.resnet101, 168 | 152: model_source.resnet152} 169 | 170 | if num_layers not in resnets: 171 | raise ValueError("{} is not a valid number of resnet layers" 172 | .format(num_layers)) 173 | 174 | encoder = resnets[num_layers](pretrained) 175 | 176 | resnet_backbone = [ 177 
| encoder.conv1, 178 | encoder.bn1, 179 | encoder.relu, 180 | encoder.maxpool, 181 | encoder.layer1, 182 | ] 183 | 184 | if num_layers > 34: 185 | self.num_ch_enc[1:] *= 4 186 | 187 | self.num_ch_out = num_ch_out 188 | 189 | self.net = nn.Sequential( 190 | *resnet_backbone, 191 | nn.Conv2d(self.num_ch_enc[-1], 128, (1, 1)), 192 | nn.InstanceNorm2d(128), 193 | nn.LeakyReLU(0.2, True), 194 | nn.Conv2d( 195 | 128, 196 | self.num_ch_out, 197 | (3, 3), 198 | padding=1, 199 | padding_mode="replicate" 200 | ), 201 | nn.InstanceNorm2d(self.num_ch_out) 202 | ) 203 | 204 | def forward(self, input_image): 205 | return self.net(input_image) 206 | 207 | class UNetMatchingEncoder(nn.Module): 208 | def __init__(self): 209 | super().__init__() 210 | self.encoder = timm.create_model( 211 | "mnasnet_100", 212 | pretrained=True, 213 | features_only=True, 214 | ) 215 | 216 | self.decoder = FeaturePyramidNetwork( 217 | self.encoder.feature_info.channels(), 218 | out_channels=32, 219 | ) 220 | self.outconv = nn.Sequential( 221 | nn.LeakyReLU(0.2, True), 222 | nn.Conv2d(32, 16, 1), 223 | nn.InstanceNorm2d(16), 224 | ) 225 | 226 | def forward(self, x): 227 | encoder_feats = {f"feat_{i}": f for i, f in enumerate(self.encoder(x))} 228 | return self.outconv(self.decoder(encoder_feats)["feat_1"]) 229 | -------------------------------------------------------------------------------- /pc_fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fuses depth maps into point clouds using the PC fuser from 3 | https://github.com/alexrich021/3dvnet/blob/main/mv3d/eval/pointcloudfusion_custom.py 4 | 5 | This script follows the format in test.py. It expects a model to use for 6 | depth prediction. 7 | 8 | Example command: 9 | python pc_fusion.py --name HERO_MODEL \ 10 | --output_base_path OUTPUT_PATH \ 11 | --config_file models/hero_model.yaml \ 12 | --load_weights_from_checkpoint models/hero_model.ckpt \ 13 | --data_config configs/data/scannet_default_test.yaml \ 14 | --num_workers 8 \ 15 | --batch_size 8; 16 | """ 17 | 18 | import os 19 | from pathlib import Path 20 | 21 | import open3d as o3d 22 | import torch 23 | import torch.nn.functional as F 24 | from tqdm import tqdm 25 | 26 | from experiment_modules.depth_model import DepthModel 27 | import options 28 | import tools.torch_point_cloud_fusion as torch_point_cloud_fusion 29 | from utils.dataset_utils import get_dataset 30 | from utils.generic_utils import (to_gpu, reverse_imagenet_normalize) 31 | 32 | import modules.cost_volume as cost_volume 33 | 34 | def main(opts): 35 | 36 | # get dataset 37 | dataset_class, scans = get_dataset(opts.dataset, 38 | opts.dataset_scan_split_file, opts.single_debug_scan_id) 39 | 40 | # fusion params 41 | N_CONSISTENT_THRESH = opts.n_consistent_thresh 42 | Z_THRESH = opts.pc_fusion_z_thresh 43 | VOXEL_DOWNSAMPLE = opts.voxel_downsample 44 | 45 | # output location 46 | pc_output_folder_name = f"{N_CONSISTENT_THRESH}_{Z_THRESH}_{VOXEL_DOWNSAMPLE}_{opts.fusion_max_depth}" 47 | 48 | # path where results for this model, dataset, and tuple type are. 
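# (joined as <output_base_path>/<name>/<dataset>/<frame_tuple_type>; the fused
# point clouds are then written to a "pcs/<fusion params>" subfolder of it.)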
49 | results_path = os.path.join(opts.output_base_path, opts.name, 50 | opts.dataset, opts.frame_tuple_type) 51 | 52 | # ouput path 53 | pcs_output_dir = os.path.join(results_path, "pcs", pc_output_folder_name) 54 | 55 | Path(os.path.join(pcs_output_dir)).mkdir(parents=True, exist_ok=True) 56 | print(f"".center(80, "#")) 57 | print(f" Running PC Fusion!".center(80, "#")) 58 | print(f"Output directory:\n{pcs_output_dir} ".center(80, "#")) 59 | print(f"".center(80, "#")) 60 | print("") 61 | 62 | # load model 63 | model = DepthModel.load_from_checkpoint( 64 | opts.load_weights_from_checkpoint, 65 | args=None) 66 | if (opts.fast_cost_volume and 67 | isinstance(model.cost_volume, cost_volume.FeatureVolumeManager)): 68 | model.cost_volume = model.cost_volume.to_fast() 69 | model = model.cuda().eval() 70 | 71 | with torch.inference_mode(): 72 | for scan in tqdm(scans): 73 | 74 | # set up dataset with current scan 75 | dataset = dataset_class( 76 | opts.dataset_path, 77 | split=opts.split, 78 | mv_tuple_file_suffix=opts.mv_tuple_file_suffix, 79 | limit_to_scan_id=scan, 80 | include_full_res_depth=True, 81 | tuple_info_file_location=opts.tuple_info_file_location, 82 | num_images_in_tuple=None, 83 | shuffle_tuple=opts.shuffle_tuple, 84 | include_high_res_color=True, 85 | include_full_depth_K=True, 86 | skip_frames=opts.skip_frames, 87 | skip_to_frame=opts.skip_to_frame, 88 | image_width=opts.image_width, 89 | image_height=opts.image_height, 90 | pass_frame_id=True, 91 | ) 92 | 93 | dataloader = torch.utils.data.DataLoader( 94 | dataset, 95 | batch_size=opts.batch_size, 96 | shuffle=False, 97 | num_workers=opts.num_workers, 98 | drop_last=False, 99 | ) 100 | 101 | # loop and collate data 102 | images_list = [] 103 | depths_list = [] 104 | poses_list = [] 105 | K_list = [] 106 | 107 | for _, batch in enumerate(tqdm(dataloader)): 108 | 109 | # get data, move to GPU 110 | cur_data, src_data = batch 111 | 112 | cur_data = to_gpu(cur_data, key_ignores=["frame_id_string"]) 113 | src_data = to_gpu(src_data, key_ignores=["frame_id_string"]) 114 | 115 | outputs = model( 116 | "test", 117 | cur_data, src_data, 118 | unbatched_matching_encoder_forward=True, 119 | return_mask=True, 120 | ) 121 | 122 | depth_pred_s0_b1hw = outputs["depth_pred_s0_b1hw"].cuda() 123 | 124 | depth_pred_s0_b1hw[depth_pred_s0_b1hw > 125 | opts.fusion_max_depth] = 0 126 | 127 | upsampled_depth_pred = F.interpolate( 128 | depth_pred_s0_b1hw, 129 | size=(480, 640), 130 | mode="nearest", 131 | ) 132 | 133 | depths_list.append(upsampled_depth_pred) 134 | 135 | poses_list.append(cur_data["cam_T_world_b44"].clone()) 136 | 137 | K_33 = cur_data["K_s0_b44"].clone() 138 | K_33[:,0] *= (640/depth_pred_s0_b1hw.shape[-1]) 139 | K_33[:,1] *= (480/depth_pred_s0_b1hw.shape[-2]) 140 | 141 | K_list.append(K_33.clone()) 142 | 143 | cur_data["high_res_color_b3hw"] = F.interpolate( 144 | cur_data["high_res_color_b3hw"], 145 | size=(480, 640), 146 | mode="bilinear", 147 | ) 148 | image = cur_data["high_res_color_b3hw"].cuda() 149 | image = reverse_imagenet_normalize(image) 150 | images_list.append(image) 151 | 152 | # pass data to pc fuser 153 | depths_preds_bhw = torch.cat(depths_list, dim=0).squeeze(1) 154 | poses_b44 = torch.cat(poses_list, dim=0) 155 | image_bhw3 = torch.cat(images_list, dim=0).permute(0,2,3,1)*255 156 | K_b33 = torch.cat(K_list, dim=0)[:,:3,:3] 157 | 158 | fused_pts, fused_rgb, _ = torch_point_cloud_fusion.process_scene( 159 | depths_preds_bhw, 160 | image_bhw3.to(torch.uint8), 161 | poses_b44, 162 | K_b33, 163 | Z_THRESH, 164 | 
N_CONSISTENT_THRESH, 165 | ) 166 | pcd_pred = o3d.geometry.PointCloud() 167 | pcd_pred.points = o3d.utility.Vector3dVector(fused_pts) 168 | pcd_pred.colors = o3d.utility.Vector3dVector(fused_rgb / 255.) 169 | pcd_pred = pcd_pred.voxel_down_sample(VOXEL_DOWNSAMPLE) 170 | 171 | pcd_filepath = os.path.join(pcs_output_dir, f"{scan}.ply") 172 | o3d.io.write_point_cloud(pcd_filepath, pcd_pred) 173 | 174 | if __name__ == '__main__': 175 | # don't need grad for test. 176 | torch.set_grad_enabled(False) 177 | 178 | # get an instance of options and load it with config file(s) and cli args. 179 | option_handler = options.OptionsHandler() 180 | option_handler.parse_and_merge_options() 181 | option_handler.pretty_print_options() 182 | print("\n") 183 | opts = option_handler.options 184 | 185 | # if no GPUs are available for us then, use the 32 bit on CPU 186 | if opts.gpus == 0: 187 | print("Setting precision to 32 bits since --gpus is set to 0.") 188 | opts.precision = 32 189 | 190 | main(opts) 191 | -------------------------------------------------------------------------------- /simplerecon_env.yml: -------------------------------------------------------------------------------- 1 | name: simplerecon 2 | channels: 3 | - default 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - clang=15.0.7 8 | - llvm-openmp=15.0.7 9 | - python=3.9.7 10 | - pytorch=1.10.0 11 | - torchvision=0.11.1 12 | - cudatoolkit=11.3 13 | - pytorch-lightning=1.5.4 # training utils 14 | - pillow # image op 15 | - tensorboard # logging 16 | - matplotlib # plotting 17 | - pip 18 | - pip: 19 | - kornia==0.6.7 # gradients 20 | - antialiased-cnns # anti aliased resnet 21 | - efficientnet_pytorch 22 | - timm # efficent 23 | - trimesh # mesh loading/storage, and mesh generation 24 | - transforms3d # for NeuralRecon's arkit 25 | - einops # batching one liners 26 | - moviepy # storing videos 27 | - pyrender # rendering meshes 28 | - open3d==0.14.1 # mesh fusion 29 | - scipy # transformations and a few others 30 | - protobuf<4.21.0 # lighting/tensorboard fix 31 | - setuptools==59.5.0 # fix for tensorboard 32 | - https://github.com/JamieWatson683/scikit-image/archive/single_mesh.zip # single mesh exporting for measure.marching_cubes 33 | -------------------------------------------------------------------------------- /tools/fusers_helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import torch 4 | import trimesh 5 | from datasets.scannet_dataset import ScannetDataset 6 | from utils.generic_utils import reverse_imagenet_normalize 7 | 8 | from tools.tsdf import TSDF, TSDFFuser 9 | 10 | 11 | class DepthFuser(): 12 | def __init__( 13 | self, 14 | gt_path="", 15 | fusion_resolution=0.04, 16 | max_fusion_depth=3.0, 17 | fuse_color=False 18 | ): 19 | self.fusion_resolution = fusion_resolution 20 | self.max_fusion_depth = max_fusion_depth 21 | 22 | class OurFuser(DepthFuser): 23 | """ 24 | This is the fuser used for scores in the SimpleRecon paper. Note that 25 | unlike open3d's fuser this implementation does not do voxel hashing. If a 26 | path to a known mehs reconstruction is provided, this function will limit 27 | bounds to that mesh's extent, otherwise it'll use a wide volume to prevent 28 | clipping. 29 | 30 | It's best not to use this fuser unless you need to recreate numbers from the 31 | paper. 
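When no ground truth mesh path is given, the volume falls back to fixed
bounds of -10m to 10m along every axis (a 20m cube centred at the origin).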
32 | 33 | """ 34 | def __init__( 35 | self, 36 | gt_path="", 37 | fusion_resolution=0.04, 38 | max_fusion_depth=3, 39 | fuse_color=False, 40 | ): 41 | super().__init__( 42 | gt_path, 43 | fusion_resolution, 44 | max_fusion_depth, 45 | fuse_color, 46 | ) 47 | 48 | if gt_path is not None: 49 | gt_mesh = trimesh.load(gt_path, force='mesh') 50 | tsdf_pred = TSDF.from_mesh(gt_mesh, voxel_size=fusion_resolution) 51 | else: 52 | bounds = {} 53 | bounds["xmin"] = -10.0 54 | bounds["xmax"] = 10.0 55 | bounds["ymin"] = -10.0 56 | bounds["ymax"] = 10.0 57 | bounds["zmin"] = -10.0 58 | bounds["zmax"] = 10.0 59 | 60 | tsdf_pred = TSDF.from_bounds(bounds, voxel_size=fusion_resolution) 61 | 62 | self.tsdf_fuser_pred = TSDFFuser(tsdf_pred, max_depth=max_fusion_depth) 63 | 64 | def fuse_frames(self, depths_b1hw, K_b44, 65 | cam_T_world_b44, 66 | color_b3hw): 67 | self.tsdf_fuser_pred.integrate_depth( 68 | depth_b1hw=depths_b1hw.half(), 69 | cam_T_world_T_b44=cam_T_world_b44.half(), 70 | K_b44=K_b44.half(), 71 | ) 72 | 73 | def export_mesh(self, path, export_single_mesh=True): 74 | _ = trimesh.exchange.export.export_mesh( 75 | self.tsdf_fuser_pred.tsdf.to_mesh( 76 | export_single_mesh=export_single_mesh), 77 | path, 78 | ) 79 | 80 | def get_mesh(self, export_single_mesh=True, convert_to_trimesh=True): 81 | return self.tsdf_fuser_pred.tsdf.to_mesh( 82 | export_single_mesh=export_single_mesh) 83 | 84 | class Open3DFuser(DepthFuser): 85 | """ 86 | Wrapper class for the open3d fuser. 87 | 88 | This wrapper does not support fusion of tensors with higher than batch 1. 89 | """ 90 | def __init__( 91 | self, 92 | gt_path="", 93 | fusion_resolution=0.04, 94 | max_fusion_depth=3, 95 | fuse_color=False, 96 | use_upsample_depth=False, 97 | ): 98 | super().__init__( 99 | gt_path, 100 | fusion_resolution, 101 | max_fusion_depth, 102 | fuse_color, 103 | ) 104 | 105 | self.fuse_color = fuse_color 106 | self.use_upsample_depth = use_upsample_depth 107 | self.fusion_max_depth = max_fusion_depth 108 | 109 | voxel_size = fusion_resolution * 100 110 | self.volume = o3d.pipelines.integration.ScalableTSDFVolume( 111 | voxel_length=float(voxel_size) / 100, 112 | sdf_trunc=3 * float(voxel_size) / 100, 113 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8 114 | ) 115 | 116 | def fuse_frames( 117 | self, 118 | depths_b1hw, 119 | K_b44, 120 | cam_T_world_b44, 121 | color_b3hw, 122 | ): 123 | 124 | width = depths_b1hw.shape[-1] 125 | height = depths_b1hw.shape[-2] 126 | 127 | if self.fuse_color: 128 | color_b3hw = torch.nn.functional.interpolate( 129 | color_b3hw, 130 | size=(height, width), 131 | ) 132 | color_b3hw = reverse_imagenet_normalize(color_b3hw) 133 | 134 | for batch_index in range(depths_b1hw.shape[0]): 135 | if self.fuse_color: 136 | image_i = color_b3hw[batch_index].permute(1,2,0) 137 | 138 | color_im = (image_i * 255).cpu().numpy().astype( 139 | np.uint8 140 | ).copy(order='C') 141 | else: 142 | # mesh will now be grey 143 | color_im = 0.7*torch.ones_like( 144 | depths_b1hw[batch_index] 145 | ).squeeze().cpu().clone().numpy() 146 | color_im = np.repeat( 147 | color_im[:, :, np.newaxis] * 255, 148 | 3, 149 | axis=2 150 | ).astype(np.uint8) 151 | 152 | depth_pred = depths_b1hw[batch_index].squeeze().cpu().clone().numpy() 153 | depth_pred = o3d.geometry.Image(depth_pred) 154 | color_im = o3d.geometry.Image(color_im) 155 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 156 | color_im, 157 | depth_pred, 158 | depth_scale=1.0, 159 | depth_trunc=self.fusion_max_depth, 160 | 
convert_rgb_to_intensity=False, 161 | ) 162 | cam_intr = K_b44[batch_index].cpu().clone().numpy() 163 | cam_T_world_44 = cam_T_world_b44[batch_index].cpu().clone().numpy() 164 | 165 | self.volume.integrate( 166 | rgbd, 167 | o3d.camera.PinholeCameraIntrinsic( 168 | width=width, 169 | height=height, fx=cam_intr[0, 0], 170 | fy=cam_intr[1, 1], 171 | cx=cam_intr[0, 2], 172 | cy=cam_intr[1, 2] 173 | ), 174 | cam_T_world_44, 175 | ) 176 | 177 | def export_mesh(self, path, use_marching_cubes_mask=None): 178 | o3d.io.write_triangle_mesh(path, self.volume.extract_triangle_mesh()) 179 | 180 | def get_mesh(self, export_single_mesh=None, convert_to_trimesh=False): 181 | mesh = self.volume.extract_triangle_mesh() 182 | 183 | if convert_to_trimesh: 184 | mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.triangles) 185 | 186 | return mesh 187 | 188 | def get_fuser(opts, scan): 189 | """Returns the depth fuser required. Our fuser doesn't allow for """ 190 | 191 | if opts.dataset == "scannet": 192 | gt_path = ScannetDataset.get_gt_mesh_path(opts.dataset_path, 193 | opts.split, scan) 194 | else: 195 | gt_path = None 196 | 197 | if opts.depth_fuser == "ours": 198 | if opts.fuse_color: 199 | print("WARNING: fusing color using 'ours' fuser is not supported, " 200 | "Color will not be fused.") 201 | 202 | return OurFuser( 203 | gt_path=gt_path, 204 | fusion_resolution=opts.fusion_resolution, 205 | max_fusion_depth=opts.fusion_max_depth, 206 | fuse_color=False, 207 | ) 208 | if opts.depth_fuser == "open3d": 209 | return Open3DFuser( 210 | gt_path=gt_path, 211 | fusion_resolution=opts.fusion_resolution, 212 | max_fusion_depth=opts.fusion_max_depth, 213 | fuse_color=opts.fuse_color, 214 | ) 215 | else: 216 | raise ValueError("Unrecognized fuser!") -------------------------------------------------------------------------------- /tools/keyframe_buffer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Most of this is a modified version of code from the DeepVideoMVS repository 3 | at https://github.com/ardaduz/deep-video-mvs/blob/master/dvmvs/keyframe_buffer.py 4 | """ 5 | 6 | from collections import deque 7 | import functools 8 | import numpy as np 9 | 10 | 11 | 12 | class DVMVS_Config: 13 | # train tuple settings 14 | train_minimum_pose_distance = 0.125 15 | train_maximum_pose_distance = 0.325 16 | train_crawl_step = 3 17 | 18 | # test tuple settings 19 | test_keyframe_buffer_size = 30 20 | test_keyframe_pose_distance = 0.1 21 | test_optimal_t_measure = 0.15 22 | test_optimal_R_measure = 0.0 23 | 24 | def is_pose_available(pose): 25 | is_nan = np.isnan(pose).any() 26 | is_inf = np.isinf(pose).any() 27 | is_neg_inf = np.isneginf(pose).any() 28 | if is_nan or is_inf or is_neg_inf: 29 | return False 30 | else: 31 | return True 32 | 33 | def is_valid_pair( 34 | reference_pose, 35 | measurement_pose, 36 | pose_dist_min, 37 | pose_dist_max, 38 | t_norm_threshold=0.05, 39 | return_measure=False, 40 | ): 41 | combined_measure, _, t_measure = pose_distance(reference_pose, measurement_pose) 42 | 43 | if (pose_dist_min <= combined_measure <= pose_dist_max 44 | and t_measure >= t_norm_threshold): 45 | result = True 46 | else: 47 | result = False 48 | 49 | if return_measure: 50 | return result, combined_measure 51 | else: 52 | return result 53 | 54 | def pose_distance(reference_pose, measurement_pose): 55 | """ 56 | :param reference_pose: 4x4 numpy array, reference frame camera-to-world pose 57 | (not extrinsic matrix!) 
58 | :param measurement_pose: 4x4 numpy array, measurement frame camera-to-world 59 | pose (not extrinsic matrix!) 60 | :return combined_measure: float, combined pose distance measure 61 | :return R_measure: float, rotation distance measure 62 | :return t_measure: float, translation distance measure 63 | """ 64 | rel_pose = np.dot(np.linalg.inv(reference_pose), measurement_pose) 65 | R = rel_pose[:3, :3] 66 | t = rel_pose[:3, 3] 67 | R_measure = np.sqrt(2 * (1 - min(3.0, np.matrix.trace(R)) / 3)) 68 | t_measure = np.linalg.norm(t) 69 | combined_measure = np.sqrt(t_measure ** 2 + R_measure ** 2) 70 | return combined_measure, R_measure, t_measure 71 | 72 | class KeyframeBuffer: 73 | def __init__( 74 | self, 75 | buffer_size, 76 | keyframe_pose_distance, 77 | optimal_t_score, 78 | optimal_R_score, 79 | store_return_indices, 80 | ): 81 | self.buffer = deque([], maxlen=buffer_size) 82 | self.keyframe_pose_distance = keyframe_pose_distance 83 | self.optimal_t_score = optimal_t_score 84 | self.optimal_R_score = optimal_R_score 85 | self.__tracking_lost_counter = 0 86 | # mostly required for simulation of the frame selection 87 | self.__store_return_indices = store_return_indices 88 | 89 | def calculate_penalty(self, t_score, R_score): 90 | degree = 2.0 91 | R_penalty = np.abs(R_score - self.optimal_R_score) ** degree 92 | t_diff = t_score - self.optimal_t_score 93 | if t_diff < 0.0: 94 | t_penalty = 5.0 * (np.abs(t_diff) ** degree) 95 | else: 96 | t_penalty = np.abs(t_diff) ** degree 97 | return R_penalty + t_penalty 98 | 99 | def try_new_keyframe(self, pose, image, dist_to_last_valid=None, index=None): 100 | if self.__store_return_indices and index is None: 101 | raise ValueError("Storing and returning the frame indices is " 102 | f"requested in the constructor, but index=None is " 103 | f"passed to the function") 104 | 105 | # In case valid frames are used, this helps guess if a gap in tracking 106 | # when using valid frames and the indices are not indicative of time. 
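        # Summary of the return codes used below (inferred from the logic in
        # this method): 0 = first frame added to an empty buffer, 1 = new
        # keyframe accepted, 2 = pose valid but not enough motion since the
        # last keyframe, 3 = tracking considered lost and the buffer cleared,
        # 4 = still lost with an empty buffer, 5 = pose missing but not yet
        # missing long enough to declare tracking lost.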
107 | if (dist_to_last_valid is not None and dist_to_last_valid > 30): 108 | self.buffer.clear() 109 | self.__tracking_lost_counter = 0 110 | if self.__store_return_indices: 111 | self.buffer.append((pose, image, index)) 112 | else: 113 | self.buffer.append((pose, image)) 114 | 115 | return 3 116 | 117 | if is_pose_available(pose): 118 | self.__tracking_lost_counter = 0 119 | if len(self.buffer) == 0: 120 | if self.__store_return_indices: 121 | self.buffer.append((pose, image, index)) 122 | else: 123 | self.buffer.append((pose, image)) 124 | # pose is available, new frame added but buffer was empty, this 125 | # is the first frame, no depth map prediction will be done 126 | return 0 127 | else: 128 | if self.__store_return_indices: 129 | last_pose, last_image, last_index = self.buffer[-1] 130 | else: 131 | last_pose, last_image = self.buffer[-1] 132 | 133 | combined_measure, R_measure, t_measure = pose_distance(pose, last_pose) 134 | 135 | if combined_measure >= self.keyframe_pose_distance: 136 | if self.__store_return_indices: 137 | self.buffer.append((pose, image, index)) 138 | else: 139 | self.buffer.append((pose, image)) 140 | # pose is available, new frame added, everything is perfect, 141 | # and we will predict a depth map later 142 | return 1 143 | else: 144 | # pose is available but not enough change has happened since 145 | # the last keyframe 146 | return 2 147 | else: 148 | self.__tracking_lost_counter += 1 149 | 150 | if self.__tracking_lost_counter > 30: 151 | if len(self.buffer) > 0: 152 | self.buffer.clear() 153 | # a pose reading has not arrived for over a second, tracking 154 | # is now lost 155 | return 3 156 | else: 157 | return 4 # we are still very lost 158 | else: 159 | # pose is not available right now, but not enough time has 160 | # passed to consider lost, there is still hope :) 161 | return 5 162 | 163 | def get_best_measurement_frames(self, n_requested_measurement_frames): 164 | buffer_array = list(self.buffer) 165 | 166 | if self.__store_return_indices: 167 | reference_pose, reference_image, reference_index = buffer_array[-1] 168 | else: 169 | reference_pose, reference_image = buffer_array[-1] 170 | 171 | n_requested_measurement_frames = min(n_requested_measurement_frames, 172 | len(buffer_array) - 1) 173 | 174 | penalties = [] 175 | for i in range(len(buffer_array) - 1): 176 | measurement_pose = buffer_array[i][0] 177 | 178 | _, R_measure, t_measure = pose_distance(reference_pose, measurement_pose) 179 | penalty = self.calculate_penalty(t_measure, R_measure) 180 | penalties.append(penalty) 181 | indices = np.argpartition(penalties, n_requested_measurement_frames - 1)[:n_requested_measurement_frames] 182 | 183 | measurement_frames = [] 184 | for index in indices: 185 | measurement_frames.append(buffer_array[index]) 186 | return measurement_frames 187 | 188 | 189 | class SimpleBuffer: 190 | def __init__( 191 | self, 192 | buffer_size, 193 | store_return_indices, 194 | ): 195 | self.buffer = deque([], maxlen=buffer_size + 1) 196 | self.__tracking_lost_counter = 0 197 | # mostly required for simulation of the frame selection 198 | self.__store_return_indices = store_return_indices 199 | 200 | def try_new_keyframe(self, pose, image, index=None): 201 | if self.__store_return_indices and index is None: 202 | raise ValueError(f"Storing and returning the frame indices is " 203 | f"requested in the constructor, but index=None is " 204 | f"passed to the function") 205 | 206 | if is_pose_available(pose): 207 | self.__tracking_lost_counter = 0 208 | if len(self.buffer) == 
0: 209 | if self.__store_return_indices: 210 | self.buffer.append((pose, image, index)) 211 | else: 212 | self.buffer.append((pose, image)) 213 | # pose is available, new frame added but buffer was empty, this 214 | # is the first frame, no depth map prediction will be done 215 | return 0 216 | else: 217 | if self.__store_return_indices: 218 | self.buffer.append((pose, image, index)) 219 | else: 220 | self.buffer.append((pose, image)) 221 | # pose is available, new frame added, everything is perfect, 222 | # and we will predict a depth map later 223 | return 1 224 | else: 225 | self.__tracking_lost_counter += 1 226 | 227 | if self.__tracking_lost_counter > 30: 228 | if len(self.buffer) > 0: 229 | self.buffer.clear() 230 | # a pose reading has not arrived for over a second, tracking 231 | # is now lost 232 | return 2 233 | else: 234 | # we are still very lost 235 | return 3 236 | else: 237 | # pose is not available right now, but not enough time has 238 | # passed to consider lost, there is still hope :) 239 | return 4 240 | 241 | def get_measurement_frames(self): 242 | measurement_frames = list(self.buffer)[:-1] 243 | return measurement_frames 244 | 245 | class OfflineKeyframeBuffer: 246 | def __init__( 247 | self, 248 | buffer_size, 249 | keyframe_pose_distance, 250 | optimal_t_score, 251 | optimal_R_score, 252 | store_return_indices, 253 | ): 254 | self.buffer = deque([], maxlen=buffer_size) 255 | self.keyframe_pose_distance = keyframe_pose_distance 256 | self.optimal_t_score = optimal_t_score 257 | self.optimal_R_score = optimal_R_score 258 | self.__tracking_lost_counter = 0 259 | # mostly required for simulation of the frame selection 260 | self.__store_return_indices = store_return_indices 261 | 262 | @functools.lru_cache() 263 | def calculate_penalty(self, t_score, R_score): 264 | degree = 2.0 265 | R_penalty = np.abs(R_score - self.optimal_R_score) ** degree 266 | t_diff = t_score - self.optimal_t_score 267 | if t_diff < 0.0: 268 | t_penalty = 5.0 * (np.abs(t_diff) ** degree) 269 | else: 270 | t_penalty = np.abs(t_diff) ** degree 271 | return R_penalty + t_penalty 272 | 273 | def try_new_keyframe(self, pose, image, index=None): 274 | if self.__store_return_indices and index is None: 275 | raise ValueError(f"Storing and returning the frame indices is " 276 | f"requested in the constructor, but index=None is " 277 | f"passed to the function") 278 | 279 | if is_pose_available(pose): 280 | self.__tracking_lost_counter = 0 281 | if len(self.buffer) == 0: 282 | if self.__store_return_indices: 283 | self.buffer.append((pose, image, index)) 284 | else: 285 | self.buffer.append((pose, image)) 286 | # pose is available, new frame added but buffer was empty, this 287 | # is the first frame, no depth map prediction will be done 288 | return 0 289 | else: 290 | if self.__store_return_indices: 291 | last_pose, last_image, last_index = self.buffer[-1] 292 | else: 293 | last_pose, last_image = self.buffer[-1] 294 | 295 | accept_frame = True 296 | for buffer_pose, _, _ in list(self.buffer): 297 | combined_measure, _, _ = pose_distance(pose, buffer_pose) 298 | 299 | if combined_measure < self.keyframe_pose_distance: 300 | accept_frame = False 301 | break 302 | 303 | if accept_frame: 304 | if self.__store_return_indices: 305 | self.buffer.append((pose, image, index)) 306 | else: 307 | self.buffer.append((pose, image)) 308 | # pose is available, new frame added, everything is perfect, 309 | # and we will predict a depth map later 310 | return 1 311 | else: 312 | # pose is available but not enough 
change has happened since 313 | # the last keyframe 314 | return 2 315 | else: 316 | self.__tracking_lost_counter += 1 317 | 318 | if self.__tracking_lost_counter > 30: 319 | if len(self.buffer) > 0: 320 | self.buffer.clear() 321 | # a pose reading has not arrived for over a second, tracking 322 | # is now lost 323 | return 3 324 | else: 325 | # we are still very lost 326 | return 4 327 | else: 328 | # pose is not available right now, but not enough time has 329 | # passed to consider lost, there is still hope :) 330 | return 5 331 | 332 | def get_best_measurement_frames(self, n_requested_measurement_frames): 333 | buffer_array = list(self.buffer) 334 | 335 | if self.__store_return_indices: 336 | reference_pose, reference_image, reference_index = buffer_array[-1] 337 | else: 338 | reference_pose, reference_image = buffer_array[-1] 339 | 340 | n_requested_measurement_frames = min(n_requested_measurement_frames, 341 | len(buffer_array) - 1) 342 | 343 | penalties = [] 344 | for i in range(len(buffer_array) - 1): 345 | measurement_pose = buffer_array[i][0] 346 | _, R_measure, t_measure = pose_distance(reference_pose, measurement_pose) 347 | 348 | penalty = self.calculate_penalty(t_measure, R_measure) 349 | penalties.append(penalty) 350 | indices = np.argpartition(penalties, n_requested_measurement_frames - 1)[:n_requested_measurement_frames] 351 | 352 | measurement_frames = [] 353 | for index in indices: 354 | measurement_frames.append(buffer_array[index]) 355 | return measurement_frames 356 | 357 | def get_best_measurement_frames_for_0index(self, n_requested_measurement_frames): 358 | buffer_array = list(self.buffer)[1:] 359 | 360 | if len(buffer_array) == 0: 361 | return [] 362 | 363 | if self.__store_return_indices: 364 | reference_pose, _, _ = buffer_array[0] 365 | else: 366 | reference_pose, _ = buffer_array[0] 367 | 368 | n_requested_measurement_frames = min(n_requested_measurement_frames, len(buffer_array) - 1) 369 | 370 | penalties = [] 371 | for i in range(len(buffer_array)): 372 | measurement_pose = buffer_array[i][0] 373 | _, R_measure, t_measure = pose_distance(reference_pose, measurement_pose) 374 | penalty = self.calculate_penalty(t_measure, R_measure) 375 | penalties.append(penalty) 376 | 377 | indices = np.argpartition(penalties, n_requested_measurement_frames - 1)[:n_requested_measurement_frames] 378 | measurement_frames = [] 379 | for index in indices: 380 | measurement_frames.append(buffer_array[index]) 381 | return measurement_frames -------------------------------------------------------------------------------- /tools/torch_point_cloud_fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is borrowed from https://github.com/alexrich021/3dvnet/blob/main/mv3d/eval/pointcloudfusion_custom.py 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import tqdm 9 | 10 | IMG_BATCH = 100 # modify to scale pc fusion implementation to your GPU 11 | 12 | def process_depth(ref_depth, ref_image, src_depths, src_images, ref_P, src_Ps, ref_K, src_Ks, z_thresh=0.1, 13 | n_consistent_thresh=3): 14 | n_src_imgs = src_depths.shape[0] 15 | h, w = int(ref_depth.shape[0]), int(ref_depth.shape[1]) 16 | n_pts = h * w 17 | 18 | src_Ks = src_Ks.cuda() 19 | src_Ps = src_Ps.cuda() 20 | ref_K = ref_K.cuda() 21 | ref_P = ref_P.cuda() 22 | ref_depth = ref_depth.cuda() 23 | 24 | ref_K_inv = torch.inverse(ref_K) 25 | src_Ks_inv = torch.inverse(src_Ks) 26 | ref_P_inv = torch.inverse(ref_P) 27 | 28 | pts_x = 
np.linspace(0, w - 1, w) 29 | pts_y = np.linspace(0, h - 1, h) 30 | pts_xx, pts_yy = np.meshgrid(pts_x, pts_y) 31 | 32 | pts = torch.from_numpy(np.stack((pts_xx, pts_yy, np.ones_like(pts_xx)), axis=0)).float().cuda() 33 | pts = ref_P_inv[:3, :3] @ (ref_K_inv @ (pts * ref_depth.unsqueeze(0)).view(3, n_pts))\ 34 | + ref_P_inv[:3, 3, None] 35 | 36 | n_batches = (n_src_imgs - 1) // IMG_BATCH + 1 37 | n_valid = 0. 38 | pts_sample_all = [] 39 | valid_per_src_all = [] 40 | for b in range(n_batches): 41 | idx_start = b * IMG_BATCH 42 | idx_end = min((b + 1) * IMG_BATCH, n_src_imgs) 43 | src_Ps_batch = src_Ps[idx_start: idx_end] 44 | src_Ks_batch = src_Ks[idx_start: idx_end] 45 | src_Ks_inv_batch = src_Ks_inv[idx_start: idx_end] 46 | src_depths_batch = src_depths[idx_start: idx_end].cuda() 47 | 48 | n_batch_imgs = idx_end - idx_start 49 | pts_reproj = torch.bmm(src_Ps_batch[:, :3, :3], 50 | pts.unsqueeze(0).repeat(n_batch_imgs, 1, 1)) + src_Ps_batch[:, :3, 3, None] 51 | pts_reproj = torch.bmm(src_Ks_batch, pts_reproj) 52 | z_reproj = pts_reproj[:, 2] 53 | pts_reproj = pts_reproj / z_reproj.unsqueeze(1) 54 | 55 | valid_z = (z_reproj > 1e-4) 56 | valid_x = (pts_reproj[:, 0] >= 0.) & (pts_reproj[:, 0] <= float(w - 1)) 57 | valid_y = (pts_reproj[:, 1] >= 0.) & (pts_reproj[:, 1] <= float(h - 1)) 58 | 59 | grid = torch.clone(pts_reproj[:, :2]).transpose(2, 1).view(n_batch_imgs, n_pts, 1, 2) 60 | grid[..., 0] = (grid[..., 0] / float(w - 1)) * 2 - 1.0 # normalize to [-1, 1] 61 | grid[..., 1] = (grid[..., 1] / float(h - 1)) * 2 - 1.0 # normalize to [-1, 1] 62 | z_sample = F.grid_sample(src_depths_batch.unsqueeze(1), grid, mode='nearest', align_corners=True, 63 | padding_mode='zeros') 64 | z_sample = z_sample.squeeze(1).squeeze(-1) 65 | 66 | z_diff = torch.abs(z_reproj - z_sample) 67 | valid_disp = z_diff < z_thresh 68 | 69 | valid_per_src = (valid_disp & valid_x & valid_y & valid_z) 70 | n_valid += torch.sum(valid_per_src.int(), dim=0) 71 | 72 | # back project sampled pts for later averaging 73 | pts_sample = torch.bmm(src_Ks_inv_batch, pts_reproj * z_sample.unsqueeze(1)) 74 | pts_sample = torch.bmm(src_Ps_batch[:, :3, :3].transpose(2, 1), 75 | pts_sample - src_Ps_batch[:, :3, 3, None]) 76 | pts_sample_all.append(pts_sample) 77 | valid_per_src_all.append(valid_per_src) 78 | pts_sample_all = torch.cat(pts_sample_all, dim=0) 79 | valid_per_src_all = torch.cat(valid_per_src_all, dim=0) 80 | 81 | valid = n_valid >= n_consistent_thresh 82 | 83 | # average sampled points amongst consistent views 84 | pts_avg = pts 85 | for i in range(n_src_imgs): 86 | pts_sample_i = pts_sample_all[i] 87 | invalid_idx = torch.isnan(pts_sample_i) # filter out NaNs from div/0 due to grid sample zero padding 88 | pts_sample_i[invalid_idx] = 0. 
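        # Only samples that are geometrically consistent and not NaN contribute
        # to the running sum; pts_avg is later divided by n_valid + 1, where
        # the +1 accounts for the reference point itself.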
89 | valid_i = valid_per_src_all[i] & ~torch.any(invalid_idx, dim=0) 90 | pts_avg += pts_sample_i * valid_i.float().unsqueeze(0) 91 | pts_avg = pts_avg / (n_valid + 1).float().unsqueeze(0).expand(3, n_pts) 92 | 93 | pts_filtered = pts_avg.transpose(1, 0)[valid].cpu().numpy() 94 | valid = valid.view(ref_depth.shape[-2:]) 95 | rgb_filtered = ref_image[valid].view(-1, 3).cpu().numpy() 96 | 97 | return pts_filtered, rgb_filtered, valid.cpu().numpy() 98 | 99 | 100 | def process_scene(depth_preds, images, poses, K, z_thresh, n_consistent_thresh): 101 | n_imgs = depth_preds.shape[0] 102 | fused_pts = [] 103 | fused_rgb = [] 104 | all_idx = torch.arange(n_imgs) 105 | all_valid = [] 106 | for ref_idx in tqdm.tqdm(range(n_imgs)): 107 | src_idx = all_idx != ref_idx 108 | pts, rgb, valid = process_depth(depth_preds[ref_idx], images[ref_idx], depth_preds[src_idx], images[src_idx], 109 | poses[ref_idx], poses[src_idx], K[ref_idx], K[src_idx], z_thresh, 110 | n_consistent_thresh) 111 | fused_pts.append(pts) 112 | fused_rgb.append(rgb) 113 | all_valid.append(valid) 114 | fused_pts = np.concatenate(fused_pts, axis=0) 115 | fused_rgb = np.concatenate(fused_rgb, axis=0) 116 | all_valid = np.stack(all_valid, axis=0) 117 | 118 | return fused_pts, fused_rgb, all_valid 119 | -------------------------------------------------------------------------------- /tools/tsdf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as TF 7 | import trimesh 8 | from skimage import measure 9 | 10 | 11 | class TSDF: 12 | 13 | """ 14 | Class for housing and data handling TSDF volumes. 15 | """ 16 | # Ensures the final voxel volume dimensions are multiples of 8 17 | VOX_MOD = 8 18 | 19 | def __init__( 20 | self, 21 | voxel_coords: torch.tensor, 22 | tsdf_values: torch.tensor, 23 | tsdf_weights: torch.tensor, 24 | voxel_size: float, 25 | origin: torch.tensor, 26 | ): 27 | """ 28 | Sets interal class attributes. 29 | """ 30 | self.voxel_coords = voxel_coords.half() 31 | self.tsdf_values = tsdf_values.half() 32 | self.tsdf_weights = tsdf_weights.half() 33 | self.voxel_size = voxel_size 34 | self.origin = origin.half() 35 | 36 | @classmethod 37 | def from_file(cls, tsdf_file): 38 | """ Loads a tsdf from a numpy file. """ 39 | tsdf_data = np.load(tsdf_file) 40 | 41 | tsdf_values = torch.from_numpy(tsdf_data['tsdf_values']) 42 | origin = torch.from_numpy(tsdf_data['origin']) 43 | voxel_size = tsdf_data['voxel_size'].item() 44 | 45 | tsdf_weights = torch.ones_like(tsdf_values) 46 | 47 | voxel_coords = cls.generate_voxel_coords(origin, tsdf_values.shape[1:], voxel_size) 48 | 49 | return TSDF(voxel_coords, tsdf_values, tsdf_weights, voxel_size) 50 | 51 | @classmethod 52 | def from_mesh(cls, mesh: trimesh.Trimesh, voxel_size: float): 53 | """ Gets TSDF bounds from a mesh file. """ 54 | xmax, ymax, zmax = mesh.vertices.max(0) 55 | xmin, ymin, zmin = mesh.vertices.min(0) 56 | 57 | bounds = {'xmin': xmin, 'xmax': xmax, 58 | 'ymin': ymin, 'ymax': ymax, 59 | 'zmin': zmin, 'zmax': zmax} 60 | 61 | # create a buffer around bounds 62 | for key, val in bounds.items(): 63 | if 'min' in key: 64 | bounds[key] = val - 3 * voxel_size 65 | else: 66 | bounds[key] = val + 3 * voxel_size 67 | return cls.from_bounds(bounds, voxel_size) 68 | 69 | @classmethod 70 | def from_bounds(cls, bounds: dict, voxel_size: float): 71 | """ Creates a TSDF volume with bounds at a specific voxel size. 
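        For example (illustrative values only):

            bounds = {"xmin": -4.0, "xmax": 4.0,
                      "ymin": -4.0, "ymax": 4.0,
                      "zmin": -2.0, "zmax": 2.0}
            tsdf = TSDF.from_bounds(bounds, voxel_size=0.04)

        The volume dimensions along each axis are rounded up to a multiple of
        VOX_MOD (8) voxels.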
""" 72 | 73 | expected_keys = ['xmin', 'xmax', 'ymin', 'ymax', 'zmin', 'zmax'] 74 | for key in expected_keys: 75 | if key not in bounds.keys(): 76 | raise KeyError("Provided bounds dict need to have keys" 77 | "'xmin', 'xmax', 'ymin', 'ymax', 'zmin', 'zmax'!") 78 | 79 | num_voxels_x = int( 80 | np.ceil((bounds['xmax'] - bounds['xmin']) / voxel_size / cls.VOX_MOD)) * cls.VOX_MOD 81 | num_voxels_y = int( 82 | np.ceil((bounds['ymax'] - bounds['ymin']) / voxel_size / cls.VOX_MOD)) * cls.VOX_MOD 83 | num_voxels_z = int( 84 | np.ceil((bounds['zmax'] - bounds['zmin']) / voxel_size / cls.VOX_MOD)) * cls.VOX_MOD 85 | 86 | origin = torch.FloatTensor([bounds['xmin'], bounds['ymin'], bounds['zmin']]) 87 | 88 | voxel_coords = cls.generate_voxel_coords( 89 | origin, (num_voxels_x, num_voxels_y, num_voxels_z), voxel_size).half() 90 | 91 | # init to -1s 92 | tsdf_values = -torch.ones_like(voxel_coords[0]).half() 93 | tsdf_weights = torch.zeros_like(voxel_coords[0]).half() 94 | 95 | return TSDF(voxel_coords, tsdf_values, tsdf_weights, voxel_size, origin) 96 | 97 | @classmethod 98 | def generate_voxel_coords(cls, 99 | origin: torch.tensor, 100 | volume_dims: Tuple[int, int, int], 101 | voxel_size: float): 102 | """ Gets world coordinates for each location in the TSDF. """ 103 | 104 | grid = torch.meshgrid([torch.arange(vd) for vd in volume_dims]) 105 | 106 | voxel_coords = origin.view(3, 1, 1, 1) + torch.stack(grid, 0) * voxel_size 107 | 108 | return voxel_coords 109 | 110 | 111 | def cuda(self): 112 | """ Moves TSDF to gpu memory. """ 113 | self.voxel_coords = self.voxel_coords.cuda() 114 | self.tsdf_values = self.tsdf_values.cuda() 115 | if self.tsdf_weights is not None: 116 | self.tsdf_weights = self.tsdf_weights.cuda() 117 | 118 | def cpu(self): 119 | """ Moves TSDF to cpu memory. """ 120 | self.voxel_coords = self.voxel_coords.cpu() 121 | self.tsdf_values = self.tsdf_values.cpu() 122 | if self.tsdf_weights is not None: 123 | self.tsdf_weights = self.tsdf_weights.cpu() 124 | 125 | def to_mesh(self, scale_to_world=True, export_single_mesh=False): 126 | """ Extracts a mesh from the TSDF volume using marching cubes. 127 | 128 | Args: 129 | scale_to_world: should we scale vertices from TSDF voxel coords 130 | to world coordinates? 131 | export_single_mesh: returns a single walled mesh from marching 132 | cubes. Requires a custom implementation of 133 | measure.marching_cubes that supports single_mesh 134 | 135 | """ 136 | tsdf = self.tsdf_values.detach().cpu().clone().float() 137 | tsdf_np = tsdf.clamp(-1, 1).cpu().numpy() 138 | 139 | if export_single_mesh: 140 | verts, faces, norms, _ = measure.marching_cubes( 141 | tsdf_np, 142 | level=0, 143 | allow_degenerate=False, 144 | single_mesh = True, 145 | ) 146 | else: 147 | verts, faces, norms, _ = measure.marching_cubes( 148 | tsdf_np, 149 | level=0, 150 | allow_degenerate=False, 151 | ) 152 | 153 | if scale_to_world: 154 | verts = self.origin.cpu().view(1, 3) + verts * self.voxel_size 155 | 156 | mesh = trimesh.Trimesh(vertices=verts, faces=faces, normals=norms) 157 | return mesh 158 | 159 | def save(self, savepath, filename, save_mesh=True): 160 | """ Saves a mesh to disk. """ 161 | self.cpu() 162 | os.makedirs(savepath, exist_ok=True) 163 | 164 | if save_mesh: 165 | mesh = self.to_mesh() 166 | trimesh.exchange.export.export_mesh( 167 | mesh, os.path.join(savepath, 168 | filename).replace(".bin", ".ply"), "ply") 169 | 170 | 171 | class TSDFFuser: 172 | """ 173 | Class for fusing depth maps into TSDF volumes. 
174 | """ 175 | def __init__(self, tsdf, min_depth=0.5, max_depth=5.0, use_gpu=True): 176 | """ 177 | Inits the fuser with fusing parameters. 178 | 179 | Args: 180 | tsdf: a TSDF volume object. 181 | min_depth: minimum depth to limit inomcing depth maps to. 182 | max_depth: maximum depth to limit inomcing depth maps to. 183 | use_gpu: use cuda? 184 | 185 | """ 186 | self.tsdf = tsdf 187 | self.min_depth = min_depth 188 | self.max_depth = max_depth 189 | self.use_gpu = use_gpu 190 | self.truncation_size = 3.0 191 | self.maxW = 100.0 192 | 193 | # Create homogeneous coords once only 194 | self.hom_voxel_coords_14hwd = torch.cat( 195 | (self.voxel_coords, torch.ones_like(self.voxel_coords[:1])), 0).unsqueeze(0) 196 | 197 | @property 198 | def voxel_coords(self): 199 | return self.tsdf.voxel_coords 200 | 201 | @property 202 | def tsdf_values(self): 203 | return self.tsdf.tsdf_values 204 | 205 | @property 206 | def tsdf_weights(self): 207 | return self.tsdf.tsdf_weights 208 | 209 | @property 210 | def voxel_size(self): 211 | return self.tsdf.voxel_size 212 | 213 | @property 214 | def shape(self): 215 | return self.voxel_coords.shape[1:] 216 | 217 | @property 218 | def truncation(self): 219 | return self.truncation_size * self.voxel_size 220 | 221 | def project_to_camera(self, cam_T_world_T_b44, K_b44): 222 | 223 | if self.use_gpu: 224 | cam_T_world_T_b44 = cam_T_world_T_b44.cuda() 225 | K_b44 = K_b44.cuda() 226 | self.hom_voxel_coords_14hwd = self.hom_voxel_coords_14hwd.cuda() 227 | 228 | world_to_pix_P_b34 = torch.matmul(K_b44, cam_T_world_T_b44)[:, :3] 229 | batch_size = cam_T_world_T_b44.shape[0] 230 | 231 | world_points_b4N = \ 232 | self.hom_voxel_coords_14hwd.expand(batch_size, 4, *self.shape).flatten(start_dim=2) 233 | cam_points_b3N = torch.matmul(world_to_pix_P_b34, world_points_b4N) 234 | cam_points_b3N[:, :2] = cam_points_b3N[:, :2] / cam_points_b3N[:, 2, None] 235 | 236 | return cam_points_b3N 237 | 238 | def integrate_depth( 239 | self, 240 | depth_b1hw, 241 | cam_T_world_T_b44, 242 | K_b44, 243 | depth_mask_b1hw = None, 244 | ): 245 | """ 246 | Integrates depth maps into the volume. Supports batching. 247 | 248 | depth_b1hw: tensor with depth map 249 | cam_T_world_T_b44: camera extrinsics (not pose!). 250 | K_b44: camera intrinsics. 251 | depth_mask_b1hw: an optional boolean mask for valid depth points in 252 | the depth map. 
253 | """ 254 | img_h, img_w = depth_b1hw.shape[2:] 255 | img_size = torch.tensor([img_w, img_h], dtype=torch.float16).view(1, 1, 1, 2) 256 | if self.use_gpu: 257 | depth_b1hw = depth_b1hw.cuda() 258 | img_size = img_size.cuda() 259 | self.tsdf.cuda() 260 | 261 | # Project voxel coordinates into images 262 | cam_points_b3N = self.project_to_camera(cam_T_world_T_b44, K_b44) 263 | vox_depth_b1N = cam_points_b3N[:, 2:3] 264 | pixel_coords_b2N = cam_points_b3N[:, :2] 265 | 266 | # Reshape the projected voxel coords to a 2D view of shape Hx(WxD) 267 | pixel_coords_bhw2 = pixel_coords_b2N.view(-1, 2, self.shape[0], 268 | self.shape[1] * self.shape[2] 269 | ).permute(0, 2, 3, 1) 270 | pixel_coords_bhw2 = 2 * pixel_coords_bhw2 / img_size - 1 271 | 272 | if depth_mask_b1hw is not None: 273 | depth_b1hw = depth_b1hw.clone() 274 | depth_b1hw[~depth_mask_b1hw] = -1 275 | 276 | # Sample the depth using grid sample 277 | sampled_depth_b1hw = TF.grid_sample(input=depth_b1hw, 278 | grid=pixel_coords_bhw2, 279 | mode="nearest", 280 | padding_mode="zeros", 281 | align_corners=False) 282 | sampled_depth_b1N = sampled_depth_b1hw.flatten(start_dim=2) 283 | 284 | # Confidence from InfiniTAM 285 | confidence_b1N = torch.clamp( 286 | 1.0 - (sampled_depth_b1N - self.min_depth) / (self.max_depth - self.min_depth), 287 | min=0.0, max=1.0) ** 2 288 | 289 | # Calculate TSDF values from depth difference by normalizing to [-1, 1] 290 | dist_b1N = sampled_depth_b1N - vox_depth_b1N 291 | tsdf_vals_b1N = torch.clamp(dist_b1N / self.truncation, min=-1.0, max=1.0) 292 | 293 | # Get the valid points mask 294 | valid_points_b1N = (vox_depth_b1N > 0) & (dist_b1N > -self.truncation) & \ 295 | (sampled_depth_b1N > 0) & (vox_depth_b1N > 0) & (vox_depth_b1N < self.max_depth) & \ 296 | (confidence_b1N > 0) 297 | 298 | # Updating the TSDF has to be sequential so we break out the batch here 299 | for tsdf_val_1N, valid_points_1N, confidence_1N in zip(tsdf_vals_b1N, 300 | valid_points_b1N, 301 | confidence_b1N): 302 | # Reshape the valid mask to the TSDF's shape and read the old values 303 | valid_points_hwd = valid_points_1N.view(self.shape) 304 | old_tsdf_vals = self.tsdf_values[valid_points_hwd] 305 | old_weights = self.tsdf_weights[valid_points_hwd] 306 | 307 | # Fetch the new tsdf values and the confidence 308 | new_tsdf_vals = tsdf_val_1N[valid_points_1N] 309 | confidence = confidence_1N[valid_points_1N] 310 | 311 | # More infiniTAM magic: update faster when the new samples are more confident 312 | update_rate = torch.where(confidence < old_weights, 2.0, 5.0).half() 313 | 314 | # Compute the new weight and the normalization factor 315 | new_weights = confidence * update_rate / self.maxW 316 | total_weights = old_weights + new_weights 317 | 318 | # Update the tsdf and the weights 319 | self.tsdf_values[valid_points_hwd] = (old_tsdf_vals * old_weights + new_tsdf_vals * new_weights) / total_weights 320 | self.tsdf_weights[valid_points_hwd] = torch.clamp(total_weights, max=1.0) 321 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains a DepthModel model. Uses an MVS dataset from datasets. 3 | 4 | - Outputs logs and checkpoints to opts.log_dir/opts.name 5 | - Supports mixed precision training by setting '--precision 16' 6 | 7 | We train with a batch_size of 16 with 16-bit precision on two A100s. 
8 | 9 | Example command to train with two GPUs 10 | python train.py --name HERO_MODEL \ 11 | --log_dir logs \ 12 | --config_file configs/models/hero_model.yaml \ 13 | --data_config configs/data/scannet_default_train.yaml \ 14 | --gpus 2 \ 15 | --batch_size 16; 16 | 17 | """ 18 | 19 | 20 | import os 21 | 22 | import pytorch_lightning as pl 23 | from pytorch_lightning.callbacks import LearningRateMonitor 24 | from pytorch_lightning.loggers import TensorBoardLogger 25 | from pytorch_lightning.plugins import DDPPlugin 26 | from torch.utils.data import DataLoader 27 | 28 | import options 29 | from experiment_modules.depth_model import DepthModel 30 | from utils.generic_utils import copy_code_state 31 | from utils.dataset_utils import get_dataset 32 | 33 | 34 | def main(opts): 35 | 36 | # set seed 37 | pl.seed_everything(opts.random_seed) 38 | 39 | 40 | if opts.load_weights_from_checkpoint is not None: 41 | model = DepthModel.load_from_checkpoint( 42 | opts.load_weights_from_checkpoint, 43 | opts=opts, 44 | args=None 45 | ) 46 | else: 47 | # load model using read options 48 | model = DepthModel(opts) 49 | 50 | # load dataset and dataloaders 51 | dataset_class, _ = get_dataset(opts.dataset, 52 | opts.dataset_scan_split_file, opts.single_debug_scan_id) 53 | 54 | train_dataset = dataset_class( 55 | opts.dataset_path, 56 | split="train", 57 | mv_tuple_file_suffix=opts.mv_tuple_file_suffix, 58 | num_images_in_tuple=opts.num_images_in_tuple, 59 | tuple_info_file_location=opts.tuple_info_file_location, 60 | image_width=opts.image_width, 61 | image_height=opts.image_height, 62 | shuffle_tuple=opts.shuffle_tuple, 63 | ) 64 | 65 | train_dataloader = DataLoader( 66 | train_dataset, 67 | batch_size=opts.batch_size, 68 | shuffle=True, 69 | num_workers=opts.num_workers, 70 | pin_memory=True, 71 | drop_last=True, 72 | persistent_workers=True, 73 | ) 74 | 75 | val_dataset = dataset_class( 76 | opts.dataset_path, 77 | split="val", 78 | mv_tuple_file_suffix=opts.mv_tuple_file_suffix, 79 | num_images_in_tuple=opts.num_images_in_tuple, 80 | tuple_info_file_location=opts.tuple_info_file_location, 81 | image_width=opts.image_width, 82 | image_height=opts.image_height, 83 | include_full_res_depth=opts.high_res_validation, 84 | ) 85 | 86 | val_dataloader = DataLoader( 87 | val_dataset, 88 | batch_size=opts.val_batch_size, 89 | shuffle=False, 90 | num_workers=opts.num_workers, 91 | pin_memory=True, 92 | drop_last=True, 93 | persistent_workers=True, 94 | ) 95 | 96 | # set up a tensorboard logger through lightning 97 | logger = TensorBoardLogger(save_dir=opts.log_dir, name=opts.name) 98 | 99 | # This will copy a snapshot of the code (minus whatever is in .gitignore) 100 | # into a folder inside the main log directory. 101 | copy_code_state(path=os.path.join(logger.log_dir, "code")) 102 | 103 | # dumping a copy of the config to the directory for easy(ier) 104 | # reproducibility. 105 | options.OptionsHandler.save_options_as_yaml( 106 | os.path.join(logger.log_dir, "config.yaml"), 107 | opts, 108 | ) 109 | 110 | # set a checkpoint callback for lignting to save model checkpoints 111 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 112 | save_last=True, 113 | save_top_k=1, 114 | verbose=True, 115 | monitor='val/loss', 116 | mode='min', 117 | ) 118 | 119 | 120 | # keep track of changes in learning rate 121 | lr_monitor = LearningRateMonitor(logging_interval='step') 122 | 123 | # allowing the lightning DDPPlugin to ignore unused params. 
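    # find_unused_parameters is only set to True for the "unet_encoder"
    # matching encoder, presumably because some of its parameters may not
    # receive gradients on every step; keeping it False otherwise avoids
    # extra DDP overhead.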
124 | find_unused_parameters = (opts.matching_encoder_type == "unet_encoder") 125 | 126 | trainer = pl.Trainer( 127 | gpus=opts.gpus, 128 | log_every_n_steps=opts.log_interval, 129 | val_check_interval=opts.val_interval, 130 | limit_val_batches=opts.val_batches, 131 | max_steps=opts.max_steps, 132 | precision=opts.precision, 133 | benchmark=True, 134 | logger=logger, 135 | sync_batchnorm=False, 136 | callbacks=[checkpoint_callback, lr_monitor], 137 | num_sanity_val_steps=opts.num_sanity_val_steps, 138 | strategy=DDPPlugin( 139 | find_unused_parameters=find_unused_parameters 140 | ), 141 | resume_from_checkpoint=opts.resume, 142 | ) 143 | 144 | # start training 145 | trainer.fit(model, train_dataloader, val_dataloader) 146 | 147 | 148 | if __name__ == '__main__': 149 | # get an instance of options and load it with config file(s) and cli args. 150 | option_handler = options.OptionsHandler() 151 | option_handler.parse_and_merge_options() 152 | option_handler.pretty_print_options() 153 | print("\n") 154 | opts = option_handler.options 155 | 156 | # if no GPUs are available for us, fall back to 32 bit precision on the CPU 157 | if opts.gpus == 0: 158 | print("Setting precision to 32 bits since --gpus is set to 0.") 159 | opts.precision = 32 160 | 161 | main(opts) 162 | -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from datasets.colmap_dataset import ColmapDataset 2 | from datasets.arkit_dataset import ARKitDataset 3 | from datasets.scannet_dataset import ScannetDataset 4 | from datasets.seven_scenes_dataset import SevenScenesDataset 5 | from datasets.vdr_dataset import VDRDataset 6 | from datasets.scanniverse_dataset import ScanniverseDataset 7 | 8 | def get_dataset(dataset_name, 9 | split_filepath, 10 | single_debug_scan_id=None, 11 | verbose=True): 12 | """ Helper function for passing back the right dataset class, and for 13 | identifying the scans in a split file. 14 | 15 | dataset_name: a string pointing to the right dataset name, allowed names 16 | are: 17 | - scannet 18 | - arkit: arkit format as obtained and processed by NeuralRecon 19 | - vdr 20 | - scanniverse 21 | - colmap: colmap text format. 22 | - 7scenes: processed and undistorted seven scenes. 23 | split_filepath: a path to a text file that contains a list of scans that 24 | will be passed back as a list called scans. 25 | single_debug_scan_id: if not None will override the split file and will 26 | be passed back in scans as the only item. 27 | verbose: if True will print the dataset name and number of scans. 28 | 29 | Returns: 30 | dataset_class: A handle to the right dataset class for use in 31 | creating objects of that class. 32 | scans: a list of scans in the split file.
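    Example (using one of the bundled ScanNet split files):

        dataset_class, scans = get_dataset(
            "scannet",
            "data_splits/ScanNetv2/standard_split/scannetv2_train.txt")

    The returned class can then be instantiated as in train.py.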
33 | """ 34 | if dataset_name == "scannet": 35 | 36 | with open(split_filepath) as file: 37 | scans = file.readlines() 38 | scans = [scan.strip() for scan in scans] 39 | 40 | if single_debug_scan_id is not None: 41 | scans = [single_debug_scan_id] 42 | 43 | dataset_class = ScannetDataset 44 | if verbose: 45 | print(f"".center(80, "#")) 46 | print(f" ScanNet Dataset, number of scans: {len(scans)} ".center(80, "#")) 47 | print(f"".center(80, "#")) 48 | print("") 49 | 50 | 51 | elif dataset_name == "arkit": 52 | 53 | with open(split_filepath) as file: 54 | scans = file.readlines() 55 | scans = [scan.strip() for scan in scans] 56 | 57 | if single_debug_scan_id is not None: 58 | scans = [single_debug_scan_id] 59 | 60 | dataset_class = ARKitDataset 61 | if verbose: 62 | print(f"".center(80, "#")) 63 | print(f" ARKit Dataset, number of scans: {len(scans)} ".center(80, "#")) 64 | print(f"".center(80, "#")) 65 | print("") 66 | 67 | elif dataset_name == "vdr": 68 | 69 | with open(split_filepath) as file: 70 | scans = file.readlines() 71 | scans = [scan.strip() for scan in scans] 72 | 73 | if single_debug_scan_id is not None: 74 | scans = [single_debug_scan_id] 75 | 76 | 77 | if single_debug_scan_id is not None: 78 | scans = [single_debug_scan_id] 79 | 80 | dataset_class = VDRDataset 81 | 82 | if verbose: 83 | print(f"".center(80, "#")) 84 | print(f" VDR Dataset, number of scans: {len(scans)} ".center(80, "#")) 85 | print(f"".center(80, "#")) 86 | print("") 87 | 88 | elif dataset_name == "scanniverse": 89 | 90 | with open(split_filepath) as file: 91 | scans = file.readlines() 92 | scans = [scan.strip() for scan in scans] 93 | 94 | if single_debug_scan_id is not None: 95 | scans = [single_debug_scan_id] 96 | 97 | dataset_class = ScanniverseDataset 98 | if verbose: 99 | print(f"".center(80, "#")) 100 | print(f" Scanniverse Dataset, number of scans: {len(scans)} ".center(80, "#")) 101 | print(f"".center(80, "#")) 102 | print("") 103 | 104 | elif dataset_name == "colmap": 105 | 106 | with open(split_filepath) as file: 107 | scans = file.readlines() 108 | scans = [scan.strip() for scan in scans] 109 | 110 | if single_debug_scan_id is not None: 111 | scans = [single_debug_scan_id] 112 | 113 | dataset_class = ColmapDataset 114 | if verbose: 115 | print(f"".center(80, "#")) 116 | print(f" Colmap Dataset, number of scans: {len(scans)} ".center(80, "#")) 117 | print(f"".center(80, "#")) 118 | print("") 119 | 120 | elif dataset_name == "7scenes": 121 | 122 | with open(split_filepath) as file: 123 | scans = file.readlines() 124 | scans = [scan.strip() for scan in scans] 125 | 126 | if single_debug_scan_id is not None: 127 | scans = [single_debug_scan_id] 128 | 129 | dataset_class = SevenScenesDataset 130 | 131 | if verbose: 132 | print(f"".center(80, "#")) 133 | print(f" 7Scenes Dataset, number of scans: {len(scans)} ".center(80, "#")) 134 | print(f"".center(80, "#")) 135 | print("") 136 | 137 | else: 138 | raise ValueError(f"Not a recognized dataset: {dataset_name}") 139 | 140 | return dataset_class, scans 141 | -------------------------------------------------------------------------------- /utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | from pathlib import Path 5 | 6 | import kornia 7 | import torch 8 | import torchvision.transforms.functional as TF 9 | from PIL import Image 10 | from torch import nn 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def copy_code_state(path): 16 | 
"""Copies the code directory into the path specified using rsync. It will 17 | use a .gitignore file to exclude files in rsync. We preserve modification 18 | times in rsync.""" 19 | 20 | # create dir 21 | Path(os.path.join(path)).mkdir(parents=True, exist_ok=True) 22 | 23 | if os.path.exists("./.gitignore"): 24 | # use .gitignore to remove junk 25 | rsync_command = ( 26 | f"rsync -art --exclude-from='./.gitignore' --exclude '.git' . {path}" 27 | ) 28 | else: 29 | print("WARNING: no .gitignore found so can't use that to exlcude large " 30 | "files when making a back up of files in copy_code_state.") 31 | rsync_command = ( 32 | f"rsync -art --exclude '.git' . {path}" 33 | ) 34 | os.system(rsync_command) 35 | 36 | def readlines(filepath): 37 | """ Reads in a text file and returns lines in a list. """ 38 | with open(filepath, 'r') as f: 39 | lines = f.read().splitlines() 40 | return lines 41 | 42 | def normalize_depth_single(depth_11hw, mask_11hw, robust=False): 43 | 44 | if mask_11hw is not None: 45 | valid_depth_vals_N = depth_11hw.masked_select(mask_11hw) 46 | else: 47 | valid_depth_vals_N = torch.flatten(depth_11hw) 48 | 49 | num_valid_pix = valid_depth_vals_N.nelement() 50 | num_percentile_pix = num_valid_pix // 10 51 | 52 | if num_valid_pix == 0: 53 | return depth_11hw 54 | 55 | sorted_depth_vals_N = torch.sort(valid_depth_vals_N)[0] 56 | depth_flat_N = sorted_depth_vals_N[num_percentile_pix:-num_percentile_pix] 57 | 58 | if depth_flat_N.nelement() == 0: 59 | depth_flat_N = valid_depth_vals_N 60 | 61 | if robust: 62 | depth_shift = depth_flat_N.median() 63 | depth_scale = torch.mean(torch.abs(depth_flat_N - depth_shift)) 64 | else: 65 | depth_shift = depth_flat_N.mean() 66 | depth_scale = depth_flat_N.std() 67 | 68 | depth_norm = (depth_11hw - depth_shift) / depth_scale 69 | 70 | return depth_norm 71 | 72 | 73 | def normalize_depth(depth_b1hw: torch.Tensor, 74 | mask_b1hw: torch.Tensor = None, 75 | robust: bool = False): 76 | 77 | depths_11hw = torch.split(depth_b1hw, 1, 0) 78 | masks_11hw = ([None] * len(depths_11hw) if mask_b1hw is None 79 | else torch.split(mask_b1hw, 1, 0)) 80 | 81 | depths_norm_11hw = [normalize_depth_single(d, m, robust) 82 | for d, m in zip(depths_11hw, masks_11hw)] 83 | 84 | return torch.cat(depths_norm_11hw, dim=0) 85 | 86 | 87 | @torch.jit.script 88 | def pyrdown(input_tensor: torch.Tensor, num_scales: int = 4): 89 | """ Creates a downscale pyramid for the input tensor. """ 90 | output = [input_tensor] 91 | for _ in range(num_scales - 1): 92 | down = kornia.filters.blur_pool2d(output[-1], 3) 93 | output.append(down) 94 | return output 95 | 96 | def upsample(x): 97 | """ 98 | Upsample input tensor by a factor of 2 99 | """ 100 | return nn.functional.interpolate( 101 | x, 102 | scale_factor=2, 103 | mode="bilinear", 104 | align_corners=False, 105 | ) 106 | 107 | def batched_trace(mat_bNN): 108 | return mat_bNN.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) 109 | 110 | def tensor_B_to_bM(tensor_BS, batch_size, num_views): 111 | """Unpacks a flattened tensor of tupled elements (BS) into bMS. Tuple size 112 | is M.""" 113 | # S for wild card number of dims in the middle 114 | # tensor_bSM = tensor_BS.unfold(0, step=num_views, size=num_views) 115 | # tensor_bMS = tensor_bSM.movedim(-1, 1) 116 | tensor_bMS = tensor_BS.view([batch_size, num_views] + list(tensor_BS.shape[1:])) 117 | 118 | return tensor_bMS 119 | 120 | 121 | def tensor_bM_to_B(tensor_bMS): 122 | """Packs an inflated tensor of tupled elements (bMS) into BS. 
Tuple size 123 | is M.""" 124 | # S for wild card number of dims in the middle 125 | num_views = tensor_bMS.shape[1] 126 | num_batches = tensor_bMS.shape[0] 127 | 128 | tensor_BS = tensor_bMS.view([num_views * num_batches] + list(tensor_bMS.shape[2:])) 129 | 130 | return tensor_BS 131 | 132 | def combine_dims(x, dim_begin, dim_end): 133 | """Views x with the dimensions from dim_begin to dim_end folded.""" 134 | combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:]) 135 | return x.view(combined_shape) 136 | 137 | 138 | def to_gpu(input_dict, key_ignores=[]): 139 | """" Moves tensors in the input dict to the gpu and ignores tensors/elements 140 | as with keys in key_ignores. 141 | """ 142 | for k, v in input_dict.items(): 143 | if k not in key_ignores: 144 | input_dict[k] = v.cuda().float() 145 | return input_dict 146 | 147 | def imagenet_normalize(image): 148 | """ Normalizes an image with ImageNet statistics. """ 149 | image = TF.normalize(tensor=image, 150 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 151 | return image 152 | 153 | def reverse_imagenet_normalize(image): 154 | """ Reverses ImageNet normalization in an input image. """ 155 | 156 | image = TF.normalize(tensor=image, 157 | mean=(-2.11790393, -2.03571429, -1.80444444), 158 | std=(4.36681223, 4.46428571, 4.44444444)) 159 | return image 160 | 161 | 162 | def read_image_file(filepath, 163 | height=None, 164 | width=None, 165 | value_scale_factor=1.0, 166 | resampling_mode=Image.BILINEAR, 167 | disable_warning=False, 168 | target_aspect_ratio=None): 169 | """" Reads an image file using PIL, then optionally resizes the image, 170 | with selective resampling, scales values, and returns the image as a 171 | tensor 172 | 173 | Args: 174 | filepath: path to the image. 175 | height, width: resolution to resize the image to. Both must not be 176 | None for scaling to take place. 177 | value_scale_factor: value to scale image values with, default is 1.0 178 | resampling_mode: resampling method when resizing using PIL. Default 179 | is PIL.Image.BILINEAR 180 | target_aspect_ratio: if not None, will crop the image to match this 181 | aspect ratio. Default is None 182 | 183 | Returns: 184 | img: tensor with (optionally) scaled and resized image data. 185 | 186 | """ 187 | img = Image.open(filepath) 188 | 189 | if target_aspect_ratio: 190 | crop_image_to_target_ratio(img, target_aspect_ratio) 191 | 192 | # resize if both width and height are not none. 193 | if height is not None and width is not None: 194 | img_width, img_height = img.size 195 | # do we really need to resize? If not, skip. 196 | if (img_width, img_height) != (width, height): 197 | # warn if it doesn't make sense. 198 | if ((width > img_width or height > img_height) and 199 | not disable_warning): 200 | logger.warning( 201 | f"WARNING: target size ({width}, {height}) has a " 202 | f"dimension larger than input size ({img_width}, " 203 | f"{img_height}).") 204 | img = img.resize((width, height), resample=resampling_mode) 205 | 206 | img = TF.to_tensor(img).float() * value_scale_factor 207 | 208 | return img 209 | 210 | def crop_image_to_target_ratio(image, target_aspect_ratio=4.0/3.0): 211 | """ Crops an image to satisfy a target aspect ratio. 
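    For example, with the default 4:3 target a 1920x1080 image is
    center-cropped to 1440x1080 (width is cropped), while a 1000x1000 image is
    center-cropped to 1000x750 (height is cropped). Images already at the
    target ratio are returned unchanged.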
""" 212 | 213 | actual_aspect_ratio = image.width/image.height 214 | 215 | if actual_aspect_ratio > target_aspect_ratio: 216 | # we should crop width 217 | new_width = image.height * target_aspect_ratio 218 | 219 | left = (image.width - new_width)/2 220 | top = 0 221 | right = (image.width + new_width)/2 222 | bottom = image.height 223 | 224 | # Crop the center of the image 225 | image = image.crop((left, top, right, bottom)) 226 | 227 | elif actual_aspect_ratio < target_aspect_ratio: 228 | # we should crop height 229 | new_height = image.width/target_aspect_ratio 230 | 231 | left = 0 232 | top = (image.height - new_height)/2 233 | right = image.width 234 | bottom = (image.height + new_height)/2 235 | 236 | # Crop the center of the image 237 | image = image.crop((left, top, right, bottom)) 238 | 239 | return image 240 | 241 | def cache_model_outputs( 242 | output_path, 243 | outputs, 244 | cur_data, 245 | src_data, 246 | batch_ind, 247 | batch_size, 248 | ): 249 | """ Helper function for model output during inference. """ 250 | 251 | for elem_ind in range(outputs["depth_pred_s0_b1hw"].shape[0]): 252 | if "frame_id_string" in cur_data: 253 | frame_id = cur_data["frame_id_string"][elem_ind] 254 | else: 255 | frame_id = (batch_ind * batch_size) + elem_ind 256 | frame_id = f"{str(frame_id):6d}" 257 | 258 | elem_filepath = os.path.join(output_path, f"{frame_id}.pickle") 259 | 260 | elem_output_dict = {} 261 | 262 | for key in outputs: 263 | if outputs[key] is not None: 264 | elem_output_dict[key] = outputs[key][elem_ind].unsqueeze(0) 265 | else: 266 | elem_output_dict[key] = None 267 | 268 | # include some auxiliary information 269 | elem_output_dict["K_full_depth_b44"] = cur_data[ 270 | "K_full_depth_b44" 271 | ][elem_ind].unsqueeze(0) 272 | elem_output_dict["K_s0_b44"] = cur_data[ 273 | "K_s0_b44" 274 | ][elem_ind].unsqueeze(0) 275 | 276 | elem_output_dict["frame_id"] = cur_data["frame_id_string"][elem_ind] 277 | elem_output_dict["src_ids"] = [] 278 | for src_id_list in src_data["frame_id_string"]: 279 | elem_output_dict["src_ids"].append(src_id_list[elem_ind]) 280 | 281 | with open(elem_filepath, 'wb') as handle: 282 | pickle.dump(elem_output_dict, handle) 283 | -------------------------------------------------------------------------------- /utils/geometry_utils.py: -------------------------------------------------------------------------------- 1 | import kornia 2 | import numpy as np 3 | import torch 4 | import torch.jit as jit 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | from utils.generic_utils import batched_trace 9 | 10 | 11 | @torch.jit.script 12 | def to_homogeneous(input_tensor: Tensor, dim: int = 0) -> Tensor: 13 | """ 14 | Converts tensor to homogeneous coordinates by adding ones to the specified 15 | dimension 16 | """ 17 | ones = torch.ones_like(input_tensor.select(dim, 0).unsqueeze(dim)) 18 | output_bkN = torch.cat([input_tensor, ones], dim=dim) 19 | return output_bkN 20 | 21 | 22 | class BackprojectDepth(jit.ScriptModule): 23 | """ 24 | Layer that projects points from 2D camera to 3D space. The 3D points are 25 | represented in homogeneous coordinates. 
26 | """ 27 | 28 | def __init__(self, height: int, width: int): 29 | super().__init__() 30 | 31 | self.height = height 32 | self.width = width 33 | 34 | xx, yy = torch.meshgrid( 35 | torch.arange(self.width), 36 | torch.arange(self.height), 37 | indexing='xy', 38 | ) 39 | pix_coords_2hw = torch.stack((xx, yy), axis=0) + 0.5 40 | 41 | pix_coords_13N = to_homogeneous( 42 | pix_coords_2hw, 43 | dim=0, 44 | ).flatten(1).unsqueeze(0) 45 | 46 | # make these tensors into buffers so they are put on the correct GPU 47 | # automatically 48 | self.register_buffer("pix_coords_13N", pix_coords_13N) 49 | 50 | @jit.script_method 51 | def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor: 52 | """ 53 | Backprojects spatial points in 2D image space to world space using 54 | invK_b44 at the depths defined in depth_b1hw. 55 | """ 56 | cam_points_b3N = torch.matmul(invK_b44[:, :3, :3], self.pix_coords_13N) 57 | cam_points_b3N = depth_b1hw.flatten(start_dim=2) * cam_points_b3N 58 | cam_points_b4N = to_homogeneous(cam_points_b3N, dim=1) 59 | return cam_points_b4N 60 | 61 | 62 | class Project3D(jit.ScriptModule): 63 | """ 64 | Layer that projects 3D points into the 2D camera 65 | """ 66 | def __init__(self, eps: float = 1e-8): 67 | super().__init__() 68 | 69 | self.register_buffer("eps", torch.tensor(eps).view(1, 1, 1)) 70 | 71 | @jit.script_method 72 | def forward(self, points_b4N: Tensor, 73 | K_b44: Tensor, cam_T_world_b44: Tensor) -> Tensor: 74 | """ 75 | Projects spatial points in 3D world space to camera image space using 76 | the extrinsics matrix cam_T_world_b44 and intrinsics K_b44. 77 | """ 78 | P_b44 = K_b44 @ cam_T_world_b44 79 | 80 | cam_points_b3N = P_b44[:, :3] @ points_b4N 81 | 82 | # from Kornia and OpenCV, https://kornia.readthedocs.io/en/latest/_modules/kornia/geometry/conversions.html#convert_points_from_homogeneous 83 | mask = torch.abs(cam_points_b3N[:, 2:]) > self.eps 84 | depth_b1N = (cam_points_b3N[:, 2:] + self.eps) 85 | scale = torch.where(mask, 1.0 / depth_b1N, torch.tensor(1.0, device=depth_b1N.device)) 86 | 87 | pix_coords_b2N = cam_points_b3N[:, :2] * scale 88 | 89 | return torch.cat([pix_coords_b2N, depth_b1N], dim=1) 90 | 91 | 92 | class NormalGenerator(jit.ScriptModule): 93 | def __init__(self, height: int, width: int, 94 | smoothing_kernel_size: int=5, smoothing_kernel_std: float=2.0): 95 | """ 96 | Estimates normals from depth maps. 97 | """ 98 | super().__init__() 99 | self.height = height 100 | self.width = width 101 | 102 | self.backproject = BackprojectDepth(self.height, self.width) 103 | 104 | self.kernel_size = smoothing_kernel_size 105 | self.std = smoothing_kernel_std 106 | 107 | @jit.script_method 108 | def forward(self, depth_b1hw: Tensor, invK_b44: Tensor) -> Tensor: 109 | """ 110 | First smoothes incoming depth maps with a gaussian blur, backprojects 111 | those depth points into world space (see BackprojectDepth), estimates 112 | the spatial gradient at those points, and finally uses normalized cross 113 | correlation to estimate a normal vector at each location. 
114 | 115 | """ 116 | depth_smooth_b1hw = kornia.filters.gaussian_blur2d( 117 | depth_b1hw, 118 | (self.kernel_size, self.kernel_size), 119 | (self.std, self.std), 120 | ) 121 | cam_points_b4N = self.backproject(depth_smooth_b1hw, invK_b44) 122 | cam_points_b3hw = cam_points_b4N[:, :3].view(-1, 3, self.height, self.width) 123 | 124 | gradients_b32hw = kornia.filters.spatial_gradient(cam_points_b3hw) 125 | 126 | return F.normalize( 127 | torch.cross( 128 | gradients_b32hw[:, :, 0], 129 | gradients_b32hw[:, :, 1], 130 | dim=1, 131 | ), 132 | dim=1, 133 | ) 134 | 135 | def get_angle_dif(matA_b33, matB_b33): 136 | """Computes the angle difference between two rotation matrices.""" 137 | trace = batched_trace(torch.matmul(matA_b33, 138 | matB_b33.transpose(dim0=1, dim1=2))) 139 | angle_diff_b = torch.arccos((trace - 1) / 2) 140 | 141 | return angle_diff_b 142 | 143 | def get_camera_rays( 144 | world_T_cam_b44, 145 | world_points_b3N, 146 | in_camera_frame, 147 | cam_T_world_b44=None, 148 | eps=1e-4, 149 | ): 150 | """ 151 | Computes camera rays for given camera data and points, optionally shifts 152 | rays to camera frame. 153 | """ 154 | 155 | if in_camera_frame: 156 | batch_size = world_points_b3N.shape[0] 157 | num_points = world_points_b3N.shape[2] 158 | world_points_b4N = torch.cat( 159 | [ 160 | world_points_b3N, 161 | torch.ones(batch_size, 1, num_points).to(world_points_b3N.device), 162 | ], 163 | 1, 164 | ) 165 | camera_points_b3N = torch.matmul(cam_T_world_b44[:, :3, :4], 166 | world_points_b4N) 167 | rays_b3N = camera_points_b3N 168 | else: 169 | rays_b3N = world_points_b3N - world_T_cam_b44[:, 0:3, 3][:, :, None].expand( 170 | world_points_b3N.shape 171 | ) 172 | 173 | rays_b3N = torch.nn.functional.normalize(rays_b3N, dim=1) 174 | 175 | return rays_b3N 176 | 177 | 178 | def pose_distance(pose_b44): 179 | """ 180 | DVMVS frame pose distance. 181 | """ 182 | 183 | R = pose_b44[:, :3, :3] 184 | t = pose_b44[:, :3, 3] 185 | R_trace = R.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) 186 | R_measure = torch.sqrt(2 * 187 | (1 - torch.minimum(torch.ones_like(R_trace)*3.0, R_trace) / 3)) 188 | t_measure = torch.norm(t, dim=1) 189 | combined_measure = torch.sqrt(t_measure ** 2 + R_measure ** 2) 190 | 191 | return combined_measure, R_measure, t_measure 192 | 193 | def qvec2rotmat(qvec): 194 | """ 195 | Quaternion to 3x3 rotation matrix. 196 | """ 197 | return np.array([ 198 | [ 199 | 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 200 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 201 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2] 202 | ], [ 203 | 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 204 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 205 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1] 206 | ], [ 207 | 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 208 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 209 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2 210 | ] 211 | ]) 212 | 213 | def rotx(t): 214 | """ 215 | 3D Rotation about the x-axis. 216 | """ 217 | c = np.cos(t) 218 | s = np.sin(t) 219 | return np.array([[1, 0, 0], 220 | [0, c, -s], 221 | [0, s, c]]) 222 | 223 | def roty(t): 224 | """ 225 | 3D Rotation about the y-axis. 226 | """ 227 | c = np.cos(t) 228 | s = np.sin(t) 229 | return np.array([[c, 0, s], 230 | [0, 1, 0], 231 | [-s, 0, c]]) 232 | 233 | def rotz(t): 234 | """ 235 | 3D Rotation about the z-axis. 
236 | """ 237 | c = np.cos(t) 238 | s = np.sin(t) 239 | return np.array([[c, -s, 0], 240 | [s, c, 0], 241 | [0, 0, 1]]) 242 | -------------------------------------------------------------------------------- /utils/metrics_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def compute_depth_metrics(gt, pred, mult_a=False): 8 | """ 9 | Computes error metrics between predicted and ground truth depths 10 | """ 11 | 12 | thresh = torch.max((gt / pred), (pred / gt)) 13 | a_dict = {} 14 | a_dict["a5"] = (thresh < 1.05 ).float().mean() 15 | a_dict["a10"] = (thresh < 1.10 ).float().mean() 16 | a_dict["a25"] = (thresh < 1.25 ).float().mean() 17 | 18 | a_dict["a0"] = (thresh < 1.10 ).float().mean() 19 | a_dict["a1"] = (thresh < 1.25 ).float().mean() 20 | a_dict["a2"] = (thresh < 1.25 ** 2).float().mean() 21 | a_dict["a3"] = (thresh < 1.25 ** 3).float().mean() 22 | 23 | 24 | if mult_a: 25 | for key in a_dict: 26 | a_dict[key] = a_dict[key]*100 27 | 28 | rmse = (gt - pred) ** 2 29 | rmse = torch.sqrt(rmse.mean()) 30 | 31 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 32 | rmse_log = torch.sqrt(rmse_log.mean()) 33 | 34 | abs_rel = torch.mean(torch.abs(gt - pred) / gt) 35 | 36 | sq_rel = torch.mean((gt - pred) ** 2 / gt) 37 | 38 | abs_diff = torch.mean(torch.abs(gt - pred)) 39 | 40 | metrics_dict = { 41 | "abs_diff": abs_diff, 42 | "abs_rel": abs_rel, 43 | "sq_rel": sq_rel, 44 | "rmse": rmse, 45 | "rmse_log": rmse_log, 46 | } 47 | metrics_dict.update(a_dict) 48 | 49 | return metrics_dict 50 | 51 | def compute_depth_metrics_batched(gt_bN, pred_bN, valid_masks_bN, mult_a=False): 52 | """ 53 | Computes error metrics between predicted and ground truth depths, 54 | batched. Abuses nan behavior in torch. 
55 | """ 56 | 57 | gt_bN = gt_bN.clone() 58 | pred_bN = pred_bN.clone() 59 | 60 | gt_bN[~valid_masks_bN] = torch.nan 61 | pred_bN[~valid_masks_bN] = torch.nan 62 | 63 | thresh_bN = torch.max(torch.stack([(gt_bN / pred_bN), (pred_bN / gt_bN)], 64 | dim=2), dim=2)[0] 65 | a_dict = {} 66 | 67 | a_val = (thresh_bN < (1.0+0.05) ).float() 68 | a_val[~valid_masks_bN] = torch.nan 69 | a_dict[f"a5"] = torch.nanmean(a_val, dim=1) 70 | 71 | a_val = (thresh_bN < (1.0+0.10) ).float() 72 | a_val[~valid_masks_bN] = torch.nan 73 | a_dict[f"a10"] = torch.nanmean(a_val, dim=1) 74 | 75 | a_val = (thresh_bN < (1.0+0.25) ).float() 76 | a_val[~valid_masks_bN] = torch.nan 77 | a_dict[f"a25"] = torch.nanmean(a_val, dim=1) 78 | 79 | a_val = (thresh_bN < (1.0+0.10) ).float() 80 | a_val[~valid_masks_bN] = torch.nan 81 | a_dict[f"a0"] = torch.nanmean(a_val, dim=1) 82 | 83 | a_val = (thresh_bN < (1.0+0.25) ).float() 84 | a_val[~valid_masks_bN] = torch.nan 85 | a_dict[f"a1"] = torch.nanmean(a_val, dim=1) 86 | 87 | a_val = (thresh_bN < (1.0+0.25) ** 2).float() 88 | a_val[~valid_masks_bN] = torch.nan 89 | a_dict[f"a2"] = torch.nanmean(a_val, dim=1) 90 | 91 | a_val = (thresh_bN < (1.0+0.25) ** 3).float() 92 | a_val[~valid_masks_bN] = torch.nan 93 | a_dict[f"a3"] = torch.nanmean(a_val, dim=1) 94 | 95 | if mult_a: 96 | for key in a_dict: 97 | a_dict[key] = a_dict[key]*100 98 | 99 | rmse_bN = (gt_bN - pred_bN) ** 2 100 | rmse_b = torch.sqrt(torch.nanmean(rmse_bN, dim=1)) 101 | 102 | rmse_log_bN = (torch.log(gt_bN) - torch.log(pred_bN)) ** 2 103 | rmse_log_b = torch.sqrt(torch.nanmean(rmse_log_bN, dim=1)) 104 | 105 | abs_rel_b = torch.nanmean(torch.abs(gt_bN - pred_bN) / gt_bN, dim=1) 106 | 107 | sq_rel_b = torch.nanmean((gt_bN - pred_bN) ** 2 / gt_bN, dim=1) 108 | 109 | abs_diff_b = torch.nanmean(torch.abs(gt_bN - pred_bN), dim=1) 110 | 111 | metrics_dict = { 112 | "abs_diff": abs_diff_b, 113 | "abs_rel": abs_rel_b, 114 | "sq_rel": sq_rel_b, 115 | "rmse": rmse_b, 116 | "rmse_log": rmse_log_b, 117 | } 118 | metrics_dict.update(a_dict) 119 | 120 | return metrics_dict 121 | 122 | class ResultsAverager(): 123 | """ 124 | Helper class for stable averaging of metrics across frames and scenes. 125 | """ 126 | def __init__(self, exp_name, metrics_name): 127 | """ 128 | Args: 129 | exp_name: name of the specific experiment. 130 | metrics_name: type of metrics. 131 | """ 132 | self.exp_name = exp_name 133 | self.metrics_name = metrics_name 134 | 135 | self.elem_metrics_list = [] 136 | self.running_metrics = None 137 | self.running_count = 0 138 | 139 | self.final_computed_average = None 140 | 141 | def update_results(self, elem_metrics): 142 | """ 143 | Adds elem_matrix to elem_metrics_list. Updates running_metrics with 144 | incomming metrics to keep a running average. 145 | 146 | running_metrics are cheap to compute but not totally stable. 147 | """ 148 | 149 | self.elem_metrics_list.append(elem_metrics.copy()) 150 | 151 | if self.running_metrics is None: 152 | self.running_metrics = elem_metrics.copy() 153 | else: 154 | for key in list(elem_metrics.keys()): 155 | self.running_metrics[key] = ( 156 | self.running_metrics[key] * 157 | self.running_count 158 | + elem_metrics[key] 159 | ) / (self.running_count + 1) 160 | 161 | self.running_count += 1 162 | 163 | def print_sheets_friendly( 164 | self, print_exp_name=True, 165 | include_metrics_names=False, 166 | print_running_metrics=True, 167 | ): 168 | """ 169 | Print for easy sheets copy/paste. 170 | Args: 171 | print_exp_name: should we print the experiment name? 
172 | include_metrics_names: should we print a row for metric names? 173 | print_running_metrics: should we print running metrics or the 174 | final average? 175 | """ 176 | 177 | if print_exp_name: 178 | print(f"{self.exp_name}, {self.metrics_name}") 179 | 180 | if print_running_metrics: 181 | metrics_to_print = self.running_metrics 182 | else: 183 | metrics_to_print = self.final_metrics 184 | 185 | if len(self.elem_metrics_list) == 0: 186 | print("WARNING: No valid metrics to print.") 187 | return 188 | 189 | metric_names_row = "" 190 | metrics_row = "" 191 | for k, v in metrics_to_print.items(): 192 | metric_names_row += f"{k:8} " 193 | metric_string = f"{v:.4f}," 194 | metrics_row += f"{metric_string:8} " 195 | 196 | if include_metrics_names: 197 | print(metric_names_row) 198 | print(metrics_row) 199 | 200 | def output_json(self, filepath, print_running_metrics=False): 201 | """ 202 | Outputs metrics to a json file. 203 | Args: 204 | filepath: file path where we should save the file. 205 | print_running_metrics: should we print running metrics or the 206 | final average? 207 | """ 208 | scores_dict = {} 209 | scores_dict["exp_name"] = self.exp_name 210 | scores_dict["metrics_type"] = self.metrics_name 211 | 212 | scores_dict["scores"] = {} 213 | 214 | if print_running_metrics: 215 | metrics_to_use = self.running_metrics 216 | else: 217 | metrics_to_use = self.final_metrics 218 | 219 | if len(self.elem_metrics_list) == 0: 220 | print("WARNING: No valid metrics will be output.") 221 | 222 | metric_names_row = "" 223 | metrics_row = "" 224 | for k, v in metrics_to_use.items(): 225 | metric_names_row += f"{k:8} " 226 | metric_string = f"{v:.4f}," 227 | metrics_row += f"{metric_string:8} " 228 | scores_dict["scores"][k] = float(v) 229 | 230 | scores_dict["metrics_string"] = metric_names_row 231 | scores_dict["scores_string"] = metrics_row 232 | 233 | with open(filepath, "w") as file: 234 | json.dump(scores_dict, file, indent=4) 235 | 236 | def pretty_print_results( 237 | self, 238 | print_exp_name=True, 239 | print_running_metrics=True 240 | ): 241 | """ 242 | Pretty print for easy(ier) reading. 243 | Args: 244 | print_exp_name: should we print the experiment name? 245 | include_metrics_names: should we print a row for metric names? 246 | print_running_metrics: should we print running metrics or the 247 | final average? 248 | """ 249 | if print_running_metrics: 250 | metrics_to_print = self.running_metrics 251 | else: 252 | metrics_to_print = self.final_metrics 253 | 254 | if len(self.elem_metrics_list) == 0: 255 | print("WARNING: No valid metrics to print.") 256 | return 257 | 258 | if print_exp_name: 259 | print(f"{self.exp_name}, {self.metrics_name}") 260 | for k, v in metrics_to_print.items(): 261 | print(f"{k:8}: {v:.4f}") 262 | 263 | def compute_final_average(self, ignore_nans=False): 264 | """ 265 | Computes a final average over the metrics element list using 266 | numpy. 267 | 268 | This should be more accurate than the running metrics, as it is a single 269 | average rather than many repeated multiplications and divisions. 270 | 271 | Args: 272 | ignore_nans: ignore nans in the results and run using nanmean.
273 | """ 274 | 275 | self.final_metrics = {} 276 | 277 | if len(self.elem_metrics_list) == 0: 278 | print("WARNING: no valid entry to average!") 279 | return 280 | 281 | for key in list(self.running_metrics.keys()): 282 | values = [] 283 | for element in self.elem_metrics_list: 284 | if torch.is_tensor(element[key]): 285 | values.append(element[key].cpu().numpy()) 286 | else: 287 | values.append(element[key]) 288 | 289 | if ignore_nans: 290 | mean_value = np.nanmean(np.array(values)) 291 | else: 292 | mean_value = np.array(values).mean() 293 | self.final_metrics[key] = mean_value 294 | -------------------------------------------------------------------------------- /utils/visualization_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import moviepy.editor as mpy 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | 9 | from utils.generic_utils import reverse_imagenet_normalize 10 | 11 | 12 | def colormap_image( 13 | image_1hw, 14 | mask_1hw=None, 15 | invalid_color=(0.0, 0, 0.0), 16 | flip=True, 17 | vmin=None, 18 | vmax=None, 19 | return_vminvmax=False, 20 | colormap="turbo", 21 | ): 22 | """ 23 | Colormaps a one channel tensor using a matplotlib colormap. 24 | 25 | Args: 26 | image_1hw: the tensor to colomap. 27 | mask_1hw: an optional float mask where 1.0 donates valid pixels. 28 | colormap: the colormap to use. Default is turbo. 29 | invalid_color: the color to use for invalid pixels. 30 | flip: should we flip the colormap? True by default. 31 | vmin: if provided uses this as the minimum when normalizing the tensor. 32 | vmax: if provided uses this as the maximum when normalizing the tensor. 33 | When either of vmin or vmax are None, they are computed from the 34 | tensor. 35 | return_vminvmax: when true, returns vmin and vmax. 36 | 37 | Returns: 38 | image_cm_3hw: image of the colormapped tensor. 39 | vmin, vmax: returned when return_vminvmax is true. 40 | 41 | 42 | """ 43 | valid_vals = image_1hw if mask_1hw is None else image_1hw[mask_1hw.bool()] 44 | if vmin is None: 45 | vmin = valid_vals.min() 46 | if vmax is None: 47 | vmax = valid_vals.max() 48 | 49 | cmap = torch.Tensor( 50 | plt.cm.get_cmap(colormap)( 51 | torch.linspace(0, 1, 256) 52 | )[:, :3] 53 | ).to(image_1hw.device) 54 | if flip: 55 | cmap = torch.flip(cmap, (0,)) 56 | 57 | h, w = image_1hw.shape[1:] 58 | 59 | image_norm_1hw = (image_1hw - vmin) / (vmax - vmin) 60 | image_int_1hw = (torch.clamp(image_norm_1hw * 255, 0, 255)).byte().long() 61 | 62 | image_cm_3hw = cmap[image_int_1hw.flatten(start_dim=1) 63 | ].permute([0, 2, 1]).view([-1, h, w]) 64 | 65 | if mask_1hw is not None: 66 | invalid_color = torch.Tensor(invalid_color).view(3, 1, 1).to(image_1hw.device) 67 | image_cm_3hw = image_cm_3hw * mask_1hw + invalid_color * (1 - mask_1hw) 68 | 69 | if return_vminvmax: 70 | return image_cm_3hw, vmin, vmax 71 | else: 72 | return image_cm_3hw 73 | 74 | def save_viz_video_frames(frame_list, path, fps=30): 75 | """ 76 | Saves a video file of numpy RGB frames in frame_list. 77 | """ 78 | clip = mpy.ImageSequenceClip(frame_list, fps=fps) 79 | clip.write_videofile(path, verbose=False, logger=None) 80 | 81 | return 82 | 83 | 84 | def quick_viz_export( 85 | output_path, 86 | outputs, 87 | cur_data, 88 | batch_ind, 89 | valid_mask_b, 90 | batch_size): 91 | """ Helper function for quickly exporting depth maps during inference. 
""" 92 | 93 | if valid_mask_b.sum() == 0: 94 | batch_vmin = 0.0 95 | batch_vmax = 5.0 96 | else: 97 | batch_vmin = cur_data["full_res_depth_b1hw"][valid_mask_b].min() 98 | batch_vmax = cur_data["full_res_depth_b1hw"][valid_mask_b].max() 99 | 100 | if batch_vmax == batch_vmin: 101 | batch_vmin = 0.0 102 | batch_vmax = 5.0 103 | 104 | for elem_ind in range(outputs["depth_pred_s0_b1hw"].shape[0]): 105 | if "frame_id_string" in cur_data: 106 | frame_id = cur_data["frame_id_string"][elem_ind] 107 | else: 108 | frame_id = (batch_ind * batch_size) + elem_ind 109 | frame_id = f"{str(frame_id):6d}" 110 | 111 | # check for valid depths from dataloader 112 | if valid_mask_b[elem_ind].sum() == 0: 113 | sample_vmin = 0.0 114 | sample_vmax = 0.0 115 | else: 116 | # these will be the same when the depth map is all ones. 117 | sample_vmin = cur_data["full_res_depth_b1hw"][elem_ind][valid_mask_b[elem_ind]].min() 118 | sample_vmax = cur_data["full_res_depth_b1hw"][elem_ind][valid_mask_b[elem_ind]].max() 119 | 120 | # if no meaningful gt depth in dataloader, don't viz gt and 121 | # set vmin/max to default 122 | if sample_vmax != sample_vmin: 123 | full_res_depth_1hw = cur_data["full_res_depth_b1hw"][elem_ind] 124 | 125 | full_res_depth_3hw = colormap_image( 126 | full_res_depth_1hw, 127 | vmin=batch_vmin, vmax=batch_vmax 128 | ) 129 | 130 | full_res_depth_hw3 = np.uint8( 131 | full_res_depth_3hw.permute(1,2,0 132 | ).cpu().detach().numpy() * 255 133 | ) 134 | Image.fromarray(full_res_depth_hw3).save( 135 | os.path.join(output_path, 136 | f"{frame_id}_gt_depth.png") 137 | ) 138 | 139 | lowest_cost_3hw = colormap_image( 140 | outputs["lowest_cost_bhw"][elem_ind].unsqueeze(0), 141 | vmin=batch_vmin, vmax=batch_vmax 142 | ) 143 | pil_image = Image.fromarray( 144 | np.uint8( 145 | lowest_cost_3hw.permute(1,2,0 146 | ).cpu().detach().numpy() * 255) 147 | ) 148 | pil_image.save(os.path.join(output_path, 149 | f"{frame_id}_lowest_cost_pred.png")) 150 | 151 | depth_3hw = colormap_image( 152 | outputs["depth_pred_s0_b1hw"][elem_ind], 153 | vmin=batch_vmin, vmax=batch_vmax) 154 | pil_image = Image.fromarray( 155 | np.uint8(depth_3hw.permute(1,2,0 156 | ).cpu().detach().numpy() * 255) 157 | ) 158 | 159 | pil_image.save(os.path.join(output_path, f"{frame_id}_pred_depth.png")) 160 | 161 | main_color_3hw = cur_data["high_res_color_b3hw"][elem_ind] 162 | main_color_3hw = reverse_imagenet_normalize(main_color_3hw) 163 | pil_image = Image.fromarray( 164 | np.uint8(main_color_3hw.permute(1,2,0 165 | ).cpu().detach().numpy() * 255) 166 | ) 167 | pil_image.save(os.path.join(output_path, f"{frame_id}_color.png")) -------------------------------------------------------------------------------- /visualization_scripts/generate_gt_min_max_cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loops through an MVS dataset, and extracts smoothed max and min values 3 | across a scene. Saves those values to disk. These limits can then be used 4 | for visualization. 
5 | 6 | Example command: 7 | python visualization_scripts/generate_gt_min_max_cache.py \ 8 | --output_base_path OUTPUT_PATH \ 9 | --data_config configs/data/scannet_dense.yaml; 10 | 11 | """ 12 | 13 | import sys 14 | sys.path.append("/".join(sys.path[0].split("/")[:-1])) 15 | import os 16 | from pathlib import Path 17 | 18 | import numpy as np 19 | import options 20 | import scipy 21 | import torch 22 | from tqdm import tqdm 23 | from utils.dataset_utils import get_dataset 24 | 25 | 26 | def main(opts): 27 | 28 | print("Setting batch size to 1.") 29 | opts.batch_size = 1 30 | 31 | # by default, we skip every 12 frames. We're looking for a rough average, 32 | # so this saves a lot of time. 33 | if opts.skip_frames is None: 34 | print("Setting skip_frames size to 12.") 35 | opts.skip_frames = 12 36 | 37 | # get dataset 38 | dataset_class, scans = get_dataset(opts.dataset, 39 | opts.dataset_scan_split_file, opts.single_debug_scan_id) 40 | 41 | # will save limits per scene here 42 | gt_viz_path = os.path.join(opts.output_base_path, "gt_min_max", opts.dataset) 43 | Path(gt_viz_path).mkdir(parents=True, exist_ok=True) 44 | 45 | print(f"".center(80, "#")) 46 | print(f" Computing GT min/max.") 47 | print(f" Output directory: {gt_viz_path} ".center(80, "#")) 48 | print(f"".center(80, "#")) 49 | print("") 50 | 51 | with torch.inference_mode(): 52 | for scan in tqdm(scans): 53 | 54 | # set up dataset with current scan 55 | dataset = dataset_class( 56 | opts.dataset_path, 57 | split=opts.split, 58 | mv_tuple_file_suffix=opts.mv_tuple_file_suffix, 59 | limit_to_scan_id=scan, 60 | include_full_res_depth=False, 61 | tuple_info_file_location=opts.tuple_info_file_location, 62 | num_images_in_tuple=None, 63 | shuffle_tuple=opts.shuffle_tuple, 64 | include_high_res_color=False, 65 | pass_frame_id=True, 66 | include_full_depth_K=False, 67 | skip_frames=opts.skip_frames, 68 | skip_to_frame=opts.skip_to_frame, 69 | ) 70 | 71 | dataloader = torch.utils.data.DataLoader( 72 | dataset, 73 | batch_size=opts.batch_size, 74 | shuffle=False, 75 | num_workers=opts.num_workers, 76 | drop_last=False, 77 | ) 78 | 79 | # set inits 80 | vmin = torch.inf 81 | vmax = 0 82 | mins = [] 83 | maxs = [] 84 | for _, batch in enumerate(tqdm(dataloader)): 85 | cur_data, _ = batch 86 | 87 | depth = cur_data["depth_b1hw"].cuda()[ 88 | cur_data["mask_b_b1hw"].cuda()] 89 | # get values at 98% and 2% 90 | maxs.append(torch.quantile(depth, 0.98).squeeze().cpu()) 91 | mins.append(torch.quantile(depth, 0.02).squeeze().cpu()) 92 | 93 | # gaussian filter all values to remove any outliers, then take 94 | # min/max 95 | maxs = scipy.ndimage.gaussian_filter1d(np.array(maxs), sigma=1) 96 | vmax = np.max(maxs) 97 | 98 | mins = scipy.ndimage.gaussian_filter1d(np.array(mins), sigma=1) 99 | vmin = np.min(mins) 100 | 101 | # print and save limits to file. 102 | limits = [vmin, vmax] 103 | print(scan, limits) 104 | 105 | gt_min_max_path = os.path.join(gt_viz_path, f"{scan}.txt") 106 | with open(gt_min_max_path, 'w') as handle: 107 | handle.write(f"{vmin} {vmax}") 108 | 109 | if __name__ == '__main__': 110 | # don't need grad for test. 111 | torch.set_grad_enabled(False) 112 | 113 | # get an instance of options and load it with config file(s) and cli args. 
114 | option_handler = options.OptionsHandler() 115 | option_handler.parse_and_merge_options() 116 | option_handler.pretty_print_options() 117 | print("\n") 118 | opts = option_handler.options 119 | 120 | # if no GPUs are available for us then, use the 32 bit on CPU 121 | if opts.gpus == 0: 122 | print("Setting precision to 32 bits since --gpus is set to 0.") 123 | opts.precision = 32 124 | 125 | main(opts) 126 | -------------------------------------------------------------------------------- /visualization_scripts/load_meshes_and_include_normals.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reads plys defined with pattern in scans_path_pattern. First computes 3 | normals for each scan using open3d, then outputs each scan with normal 4 | information visualized as vertex colors. 5 | 6 | Example command: 7 | python ./visualization_scripts/load_meshes_and_include_normals.py \ 8 | --input_path simple_recon_output/HERO_MODEL/scannet/default/meshes/0.04_3.0_open3d_color/ \ 9 | --output_path simple_recon_output/HERO_MODEL/scannet/default/meshes/0.04_3.0_open3d_color_normals/; 10 | """ 11 | 12 | import argparse 13 | import glob 14 | import os 15 | from pathlib import Path 16 | 17 | import numpy as np 18 | import open3d as o3d 19 | from tqdm import tqdm 20 | 21 | parser = argparse.ArgumentParser(description='mesh normal visualizer.') 22 | 23 | parser.add_argument('--scannet_scans_path_pattern', required=False, default=None, 24 | help="Input example string pattern for one scan file. " 25 | "For ScanNet it should look something like path_to_scans/SCAN_NAME.ply." 26 | "SCAN_NAME will be replaced with each scan's name.") 27 | parser.add_argument('--input_path', required=False, default=None) 28 | parser.add_argument('--output_path', required=True) 29 | 30 | args = parser.parse_args() 31 | 32 | Path(args.output_path).mkdir(exist_ok=True, parents=True) 33 | 34 | 35 | if args.scannet_scans_path_pattern: 36 | scans = ['scene0707_00', 'scene0708_00', 'scene0709_00', 'scene0710_00', 37 | 'scene0711_00', 'scene0712_00', 'scene0713_00', 'scene0714_00', 38 | 'scene0715_00', 'scene0716_00', 'scene0717_00', 'scene0718_00', 39 | 'scene0719_00', 'scene0720_00', 'scene0721_00', 'scene0722_00', 40 | 'scene0723_00', 'scene0724_00', 'scene0725_00', 'scene0726_00', 41 | 'scene0727_00', 'scene0728_00', 'scene0729_00', 'scene0730_00', 42 | 'scene0731_00', 'scene0732_00', 'scene0733_00', 'scene0734_00', 43 | 'scene0735_00', 'scene0736_00', 'scene0737_00', 'scene0738_00', 44 | 'scene0739_00', 'scene0740_00', 'scene0741_00', 'scene0742_00', 45 | 'scene0743_00', 'scene0744_00', 'scene0745_00', 'scene0746_00', 46 | 'scene0747_00', 'scene0748_00', 'scene0749_00', 'scene0750_00', 47 | 'scene0751_00', 'scene0752_00', 'scene0753_00', 'scene0754_00', 48 | 'scene0755_00', 'scene0756_00', 'scene0757_00', 'scene0758_00', 49 | 'scene0759_00', 'scene0760_00', 'scene0761_00', 'scene0762_00', 50 | 'scene0763_00', 'scene0764_00', 'scene0765_00', 'scene0766_00', 51 | 'scene0767_00', 'scene0768_00', 'scene0769_00', 'scene0770_00', 52 | 'scene0771_00', 'scene0772_00', 'scene0773_00', 'scene0774_00', 53 | 'scene0775_00', 'scene0776_00', 'scene0777_00', 'scene0778_00', 54 | 'scene0779_00', 'scene0780_00', 'scene0781_00', 'scene0782_00', 55 | 'scene0783_00', 'scene0784_00', 'scene0785_00', 'scene0786_00', 56 | 'scene0787_00', 'scene0788_00', 'scene0789_00', 'scene0790_00', 57 | 'scene0791_00', 'scene0792_00', 'scene0793_00', 'scene0794_00', 58 | 'scene0795_00', 'scene0796_00', 'scene0797_00', 
'scene0798_00', 59 | 'scene0799_00', 'scene0800_00', 'scene0801_00', 'scene0802_00', 60 | 'scene0803_00', 'scene0804_00', 'scene0805_00', 'scene0806_00'] 61 | 62 | mesh_paths = [args.scannet_scans_path_pattern.replace("SCAN_NAME", scan) 63 | for scan in scans] 64 | 65 | elif args.input_path: 66 | os.chdir(args.input_path) 67 | mesh_paths = glob.glob("*.ply") 68 | 69 | else: 70 | raise Exception("No valid input path found.") 71 | 72 | 73 | for path_to_mesh in tqdm(mesh_paths): 74 | # read mesh 75 | mesh = o3d.io.read_triangle_mesh(path_to_mesh) 76 | 77 | # compute normals and include them as RGB information 78 | mesh.compute_vertex_normals() 79 | mesh.vertex_colors = o3d.utility.Vector3dVector( 80 | ((1+np.asarray(mesh.vertex_normals))*0.5)) 81 | 82 | # write to disk 83 | mesh_name = path_to_mesh.split("/")[-1].split(".")[0] 84 | output_path = os.path.join(args.output_path, f"{mesh_name}.ply") 85 | 86 | o3d.io.write_triangle_mesh(output_path, mesh) 87 | -------------------------------------------------------------------------------- /weights/strip_checkpoint.py: -------------------------------------------------------------------------------- 1 | # for importing options on checkpoint load. 2 | import sys 3 | sys.path.append("/".join(sys.path[0].split("/")[:-1])) 4 | 5 | import torch 6 | import argparse 7 | 8 | 9 | parser = argparse.ArgumentParser(description="Script for " 10 | "removing training state weights " 11 | "from a checkpoint.") 12 | 13 | parser.add_argument('--heavy_checkpoint_path') 14 | parser.add_argument('--output_checkpoint_path') 15 | 16 | args = parser.parse_args() 17 | 18 | checkpoint = torch.load(args.heavy_checkpoint_path) 19 | 20 | keys_to_store = ["state_dict", 'hparams_name', 'hyper_parameters'] 21 | 22 | new_checkpoint = {} 23 | for key in keys_to_store: 24 | new_checkpoint[key] = checkpoint[key] 25 | 26 | torch.save(new_checkpoint, args.output_checkpoint_path) --------------------------------------------------------------------------------
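Usage note for weights/strip_checkpoint.py above: the script only needs the two paths defined by its argument parser. A minimal invocation might look like the following; the checkpoint filenames are placeholders for illustration, not files shipped with the repository.

python weights/strip_checkpoint.py \
    --heavy_checkpoint_path logs/example_full_training_state.ckpt \
    --output_checkpoint_path weights/example_stripped.ckpt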
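For reference, a minimal sketch of how the helpers in utils/metrics_utils.py above could be combined at evaluation time, assuming the repository root is on PYTHONPATH. The experiment name, tensor shapes, noise level, and the 0.25-5.0 depth range used for the validity mask are illustrative assumptions, not values taken from the repository.

import torch

from utils.metrics_utils import ResultsAverager, compute_depth_metrics_batched

# illustrative batch of flattened depths, shape B x N
gt_bN = torch.rand(2, 1000) * 5.0
pred_bN = gt_bN + 0.05 * torch.randn_like(gt_bN)

# only score pixels with ground truth inside an assumed working range
valid_bN = (gt_bN > 0.25) & (gt_bN < 5.0)

# per-element metrics for the whole batch at once
metrics_b = compute_depth_metrics_batched(gt_bN, pred_bN, valid_bN)

# accumulate each element, then compute the stable final average
averager = ResultsAverager(exp_name="example_exp", metrics_name="frame_metrics")
for b in range(gt_bN.shape[0]):
    averager.update_results({k: v[b] for k, v in metrics_b.items()})

averager.compute_final_average(ignore_nans=True)
averager.pretty_print_results(print_running_metrics=False)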
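Similarly, a small sketch of calling colormap_image from utils/visualization_utils.py on a single depth map and saving the result as a PNG. The tensor follows the 1HW convention from the docstring; the depth values, mask, and output filename are made up for illustration.

import numpy as np
import torch
from PIL import Image

from utils.visualization_utils import colormap_image

# illustrative 1 x H x W depth tensor and float validity mask
depth_1hw = torch.rand(1, 192, 256) * 4.0 + 0.5
mask_1hw = (depth_1hw > 0.0).float()

# colormap the depth, letting the function compute vmin/vmax from the tensor
depth_cm_3hw, vmin, vmax = colormap_image(depth_1hw, mask_1hw, return_vminvmax=True)

# convert to H x W x 3 uint8 and save
depth_hw3 = np.uint8(depth_cm_3hw.permute(1, 2, 0).cpu().numpy() * 255)
Image.fromarray(depth_hw3).save("example_depth_viz.png")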