├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets ├── New.png └── demo.gif ├── cfgs ├── 7scenes.yaml ├── Cambridge.yaml ├── default.yaml └── indoor6.yaml ├── datasets ├── _base.py ├── augmentation.py ├── data_collection.py ├── dataloader.py └── test.py ├── detectors ├── line2d │ ├── DeepLSD │ │ └── deeplsd.py │ ├── LSD │ │ └── lsd.py │ ├── linebase_detector.py │ └── register_linedetector.py └── point2d │ ├── SuperPoint │ └── superpoint.py │ └── register_pointdetector.py ├── models ├── base_model.py ├── pipeline.py ├── pl2map.py ├── util.py └── util_learner.py ├── prepare_scripts ├── cambridge.sh ├── download_pre_trained_models.sh ├── indoor6.sh └── seven_scenes.sh ├── requirements.txt ├── runners ├── eval.py ├── evaluator.py ├── train.py └── trainer.py └── util ├── config.py ├── help_evaluation.py ├── io.py ├── logger.py ├── pose_estimator.py ├── read_write_model.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/gt_3Dmodels/ 2 | datasets/imgs_datasets/ 3 | detectors/point2d/SuperPoint/weights/ 4 | visualization/ 5 | visualization_all/ 6 | logs/ 7 | pre_train_logs/ 8 | experiments/ 9 | __pycache__/ 10 | train_test_datasets/ 11 | train_test_datasets_origin/ 12 | 13 | *.npy 14 | *.png 15 | *.jpg 16 | *.pyc 17 | *.tar 18 | *.obj 19 | *.pth 20 | *.gif 21 | *.mp4 22 | *.h5 23 | *.swp 24 | *.zip 25 | *.tar.gz 26 | *.th 27 | *.so 28 | 29 | .vscode -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/DeepLSD"] 2 | path = third_party/DeepLSD 3 | url = https://github.com/cvg/DeepLSD.git 4 | [submodule "third_party/pytlsd"] 5 | path = third_party/pytlsd 6 | url = https://github.com/iago-suarez/pytlsd.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 Thuan 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Point-Line to Map Regression for Camera Relocalization 2 | #### [Project Page](https://thpjp.github.io/pl2map/) | [PL2Map](https://arxiv.org/abs/2402.18011) | [PL2Map++](https://arxiv.org/pdf/2502.20814) (code for PL2Map++ is coming soon) 3 | ## Introduction 4 | 5 | ![demo_vid](assets/demo.gif) 6 | 7 | We introduce a lightweight neural network for visual localization that efficiently represents both 3D points and lines. Specifically, we use a single transformer block to convert line features into distinctive point-like descriptors. These features are then refined through self- and cross-attention in a graph-based framework before 3D map regression using simple MLPs. Our method outperforms [Hloc](https://github.com/cvg/Hierarchical-Localization) and [Limap](https://github.com/cvg/limap) in small-scale indoor localization and achieves the best results in outdoor settings, setting a new benchmark for learning-based approaches. It also operates in real time at ~16 FPS, compared to [Limap](https://github.com/cvg/limap)’s ~0.03 FPS, while requiring only 33 MB of network weights instead of [Limap](https://github.com/cvg/limap)’s multi-GB memory footprint. 8 | 9 | --- 10 | ## Papers 11 | **Improved 3D Point-Line Mapping Regression for Camera Relocalization** ![new](assets/New.png) 12 | Bach-Thuan Bui, Huy-Hoang Bui, Yasuyuki Fujii, Dinh-Tuan Tran, and Joo-Ho Lee. 13 | arXiv preprint arXiv:2502.20814, 2025. 14 | [pdf](https://arxiv.org/pdf/2502.20814) 15 | 16 | **Representing 3D sparse map points and lines for camera relocalization** 17 | Bach-Thuan Bui, Huy-Hoang Bui, Dinh-Tuan Tran, and Joo-Ho Lee. 18 | IEEE/RSJ International Conference on Intelligent Robots and Systems (**IROS**), 2024.
19 | [pdf](https://arxiv.org/abs/2402.18011) 20 | 21 | 22 | ## Installation 23 | Python 3.9 + required packages 24 | ``` 25 | git clone https://github.com/ais-lab/pl2map.git 26 | cd pl2map 27 | git submodule update --init --recursive 28 | conda create --name pl2map python=3.9 29 | conda activate pl2map 30 | # Refer to https://pytorch.org/get-started/previous-versions/ to install a PyTorch version compatible with your CUDA 31 | python -m pip install torch==1.12.0 torchvision==0.13.0 32 | python -m pip install -r requirements.txt 33 | ``` 34 | ## Supported datasets 35 | - [Microsoft 7scenes](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/) 36 | - [Cambridge Landmarks](https://www.repository.cam.ac.uk/handle/1810/251342/) 37 | - [Indoor-6](https://github.com/microsoft/SceneLandmarkLocalization) 38 | 39 | Please download and prepare the preprocessed data by running the provided scripts: 40 | 41 | 7scenes 42 | ``` 43 | ./prepare_scripts/seven_scenes.sh 44 | ``` 45 | Cambridge Landmarks 46 | ``` 47 | ./prepare_scripts/cambridge.sh 48 | ``` 49 | Indoor-6 50 | ``` 51 | ./prepare_scripts/indoor6.sh 52 | ``` 53 | 54 | ## Evaluation with pre-trained models 55 | Please download the pre-trained models by running: 56 | ``` 57 | ./prepare_scripts/download_pre_trained_models.sh 58 | ``` 59 | For example, to evaluate the KingsCollege scene: 60 | ``` 61 | python runners/eval.py --dataset Cambridge --scene KingsCollege -expv pl2map 62 | ``` 63 | 64 | ## Training 65 | ``` 66 | python runners/train.py --dataset Cambridge --scene KingsCollege -expv pl2map_test 67 | ``` 68 | 69 | ## Supported detectors 70 | ### Lines 71 | - [LSD](https://github.com/iago-suarez/pytlsd) 72 | - [DeepLSD](https://github.com/cvg/DeepLSD) 73 | ### Points 74 | - [SuperPoint](https://github.com/rpautrat/SuperPoint) 75 | 76 | 77 | ## Citation 78 | If you use this code in your project, please consider citing the following papers: 79 | ```bibtex 80 | @inproceedings{bui2024pl2map, 81 | title={Representing 3D sparse map points and lines for camera relocalization}, 82 | author={Bui, Bach-Thuan and Bui, Huy-Hoang and Tran, Dinh-Tuan and Lee, Joo-Ho}, 83 | booktitle={2024 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, 84 | year={2024} 85 | } 86 | @article{bui2025improved, 87 | title={Improved 3D Point-Line Mapping Regression for Camera Relocalization}, 88 | author={Bui, Bach-Thuan and Bui, Huy-Hoang and Fujii, Yasuyuki and Tran, Dinh-Tuan and Lee, Joo-Ho}, 89 | journal={arXiv preprint arXiv:2502.20814}, 90 | year={2025} 91 | } 92 | ``` 93 | This code builds on a previous camera relocalization pipeline, [D2S](https://github.com/ais-lab/d2s); please also consider citing: 94 | ```bibtex 95 | @article{bui2024d2s, 96 | title={D2S: Representing sparse descriptors and 3D coordinates for camera relocalization}, 97 | author={Bui, Bach-Thuan and Bui, Huy-Hoang and Tran, Dinh-Tuan and Lee, Joo-Ho}, 98 | journal={IEEE Robotics and Automation Letters}, 99 | year={2024} 100 | } 101 | ``` 102 | 103 | ## Acknowledgement 104 | This code is built on top of [Limap](https://github.com/cvg/limap) and [LineTR](https://github.com/yosungho/LineTR). We thank the authors for their useful source code.
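As a usage note for the Evaluation and Training sections above: the same CLI can be looped over several scenes. The sketch below is illustrative only; the scene identifiers are assumed to match the folder names created by the prepare scripts (e.g. the standard Cambridge Landmarks scene names) and may need to be adjusted to your local data.
```
# Hypothetical batch evaluation over Cambridge scenes (scene names are assumptions, adjust as needed)
for SCENE in KingsCollege OldHospital ShopFacade StMarysChurch; do
    python runners/eval.py --dataset Cambridge --scene "$SCENE" -expv pl2map
done
```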
105 | 106 | 107 | -------------------------------------------------------------------------------- /assets/New.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ais-lab/pl2map/8d1a9289bd9505647e2fbdaf4719310e51ba8e8b/assets/New.png -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ais-lab/pl2map/8d1a9289bd9505647e2fbdaf4719310e51ba8e8b/assets/demo.gif -------------------------------------------------------------------------------- /cfgs/7scenes.yaml: -------------------------------------------------------------------------------- 1 | line2d: 2 | max_num_2d_segs: 3000 3 | do_merge_lines: False # Not implemented 4 | visualize: False 5 | save_l3dpp: False 6 | detector: 7 | name: "deeplsd" # ["lsd", "sold2", "deeplsd"] - for loading labeled 3D model 8 | name_test_model: "deeplsd" # ["lsd", "sold2", "deeplsd"] - for testing (not depending on pre-3D model) 9 | preprocessing: 10 | grayscale: True 11 | 12 | point2d: 13 | detector: 14 | name: "superpoint" # ["superpoint", "sift"] 15 | configs: 16 | force_num_keypoints: False 17 | nms_radius: 3 18 | max_keypoints: 2048 19 | preprocessing: 20 | grayscale: True 21 | resize_max: 1024 22 | resize_force: True 23 | interpolation: 'cv2_area' # pil_linear is more accurate but slower 24 | matcher: "NN-superpoint" # ["superglue", "gluestick"] # not implemented (for unlabeled learning) 25 | 26 | regressor: 27 | name: pl2map # ["pl2map", or others] 28 | use_line: True 29 | use_point: True 30 | n_line_keypoints: 10 # number of keypoints used to represent a line 31 | 32 | train: # train configs 33 | batch_size: 1 34 | num_iters: 1500000 # number training iterations 35 | loader_shuffle: True 36 | loader_num_workers: 8 37 | log_interval: 500 # log every n batches (visdom graph) 38 | use_depth: False # use SfM corrected by depth or not 39 | loss: 40 | reprojection: 41 | apply: True 42 | type: "dyntanh" # ["l1", "l1+sqrt", "l1+log", "tanh", "dyntanh"] 43 | soft_clamp: 50 44 | soft_clamp_min: 1 45 | circle_schedule: True # 'circle'(weight increasing) or 'linear' (weight decreasing) 46 | augmentation: 47 | apply: True 48 | on_rate: 0.5 # probability of applying augmentation 49 | brightness: 0.02 50 | contrast: 0.02 51 | homography: 52 | apply: False # if apply, augmented poses will be incorrect 53 | perspective: True 54 | scaling: True 55 | rotation: True 56 | translation: True 57 | n_scales: 5 58 | n_angles: 25 59 | scaling_amplitude: 0.1 60 | perspective_amplitude_x: 0.1 61 | perspective_amplitude_y: 0.1 62 | patch_ratio: 0.8 # ratio of the patch to the image 63 | max_angle: 45 # in degrees 64 | allow_artifacts: False 65 | dsacstar: # apply DSAC*-like augmentation 66 | apply: True # homography augmentation must be disabled 67 | aug_rotation: 30 # in degrees 68 | aug_scale_min: 0.666666666 # 2/3 69 | aug_scale_max: 1.5 # 3/2 70 | 71 | 72 | optimizer: 73 | method: adam 74 | base_lr: 0.0003 # base/start learning rate 75 | weight_decay: 0.0 76 | lr_decay: 0.5 # decay rate 77 | num_lr_decay_step: 7 # decay every n epochs 78 | 79 | localization: 80 | ransac: 81 | max_reproj_error: 12.0 82 | max_epipolar_error: 10.0 -------------------------------------------------------------------------------- /cfgs/Cambridge.yaml: -------------------------------------------------------------------------------- 1 | line2d: 2 | max_num_2d_segs: 3000 3 | 
do_merge_lines: False # Not implemented 4 | visualize: False 5 | save_l3dpp: False 6 | detector: 7 | name: "deeplsd" # ["lsd", "sold2", "deeplsd"] - for loading labeled 3D model 8 | name_test_model: "deeplsd" # ["lsd", "sold2", "deeplsd"] - for testing (not depending on pre-3D model) 9 | preprocessing: 10 | grayscale: True 11 | 12 | point2d: 13 | detector: 14 | name: "superpoint" # ["superpoint", "sift"] 15 | configs: 16 | force_num_keypoints: False 17 | nms_radius: 3 18 | max_keypoints: 2048 19 | preprocessing: 20 | grayscale: True 21 | resize_max: 1024 22 | resize_force: True 23 | interpolation: 'cv2_area' # pil_linear is more accurate but slower 24 | matcher: "NN-superpoint" # ["superglue", "gluestick"] # not implemented (for unlabeled learning) 25 | 26 | regressor: 27 | name: pl2map # ["pl2map", "d2s"] 28 | use_line: True 29 | use_point: True 30 | n_line_keypoints: 10 # number of keypoints used to represent a line 31 | 32 | train: # train configs 33 | batch_size: 1 34 | num_iters: 1500000 # number training iterations 35 | loader_shuffle: True 36 | loader_num_workers: 8 37 | log_interval: 500 # log every n batches (visdom graph) 38 | loss: 39 | reprojection: 40 | apply: True 41 | start_apply: 0.8 # start applying reprojection loss 42 | type: "dyntanh" # ["l1", "l1+sqrt", "l1+log", "tanh", "dyntanh"] 43 | soft_clamp: 100 44 | soft_clamp_min: 1 45 | circle_schedule: True # 'circle'(weight increasing) or 'linear' (weight decreasing) 46 | augmentation: 47 | apply: True 48 | on_rate: 0.5 # probability of applying augmentation 49 | brightness: 0.15 50 | contrast: 0.1 51 | homography: 52 | apply: False # if apply, augmented poses will be incorrect 53 | perspective: True 54 | scaling: True 55 | rotation: True 56 | translation: True 57 | n_scales: 5 58 | n_angles: 25 59 | scaling_amplitude: 0.1 60 | perspective_amplitude_x: 0.1 61 | perspective_amplitude_y: 0.1 62 | patch_ratio: 0.8 # ratio of the patch to the image 63 | max_angle: 45 # in degrees 64 | allow_artifacts: False 65 | dsacstar: # apply DSAC*-like augmentation 66 | apply: True # homography augmentation must be disabled 67 | aug_rotation: 30 # in degrees 68 | aug_scale_min: 0.666666666 # 2/3 69 | aug_scale_max: 1.5 # 3/2 70 | 71 | 72 | optimizer: 73 | method: adam 74 | base_lr: 0.0005 # base/start learning rate 75 | weight_decay: 0.0 76 | lr_decay: 0.5 # decay rate 77 | num_lr_decay_step: 7 # decay every n epochs, 7 78 | 79 | localization: 80 | ransac: 81 | max_reproj_error: 12.0 82 | max_epipolar_error: 10.0 83 | -------------------------------------------------------------------------------- /cfgs/default.yaml: -------------------------------------------------------------------------------- 1 | line2d: 2 | max_num_2d_segs: 3000 3 | do_merge_lines: False # Not implemented 4 | visualize: False 5 | save_l3dpp: False 6 | detector: 7 | name: "deeplsd" # ["lsd", "sold2", "deeplsd"] - for loading labeled 3D model 8 | name_test_model: "deeplsd" # ["lsd", "sold2", "deeplsd"] - for testing (not depending on pre-3D model) 9 | preprocessing: 10 | grayscale: True 11 | 12 | point2d: 13 | detector: 14 | name: "superpoint" # ["superpoint", "sift"] 15 | configs: 16 | force_num_keypoints: False 17 | nms_radius: 3 18 | max_keypoints: 2048 19 | preprocessing: 20 | grayscale: True 21 | resize_max: 1024 22 | resize_force: True 23 | interpolation: 'cv2_area' # pil_linear is more accurate but slower 24 | matcher: "NN-superpoint" # ["superglue", "gluestick"] # not implemented (for unlabeled learning) 25 | 26 | regressor: 27 | name: pl2map # ["pl2map", "d2s"] 
28 | use_line: True 29 | use_point: True 30 | n_line_keypoints: 10 # number of keypoints used to represent a line 31 | 32 | train: # train configs 33 | batch_size: 1 34 | num_iters: 2500000 # number training iterations 35 | loader_shuffle: True 36 | loader_num_workers: 8 37 | log_interval: 500 # log every n batches (visdom graph) 38 | use_depth: False # use SfM corrected by depth or not 39 | loss: 40 | reprojection: 41 | apply: False 42 | type: "dyntanh" # ["l1", "l1+sqrt", "l1+log", "tanh", "dyntanh"] 43 | soft_clamp: 50 44 | soft_clamp_min: 1 45 | circle_schedule: True # 'circle'(weight increasing) or 'linear' (weight decreasing) 46 | augmentation: 47 | apply: False 48 | on_rate: 0.5 # probability of applying augmentation 49 | brightness: 0.1 50 | contrast: 0.1 51 | homography: 52 | apply: False # if apply, augmented poses will be incorrect 53 | perspective: True 54 | scaling: True 55 | rotation: True 56 | translation: True 57 | n_scales: 5 58 | n_angles: 25 59 | scaling_amplitude: 0.1 60 | perspective_amplitude_x: 0.1 61 | perspective_amplitude_y: 0.1 62 | patch_ratio: 0.8 # ratio of the patch to the image 63 | max_angle: 45 # in degrees 64 | allow_artifacts: False 65 | dsacstar: # apply DSAC*-like augmentation 66 | apply: False # homography augmentation must be disabled 67 | aug_rotation: 30 # in degrees 68 | aug_scale_min: 0.666666666 # 2/3 69 | aug_scale_max: 1.5 # 3/2 70 | 71 | localization: 72 | 2d_matcher: "sold2" # ["epipolar", "sold2", "superglue_endpoints"] Other configs for superglue_endpoints are the same as in "line2d" section 73 | epipolar_filter: False # No use for epipolar matcher 74 | IoU_threshold: 0.2 75 | reprojection_filter: null # [null, "Perpendicular", "Midpoint", "Midpoint_Perpendicular"] 76 | ransac: 77 | method: "hybrid" # [null, "ransac", "solver", "hybrid"] 78 | thres: 10.0 # Only for normal & solver 79 | thres_point: 10.0 80 | thres_line: 10.0 81 | weight_point: 1.0 # data type weights for scoring 82 | weight_line: 1.0 # data type weights for scoring 83 | final_least_squares: True 84 | min_num_iterations: 100 85 | solver_flags: [True, True, True, True] 86 | optimize: 87 | loss_func: "TrivialLoss" 88 | loss_func_args: [] 89 | line_cost_func: "PerpendicularDist" 90 | line_weight: 1.0 # weight for optimization (cost function) 91 | hloc: 92 | skip_exists: False 93 | skip_exists: False 94 | -------------------------------------------------------------------------------- /cfgs/indoor6.yaml: -------------------------------------------------------------------------------- 1 | line2d: 2 | max_num_2d_segs: 3000 3 | do_merge_lines: False # Not implemented 4 | visualize: False 5 | save_l3dpp: False 6 | detector: 7 | name: "deeplsd" # ["lsd", "sold2", "deeplsd"] - for loading labeled 3D model 8 | name_test_model: "deeplsd" # ["lsd", "sold2", "deeplsd"] - for testing (not depending on pre-3D model) 9 | preprocessing: 10 | grayscale: True 11 | 12 | point2d: 13 | detector: 14 | name: "superpoint" # ["superpoint", "sift"] 15 | configs: 16 | force_num_keypoints: False 17 | nms_radius: 3 18 | max_keypoints: 2048 19 | preprocessing: 20 | grayscale: True 21 | resize_max: 1024 22 | resize_force: True 23 | interpolation: 'cv2_area' # pil_linear is more accurate but slower 24 | matcher: "NN-superpoint" # ["superglue", "gluestick"] # not implemented (for unlabeled learning) 25 | 26 | regressor: 27 | name: pl2map # pl2map_only_point or pl2map_sep 28 | use_line: True 29 | use_point: True 30 | n_line_keypoints: 10 # number of keypoints used to represent a line 31 | 32 | train: # train 
configs 33 | batch_size: 1 34 | num_iters: 1500000 # number training iterations 35 | loader_shuffle: True 36 | loader_num_workers: 8 37 | log_interval: 500 # log every n batches (visdom graph) 38 | loss: 39 | reprojection: 40 | apply: False 41 | start_apply: 0.05 # start applying reprojection loss 42 | type: "dyntanh" # ["l1", "l1+sqrt", "l1+log", "tanh", "dyntanh"] 43 | soft_clamp: 100 44 | soft_clamp_min: 1 45 | circle_schedule: True # 'circle'(weight increasing) or 'linear' (weight decreasing) 46 | augmentation: 47 | apply: True 48 | on_rate: 0.5 # probability of applying augmentation 49 | brightness: 0.15 50 | contrast: 0.1 51 | homography: 52 | apply: True # if apply, augmented poses will be incorrect 53 | perspective: True 54 | scaling: True 55 | rotation: True 56 | translation: True 57 | n_scales: 5 58 | n_angles: 25 59 | scaling_amplitude: 0.1 60 | perspective_amplitude_x: 0.1 61 | perspective_amplitude_y: 0.1 62 | patch_ratio: 0.8 # ratio of the patch to the image 63 | max_angle: 45 # in degrees 64 | allow_artifacts: False 65 | dsacstar: # apply DSAC*-like augmentation 66 | apply: False # homography augmentation must be disabled 67 | aug_rotation: 30 # in degrees 68 | aug_scale_min: 0.666666666 # 2/3 69 | aug_scale_max: 1.5 # 3/2 70 | 71 | 72 | optimizer: 73 | method: adam 74 | base_lr: 0.0002 # base/start learning rate 75 | weight_decay: 0.0 76 | lr_decay: 0.5 # decay rate 77 | num_lr_decay_step: 7 # decay every n epochs, 7 78 | 79 | localization: 80 | ransac: 81 | max_reproj_error: 12.0 82 | max_epipolar_error: 10.0 -------------------------------------------------------------------------------- /datasets/_base.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | import numpy as np 3 | import os 4 | import sys 5 | import torch 6 | import math 7 | from scipy.spatial.transform import Rotation as R 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | from detectors.line2d.register_linedetector import get_linedetector 10 | from detectors.point2d.register_pointdetector import get_pointdetector 11 | from util.io import read_image 12 | import copy 13 | import datasets.augmentation as aug 14 | 15 | 16 | def frame2tensor(frame, device): 17 | return torch.from_numpy(frame/255.).float()[None, None].to(device) 18 | 19 | class Camera(): 20 | ''' 21 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 22 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 23 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 24 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 25 | ''' 26 | model_name2id = {"SIMPLE_PINHOLE": 0, "PINHOLE": 1, 27 | "SIMPLE_RADIAL": 2, "RADIAL": 3} 28 | 29 | def __init__(self, camera, iscolmap=True) -> None: 30 | if iscolmap: 31 | self.name = camera.model 32 | self.get_camera_vector_colmap(camera) 33 | else: # list type 34 | self.name = camera[0] 35 | self.camera_array = np.array([self.model_name2id[self.name]] + camera[1:]) 36 | def get_camera_vector_colmap(self, camera): 37 | ''' 38 | Return a camera vector from a colmap camera object 39 | return: numpy array of camera vector 40 | [modelid, width, height, focal,..., cx, cy,...] 
41 | ''' 42 | id = self.model_name2id[camera.model] 43 | array = [id, camera.width, camera.height] 44 | array.extend(camera.params) 45 | self.camera_array = np.array(array) 46 | 47 | def update_scale(self, scale_factor): 48 | self.camera_array[1:] = self.camera_array[1:]*scale_factor 49 | 50 | def get_dict_camera(self): 51 | ''' 52 | Return a dictionary of camera 53 | ''' 54 | return {"model": self.name, "width": self.camera_array[1], "height": self.camera_array[2], 55 | "params": self.camera_array[3:].tolist()} 56 | 57 | class Line3D(): 58 | def __init__(self, start, end) -> None: 59 | self.start = np.asarray(start) 60 | self.end = np.asarray(end) 61 | def get_line3d_vector(self): 62 | return np.hstack([self.start, self.end]) 63 | 64 | class Pose(): 65 | def __init__(self, qvec, tvec) -> None: 66 | self.qvec = qvec # quaternion, [w,x,y,z] 67 | self.tvec = tvec # translation, [x,y,z] 68 | def get_pose_vector(self): 69 | """ 70 | Return a pose vector [tvec, qvec] 71 | """ 72 | return np.hstack([self.tvec, self.qvec]) 73 | def get_pose_Tmatrix(self): 74 | """ 75 | Return a pose matrix [R|t] 76 | """ 77 | # Convert the quaternion to a rotation matrix 78 | qvec = np.zeros(4) 79 | qvec[:3] = self.qvec[1:] # convert quaternion from [w,x,y,z] (colmap) to [x,y,z,w] (scipy) 80 | qvec[3] = self.qvec[0] 81 | rotation = R.from_quat(qvec) 82 | rotation_matrix = rotation.as_matrix() 83 | # Create a 4x4 transformation matrix 84 | T = np.eye(4) 85 | T[:3, :3] = rotation_matrix 86 | T[:3, 3] = self.tvec 87 | return T 88 | 89 | def rotate(self, angle): 90 | pose = self.get_pose_Tmatrix() 91 | angle = -angle * math.pi / 180 # convert to radian, and reverse direction != opencv 92 | pose_rot = np.eye(4) 93 | pose_rot[0, 0] = math.cos(angle) 94 | pose_rot[0, 1] = -math.sin(angle) 95 | pose_rot[1, 0] = math.sin(angle) 96 | pose_rot[1, 1] = math.cos(angle) 97 | pose = np.matmul(pose, pose_rot) 98 | self.tvec = pose[:3, 3] 99 | rotation = R.from_matrix(pose[:3, :3]) 100 | qvec = rotation.as_quat() 101 | self.qvec = np.hstack([qvec[3], qvec[:3]]) # convert quaternion from [x,y,z,w] to [w,x,y,z] colmap 102 | 103 | class Image_Class(): 104 | def __init__(self,imgname:str) -> None: 105 | ''' 106 | - image class for storing 2D & 3D points, 2D & 3D lines, camera vector, pose vector 107 | - comments with ### can be changed if augmenting data, otherwise must be fixed 108 | #@ no change, but can be reduced if augmenting data 109 | ''' 110 | self.points2Ds = None ### numpy matrix of 2D points (Nx2) 111 | self.points3Ds = None #@ numpy matrix of 3D points, including np.array[0,0,0] if not available 112 | self.validPoints = None # numpy array of valid 2D points (have 2D-3D points correspondence) 113 | self.line2Ds = None ### numpy matrix of 2D line segments (Nx4) 114 | self.line3Ds = None #@ list of 3D line segment objects, including None if not available 115 | self.line3Ds_matrix = None # numpy matrix of 3D line segments, including np.array[0,0,0,0,0,0] if not available 116 | self.validLines = None # numpy array of valid 3D lines (have 2D-3D lines correspondence) 117 | self.camera = None # camera class 118 | self.id = None 119 | self.imgname = imgname # string: image name 120 | self.pose = None ### Pose object 121 | def get_line3d_matrix(self): 122 | ''' 123 | Return a matrix of line3D vectors 124 | ''' 125 | self.line3Ds_matrix = np.stack([ii.get_line3d_vector() if ii is not None else 126 | np.array([0,0,0,0,0,0]) for ii in self.line3Ds], 0) 127 | self.validLines = np.stack([1 if ii is not None else 128 | 0 for ii in 
self.line3Ds], 0) 129 | 130 | 131 | class Base_Collection(): 132 | def __init__(self, args, cfg, mode) -> None: 133 | self.args = args 134 | self.cfg = cfg 135 | self.device = f'cuda:{args.cudaid}' if torch.cuda.is_available() else 'cpu' 136 | if mode == "test": 137 | self.get_detector_models() 138 | 139 | def get_point_detector_model(self): 140 | ''' 141 | Return a point detector model 142 | ''' 143 | configs = self.cfg.point2d.detector.configs 144 | method = self.cfg.point2d.detector.name 145 | return get_pointdetector(method = method, configs=configs) 146 | 147 | def get_line_detector_model(self): 148 | ''' 149 | Return a line detector model 150 | ''' 151 | max_num_2d_segs = self.cfg.line2d.max_num_2d_segs 152 | do_merge_lines = self.cfg.line2d.do_merge_lines 153 | visualize = self.cfg.line2d.visualize 154 | method = self.cfg.line2d.detector.name_test_model 155 | return get_linedetector(method= method, max_num_2d_segs=max_num_2d_segs, 156 | do_merge_lines=do_merge_lines, visualize=visualize, cudaid=self.args.cudaid) 157 | 158 | def get_detector_models(self): 159 | self.line_detector = self.get_line_detector_model() 160 | # self.point_detector = self.get_point_detector_model().eval().to(self.device) 161 | 162 | def do_augmentation(self, image, image_infor_class, debug = False): 163 | if not aug.is_apply_augment(self.cfg.train.augmentation.on_rate): 164 | # No apply augmentation 165 | return image, image_infor_class 166 | # Apply the brightness and contrast 167 | transf_image = aug.random_brightness_contrast(image, self.cfg.train.augmentation.brightness, 168 | self.cfg.train.augmentation.contrast) 169 | points2Ds = image_infor_class.points2Ds 170 | lines2Ds = image_infor_class.line2Ds 171 | camera = image_infor_class.camera 172 | pose = image_infor_class.pose 173 | if self.cfg.train.augmentation.homography.apply: 174 | # camera and pose are not correct after applying homography 175 | H,W = image.shape 176 | shape = np.array([H,W]) 177 | h_matrix = aug.sample_homography(shape, self.cfg.train.augmentation.homography) # sample homography matrix 178 | transf_image = aug.warpPerspective_forimage(transf_image, h_matrix) 179 | points2Ds = aug.perspectiveTransform_forpoints(image_infor_class.points2Ds, h_matrix) 180 | lines2Ds = aug.perspectiveTransform_forlines(image_infor_class.line2Ds, h_matrix) 181 | 182 | # dsacstar-like augmentation method. 
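        # (Descriptive note, based on dsacstar_augmentation in datasets/augmentation.py:
        # it randomly rescales the image together with the 2D points, 2D lines and camera
        # intrinsics, then applies a random in-plane rotation and rotates the ground-truth
        # pose by the same angle, so the 2D-3D supervision stays geometrically consistent.)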
183 | if self.cfg.train.augmentation.dsacstar.apply: 184 | # camera and pose will be corrected in this augmentation 185 | assert not self.cfg.train.augmentation.homography.apply, "dsacstar augmentation cannot be applied with homography augmentation" 186 | transf_image, points2Ds, lines2Ds, camera, pose = aug.dsacstar_augmentation( 187 | transf_image, self.cfg.train.augmentation.dsacstar, points2Ds, lines2Ds, camera, pose) 188 | 189 | if debug: 190 | from util.visualize import visualize_img_withlinesandpoints 191 | visualize_img_withlinesandpoints(image, image_infor_class.points2Ds,image_infor_class.line2Ds) 192 | image_infor_class.points2Ds = points2Ds 193 | image_infor_class.line2Ds = lines2Ds 194 | image_infor_class.camera = camera 195 | image_infor_class.pose = pose 196 | # correct points and lines inside image 197 | if self.cfg.train.augmentation.dsacstar.apply: 198 | image_infor_class = aug.correct_points_lines_inside_image(transf_image.shape, image_infor_class) 199 | if debug: 200 | visualize_img_withlinesandpoints(transf_image, image_infor_class.points2Ds,image_infor_class.line2Ds, True) 201 | return transf_image, image_infor_class 202 | 203 | def image_loader(self, image_name, augmentation=False, debug = False): 204 | ''' 205 | (use only for point2d detector model) 206 | Read an image, do augmentation if needed, preprocess it, and 207 | return a dictionary of image data and a Image_Class object 208 | ''' 209 | resize_max = self.cfg.point2d.detector.preprocessing.resize_max 210 | resize_force = self.cfg.point2d.detector.preprocessing.resize_force 211 | interpolation = self.cfg.point2d.detector.preprocessing.interpolation 212 | grayscale = self.cfg.point2d.detector.preprocessing.grayscale 213 | path_to_image = self.get_image_path(image_name) 214 | image = read_image(path_to_image, grayscale=grayscale) 215 | 216 | size = image.shape[:2][::-1] 217 | if resize_force and (max(size) > resize_max): 218 | scale = resize_max / max(size) 219 | size_new = tuple(int(round(x*scale)) for x in size) 220 | image = aug.resize_image(image, size_new, interpolation) 221 | # rescale 2D points and lines, camera focal length 222 | raise NotImplementedError 223 | 224 | image_infor_class = copy.deepcopy(self.imgname2imgclass[image_name]) 225 | if augmentation: 226 | image, image_infor_class = self.do_augmentation(image, image_infor_class, debug) 227 | if debug: 228 | print("Debugging image_loader") 229 | return None 230 | 231 | image = image.astype(np.float32) 232 | if grayscale: 233 | image = image[None] 234 | else: 235 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW 236 | image = image / 255. 237 | original_size = np.array(size) 238 | data = { 239 | 'image': image, 240 | 'original_size': original_size, 241 | } 242 | return data, image_infor_class 243 | 244 | def detect_points2D(self, image_name): 245 | ''' 246 | Read an image, preprocess it, and 247 | Return a keypoints from that image using 248 | loaded point detector model. 
249 | ''' 250 | point_detector = self.get_point_detector_model().eval().to(self.device) 251 | resize_force = self.cfg.point2d.detector.preprocessing.resize_force 252 | resize_max = self.cfg.point2d.detector.preprocessing.resize_max 253 | data,_ = self.image_loader(image_name, False) 254 | data['image'] = torch.from_numpy(data['image'][None]).float().to(self.device) 255 | keypointsdict = point_detector._forward_default(data) 256 | scale = resize_max / max(data['original_size']) 257 | if resize_force and (max(data['original_size']) > resize_max): 258 | keypointsdict['keypoints'][0] = (keypointsdict['keypoints'][0] + .5)/scale - .5 259 | else: 260 | keypointsdict['keypoints'][0] += .5 261 | return keypointsdict 262 | 263 | def detect_lines2D(self, image_name): 264 | ''' 265 | Return a list of lines2D in the image 266 | ''' 267 | grayscale = self.cfg.line2d.preprocessing.grayscale 268 | image_path = self.get_image_path(image_name) 269 | image = read_image(image_path, grayscale=grayscale) 270 | if self.line_detector.get_module_name() == "deeplsd": 271 | image = frame2tensor(image, self.device) 272 | segs = self.line_detector.detect(image) 273 | return segs 274 | def get_2dpoints_lines_for_testing(self, image_name): 275 | ''' 276 | Return a list of points2D and a list of lines2D in the image 277 | ''' 278 | raise NotImplementedError 279 | 280 | def get_image_path(self, image_name): 281 | ''' 282 | Return a path to image 283 | ''' 284 | img_path = os.path.join(self.args.dataset_dir, self.args.dataset, self.args.scene, image_name) 285 | return img_path 286 | 287 | 288 | -------------------------------------------------------------------------------- /datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import math 4 | import random 5 | import PIL.Image 6 | 7 | def resize_image(image, size, interp): 8 | if interp.startswith('cv2_'): 9 | interp = getattr(cv2, 'INTER_'+interp[len('cv2_'):].upper()) 10 | h, w = image.shape[:2] 11 | if interp == cv2.INTER_AREA and (w < size[0] or h < size[1]): 12 | interp = cv2.INTER_LINEAR 13 | resized = cv2.resize(image, size, interpolation=interp) 14 | elif interp.startswith('pil_'): 15 | interp = getattr(PIL.Image, interp[len('pil_'):].upper()) 16 | resized = PIL.Image.fromarray(image.astype(np.uint8)) 17 | resized = resized.resize(size, resample=interp) 18 | resized = np.asarray(resized, dtype=image.dtype) 19 | else: 20 | raise ValueError( 21 | f'Unknown interpolation {interp}.') 22 | return resized 23 | 24 | 25 | def sample_homography(shape, cfg): 26 | """Sample a random valid homography. 27 | 28 | Computes the homography transformation between a random patch in the original image 29 | and a warped projection with the same image size. 30 | As in `tf.contrib.image.transform`, it maps the output point (warped patch) to a 31 | transformed input point (original patch). 32 | The original patch, which is initialized with a simple half-size centered crop, is 33 | iteratively projected, scaled, rotated and translated. 34 | 35 | Arguments: 36 | shape: A numpy array [H,W] specifying the height and width of the original image. 37 | perspective: A boolean that enables the perspective and affine transformations. 38 | scaling: A boolean that enables the random scaling of the patch. 39 | rotation: A boolean that enables the random rotation of the patch. 40 | translation: A boolean that enables the random translation of the patch. 
41 | n_scales: The number of tentative scales that are sampled when scaling. 42 | n_angles: The number of tentatives angles that are sampled when rotating. 43 | scaling_amplitude: Controls the amount of scale. 44 | perspective_amplitude_x: Controls the perspective effect in x direction. 45 | perspective_amplitude_y: Controls the perspective effect in y direction. 46 | patch_ratio: Controls the size of the patches used to create the homography. 47 | max_angle: Maximum angle used in rotations. 48 | allow_artifacts: A boolean that enables artifacts when applying the homography. 49 | translation_overflow: Amount of border artifacts caused by translation. 50 | 51 | Returns: 52 | A numpy of shape 3x3 corresponding to the homography transform. 53 | """ 54 | shift=0 55 | perspective=cfg.perspective 56 | scaling=cfg.scaling 57 | rotation=cfg.rotation 58 | translation=cfg.translation 59 | n_scales=cfg.n_scales 60 | n_angles=cfg.n_angles 61 | scaling_amplitude=cfg.scaling_amplitude 62 | perspective_amplitude_x=cfg.perspective_amplitude_x 63 | perspective_amplitude_y=cfg.perspective_amplitude_y 64 | patch_ratio=cfg.patch_ratio 65 | max_angle=math.pi*(cfg.max_angle/180) 66 | allow_artifacts=cfg.allow_artifacts 67 | translation_overflow=0. 68 | 69 | # Corners of the output image 70 | pts1 = np.stack([[0., 0.], [0., 1.], [1., 1.], [1., 0.]], axis=0) 71 | # Corners of the input patch 72 | margin = (1 - patch_ratio) / 2 73 | pts2 = margin + np.array([[0, 0], [0, patch_ratio], 74 | [patch_ratio, patch_ratio], [patch_ratio, 0]]) 75 | 76 | from numpy.random import normal 77 | from numpy.random import uniform 78 | from scipy.stats import truncnorm 79 | 80 | # Random perspective and affine perturbations 81 | # lower, upper = 0, 2 82 | std_trunc = 2 83 | 84 | if perspective: 85 | if not allow_artifacts: 86 | perspective_amplitude_x = min(perspective_amplitude_x, margin) 87 | perspective_amplitude_y = min(perspective_amplitude_y, margin) 88 | 89 | perspective_displacement = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_y/2).rvs(1) 90 | h_displacement_left = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_x/2).rvs(1) 91 | h_displacement_right = truncnorm(-1*std_trunc, std_trunc, loc=0, scale=perspective_amplitude_x/2).rvs(1) 92 | pts2 += np.array([[h_displacement_left, perspective_displacement], 93 | [h_displacement_left, -perspective_displacement], 94 | [h_displacement_right, perspective_displacement], 95 | [h_displacement_right, -perspective_displacement]]).squeeze() 96 | 97 | # Random scaling 98 | # sample several scales, check collision with borders, randomly pick a valid one 99 | if scaling: 100 | scales = truncnorm(-1*std_trunc, std_trunc, loc=1, scale=scaling_amplitude/2).rvs(n_scales) 101 | scales = np.concatenate((np.array([1]), scales), axis=0) 102 | 103 | center = np.mean(pts2, axis=0, keepdims=True) 104 | scaled = (pts2 - center)[np.newaxis, :, :] * scales[:, np.newaxis, np.newaxis] + center 105 | if allow_artifacts: 106 | valid = np.arange(n_scales) # all scales are valid except scale=1 107 | else: 108 | valid = (scaled >= 0.) * (scaled < 1.) 
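            # a sampled scale is kept only if all four warped corners remain inside the unit square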
109 | valid = valid.prod(axis=1).prod(axis=1) 110 | valid = np.where(valid)[0] 111 | idx = valid[np.random.randint(valid.shape[0], size=1)].squeeze().astype(int) 112 | pts2 = scaled[idx,:,:] 113 | 114 | # Random translation 115 | if translation: 116 | t_min, t_max = np.min(pts2, axis=0), np.min(1 - pts2, axis=0) 117 | if allow_artifacts: 118 | t_min += translation_overflow 119 | t_max += translation_overflow 120 | pts2 += np.array([uniform(-t_min[0], t_max[0],1), uniform(-t_min[1], t_max[1], 1)]).T 121 | 122 | # Random rotation 123 | # sample several rotations, check collision with borders, randomly pick a valid one 124 | if rotation: 125 | angles = np.linspace(-max_angle, max_angle, num=n_angles) 126 | angles = np.concatenate((angles, np.array([0.])), axis=0) # in case no rotation is valid 127 | center = np.mean(pts2, axis=0, keepdims=True) 128 | rot_mat = np.reshape(np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), 129 | np.cos(angles)], axis=1), [-1, 2, 2]) 130 | rotated = np.matmul( (pts2 - center)[np.newaxis,:,:], rot_mat) + center 131 | if allow_artifacts: 132 | valid = np.arange(n_angles) # all scales are valid except scale=1 133 | else: 134 | valid = (rotated >= 0.) * (rotated < 1.) 135 | valid = valid.prod(axis=1).prod(axis=1) 136 | valid = np.where(valid)[0] 137 | idx = valid[np.random.randint(valid.shape[0], size=1)].squeeze().astype(int) 138 | pts2 = rotated[idx,:,:] 139 | 140 | 141 | # Rescale to actual size 142 | shape = shape[::-1] # different convention [y, x] 143 | pts1 *= shape[np.newaxis,:] 144 | pts2 *= shape[np.newaxis,:] 145 | 146 | def ax(p, q): return [p[0], p[1], 1, 0, 0, 0, -p[0] * q[0], -p[1] * q[0]] 147 | def ay(p, q): return [0, 0, 0, p[0], p[1], 1, -p[0] * q[1], -p[1] * q[1]] 148 | 149 | homography = cv2.getPerspectiveTransform(np.float32(pts1+shift), np.float32(pts2+shift)) 150 | return homography 151 | 152 | def warpPerspective_forimage(ori_img, h_matrix): 153 | # Apply the homography transformation to the image 154 | transformed_image = cv2.warpPerspective(ori_img, h_matrix, (ori_img.shape[1], ori_img.shape[0])) 155 | return transformed_image 156 | 157 | def perspectiveTransform_forpoints(positions, h_matrix): 158 | # Apply the homography transformation to the list of positions 159 | transformed_positions = cv2.perspectiveTransform(np.array([positions]), h_matrix) 160 | return transformed_positions[0,:,:] 161 | 162 | def perspectiveTransform_forlines(lines, h_matrix): 163 | # Apply the homography transformation to the list of 2D lines 164 | start_points = lines[:,:2] 165 | end_points = lines[:,2:] 166 | transformed_start_points = cv2.perspectiveTransform(np.array([start_points]), h_matrix)[0,:,:] 167 | transformed_end_points = cv2.perspectiveTransform(np.array([end_points]), h_matrix)[0,:,:] 168 | transformed_lines = np.concatenate((transformed_start_points, transformed_end_points), axis=1) 169 | return transformed_lines 170 | 171 | def random_brightness_contrast(image, b_rate, c_rate): 172 | # Random the brightness and contrast values 173 | contrast = [1.0, 1.0+2.0*c_rate] 174 | brightness = [-100*b_rate, 100*b_rate] 175 | alpha = random.uniform(contrast[0], contrast[1]) 176 | beta = random.uniform(brightness[0], brightness[1]) 177 | return cv2.convertScaleAbs(image, alpha=alpha, beta=beta) 178 | 179 | 180 | def is_apply_augment(rate): 181 | ''' 182 | Return True if the augmentation is applied by select a random option 183 | ''' 184 | # Define the options and their probabilities 185 | options = [True, False] 186 | probabilities = [rate, 1-rate] 
187 | # Choose an option to turn on augmentation or not 188 | return random.choices(options, weights=probabilities, k=1)[0] 189 | 190 | def dsacstar_augmentation(image, cfg, points2d, lines2d, camera, pose, interpolation='cv2_area'): 191 | ''' 192 | Apply the augmentation to the input image, points, lines, camera, and pose 193 | args: 194 | image: input image np.array WxH 195 | cfg: configuration file .yaml 196 | points2d: 2D points np.array Nx2 197 | lines2d: 2D lines np.array Nx4 198 | camera: camera parameters np.array[w, h, f, cx, cy,...] 199 | pose: camera pose 'class _base.Pose' 200 | ''' 201 | # Random the scale factor and rotation angle 202 | scale_factor = random.uniform(cfg.aug_scale_min, cfg.aug_scale_max) 203 | angle = random.uniform(-cfg.aug_rotation, cfg.aug_rotation) 204 | 205 | # Apply the scale factor and rotation angle to the image 206 | new_shape = (int(image.shape[1] * scale_factor), int(image.shape[0] * scale_factor)) # height, width 207 | image = resize_image(image, new_shape, interpolation) 208 | 209 | # ajust the points and lines coordinates 210 | points2d = points2d * scale_factor 211 | lines2d = lines2d * scale_factor 212 | 213 | # ajust the camera parameters 214 | camera.update_scale(scale_factor) 215 | 216 | # rotate input image 217 | # Get the rotation matrix 218 | M = cv2.getRotationMatrix2D((new_shape[0] / 2, new_shape[1] / 2), angle, 1) 219 | # Rotate the image 220 | image = cv2.warpAffine(image, M, new_shape) 221 | points2d = rotate_points_dsacstar(points2d, M) 222 | lines2d = rotate_lines_dsacstar(lines2d, M) 223 | # rotate ground truth camera pose 224 | pose.rotate(angle) 225 | 226 | return image, points2d, lines2d, camera, pose 227 | 228 | def rotate_points_dsacstar(points, M): 229 | # Convert the points to homogeneous coordinates 230 | points_hom = np.hstack((points, np.ones((points.shape[0], 1)))) 231 | # Rotate the points 232 | rotated_points_hom = np.dot(M, points_hom.T).T 233 | # Convert the points back to 2D 234 | rotated_points = rotated_points_hom[:, :2] 235 | return rotated_points 236 | 237 | def rotate_lines_dsacstar(lines, M): 238 | start_points = lines[:,:2] 239 | end_points = lines[:,2:] 240 | start_points = rotate_points_dsacstar(start_points, M) 241 | end_points = rotate_points_dsacstar(end_points, M) 242 | rotated_lines = np.concatenate((start_points, end_points), axis=1) 243 | return rotated_lines 244 | 245 | def is_inside_img(points, img_shape): 246 | h, w = img_shape[0], img_shape[1] 247 | return (points[:, 0] >= 0) & (points[:, 0] < w) & (points[:, 1] >= 0) & (points[:, 1] < h) 248 | 249 | def correct_points_lines_inside_image(shape, image_infor_class): 250 | ''' 251 | Correct the points and lines coordinates to be inside the image 252 | if the points/lines are outside the image, remove them 253 | if lines have half inside and half outside, shrink the line to be inside the image 254 | Then, correct the 3D ground truth points coordinates 255 | Args: 256 | shape: image shape (height, width) 257 | image_infor_class: class _base.ImageInfor 258 | ''' 259 | H, W = shape[0], shape[1] 260 | # correct 2d points 261 | points2d = image_infor_class.points2Ds 262 | valid_points = is_inside_img(points2d, shape) 263 | points2d = points2d[valid_points] 264 | image_infor_class.points2Ds = points2d 265 | # correct 3d points 266 | image_infor_class.points3Ds = image_infor_class.points3Ds[valid_points] 267 | # correct id of valids 268 | image_infor_class.validPoints = image_infor_class.validPoints[valid_points] 269 | assert 
len(image_infor_class.points2Ds) == len(image_infor_class.points3Ds) == len(image_infor_class.validPoints) 270 | 271 | # correct 2d lines 272 | lines2d = image_infor_class.line2Ds 273 | lines3d = image_infor_class.line3Ds_matrix 274 | valids_lines2d = image_infor_class.validLines 275 | start_points = lines2d[:,:2] 276 | end_points = lines2d[:,2:] 277 | valid_start_points = is_inside_img(start_points, shape) 278 | valid_end_points = is_inside_img(end_points, shape) 279 | # remove lines that are outside the image 280 | valid_lines = valid_start_points | valid_end_points 281 | 282 | start_points = start_points[valid_lines] 283 | end_points = end_points[valid_lines] 284 | lines3d = lines3d[valid_lines] 285 | valids_lines2d = valids_lines2d[valid_lines] 286 | 287 | valid_start_points = valid_start_points[valid_lines] 288 | valid_end_points = valid_end_points[valid_lines] 289 | # shrink lines that are half inside and half outside the image 290 | indices = np.where(~valid_start_points)[0] 291 | for idx in indices: 292 | start = start_points[idx,:] # outside points 293 | end = end_points[idx,:] 294 | m, c = line_equation(start, end) # y = mx + c 295 | if start[0] < 0: 296 | start[0] = 0+1 297 | start[1] = compute_y(m, c, 0) 298 | elif start[0] > W: 299 | start[0] = W - 1 300 | start[1] = compute_y(m, c, W) 301 | if start[1] < 0: 302 | start[0] = compute_x(m, c, 0) 303 | start[1] = 0 + 1 304 | elif start[1] > H: 305 | start[0] = compute_x(m, c, H) 306 | start[1] = H - 1 307 | start_points[idx] = start 308 | indices = np.where(~valid_end_points)[0] 309 | for idx in indices: 310 | start = start_points[idx] 311 | end = end_points[idx] # outside points 312 | m, c = line_equation(start, end) # y = mx + c 313 | if end[0] < 0: 314 | end[0] = 0+1 315 | end[1] = compute_y(m, c, 0) 316 | elif end[0] > W: 317 | end[0] = W - 1 318 | end[1] = compute_y(m, c, W) 319 | if end[1] < 0: 320 | end[0] = compute_x(m, c, 0) 321 | end[1] = 0 + 1 322 | elif end[1] > H: 323 | end[0] = compute_x(m, c, H) 324 | end[1] = H - 1 325 | end_points[idx] = end 326 | 327 | assert np.all(is_inside_img(start_points, shape)) 328 | assert np.all(is_inside_img(end_points, shape)) 329 | lines2d = np.concatenate((start_points, end_points), axis=1) 330 | assert len(lines2d) == len(lines3d) == len(valids_lines2d) 331 | image_infor_class.line2Ds = lines2d 332 | image_infor_class.line3Ds_matrix = lines3d 333 | image_infor_class.validLines = valids_lines2d 334 | return image_infor_class 335 | 336 | def line_equation(start, end): 337 | # Calculate the slope 338 | m = (end[1] - start[1]) / (end[0] - start[0]) 339 | # Calculate the y-intercept 340 | c = start[1] - m * start[0] 341 | return m, c # y = mx + c 342 | def compute_x(m, c, y): 343 | # Calculate the x value that corresponds to the given y value 344 | # and the line equation y = mx + c 345 | return (y - c) / m 346 | def compute_y(m, c, x): 347 | # Calculate the y value that corresponds to the given x value 348 | # and the line equation y = mx + c 349 | return m * x + c -------------------------------------------------------------------------------- /datasets/data_collection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from util.read_write_model import read_model 6 | import numpy as np 7 | from pathlib import Path 8 | from datasets._base import (Image_Class, Base_Collection, Line3D, Pose, 9 | Camera) 10 
| 11 | strlist2floatlist = lambda strlist: [float(s) for s in strlist]
12 | strlist2intlist = lambda strlist: [int(s) for s in strlist]
13 |
14 | class DataCollection(Base_Collection):
15 | def __init__(self, args:dict, cfg:dict, mode="train")->None:
16 | super(DataCollection, self).__init__(args, cfg, mode)
17 | self.gt_3Dmodels_path = self.args.sfm_dir / f"{self.args.dataset}/{self.args.scene}"
18 | self.SfM_with_depth = self.cfg.train.use_depth # whether to use SfM labels that have been corrected by depth
19 | self.train_imgs = [] # list of train image names
20 | self.test_imgs = [] # list of test image names
21 | self.imgname2limapID = {} # map from image name to limap image id
22 | self.limapID2imgname = {} # map from limap image id to image name
23 | # load all images' 2D & 3D points and create Image_Class objects
24 | self.imgname2imgclass = {} # map from image name to Image_Class object
25 | self.load_all_2Dpoints_by_dataset(self.args.dataset)
26 | self.load_imgname2limapID()
27 | # load lines data from Limap output
28 | self.load_all_2Dlines_data()
29 | # load alltracks data from Limap output
30 | self.load_alltracks_limap_3Dlines()
31 |
32 | def load_all_2Dpoints_by_dataset(self, dataset):
33 | if dataset == "7scenes":
34 | self.load_all_2Dpoints_7scenes()
35 | elif dataset == "Cambridge" or dataset == "indoor6":
36 | self.load_all_2Dpoints_Cambridge()
37 | else:
38 | raise NotImplementedError
39 |
40 |
41 | def load_all_2Dpoints_7scenes(self):
42 | # currently used for 7scenes.
43 | # load all 2d & 3d points from colmap output
44 | path_gt_3Dmodels_full = self.gt_3Dmodels_path/"sfm_sift_full"
45 | if self.SfM_with_depth:
46 | print("[INFOR] Using SfM labels corrected by depth.")
47 | path_gt_3Dmodels_train = self.gt_3Dmodels_path/"sfm_superpoint+superglue+depth"
48 | else:
49 | path_gt_3Dmodels_train = self.gt_3Dmodels_path/"sfm_superpoint+superglue"
50 | testlist_path = path_gt_3Dmodels_full/"list_test.txt"
51 | cameras_all, images_all, _ = read_model(path=path_gt_3Dmodels_full, ext=".bin")
52 | _, images_train, points3D_train = read_model(path=path_gt_3Dmodels_train, ext=".bin")
53 | name2id_train = {image.name: i for i, image in images_train.items()}
54 |
55 | if os.path.exists(testlist_path):
56 | with open(testlist_path, 'r') as f:
57 | testlist = f.read().rstrip().split('\n')
58 | else:
59 | raise ValueError("Error! 
Input file/directory {0} not found.".format(testlist_path)) 60 | for id_, image in images_all.items(): 61 | img_name = image.name 62 | self.imgname2imgclass[img_name] = Image_Class(img_name) 63 | if image.name in testlist: 64 | # fill data to TEST img classes 65 | self.test_imgs.append(img_name) 66 | self.imgname2imgclass[img_name].pose = Pose(image.qvec, image.tvec) 67 | self.imgname2imgclass[img_name].camera = Camera(cameras_all[image.camera_id], 68 | iscolmap=True) 69 | else: 70 | # fill data to TRAIN img classes 71 | self.train_imgs.append(img_name) 72 | self.imgname2imgclass[img_name].pose = Pose(image.qvec, image.tvec) 73 | image_train = images_train[name2id_train[img_name]] 74 | self.imgname2imgclass[img_name].points2Ds = image_train.xys 75 | self.imgname2imgclass[img_name].points3Ds = np.stack([points3D_train[ii].xyz if ii != -1 else 76 | np.array([0,0,0]) for ii in image_train.point3D_ids], 0) 77 | self.imgname2imgclass[img_name].validPoints = np.stack([1 if ii != -1 else 78 | 0 for ii in image_train.point3D_ids], 0) 79 | self.imgname2imgclass[img_name].camera = Camera(cameras_all[image.camera_id], 80 | iscolmap=True) 81 | 82 | def load_all_2Dpoints_Cambridge(self): 83 | # load all 2d & 3d points from colmap output 84 | path_gt_3Dmodels_full = self.gt_3Dmodels_path/"sfm_sift_full" 85 | 86 | # load query_list_with_intrinsics.txt 87 | query_list_with_intrinsics = self.gt_3Dmodels_path/"query_list_with_intrinsics.txt" 88 | if not os.path.exists(query_list_with_intrinsics): 89 | raise ValueError("Error! Input file/directory {0} not found.".format(query_list_with_intrinsics)) 90 | query_list_with_intrinsics = pd.read_csv(query_list_with_intrinsics, sep=" ", header=None) 91 | # get test dictionary with its intrinsic 92 | testimgname2intrinsic = {query_list_with_intrinsics.iloc[i,0]:list(query_list_with_intrinsics.iloc[i,1:]) 93 | for i in range(len(query_list_with_intrinsics))} 94 | 95 | # load id_to_origin_name.txt 96 | import json 97 | id_to_origin_name = self.gt_3Dmodels_path / "id_to_origin_name.txt" 98 | with open(id_to_origin_name, 'r') as f: 99 | id_to_origin_name = json.load(f) 100 | 101 | originalname2newimgname = {} 102 | for id, originalname in id_to_origin_name.items(): 103 | id = int(id) 104 | originalname2newimgname[originalname] = "image{0:08d}.png".format(id) 105 | 106 | 107 | # load the camera model from colmap output 108 | _, images_all, _ = read_model(path=path_gt_3Dmodels_full, ext=".bin") 109 | path_gt_3Dmodels_train = self.gt_3Dmodels_path/"sfm_superpoint+superglue" 110 | cameras_train, images_train, points3D_train = read_model(path=path_gt_3Dmodels_train, ext=".bin") 111 | name2id_train = {image.name: i for i, image in images_train.items()} 112 | 113 | for _, image in images_all.items(): 114 | img_name = image.name 115 | new_img_name = originalname2newimgname[img_name] 116 | self.imgname2imgclass[new_img_name] = Image_Class(new_img_name) 117 | if new_img_name in testimgname2intrinsic: 118 | # fill data to TEST img classes 119 | self.test_imgs.append(new_img_name) 120 | self.imgname2imgclass[new_img_name].pose = Pose(image.qvec, image.tvec) 121 | self.imgname2imgclass[new_img_name].camera = Camera(testimgname2intrinsic[new_img_name], 122 | iscolmap=False) 123 | else: 124 | # fill data to TRAIN img classes 125 | if new_img_name not in name2id_train: 126 | continue 127 | image_train = images_train[name2id_train[new_img_name]] 128 | if len(image_train.point3D_ids) == 0: 129 | continue 130 | self.train_imgs.append(new_img_name) 131 | 
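# Descriptive note for the train branch below: each training Image_Class stores the SfM 2D
# keypoints (image_train.xys), their triangulated 3D points (with a (0, 0, 0) placeholder
# wherever point3D_id == -1), a validPoints mask (1 when a keypoint has a valid 3D point,
# else 0), and the COLMAP camera. Test images, listed in query_list_with_intrinsics.txt,
# only receive a pose and a camera built from the provided intrinsics.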
self.imgname2imgclass[new_img_name].pose = Pose(image.qvec, image.tvec) 132 | self.imgname2imgclass[new_img_name].points2Ds = image_train.xys 133 | self.imgname2imgclass[new_img_name].points3Ds = np.stack([points3D_train[ii].xyz if ii != -1 else 134 | np.array([0,0,0]) for ii in image_train.point3D_ids], 0) 135 | self.imgname2imgclass[new_img_name].validPoints = np.stack([1 if ii != -1 else 136 | 0 for ii in image_train.point3D_ids], 0) 137 | self.imgname2imgclass[new_img_name].camera = Camera(cameras_train[image_train.camera_id], 138 | iscolmap=True) 139 | 140 | def load_imgname2limapID(self): 141 | # load path image list from limap output 142 | img_list_path = self.gt_3Dmodels_path/f"limap/{self.cfg.line2d.detector.name}/image_list.txt" 143 | if not os.path.exists(img_list_path): 144 | raise ValueError("Error! Input file/directory {0} not found.".format(img_list_path)) 145 | with open(img_list_path, 'r') as f: 146 | lines = f.readlines()[1:] # read all lines except the first one 147 | for line in lines: 148 | img_id, img_name = line.strip().split(',') # assuming two columns separated by comma 149 | self.imgname2limapID[img_name] = int(img_id) 150 | self.limapID2imgname[int(img_id)] = img_name 151 | 152 | 153 | 154 | def load_all_2Dlines_data(self): 155 | # load train all lines data from limap output (all exixting lines in all images) 156 | # then create line3D objects for each image 157 | segments_path = self.gt_3Dmodels_path /f"limap/{self.cfg.line2d.detector.name}/segments" 158 | def read_segments_file(img_id, segments_path): 159 | segments_file = segments_path / f"segments_{img_id}.txt" 160 | if not os.path.exists(segments_file): 161 | raise ValueError("Error! Input file/directory {0} not found.".format(segments_file)) 162 | segments_matrix = pd.read_csv(segments_file, sep=' ', skiprows=1, header=None).to_numpy() 163 | return segments_matrix 164 | for img_name in self.train_imgs: 165 | img_id = self.imgname2limapID[img_name] 166 | segments_matrix = read_segments_file(img_id, segments_path) 167 | length = segments_matrix.shape[0] 168 | self.imgname2imgclass[img_name].line2Ds = segments_matrix 169 | self.imgname2imgclass[img_name].line3Ds = [None for _ in range(length)] 170 | # if length < 80: 171 | # print(length, img_name) 172 | 173 | 174 | 175 | def load_alltracks_limap_3Dlines(self): 176 | # load all tracks data from limap output (training data only) 177 | track_file = "fitnmerge_alltracks.txt" if self.SfM_with_depth else "alltracks.txt" 178 | tracks_path = self.gt_3Dmodels_path/ f"limap/{self.cfg.line2d.detector.name}/{track_file}" 179 | if not os.path.exists(tracks_path): 180 | raise ValueError("Error! 
Input file/directory {0} not found.".format(tracks_path)) 181 | with open(tracks_path, 'r') as f: 182 | lines = f.readlines() 183 | number3Dlines = int(lines[0].strip()) 184 | i = 1 185 | length = len(lines) 186 | while i < length: 187 | i += 1 # skip the first line (3dline id, #2dlines, #imgs) 188 | start_3d = strlist2floatlist(lines[i].strip().split(' ')) 189 | i += 1 190 | end_3d = strlist2floatlist(lines[i].strip().split(' ')) 191 | i += 1 192 | # load img ids 193 | img_ids = strlist2intlist(lines[i].strip().split(' ')) 194 | i += 1 195 | # load 2d line ids 196 | line2d_ids = strlist2intlist(lines[i].strip().split(' ')) 197 | # fill data to Image_Class objects 198 | for img_id, line2d_id in zip(img_ids, line2d_ids): 199 | img_name = self.limapID2imgname[img_id] 200 | if img_name not in self.train_imgs: 201 | continue 202 | if self.imgname2imgclass[img_name].line3Ds[line2d_id] is not None: 203 | raise ValueError("Error! 3D line {0} in image {1} is already filled.".format(line2d_id, img_id)) 204 | self.imgname2imgclass[img_name].line3Ds[line2d_id] = Line3D(start_3d, end_3d) 205 | i += 1 206 | self.load_all_lines3D_matrix() 207 | 208 | def load_all_lines3D_matrix(self): 209 | # load all lines3D matrix from colmap output 210 | for img_name in self.train_imgs: 211 | self.imgname2imgclass[img_name].get_line3d_matrix() 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from datasets.data_collection import DataCollection 6 | import numpy as np 7 | 8 | class Collection_Loader(Dataset): 9 | def __init__(self, args, cfg, mode="train"): 10 | self.DataCol = DataCollection(args, cfg, mode=mode) 11 | self.mode = mode 12 | if "train" in mode: 13 | self.image_list = self.DataCol.train_imgs 14 | self.augmentation = cfg.train.augmentation.apply if mode == "train" else False 15 | if self.augmentation: print("[INFOR] Augmentation is applied") 16 | elif mode == "test": 17 | self.augmentation = False 18 | self.image_list = self.DataCol.test_imgs 19 | else: 20 | raise ValueError("Error! 
Mode {0} not supported.".format(mode)) 21 | 22 | def __len__(self): 23 | return len(self.image_list) 24 | 25 | def __getitem__(self, index): 26 | image_name = self.image_list[index] 27 | data, infor = self.DataCol.image_loader(image_name, augmentation=self.augmentation) # dict:{img, ori_img_size} 28 | target = {} 29 | if self.mode == "test": 30 | data['lines'] = self.DataCol.detect_lines2D(image_name)[:,:4] # detect lines2D 31 | data['keypoints'] = 'None' # to show there is no keypoints 32 | if "train" in self.mode: 33 | data['lines'] = infor.line2Ds 34 | data['keypoints'] = infor.points2Ds 35 | target['lines3D'] = infor.line3Ds_matrix.T 36 | target['points3D'] = infor.points3Ds.T 37 | target['validPoints'] = infor.validPoints 38 | target['validLines'] = infor.validLines 39 | assert data['lines'].shape[0] == target['lines3D'].shape[1] == target['validLines'].shape[0] 40 | assert data['keypoints'].shape[0] == target['points3D'].shape[1] == target['validPoints'].shape[0] 41 | target['pose'] = infor.pose.get_pose_vector() 42 | target['camera'] = infor.camera.camera_array 43 | data['imgname'] = image_name 44 | data = map_dict_to_torch(data) 45 | target = map_dict_to_torch(target) 46 | return data, target 47 | 48 | def map_dict_to_torch(data): 49 | for k, v in data.items(): 50 | if isinstance(v, str): 51 | continue 52 | elif isinstance(v, np.ndarray): 53 | data[k] = torch.from_numpy(v).float() 54 | else: 55 | raise ValueError("Error! Type {0} not supported.".format(type(v))) 56 | return data -------------------------------------------------------------------------------- /datasets/test.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | from data_collection import DataCollection 4 | import sys, os 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | import util.config as utilcfg 7 | import util.visualize as u_vis 8 | from omegaconf import OmegaConf 9 | 10 | def parse_config(): 11 | arg_parser = argparse.ArgumentParser(description='pre-processing for PL2Map dataset') 12 | arg_parser.add_argument('-d', '--dataset_dir', type=Path, default='datasets/imgs_datasets/', help='') 13 | arg_parser.add_argument('--dataset', type=str, default="7scenes", help='dataset name') 14 | arg_parser.add_argument('-s', '--scene', type=str, default="office", help='scene name(s)') 15 | arg_parser.add_argument('-cp','--checkpoint', type=int, default=0, choices=[0,1], help='use pre-trained model') 16 | arg_parser.add_argument('--visdom', type=int, default=1, choices=[0,1], help='visualize loss using visdom') 17 | arg_parser.add_argument('-c','--cudaid', type=int, default=0, help='specify cuda device id') 18 | arg_parser.add_argument('--use_depth', type=int, default=0, choices=[0,1], help='use SfM corrected by depth or not') 19 | arg_parser.add_argument('-o','--outputs', type=Path, default='logs/', 20 | help='Path to the output directory, default: %(default)s') 21 | arg_parser.add_argument('-expv', '--experiment_version', type=str, default="pl2map", help='experiment version folder') 22 | args, _ = arg_parser.parse_known_args() 23 | args.outputs = os.path.join(args.outputs, args.scene + "_" + args.experiment_version) 24 | print("Dataset: {} | Scene: {}".format(args.dataset, args.scene)) 25 | cfg = utilcfg.load_config(f'cfgs/{args.dataset}.yaml', default_path='cfgs/default.yaml') 26 | cfg = OmegaConf.create(cfg) 27 | utilcfg.mkdir(args.outputs) 28 | 29 | # Save the config file for evaluation purposes 30 | 
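# The merged OmegaConf config is saved next to the experiment outputs so that a later
# evaluation run can rebuild the same configuration. A minimal sketch of the eval-side
# reload (hypothetical usage, not part of this script):
#   cfg = OmegaConf.load(os.path.join(args.outputs, 'config.yaml'))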
config_file_path = os.path.join(args.outputs, 'config.yaml') 31 | OmegaConf.save(cfg, config_file_path) 32 | 33 | return args, cfg 34 | 35 | def main(): 36 | args, cfg = parse_config() 37 | dataset = DataCollection(args, cfg, mode="test") 38 | # img_name = "seq-06/frame-000780.color.png" 39 | 40 | # print(dataset.imgname2imgclass[img_name].camera.camera_array) 41 | # print(dataset.imgname2imgclass[img_name].pose.get_pose_vector()) 42 | 43 | # u_vis.visualize_2d_points_lines_from_collection(dataset, img_name, mode="online") 44 | # u_vis.visualize_2d_lines_from_collection(dataset, img_name, mode="online") 45 | # u_vis.visualize_2d_lines_from_collection(dataset, img_name, mode="offline") 46 | # u_vis.open3d_vis_3d_points_from_datacollection(dataset) 47 | # u_vis.open3d_vis_3d_lines_from_single_imgandcollection(dataset, img_name) 48 | u_vis.open3d_vis_3d_lines_from_datacollection(dataset) 49 | # u_vis.visualize_2d_points_from_collection(dataset, img_name, mode="online") 50 | # u_vis.visualize_2d_points_from_collection(dataset, img_name, mode="offline") 51 | # dataset.image_loader(img_name, cfg.train.augmentation.apply, debug=True) 52 | # img_name = "seq-06/frame-000499.color.png" 53 | # train_img_list = dataset.train_imgs 54 | # i = 0 55 | # for img_name in train_img_list: 56 | # i+=1 57 | # if i%5 == 0: 58 | # continue 59 | # print(img_name) 60 | # # u_vis.visualize_2d_points_from_collection(dataset, img_name, mode="offline") 61 | # # u_vis.visualize_2d_points_from_collection(dataset, img_name, mode="online") 62 | # u_vis.visualize_2d_lines_from_collection(dataset, img_name, mode="offline") 63 | # # u_vis.visualize_2d_lines_from_collection(dataset, img_name, mode="online") 64 | # # visualize 3D train lines 65 | # # u_vis.open3d_vis_3d_lines_from_datacollection(dataset) 66 | # if i > 2000: 67 | # break 68 | if __name__ == "__main__": 69 | main() 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /detectors/line2d/DeepLSD/deeplsd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from third_party.DeepLSD.deeplsd.models.deeplsd_inference import DeepLSD 5 | from ..linebase_detector import LineBaseDetector, BaseDetectorOptions 6 | 7 | class DeepLSDDetector(LineBaseDetector): 8 | def __init__(self, options = BaseDetectorOptions()): 9 | super(DeepLSDDetector, self).__init__(options) 10 | 11 | conf = { 12 | 'detect_lines': True, 13 | 'line_detection_params': { 14 | 'merge': False, 15 | 'grad_nfa': True, 16 | 'filtering': 'normal', 17 | 'grad_thresh': 3, 18 | }, 19 | } 20 | self.device = f'cuda:{self.cudaid}' if torch.cuda.is_available() else 'cpu' 21 | if self.weight_path is None: 22 | ckpt = os.path.join(os.path.dirname(__file__), 'deeplsd_md.tar') 23 | else: 24 | ckpt = os.path.join(self.weight_path, "line2d", "DeepLSD", 'deeplsd_md.tar') 25 | if not os.path.isfile(ckpt): 26 | self.download_model(ckpt) 27 | ckpt = torch.load(ckpt, map_location='cpu') 28 | print('Loaded DeepLSD model') 29 | self.net = DeepLSD(conf).eval() 30 | self.net.load_state_dict(ckpt['model']) 31 | self.net = self.net.to(self.device) 32 | 33 | def download_model(self, path): 34 | import subprocess 35 | if not os.path.exists(os.path.dirname(path)): 36 | os.makedirs(os.path.dirname(path)) 37 | link = "https://www.polybox.ethz.ch/index.php/s/XVb30sUyuJttFys/download" 38 | cmd = ["wget", link, "-O", path] 39 | print("Downloading DeepLSD model...") 40 | 
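# Note: this download shells out to wget, so it assumes wget is available on PATH.
# A pure-Python fallback sketch (an assumption, not used by the original code):
#   import urllib.request
#   urllib.request.urlretrieve(link, path)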
subprocess.run(cmd, check=True) 41 | 42 | def get_module_name(self): 43 | return "deeplsd" 44 | 45 | def detect(self, image): 46 | 47 | with torch.no_grad(): 48 | lines = self.net({'image': image})['lines'][0] 49 | 50 | # Use the line length as score 51 | lines = np.concatenate([ 52 | lines.reshape(-1, 4), 53 | np.linalg.norm(lines[:, 0] - lines[:, 1], axis=1, keepdims=True)], 54 | axis=1) 55 | return lines 56 | -------------------------------------------------------------------------------- /detectors/line2d/LSD/lsd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytlsd 3 | import numpy as np 4 | from ..linebase_detector import LineBaseDetector, BaseDetectorOptions 5 | 6 | class LSDDetector(LineBaseDetector): 7 | def __init__(self, options = BaseDetectorOptions()): 8 | super(LSDDetector, self).__init__(options) 9 | 10 | def get_module_name(self): 11 | return "lsd" 12 | 13 | def detect(self, image): 14 | max_n_lines = None # 80 15 | min_length = 15 16 | lines, scores, valid_lines = [], [], [] 17 | if max_n_lines is None: 18 | b_segs = pytlsd.lsd(image) 19 | else: 20 | for s in [0.3, 0.4, 0.5, 0.7, 0.8, 1.0]: 21 | b_segs = pytlsd.lsd(image, scale=s) 22 | # print(len(b_segs)) 23 | if len(b_segs) >= max_n_lines: 24 | break 25 | # print(len(b_segs)) 26 | segs_length = np.linalg.norm(b_segs[:, 2:4] - b_segs[:, 0:2], axis=1) 27 | # Remove short lines 28 | # b_segs = b_segs[segs_length >= min_length] 29 | # segs_length = segs_length[segs_length >= min_length] 30 | b_scores = b_segs[:, -1] * np.sqrt(segs_length) 31 | # Take the most relevant segments with 32 | indices = np.argsort(-b_scores) 33 | if max_n_lines is not None: 34 | indices = indices[:max_n_lines] 35 | b_segs = b_segs[indices, :] 36 | # print(b_segs.shape) 37 | # segs = pytlsd.lsd(image) 38 | return b_segs 39 | 40 | -------------------------------------------------------------------------------- /detectors/line2d/linebase_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import cv2 5 | from torch import nn 6 | 7 | #import limap.util.io as limapio 8 | #import limap.visualize as limapvis 9 | 10 | import collections 11 | from typing import NamedTuple 12 | class BaseDetectorOptions(NamedTuple): 13 | """ 14 | Base options for the line detector 15 | 16 | :param set_gray: whether to set the image to gray scale (sometimes depending on the detector) 17 | :param max_num_2d_segs: maximum number of detected line segments (default = 3000) 18 | :param do_merge_lines: whether to merge close similar lines at post-processing (default = False) 19 | :param visualize: whether to output visualizations into output folder along with the detections (default = False) 20 | :param weight_path: specify path to load weights (at default, weights will be downloaded to ~/.local) 21 | """ 22 | max_num_2d_segs: int = 3000 23 | do_merge_lines: bool = False 24 | visualize: bool = False 25 | weight_path: str = None 26 | cudaid: int = 0 27 | 28 | class LineBaseDetector(): 29 | """ 30 | Virtual class for line detector 31 | """ 32 | def __init__(self, options = BaseDetectorOptions()): 33 | self.max_num_2d_segs = options.max_num_2d_segs 34 | self.do_merge_lines = options.do_merge_lines 35 | self.visualize = options.visualize 36 | self.weight_path = options.weight_path 37 | self.cudaid = options.cudaid 38 | 39 | # Module name needs to be set 40 | def get_module_name(self): 41 | """ 42 | Virtual method 
(need to be implemented) - return the name of the module 43 | """ 44 | raise NotImplementedError 45 | # The functions below are required for detectors 46 | def detect(self, data): 47 | """ 48 | Virtual method (for detector) - detect 2D line segments 49 | 50 | Args: 51 | view (:class:`limap.base.CameraView`): The `limap.base.CameraView` instance corresponding to the image 52 | Returns: 53 | :class:`np.array` of shape (N, 5): line detections. Each row corresponds to x1, y1, x2, y2 and score. 54 | """ 55 | raise NotImplementedError 56 | # The functions below are required for extractors 57 | def extract(self, camview, segs): 58 | """ 59 | Virtual method (for extractor) - extract the features for the detected segments 60 | 61 | Args: 62 | view (:class:`limap.base.CameraView`): The `limap.base.CameraView` instance corresponding to the image 63 | segs: :class:`np.array` of shape (N, 5), line detections. Each row corresponds to x1, y1, x2, y2 and score. Computed from the `detect` method. 64 | Returns: 65 | The extracted feature 66 | """ 67 | raise NotImplementedError 68 | def get_descinfo_fname(self, descinfo_folder, img_id): 69 | """ 70 | Virtual method (for extractor) - Get the target filename of the extracted feature 71 | 72 | Args: 73 | descinfo_folder (str): The output folder 74 | img_id (int): The image id 75 | Returns: 76 | str: target filename 77 | """ 78 | raise NotImplementedError 79 | def save_descinfo(self, descinfo_folder, img_id, descinfo): 80 | """ 81 | Virtual method (for extractor) - Save the extracted feature to the target folder 82 | 83 | Args: 84 | descinfo_folder (str): The output folder 85 | img_id (int): The image id 86 | descinfo: The features extracted from the function `extract` 87 | """ 88 | raise NotImplementedError 89 | def read_descinfo(self, descinfo_folder, img_id): 90 | """ 91 | Virtual method (for extractor) - Read in the extracted feature. Dual function for `save_descinfo`. 92 | 93 | Args: 94 | descinfo_folder (str): The output folder 95 | img_id (int): The image id 96 | Returns: 97 | The extracted feature 98 | """ 99 | raise NotImplementedError 100 | # The functions below are required for double-functioning objects 101 | def detect_and_extract(self, camview): 102 | """ 103 | Virtual method (for dual-functional class that can perform both detection and extraction) - Detect and extract on a single image 104 | 105 | Args: 106 | view (:class:`limap.base.CameraView`): The `limap.base.CameraView` instance corresponding to the image 107 | Returns: 108 | segs (:class:`np.array`): of shape (N, 5), line detections. Each row corresponds to x1, y1, x2, y2 and score. Computed from the `detect` method. 109 | descinfo: The features extracted from the function `extract` 110 | """ 111 | raise NotImplementedError 112 | def sample_descinfo_by_indexes(self, descinfo, indexes): 113 | """ 114 | Virtual method (for dual-functional class that can perform both detection and extraction) - sample descriptors for a subset of images 115 | 116 | Args: 117 | descinfo: The features extracted from the function `extract`. 118 | indexes (list[int]): List of image ids for the subset. 
119 | """ 120 | raise NotImplementedError 121 | 122 | def get_segments_folder(self, output_folder): 123 | """ 124 | Return the folder path to the detected segments 125 | 126 | Args: 127 | output_folder (str): The output folder 128 | Returns: 129 | path_to_segments (str): The path to the saved segments 130 | """ 131 | return os.path.join(output_folder, "segments") 132 | 133 | def merge_lines(self, segs): 134 | from limap.line2d.line_utils import merge_lines 135 | segs = segs[:, :4].reshape(-1, 2, 2) 136 | segs = merge_lines(segs) 137 | segs = segs.reshape(-1, 4) 138 | return segs 139 | 140 | def take_longest_k(self, segs, max_num_2d_segs=3000): 141 | indexes = np.arange(0, segs.shape[0]) 142 | if max_num_2d_segs is None or max_num_2d_segs == -1: 143 | pass 144 | elif segs.shape[0] > max_num_2d_segs: 145 | lengths_squared = (segs[:,2] - segs[:,0]) ** 2 + (segs[:,3] - segs[:,1]) ** 2 146 | indexes = np.argsort(lengths_squared)[::-1][:max_num_2d_segs] 147 | segs = segs[indexes,:] 148 | return segs, indexes 149 | 150 | def visualize_segs(self, output_folder, imagecols, first_k=10): 151 | seg_folder = self.get_segments_folder(output_folder) 152 | n_vis_images = min(first_k, imagecols.NumImages()) 153 | vis_folder = os.path.join(output_folder, "visualize") 154 | limapio.check_makedirs(vis_folder) 155 | image_ids = imagecols.get_img_ids()[:n_vis_images] 156 | for img_id in image_ids: 157 | img = imagecols.read_image(img_id) 158 | segs = limapio.read_txt_segments(seg_folder, img_id) 159 | img = limapvis.draw_segments(img, segs, (0, 255, 0)) 160 | fname = os.path.join(vis_folder, "img_{0}_det.png".format(img_id)) 161 | cv2.imwrite(fname, img) 162 | 163 | def detect_all_images(self, output_folder, imagecols, skip_exists=False): 164 | """ 165 | Perform line detection on all images and save the line segments 166 | 167 | Args: 168 | output_folder (str): The output folder 169 | imagecols (:class:`limap.base.ImageCollection`): The input image collection 170 | skip_exists (bool): Whether to skip already processed images 171 | Returns: 172 | dict[int -> :class:`np.array`]: The line detection for each image indexed by the image id. Each segment is with shape (N, 5). Each row corresponds to x1, y1, x2, y2 and score. 
173 | """ 174 | seg_folder = self.get_segments_folder(output_folder) 175 | if not skip_exists: 176 | limapio.delete_folder(seg_folder) 177 | limapio.check_makedirs(seg_folder) 178 | if self.visualize: 179 | vis_folder = os.path.join(output_folder, "visualize") 180 | limapio.check_makedirs(vis_folder) 181 | for img_id in tqdm(imagecols.get_img_ids()): 182 | if skip_exists and limapio.exists_txt_segments(seg_folder, img_id): 183 | if self.visualize: 184 | segs = limapio.read_txt_segments(seg_folder, img_id) 185 | else: 186 | segs = self.detect(imagecols.camview(img_id)) 187 | if self.do_merge_lines: 188 | segs = self.merge_lines(segs) 189 | segs, _ = self.take_longest_k(segs, max_num_2d_segs=self.max_num_2d_segs) 190 | limapio.save_txt_segments(seg_folder, img_id, segs) 191 | if self.visualize: 192 | img = imagecols.read_image(img_id) 193 | img = limapvis.draw_segments(img, segs, (0, 255, 0)) 194 | fname = os.path.join(vis_folder, "img_{0}_det.png".format(img_id)) 195 | cv2.imwrite(fname, img) 196 | all_2d_segs = limapio.read_all_segments_from_folder(seg_folder) 197 | all_2d_segs = {id: all_2d_segs[id] for id in imagecols.get_img_ids()} 198 | return all_2d_segs -------------------------------------------------------------------------------- /detectors/line2d/register_linedetector.py: -------------------------------------------------------------------------------- 1 | from .linebase_detector import BaseDetectorOptions 2 | 3 | def get_linedetector(method="lsd", max_num_2d_segs=3000, 4 | do_merge_lines=False, visualize=False, weight_path=None, 5 | cudaid=0): 6 | """ 7 | Get a line detector 8 | """ 9 | options = BaseDetectorOptions() 10 | options = options._replace(max_num_2d_segs=max_num_2d_segs, 11 | do_merge_lines=do_merge_lines, visualize=visualize, weight_path=weight_path, 12 | cudaid=cudaid) 13 | 14 | if method == "lsd": 15 | from .LSD.lsd import LSDDetector 16 | return LSDDetector(options) 17 | elif method == "deeplsd": 18 | from .DeepLSD.deeplsd import DeepLSDDetector 19 | return DeepLSDDetector(options) 20 | else: 21 | raise NotImplementedError -------------------------------------------------------------------------------- /detectors/point2d/SuperPoint/superpoint.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. 
ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | import os 44 | import sys 45 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))) 46 | from models.base_model import BaseModel 47 | from pathlib import Path 48 | import torch 49 | from torch import nn 50 | 51 | def simple_nms(scores, nms_radius: int): 52 | """ Fast Non-maximum suppression to remove nearby points """ 53 | assert(nms_radius >= 0) 54 | 55 | def max_pool(x): 56 | return torch.nn.functional.max_pool2d( 57 | x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) 58 | 59 | zeros = torch.zeros_like(scores) 60 | max_mask = scores == max_pool(scores) 61 | for _ in range(2): 62 | supp_mask = max_pool(max_mask.float()) > 0 63 | supp_scores = torch.where(supp_mask, zeros, scores) 64 | new_max_mask = supp_scores == max_pool(supp_scores) 65 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 66 | return torch.where(max_mask, scores, zeros) 67 | 68 | 69 | def remove_borders(keypoints, scores, border: int, height: int, width: int): 70 | """ Removes keypoints too close to the border """ 71 | mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) 72 | mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) 73 | mask = mask_h & mask_w 74 | return keypoints[mask], scores[mask] 75 | 76 | 77 | def top_k_keypoints(keypoints, scores, k: int): 78 | if k >= len(keypoints): 79 | return keypoints, scores 80 | scores, indices = torch.topk(scores, k, dim=0) 81 | return keypoints[indices], scores 82 | 83 | 84 | def sample_descriptors(keypoints, descriptors, s: int = 8): 85 | """ Interpolate descriptors at keypoint locations """ 86 | b, c, h, w = descriptors.shape 87 | keypoints = keypoints - s / 2 + 0.5 88 | keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], 89 | ).to(keypoints)[None] 90 | keypoints = keypoints*2 - 1 # normalize to (-1, 1) 91 | descriptors = torch.nn.functional.grid_sample( 92 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', 93 | align_corners=True) 94 | descriptors = torch.nn.functional.normalize( 95 | descriptors.reshape(b, c, -1), p=2, dim=1) 96 | return descriptors 97 | 98 | class SuperPoint(BaseModel): 99 | """SuperPoint Convolutional Detector and Descriptor 100 | 101 | SuperPoint: Self-Supervised Interest Point Detection and 102 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew 103 | Rabinovich. In CVPRW, 2019. 
https://arxiv.org/abs/1712.07629 104 | 105 | """ 106 | default_conf = { 107 | 'name': 'SuperPoint', 108 | 'trainable': False, 109 | 'descriptor_dim': 256, 110 | 'nms_radius': 4, 111 | 'keypoint_threshold': 0.005, 112 | 'max_keypoints': -1, 113 | 'remove_borders': 4, 114 | 'weight_path': None, 115 | 'force_num_keypoints': False, 116 | } 117 | required_data_keys = ['image'] 118 | 119 | def _init(self, conf): 120 | if self.conf.force_num_keypoints: 121 | print('[WARNING]: \"force_num_keypoints\" is applied') 122 | self.conf.keypoint_threshold = 0.0 123 | self.relu = nn.ReLU(inplace=True) 124 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 125 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 126 | 127 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 128 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 129 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 130 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 131 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 132 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 133 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 134 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 135 | 136 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 137 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 138 | 139 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 140 | self.convDb = nn.Conv2d( 141 | c5, self.conf.descriptor_dim, 142 | kernel_size=1, stride=1, padding=0) 143 | 144 | if self.conf.weight_path is None: 145 | path = Path(__file__).parent / 'weights/superpoint_v1.pth' 146 | else: 147 | path = os.path.join(self.conf.weight_path, "point2d", "superpoint", "weights/superpoint_v1.pth") 148 | if not os.path.isfile(path): 149 | self.download_model(path) 150 | self.load_state_dict(torch.load(str(path))) 151 | 152 | mk = self.conf.max_keypoints 153 | if mk == 0 or mk < -1: 154 | raise ValueError('\"max_keypoints\" must be positive or \"-1\"') 155 | 156 | print('Loaded SuperPoint model') 157 | 158 | def download_model(self, path): 159 | import subprocess 160 | if not os.path.exists(os.path.dirname(path)): 161 | os.makedirs(os.path.dirname(path)) 162 | link = "https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/superpoint_v1.pth?raw=true" 163 | cmd = ["wget", link, "-O", path] 164 | print("Downloading SuperPoint model...") 165 | subprocess.run(cmd, check=True) 166 | 167 | def compute_dense_descriptor(self, image): 168 | """ Compute keypoints, scores, descriptors for image """ 169 | # Shared Encoder 170 | x = self.relu(self.conv1a(image)) 171 | x = self.relu(self.conv1b(x)) 172 | x = self.pool(x) 173 | x = self.relu(self.conv2a(x)) 174 | x = self.relu(self.conv2b(x)) 175 | x = self.pool(x) 176 | x = self.relu(self.conv3a(x)) 177 | x = self.relu(self.conv3b(x)) 178 | x = self.pool(x) 179 | x = self.relu(self.conv4a(x)) 180 | x = self.relu(self.conv4b(x)) 181 | 182 | # Compute the dense keypoint scores 183 | cPa = self.relu(self.convPa(x)) 184 | scores = self.convPb(cPa) 185 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1] 186 | b, _, h, w = scores.shape 187 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 188 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) 189 | scores = simple_nms(scores, self.conf.nms_radius) 190 | 191 | # Extract keypoints 192 | keypoints = [ 193 | torch.nonzero(s > 
self.conf.keypoint_threshold) 194 | for s in scores] 195 | scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] 196 | 197 | # Discard keypoints near the image borders 198 | keypoints, scores = list(zip(*[ 199 | remove_borders(k, s, self.conf.remove_borders, h*8, w*8) 200 | for k, s in zip(keypoints, scores)])) 201 | 202 | # Keep the k keypoints with highest score 203 | if self.conf.max_keypoints >= 0: 204 | keypoints, scores = list(zip(*[ 205 | top_k_keypoints(k, s, self.conf.max_keypoints) 206 | for k, s in zip(keypoints, scores)])) 207 | 208 | # Convert (h, w) to (x, y) 209 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 210 | 211 | # Compute the dense descriptors 212 | cDa = self.relu(self.convDa(x)) 213 | descriptors = self.convDb(cDa) 214 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) 215 | return keypoints, scores, descriptors 216 | 217 | def compute_dense_descriptor_and_score(self, image): 218 | """ Compute dense scores and descriptors for an image """ 219 | # Shared Encoder 220 | x = self.relu(self.conv1a(image)) 221 | x = self.relu(self.conv1b(x)) 222 | x = self.pool(x) 223 | x = self.relu(self.conv2a(x)) 224 | x = self.relu(self.conv2b(x)) 225 | x = self.pool(x) 226 | x = self.relu(self.conv3a(x)) 227 | x = self.relu(self.conv3b(x)) 228 | x = self.pool(x) 229 | x = self.relu(self.conv4a(x)) 230 | x = self.relu(self.conv4b(x)) 231 | 232 | # Compute the dense keypoint scores 233 | cPa = self.relu(self.convPa(x)) 234 | scores = self.convPb(cPa) 235 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1] 236 | b, _, h, w = scores.shape 237 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 238 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) 239 | 240 | # Compute the dense descriptors 241 | cDa = self.relu(self.convDa(x)) 242 | descriptors = self.convDb(cDa) 243 | dense_descriptor = torch.nn.functional.normalize(descriptors, p=2, dim=1) 244 | return { 245 | 'dense_score': scores, 246 | 'dense_descriptor': dense_descriptor 247 | } 248 | 249 | def sample_descriptors(self, data, keypoints): 250 | _, _, descriptors = self.compute_dense_descriptor(data) 251 | 252 | # Extract descriptors 253 | descriptors = [sample_descriptors(k[None], d[None], 8)[0] 254 | for k, d in zip(keypoints, descriptors)] 255 | 256 | return { 257 | 'keypoints': keypoints, 258 | 'descriptors': descriptors 259 | } 260 | def _forward(self, datain): # sample_points_and_lines_descriptors 261 | data = datain[0] # image 262 | keypoints = datain[1] # keypoints 263 | line_keypoints = datain[2] # line_keypoints 264 | if isinstance(keypoints, list): 265 | # found test mode, then get keypoints too 266 | keypoints, _, descriptors = self.compute_dense_descriptor(data) 267 | else: 268 | _, _, descriptors = self.compute_dense_descriptor(data) 269 | # Extract keypoints descriptors 270 | points_descriptors = [sample_descriptors(k[None], d[None], 8)[0] 271 | for k, d in zip(keypoints, descriptors)] 272 | points_descriptors = torch.stack(points_descriptors, dim=0) 273 | 274 | # Extract line keypoints descriptors 275 | bs,nline,npoints,_ = line_keypoints.shape 276 | line_keypoints = line_keypoints.view(bs,nline*npoints,2) 277 | lines_descriptors = [sample_descriptors(k[None], d[None], 8)[0] 278 | for k, d in zip(line_keypoints, descriptors)] 279 | # reshape and merge lines_descriptors 280 | for i in range(bs): 281 | lines_descriptors[i] = lines_descriptors[i].view(-1, nline, npoints).permute(1,2,0) # -> nline x npoints x 256 282 | lines_descriptors = 
torch.stack(lines_descriptors, dim=0) 283 | return { 284 | 'points_descriptors': points_descriptors, 285 | 'lines_descriptors': lines_descriptors, 286 | 'keypoints': keypoints, 287 | } 288 | 289 | def _forward_default(self, data): 290 | image = data['image'] 291 | keypoints, scores, descriptors = self.compute_dense_descriptor(image) 292 | # Extract descriptors 293 | descriptors = [sample_descriptors(k[None], d[None], 8)[0] 294 | for k, d in zip(keypoints, descriptors)] 295 | return { 296 | 'keypoints': keypoints, 297 | 'scores': scores, 298 | 'descriptors': descriptors, 299 | } 300 | def loss(self, pred, data): 301 | raise NotImplementedError -------------------------------------------------------------------------------- /detectors/point2d/register_pointdetector.py: -------------------------------------------------------------------------------- 1 | def get_pointdetector(method="superpoint", configs=dict()): 2 | """ 3 | Get a point detector 4 | """ 5 | if method == "superpoint": 6 | from .SuperPoint.superpoint import SuperPoint 7 | return SuperPoint(configs) 8 | else: 9 | raise NotImplementedError -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from omegaconf import OmegaConf 3 | from abc import ABCMeta, abstractmethod 4 | 5 | class MetaModel(ABCMeta): 6 | def __prepare__(name, bases, **kwds): 7 | total_conf = OmegaConf.create() 8 | for base in bases: 9 | for key in ('base_default_conf', 'default_conf'): 10 | update = getattr(base, key, {}) 11 | if isinstance(update, dict): 12 | update = OmegaConf.create(update) 13 | total_conf = OmegaConf.merge(total_conf, update) 14 | return dict(base_default_conf=total_conf) 15 | 16 | class BaseModel(nn.Module, metaclass=MetaModel): 17 | default_conf = { 18 | 'name': None, 19 | 'trainable': False, 20 | } 21 | required_data = [] 22 | 23 | def __init__(self, conf): 24 | super().__init__() 25 | default_conf = OmegaConf.merge( 26 | self.base_default_conf, OmegaConf.create(self.default_conf)) 27 | self.conf = conf = OmegaConf.merge(default_conf, conf) 28 | self._init(conf) 29 | if not conf.trainable: 30 | for param in self.parameters(): 31 | param.requires_grad = False 32 | 33 | def forward(self, data): 34 | """Check the data and call the _forward method of the child model.""" 35 | def recursive_key_check(expected, given): 36 | for key in expected: 37 | assert key in given, f'Missing key {key} in data' 38 | if isinstance(expected, dict): 39 | recursive_key_check(expected[key], given[key]) 40 | recursive_key_check(self.required_data, data) 41 | return self._forward(data) 42 | 43 | @abstractmethod 44 | def _init(self, conf): 45 | """To be implemented by child class.""" 46 | raise NotImplementedError 47 | @abstractmethod 48 | def _forward(self, data): 49 | """To be implemented by child class.""" 50 | raise NotImplementedError 51 | @abstractmethod 52 | def loss(self, pred, data): 53 | """To be implemented by child class.""" 54 | raise NotImplementedError 55 | -------------------------------------------------------------------------------- /models/pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from models.base_model import BaseModel 5 | from models.util import get_model 6 | import os 7 | import os.path as osp 8 | 9 | class Pipeline(BaseModel): 10 | default_conf = { 11 | 'trainable': 
True, 12 | } 13 | required_data = ['image', 'original_size', 'keypoints', 'lines'] 14 | 15 | def _init(self, conf): 16 | # get detector model 17 | self.detector = get_model(conf.point2d.detector.name, "detector")(conf.point2d.detector.configs) 18 | assert self.detector.conf.trainable == False, "detector must be fixed, not trainable" 19 | # get regressor model 20 | self.regressor = get_model(conf.regressor.name, "regressor")(conf.regressor).train() 21 | assert self.regressor.conf.trainable == True, "regressor must be trainable" 22 | print(f'The model regresor {conf.regressor.name} has {count_parameters(self.regressor):,} trainable parameters') 23 | 24 | def _forward(self, data): 25 | # Pre process data 26 | # convert lines to line_keypoints | BxLx4 -> BxLx(4+n_line_keypoints*2) 27 | line_keypoints = get_line_keypoints(data['lines'], self.conf.regressor.n_line_keypoints) 28 | # sample descriptors using superpoint 29 | regressor_data = self.detector((data['image'], data['keypoints'], line_keypoints)) 30 | # regress descriptors to 3D points and lines 31 | pred = self.regressor(regressor_data) 32 | pred['keypoints'] = regressor_data['keypoints'] 33 | pred['lines'] = data['lines'] 34 | return pred 35 | 36 | def loss(self, pred, data): 37 | pass 38 | def save_checkpoint(self, path, name, epoch, final = False): 39 | if os.path.exists(path) == False: 40 | os.makedirs(path) 41 | filename = osp.join(path, '{}_final.pth.tar'.format(name)) if final \ 42 | else osp.join(path, '{}.pth.tar'.format(name)) 43 | checkpoint_dict =\ 44 | {'epoch': epoch, 'model_state_dict': self.regressor.state_dict()} 45 | torch.save(checkpoint_dict, filename) 46 | def load_checkpoint(self, path, exp_name): 47 | ''' Load regressor checkpoint from path''' 48 | filename = osp.join(path, '{}.pth.tar'.format(exp_name)) 49 | if not osp.exists(filename): 50 | raise FileNotFoundError(f'Cannot find checkpoint at {filename}') 51 | devide = torch.device(f'cuda:{torch.cuda.current_device()}' \ 52 | if torch.cuda.is_available() else 'cpu') 53 | checkpoint_dict = torch.load(filename, map_location=torch.device(devide)) 54 | self.regressor.load_state_dict(checkpoint_dict['model_state_dict']) 55 | print(f'[INFOR] Loaded checkpoint from {filename}') 56 | return checkpoint_dict['epoch'] 57 | 58 | def get_line_keypoints(lines, n_line_keypoints): 59 | # convert lines to line_keypoints | BxLx4 -> BxLx(n_line_keypoints+2)x2 60 | bs,n_line,_ = lines.shape 61 | total_points = n_line_keypoints + 2 # start point + end point + n_line_keypoints 62 | line_keypoints = lines.new_zeros((bs,n_line, total_points,2)) 63 | line_keypoints[:,:,0,:] = lines[:,:,:2] # start point 64 | line_keypoints[:,:,total_points-1,:] = lines[:,:,2:] # end point 65 | per_distance = (lines[:,:,2:] - lines[:,:,:2])/(n_line_keypoints+2-1) # stop - start point 66 | for i in range(n_line_keypoints): 67 | line_keypoints[:,:,i+1,:] = line_keypoints[:,:,0,:] + per_distance*(i+1) 68 | return line_keypoints 69 | 70 | def count_parameters(model): 71 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 72 | 73 | -------------------------------------------------------------------------------- /models/pl2map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from models.base_model import BaseModel 5 | from models.util import get_model 6 | from copy import deepcopy 7 | from typing import Tuple, List 8 | import torch.nn.functional as F 9 | 10 | class PL2Map(BaseModel): 11 | 
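# Descriptive summary of the regressor below: SuperPoint point descriptors are passed straight
# to the attentional GNN, while each line's descriptors (one per point sampled along the
# segment) are first summarized by the LineEncoder, which applies self-attention over the
# sampled points and keeps the first token as the line descriptor. The GNN then alternates
# self/cross attention between point and line descriptors, and two MLP heads map them to
# 4 values per point (XYZ plus a channel used as an uncertainty score in the criterion of
# util_learner.py) and 7 values per line (two 3D endpoints plus the same kind of channel).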
default_conf = { 12 | 'trainable': True, 13 | 'n_heads': 4, 14 | 'd_inner': 1024, 15 | 'n_att_layers': 1, 16 | 'feature_dim': 256, 17 | 'GNN_layers': ['self', 'cross', 'self', 'cross', 'self'], 18 | 'mapping_layers': [512, 1024, 512], 19 | } 20 | required_data = ['points_descriptors', 'lines_descriptors'] 21 | 22 | def _init(self, conf): 23 | self.line_encoder = LineEncoder(conf.feature_dim, conf.n_heads, conf.n_att_layers, conf.d_inner) 24 | self.gnn = AttentionalGNN( 25 | feature_dim=self.conf.feature_dim, layer_names=self.conf.GNN_layers) 26 | self.mapping_p = MLP([conf.feature_dim]+self.conf.mapping_layers+[4]) # mapping point descriptors to 3D points 27 | self.mapping_l = MLP([conf.feature_dim]+self.conf.mapping_layers+[7]) # mapping line descriptors to 3D lines 28 | 29 | def _forward(self, data): 30 | # get line descriptors 31 | p_desc = data['points_descriptors'] 32 | l_desc = self.line_encoder(data['lines_descriptors']) 33 | p_desc, l_desc = self.gnn(p_desc, l_desc) 34 | pred = {} 35 | pred['points3D'] = self.mapping_p(p_desc) 36 | pred['lines3D'] = self.mapping_l(l_desc) 37 | return pred 38 | def loss(self, pred, data): 39 | pass 40 | 41 | 42 | class ScaledDotProduct(nn.Module): 43 | """ Scaled Dot-Product Attention """ 44 | def __init__(self, scale, attn_dropout=0.1): 45 | super().__init__() 46 | self.scale = scale 47 | self.dropout = nn.Dropout(attn_dropout) 48 | 49 | def forward(self, q, k, v, mask=None): 50 | attn = torch.matmul(q / self.scale, k.transpose(3, 4)) 51 | if mask is not None: 52 | attn = attn.masked_fill(mask == 0, -1e9) 53 | attn = self.dropout(F.softmax(attn, dim=-1)) 54 | output = torch.matmul(attn, v) 55 | 56 | return output, attn 57 | 58 | class MultiHeadAttention_Line(nn.Module): 59 | """ Multi-Headed Attention """ 60 | def __init__(self, n_heads: int, d_feature: int, dropout=0.1): 61 | super().__init__() 62 | assert d_feature % n_heads == 0 63 | dim = d_feature // n_heads 64 | self.dim = dim 65 | self.n_heads = n_heads 66 | 67 | self.w_qs = nn.Linear(d_feature, n_heads * dim, bias=True) 68 | self.w_ks = nn.Linear(d_feature, n_heads * dim, bias=True) 69 | self.w_vs = nn.Linear(d_feature, n_heads * dim, bias=True) 70 | self.fc = nn.Linear(n_heads * dim, d_feature, bias=True) 71 | 72 | self.attention = ScaledDotProduct(scale = dim ** 0.5) 73 | 74 | self.dropout = nn.Dropout(dropout) 75 | self.layer_norm = nn.LayerNorm(d_feature, eps=1e-6) 76 | 77 | def forward(self, q, k, v, mask=None): 78 | d_k = self.dim 79 | d_v = self.dim 80 | n_heads = self.n_heads 81 | 82 | n_batches = q.size(0) 83 | n_sublines = q.size(1) 84 | n_words_q = q.size(2) 85 | n_words_k = k.size(2) 86 | n_words_v = v.size(2) 87 | 88 | residual = q 89 | 90 | q = self.w_qs(q).view(n_batches, n_sublines, n_words_q, n_heads, d_k) 91 | k = self.w_ks(k).view(n_batches, n_sublines, n_words_k, n_heads, d_k) 92 | v = self.w_vs(v).view(n_batches, n_sublines, n_words_v, n_heads, d_k) 93 | 94 | # Transpose for attention dot product: b x n x lq x dv 95 | q, k, v = q.transpose(2, 3), k.transpose(2, 3), v.transpose(2, 3) 96 | 97 | if mask is not None: 98 | mask = mask.unsqueeze(2) # For head axis broadcasting. 
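# Shape note: after the transpose above, q, k and v are
# (n_batches, n_sublines, n_heads, n_words, d_k), so ScaledDotProduct attends across the
# points sampled along each individual line (its matmul transposes the last two dims) and
# never mixes different lines or batch elements.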
99 | 100 | q, attn = self.attention(q, k, v, mask=mask) 101 | 102 | # Transpose to move the head dimension back: b x lq x n x dv 103 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 104 | q = q.transpose(2,3).contiguous().view(n_batches, n_sublines, n_words_q, -1) 105 | q = self.dropout(self.fc(q)) 106 | 107 | q += residual 108 | q = self.layer_norm(q) 109 | 110 | return q, attn 111 | 112 | class FeedForward(nn.Module): 113 | """ Feed Forward layer """ 114 | def __init__(self, d_in, d_hid, dropout=0.1): 115 | super().__init__() 116 | self.w_1 = nn.Linear(d_in, d_hid) # d_in: 256, d_hid: 1024 117 | self.w_2 = nn.Linear(d_hid, d_in) # d_hid: 1024, d_in: 256 118 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 119 | self.dropout = nn.Dropout(dropout) 120 | def forward(self,x): 121 | residual = x 122 | x = self.w_2(F.gelu(self.w_1(x))) 123 | x = self.dropout(x) 124 | x += residual 125 | x = self.layer_norm(x) 126 | return x 127 | 128 | 129 | class LineDescriptiveEncoder(nn.Module): 130 | """ Line Descriptive Network using the transformer """ 131 | def __init__(self, d_feature: int, n_heads: int, d_inner: int, dropout=0.1): 132 | super().__init__() 133 | self.slf_attn = MultiHeadAttention_Line(n_heads, d_feature) 134 | self.pos_ffn = FeedForward(d_feature, d_inner, dropout=dropout) 135 | 136 | def forward(self, desc, slf_attn_mask=None): 137 | 138 | desc, enc_slf_attn = self.slf_attn(desc, desc, desc, mask=slf_attn_mask) 139 | desc = self.pos_ffn(desc) 140 | 141 | return desc, enc_slf_attn 142 | 143 | class LineEncoder(nn.Module): 144 | """ LineEncoder mimics the transformer model""" 145 | def __init__(self, feature_dim, n_heads, n_att_layers, d_inner, dropout=0.1): 146 | super().__init__() 147 | self.feature_dim = feature_dim 148 | self.desc_layers = nn.ModuleList([ 149 | LineDescriptiveEncoder(feature_dim, n_heads, d_inner, dropout=dropout) 150 | for _ in range(n_att_layers)]) 151 | 152 | def forward(self, desc, return_attns=False): 153 | enc_slf_attn_list = [] 154 | for desc_layer in self.desc_layers: 155 | enc_output, enc_slf_attn = desc_layer(desc) 156 | enc_slf_attn_list += [enc_slf_attn] if return_attns else [] 157 | # get the first token of each line 158 | sentence = enc_output[:,:,0,:].transpose(1,2) 159 | return sentence # line descriptors 160 | 161 | 162 | def MLP(channels:list): 163 | layers = [] 164 | n_chnls = len(channels) 165 | for i in range(1, n_chnls): 166 | layers.append(nn.Conv1d(channels[i-1], channels[i], 167 | kernel_size=1, bias=True)) 168 | if i < n_chnls-1: 169 | layers.append(nn.ReLU()) 170 | return nn.Sequential(*layers) 171 | 172 | 173 | def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]: 174 | dim = query.shape[1] 175 | scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5 176 | prob = torch.nn.functional.softmax(scores, dim=-1) 177 | return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob 178 | 179 | 180 | class MultiHeadedAttention(nn.Module): 181 | """ Multi-head attention to increase model expressivitiy """ 182 | def __init__(self, num_heads: int, d_model: int): 183 | super().__init__() 184 | assert d_model % num_heads == 0 185 | self.dim = d_model // num_heads 186 | self.num_heads = num_heads 187 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) 188 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) 189 | 190 | def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: 191 | batch_dim 
= query.size(0) 192 | query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1) 193 | for l, x in zip(self.proj, (query, key, value))] 194 | x, _ = attention(query, key, value) 195 | return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1)) 196 | 197 | 198 | class AttentionalPropagation(nn.Module): 199 | def __init__(self, feature_dim: int, num_heads: int): 200 | super().__init__() 201 | self.attn = MultiHeadedAttention(num_heads, feature_dim) 202 | self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim]) 203 | nn.init.constant_(self.mlp[-1].bias, 0.0) 204 | 205 | def forward(self, x: torch.Tensor, source: torch.Tensor) -> torch.Tensor: 206 | message = self.attn(x, source, source) 207 | return self.mlp(torch.cat([x, message], dim=1)) 208 | 209 | class AttentionalGNN(nn.Module): 210 | def __init__(self, feature_dim: int, layer_names: List[str]) -> None: 211 | super().__init__() 212 | self.layers = nn.ModuleList([ 213 | AttentionalPropagation(feature_dim, 4) 214 | for _ in range(len(layer_names))]) 215 | self.names = layer_names 216 | def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]: 217 | for layer, name in zip(self.layers, self.names): 218 | if name == 'cross': 219 | src0, src1 = desc1, desc0 220 | else: # if name == 'self': 221 | src0, src1 = desc0, desc1 222 | delta0, delta1 = layer(desc0, src0), layer(desc1, src1) 223 | desc0, desc1 = (desc0 + delta0), (desc1 + delta1) 224 | return desc0, desc1 -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | import collections.abc as collections 2 | import torch 3 | 4 | 5 | def get_class(mod_path, BaseClass): 6 | """Get the class object which inherits from BaseClass and is defined in 7 | the module named mod_name, child of base_path. 8 | """ 9 | import inspect 10 | mod = __import__(mod_path, fromlist=['']) 11 | classes = inspect.getmembers(mod, inspect.isclass) 12 | # Filter classes defined in the module 13 | classes = [c for c in classes if c[1].__module__ == mod_path] 14 | # Filter classes inherited from BaseModel 15 | classes = [c for c in classes if issubclass(c[1], BaseClass)] 16 | assert len(classes) == 1, classes 17 | return classes[0][1] 18 | 19 | 20 | def get_model(name, _type = "detector"): 21 | from models.base_model import BaseModel 22 | if _type == "detector" and name == "superpoint": 23 | base_path = 'detectors.point2d.SuperPoint.' 24 | elif _type == "regressor": 25 | base_path = 'models.' 
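# Illustrative resolution (hypothetical names, assuming the config uses the module file names):
#   get_model("superpoint", "detector") -> detectors.point2d.SuperPoint.superpoint.SuperPoint
#   get_model("pl2map", "regressor")    -> models.pl2map.PL2Map
# get_class() returns the single class in that module which subclasses BaseModel. If name
# matches neither branch, base_path is never assigned and the return below raises NameError.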
26 | return get_class(base_path + name, BaseModel) 27 | 28 | 29 | def numpy_image_to_torch(image): 30 | """Normalize the image tensor and reorder the dimensions.""" 31 | if image.ndim == 3: 32 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW 33 | elif image.ndim == 2: 34 | image = image[None] # add channel axis 35 | else: 36 | raise ValueError(f'Not an image: {image.shape}') 37 | return torch.from_numpy(image / 255.).float() 38 | 39 | 40 | def map_tensor(input_, func): 41 | if isinstance(input_, (str, bytes)): 42 | return input_ 43 | elif isinstance(input_, collections.Mapping): 44 | return {k: map_tensor(sample, func) for k, sample in input_.items()} 45 | elif isinstance(input_, collections.Sequence): 46 | return [map_tensor(sample, func) for sample in input_] 47 | else: 48 | return func(input_) 49 | 50 | 51 | def batch_to_np(batch): 52 | return map_tensor(batch, lambda t: t.detach().cpu().numpy()[0]) 53 | -------------------------------------------------------------------------------- /models/util_learner.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.optim as optim 3 | import torch 4 | 5 | class CriterionPointLine(nn.Module): 6 | ''' 7 | Criterion for point and line''' 8 | def __init__(self, rpj_cfg, total_iterations=2000000): 9 | super(CriterionPointLine, self).__init__() 10 | self.rpj_cfg = rpj_cfg 11 | self.reprojection_loss = ReproLoss(total_iterations, self.rpj_cfg.soft_clamp, 12 | self.rpj_cfg.soft_clamp_min, self.rpj_cfg.type, 13 | self.rpj_cfg.circle_schedule) 14 | self.zero = fakezero() 15 | self.total_iterations = total_iterations 16 | 17 | def forward(self, pred, target, iteration=2000000): 18 | batch_size, _, _ = pred['points3D'].shape 19 | validPoints = target["validPoints"] 20 | validLines = target["validLines"] 21 | # get losses for points 22 | square_errors_points = torch.norm((pred['points3D'][:,:3,:] - target["points3D"]), dim = 1) 23 | loss_points = torch.sum(validPoints*square_errors_points)/batch_size 24 | uncer_loss_points = torch.sum(torch.norm(validPoints - 1/(1+100*torch.abs(pred['points3D'][:,3,:])), dim = 1))/batch_size 25 | # get losses for lines 26 | square_errors_lines = torch.norm((pred['lines3D'][:,:6,:] - target["lines3D"]), dim = 1) 27 | loss_lines = torch.sum(validLines*square_errors_lines)/batch_size 28 | uncer_loss_lines = torch.sum(torch.norm(validLines - 1/(1+100*torch.abs(pred['lines3D'][:,6,:])), dim = 1))/batch_size 29 | 30 | points_proj_loss = 0 31 | lines_proj_loss = 0 32 | 33 | if self.rpj_cfg.apply: 34 | # get projection losses for points 35 | for i in range(batch_size): # default batch_size = 1 36 | prp_error, prp= project_loss_points(pred['keypoints'][i,:,:], pred['points3D'][i,:3,:], 37 | target['pose'][i,:], target['camera'][i,:], validPoints[i,:]) 38 | points_proj_loss += self.reprojection_loss.compute_point(prp_error, prp, iteration, validPoints[i,:]) 39 | points_proj_loss = points_proj_loss / batch_size 40 | # get projection losses for lines 41 | 42 | for i in range(batch_size): 43 | prl_error, prp_s, prp_e = project_loss_lines(pred['lines'][i,:,:], pred['lines3D'][i,:6,:], 44 | target['pose'][i,:], target['camera'][i,:], validLines[i,:]) 45 | lines_proj_loss += self.reprojection_loss.compute_line(prl_error, prp_s, prp_e, iteration, validLines[i,:]) 46 | lines_proj_loss = lines_proj_loss / batch_size 47 | if iteration/self.total_iterations < self.rpj_cfg.start_apply: 48 | total_loss = loss_points + uncer_loss_points + loss_lines + uncer_loss_lines 49 
| else: 50 | total_loss = loss_points + uncer_loss_points + loss_lines + uncer_loss_lines + points_proj_loss + lines_proj_loss 51 | 52 | points_proj_loss = self.zero if (isinstance(points_proj_loss, int) or isinstance(points_proj_loss, float)) else points_proj_loss 53 | lines_proj_loss = self.zero if (isinstance(lines_proj_loss, int) or isinstance(lines_proj_loss, float)) else lines_proj_loss 54 | return total_loss, loss_points, uncer_loss_points, loss_lines, uncer_loss_lines, points_proj_loss, lines_proj_loss 55 | 56 | 57 | class fakezero(object): 58 | def __init__(self): 59 | pass 60 | def item(self): 61 | return 0 62 | 63 | 64 | def qvec2rotmat(qvec): 65 | return torch.tensor([ 66 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 67 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 68 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 69 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 70 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 71 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 72 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 73 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 74 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 75 | 76 | def project_loss_points(gt_pt2Ds, pt3Ds, c_pose, camera, valids): 77 | ''' 78 | gt_pt2Ds: 2xN 79 | pt3Ds: 3xN 80 | c_pose: 1x7 81 | camera: 1x5 82 | valids: 1xN 83 | ''' 84 | device = pt3Ds.device 85 | R = qvec2rotmat(c_pose[3:]).to(device=device) 86 | t = torch.unsqueeze(c_pose[:3], dim = 1).to(device=device) 87 | if camera[0] == 0.0: # SIMPLE_PINHOLE 88 | fx = fy = camera[3] # focal length 89 | ppx = camera[4] 90 | ppy = camera[5] 91 | elif camera[0] == 1.0: # PINHOLE 92 | fx = camera[3] # focal length 93 | fy = camera[4] 94 | ppx = camera[5] 95 | ppy = camera[6] 96 | else: 97 | raise f"Camera type {camera[0]} is not implemented" 98 | prd_2Ds = R@pt3Ds + t 99 | # project 100 | px = fx*prd_2Ds[0,:]/prd_2Ds[2,:] + ppx 101 | py = fy*prd_2Ds[1,:]/prd_2Ds[2,:] + ppy 102 | errors_x = (gt_pt2Ds[:,0] - px)**2 103 | errors_y = (gt_pt2Ds[:,1] - py)**2 104 | # return torch.mean(valids * torch.sqrt(errors_x + errors_y)) 105 | return torch.sqrt(errors_x + errors_y), prd_2Ds # l2 distance error, and projected 2D points 106 | 107 | def project_loss_lines(gt_line2Ds, line3Ds, c_pose, camera, valids): 108 | ''' 109 | gt_line2Ds: 4xN 110 | line3Ds: 6xN 111 | c_pose: 1x7 # camera pose 112 | camera: 1x5 113 | valids: Nx1 114 | ''' 115 | device = line3Ds.device 116 | R = qvec2rotmat(c_pose[3:]).to(device=device) 117 | t = torch.unsqueeze(c_pose[:3], dim = 1).to(device=device) 118 | if camera[0] == 0.0: # SIMPLE_PINHOLE 119 | fx = fy = camera[3] # focal length 120 | ppx = camera[4] 121 | ppy = camera[5] 122 | elif camera[0] == 1.0: # PINHOLE 123 | fx = camera[3] # focal length 124 | fy = camera[4] 125 | ppx = camera[5] 126 | ppy = camera[6] 127 | else: 128 | raise f"Camera type {camera[0]} is not implemented" 129 | start_point = line3Ds[:3,:] 130 | end_point = line3Ds[3:,:] 131 | prd_2Ds_start = R@start_point + t 132 | prd_2Ds_end = R@end_point + t 133 | # project start point 134 | px_start = fx*prd_2Ds_start[0,:]/prd_2Ds_start[2,:] + ppx # (N,) 135 | py_start = fy*prd_2Ds_start[1,:]/prd_2Ds_start[2,:] + ppy # (N,) 136 | 137 | # project end point 138 | px_end = fx*prd_2Ds_end[0,:]/prd_2Ds_end[2,:] + ppx # (N,) 139 | py_end = fy*prd_2Ds_end[1,:]/prd_2Ds_end[2,:] + ppy # (N,) 140 | 141 | # project startpoint to line 142 | AB = gt_line2Ds[:,2:4] - gt_line2Ds[:,0:2] # ground truth line vector 143 | APstart = torch.stack([px_start - gt_line2Ds[:,0], py_start - gt_line2Ds[:,1]], dim = 1) 144 
| APend = torch.stack([px_end - gt_line2Ds[:,0], py_end - gt_line2Ds[:,1]], dim = 1) 145 | # calculate the cross product 146 | cross_product_start = APstart[:,0]*AB[:,1] - APstart[:,1]*AB[:,0] 147 | AB_magnitude = torch.sqrt((AB**2).sum(dim=1)) 148 | # calculate the distance 149 | distance_start = torch.abs(cross_product_start) / AB_magnitude 150 | cross_product_end = APend[:,0]*AB[:,1] - APend[:,1]*AB[:,0] 151 | # calculate the distance 152 | distance_end = torch.abs(cross_product_end) / AB_magnitude 153 | repr_error = distance_start + distance_end 154 | # return torch.mean(valids * (repr_error)) 155 | return repr_error, prd_2Ds_start, prd_2Ds_end # l2 distance, and projected 2D points 156 | 157 | 158 | 159 | 160 | def weighted_tanh(repro_errs, weight): 161 | # return weight * torch.tanh(repro_errs / weight).sum() 162 | return torch.mean(weight * torch.tanh(repro_errs / weight)) 163 | 164 | import numpy as np 165 | class ReproLoss: 166 | """ 167 | Original from: https://github.com/nianticlabs/ace 168 | Compute per-pixel reprojection loss using different configurable approaches. 169 | 170 | - tanh: tanh loss with a constant scale factor given by the `soft_clamp` parameter (when a pixel's reprojection 171 | error is equal to `soft_clamp`, its loss is equal to `soft_clamp * tanh(1)`). 172 | - dyntanh: Used in the paper, similar to the tanh loss above, but the scaling factor decreases during the course of 173 | the training from `soft_clamp` to `soft_clamp_min`. The decrease is linear, unless `circle_schedule` 174 | is True (default), in which case it applies a circular scheduling. See paper for details. 175 | - l1: Standard L1 loss, computed only on those pixels having an error lower than `soft_clamp` 176 | - l1+sqrt: L1 loss for pixels with reprojection error smaller than `soft_clamp` and 177 | `sqrt(soft_clamp * reprojection_error)` for pixels with a higher error. 178 | - l1+logl1: Similar to the above, but using log L1 for pixels with high reprojection error. 179 | """ 180 | 181 | def __init__(self, 182 | total_iterations, 183 | soft_clamp=50, 184 | soft_clamp_min=1, 185 | type='dyntanh', 186 | circle_schedule=True): 187 | 188 | self.total_iterations = total_iterations 189 | self.soft_clamp = soft_clamp 190 | self.soft_clamp_min = soft_clamp_min 191 | self.type = type 192 | self.circle_schedule = circle_schedule 193 | 194 | def compute_point(self, reprojection_error_b1, pred_cam_coords_b31, iteration, valids): 195 | 196 | # Predicted coordinates behind or close to camera plane. 197 | invalid_min_depth_b1 = pred_cam_coords_b31[2, :] < 0.1 # 0.1 is the min depth 198 | # Very large reprojection errors. 199 | invalid_repro_b1 = reprojection_error_b1 > 1000 # repro_loss_hard_clamp 200 | # Predicted coordinates beyond max distance. 201 | invalid_max_depth_b1 = pred_cam_coords_b31[2, :] > 1000 # 1000 is the max depth 202 | valids = valids.bool() 203 | # Invalid mask is the union of all these. Valid mask is the opposite. 204 | invalid_mask_b1 = (valids | invalid_min_depth_b1 | invalid_repro_b1 | invalid_max_depth_b1) 205 | valid_mask_b1 = ~invalid_mask_b1 206 | 207 | # Reprojection error for all valid scene coordinates. 208 | repro_errs_b1N = reprojection_error_b1[valid_mask_b1] # valid_reprojection_error_b1 209 | return self.final_compute(repro_errs_b1N, iteration) 210 | 211 | def compute_line(self, reprojection_error_b1, pred_cam_coords_b31_1, 212 | pred_cam_coords_b31_2, iteration, valids): 213 | # Predicted coordinates behind or close to camera plane. 
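# As in compute_point() above, a line is kept only if both predicted endpoints lie
# at a camera-frame depth between 0.1 and 1000 and its summed start/end
# point-to-line error stays below the hard clamp of 1000 px; entries flagged in
# `valids` are ORed into the invalid mask, so the loss is evaluated only on the
# remaining lines.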
214 | invalid_min_depth_b1_1 = pred_cam_coords_b31_1[2, :] < 0.1 # 0.1 is the min depth 215 | invalid_min_depth_b1_2 = pred_cam_coords_b31_2[2, :] < 0.1 # 0.1 is the min depth 216 | # Very large reprojection errors. 217 | invalid_repro_b1 = reprojection_error_b1 > 1000 # repro_loss_hard_clamp 218 | # Predicted coordinates beyond max distance. 219 | invalid_max_depth_b1_1 = pred_cam_coords_b31_1[2, :] > 1000 # 1000 is the max depth 220 | invalid_max_depth_b1_2 = pred_cam_coords_b31_2[2, :] > 1000 # 1000 is the max depth 221 | valids = valids.bool() 222 | # Invalid mask is the union of all these. Valid mask is the opposite. 223 | invalid_mask_b1 = (valids | invalid_min_depth_b1_1 | invalid_repro_b1 | invalid_max_depth_b1_1 224 | | invalid_min_depth_b1_2 | invalid_max_depth_b1_2) 225 | valid_mask_b1 = ~invalid_mask_b1 226 | 227 | # Reprojection error for all valid scene coordinates. 228 | repro_errs_b1N = reprojection_error_b1[valid_mask_b1] # valid_reprojection_error_b1 229 | return self.final_compute(repro_errs_b1N, iteration) 230 | 231 | def final_compute(self, repro_errs_b1N, iteration): 232 | 233 | if repro_errs_b1N.nelement() == 0: 234 | return 0 235 | 236 | if self.type == "tanh": 237 | return weighted_tanh(repro_errs_b1N, self.soft_clamp) 238 | 239 | elif self.type == "dyntanh": 240 | # Compute the progress over the training process. 241 | schedule_weight = iteration / self.total_iterations 242 | 243 | if self.circle_schedule: 244 | # Optionally scale it using the circular schedule. 245 | schedule_weight = 1 - np.sqrt(1 - schedule_weight ** 2) 246 | 247 | # Compute the weight to use in the tanh loss. 248 | loss_weight = (1 - schedule_weight) * self.soft_clamp + self.soft_clamp_min 249 | 250 | # Compute actual loss. 251 | return weighted_tanh(repro_errs_b1N, loss_weight) 252 | 253 | elif self.type == "l1": 254 | # L1 loss on all pixels with small-enough error. 255 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp 256 | return repro_errs_b1N[~softclamp_mask_b1].sum() 257 | 258 | elif self.type == "l1+sqrt": 259 | # L1 loss on pixels with small errors and sqrt for the others. 260 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp 261 | loss_l1 = repro_errs_b1N[~softclamp_mask_b1].sum() 262 | loss_sqrt = torch.sqrt(self.soft_clamp * repro_errs_b1N[softclamp_mask_b1]).sum() 263 | 264 | return loss_l1 + loss_sqrt 265 | 266 | else: 267 | # l1+logl1: same as above, but use log(L1) for pixels with a larger error. 
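# Worked example with the default soft_clamp of 50: a 10 px error falls in the L1
# branch and contributes 10 to the sum, whereas a 200 px outlier contributes
# log(1 + 50*200) ~= 9.2 here, versus sqrt(50*200) = 100 in the l1+sqrt branch
# above, so gross outliers are damped far more aggressively.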
268 | softclamp_mask_b1 = repro_errs_b1N > self.soft_clamp 269 | loss_l1 = repro_errs_b1N[~softclamp_mask_b1].sum() 270 | loss_logl1 = torch.log(1 + (self.soft_clamp * repro_errs_b1N[softclamp_mask_b1])).sum() 271 | 272 | return loss_l1 + loss_logl1 273 | 274 | #### Optimizer #### 275 | 276 | class Optimizer: 277 | """ 278 | Wrapper around torch.optim + learning rate 279 | """ 280 | def __init__(self, params, nepochs, **kwargs): 281 | self.method = kwargs.pop("method") 282 | self.base_lr = kwargs.pop("base_lr") 283 | self.lr = self.base_lr 284 | self.lr_decay_step = int(nepochs/kwargs.pop("num_lr_decay_step")) 285 | self.lr_decay = kwargs.pop('lr_decay') 286 | self.nfactor = 0 287 | if self.method == 'sgd': 288 | print("OPTIMIZER: --- sgd") 289 | self.learner = optim.SGD(params, lr=self.base_lr, 290 | weight_decay=kwargs.pop("weight_decay"), **kwargs) 291 | elif self.method == 'adam': 292 | print("OPTIMIZER: --- adam") 293 | self.learner = optim.Adam(params, lr=self.base_lr, 294 | weight_decay=kwargs.pop("weight_decay"), **kwargs) 295 | elif self.method == 'rmsprop': 296 | print("OPTIMIZER: --- rmsprop") 297 | self.learner = optim.RMSprop(params, lr=self.base_lr, 298 | weight_decay=kwargs.pop("weight_decay"), **kwargs) 299 | 300 | def adjust_lr(self, epoch): 301 | ''' Adjust learning rate based on epoch. 302 | Optional: call this function if keep training the model after loading checkpoint 303 | ''' 304 | if (self.method not in ['sgd', 'adam']) or (self.lr_decay_step == 0.0): 305 | return self.base_lr 306 | nfactor = epoch // self.lr_decay_step 307 | if nfactor > self.nfactor: 308 | decay_factor = (1-self.lr_decay)**nfactor 309 | self.lr = self.base_lr * decay_factor 310 | for param_group in self.learner.param_groups: 311 | param_group['lr'] = self.lr 312 | return self.lr 313 | 314 | 315 | -------------------------------------------------------------------------------- /prepare_scripts/cambridge.sh: -------------------------------------------------------------------------------- 1 | # Description: Prepare the directory structure for the Cambridge dataset 2 | 3 | if [ ! -d "train_test_datasets" ]; then 4 | mkdir train_test_datasets 5 | fi 6 | 7 | if [ ! -d "train_test_datasets/gt_3Dmodels" ]; then 8 | mkdir train_test_datasets/gt_3Dmodels 9 | fi 10 | 11 | if [ ! -d "train_test_datasets/imgs_datasets" ]; then 12 | mkdir train_test_datasets/imgs_datasets 13 | fi 14 | 15 | TARGET_FOLDER="train_test_datasets/gt_3Dmodels" 16 | OUTPUT_FILE="Cambridge.zip" 17 | FILE_ID="19LRQ5j9I4YdrUykkoavcRTR6ekygU5iU" 18 | 19 | # Download the file from Google Drive using gdown and save it in the target folder 20 | gdown --id $FILE_ID -O $TARGET_FOLDER/$OUTPUT_FILE 21 | 22 | # Unzip the downloaded file in the target folder 23 | unzip $TARGET_FOLDER/$OUTPUT_FILE -d $TARGET_FOLDER 24 | 25 | # Remove the zip file after extraction 26 | rm $TARGET_FOLDER/$OUTPUT_FILE 27 | 28 | echo "Download, extraction, and cleanup completed in $TARGET_FOLDER." 29 | 30 | TARGET_FOLDER="train_test_datasets/imgs_datasets" 31 | FILE_ID="1MZyLPu9Z7tKCeuM4DchseoX4STIhKyi7" 32 | 33 | # Download the file from Google Drive using gdown and save it in the target folder 34 | gdown --id $FILE_ID -O $TARGET_FOLDER/$OUTPUT_FILE 35 | 36 | # Unzip the downloaded file in the target folder 37 | unzip $TARGET_FOLDER/$OUTPUT_FILE -d $TARGET_FOLDER 38 | 39 | # Remove the zip file after extraction 40 | rm $TARGET_FOLDER/$OUTPUT_FILE 41 | 42 | echo "Download, extraction, and cleanup completed in $TARGET_FOLDER." 
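# Usage note (assuming the script is run from the repository root, so that the
# relative train_test_datasets/ paths above match the defaults in runners/train.py):
#   bash prepare_scripts/cambridge.sh
# Recent gdown releases deprecate the --id flag; if the download fails, passing the
# file id positionally (e.g. gdown $FILE_ID -O $TARGET_FOLDER/$OUTPUT_FILE) is the
# usual replacement.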
-------------------------------------------------------------------------------- /prepare_scripts/download_pre_trained_models.sh: -------------------------------------------------------------------------------- 1 | OUTPUT_FILE="logs.zip" 2 | FILE_ID="1iH8PfqgPPQod0q_I8T_ZSO_mSj5XRUuO" 3 | 4 | # Download the file from Google Drive using gdown and save it in the target folder 5 | gdown --id $FILE_ID -O $OUTPUT_FILE 6 | 7 | # Unzip the downloaded file in the target folder 8 | unzip $OUTPUT_FILE 9 | 10 | # Remove the zip file after extraction 11 | rm $OUTPUT_FILE 12 | 13 | echo "Download, extraction, and cleanup completed." -------------------------------------------------------------------------------- /prepare_scripts/indoor6.sh: -------------------------------------------------------------------------------- 1 | # Description: Prepare the directory structure for the indoor6 dataset 2 | 3 | if [ ! -d "train_test_datasets" ]; then 4 | mkdir train_test_datasets 5 | fi 6 | 7 | if [ ! -d "train_test_datasets/gt_3Dmodels" ]; then 8 | mkdir train_test_datasets/gt_3Dmodels 9 | fi 10 | 11 | if [ ! -d "train_test_datasets/imgs_datasets" ]; then 12 | mkdir train_test_datasets/imgs_datasets 13 | fi 14 | 15 | TARGET_FOLDER="train_test_datasets/gt_3Dmodels" 16 | OUTPUT_FILE="indoor6.zip" 17 | FILE_ID="1q28Tkldc--ucD4l7q15RDVsuZ7IN3CEV" 18 | 19 | # Download the file from Google Drive using gdown and save it in the target folder 20 | gdown --id $FILE_ID -O $TARGET_FOLDER/$OUTPUT_FILE 21 | 22 | # Unzip the downloaded file in the target folder 23 | unzip $TARGET_FOLDER/$OUTPUT_FILE -d $TARGET_FOLDER 24 | 25 | # Remove the zip file after extraction 26 | rm $TARGET_FOLDER/$OUTPUT_FILE 27 | 28 | echo "Download, extraction, and cleanup completed in $TARGET_FOLDER." 29 | 30 | TARGET_FOLDER="train_test_datasets/imgs_datasets" 31 | FILE_ID="1kzLPt7LuVJIqKrJMYSFicJ231KDDJxVh" 32 | 33 | # Download the file from Google Drive using gdown and save it in the target folder 34 | gdown --id $FILE_ID -O $TARGET_FOLDER/$OUTPUT_FILE 35 | 36 | # Unzip the downloaded file in the target folder 37 | unzip $TARGET_FOLDER/$OUTPUT_FILE -d $TARGET_FOLDER 38 | 39 | # Remove the zip file after extraction 40 | rm $TARGET_FOLDER/$OUTPUT_FILE 41 | 42 | echo "Download, extraction, and cleanup completed in $TARGET_FOLDER." -------------------------------------------------------------------------------- /prepare_scripts/seven_scenes.sh: -------------------------------------------------------------------------------- 1 | # Description: Prepare the directory structure for the seven scene dataset 2 | 3 | if [ ! -d "train_test_datasets" ]; then 4 | mkdir train_test_datasets 5 | fi 6 | 7 | if [ ! -d "train_test_datasets/gt_3Dmodels" ]; then 8 | mkdir train_test_datasets/gt_3Dmodels 9 | fi 10 | 11 | if [ ! -d "train_test_datasets/imgs_datasets" ]; then 12 | mkdir train_test_datasets/imgs_datasets 13 | fi 14 | 15 | TARGET_FOLDER="train_test_datasets/gt_3Dmodels" 16 | OUTPUT_FILE="7scenes.zip" 17 | FILE_ID="1X8_tV0Y4b_W-vPgeXKoqtFaDCQ5_csL3" 18 | 19 | # Download the file from Google Drive using gdown and save it in the target folder 20 | gdown --id $FILE_ID -O $TARGET_FOLDER/$OUTPUT_FILE 21 | 22 | # Unzip the downloaded file in the target folder 23 | unzip $TARGET_FOLDER/$OUTPUT_FILE -d $TARGET_FOLDER 24 | 25 | # Remove the zip file after extraction 26 | rm $TARGET_FOLDER/$OUTPUT_FILE 27 | 28 | echo "Download, extraction, and cleanup completed in $TARGET_FOLDER." 
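# The block below fetches the raw 7-Scenes RGB-D sequences directly from the
# Microsoft download server (one zip per scene); each scene archive contains
# further per-sequence zips, hence the inner unzip loop after extraction.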
29 | 30 | 31 | cd train_test_datasets/imgs_datasets 32 | mkdir 7scenes 33 | cd 7scenes 34 | 35 | # List of datasets 36 | datasets=("chess" "fire" "heads" "office" "pumpkin" "redkitchen" "stairs") 37 | 38 | # Loop through each dataset 39 | for ds in "${datasets[@]}"; do 40 | # Check if the dataset directory exists 41 | if [ ! -d "$ds" ]; then 42 | echo "=== Downloading 7scenes Data: $ds ===============================" 43 | 44 | # Download the dataset zip file 45 | wget "http://download.microsoft.com/download/2/8/5/28564B23-0828-408F-8631-23B1EFF1DAC8/$ds.zip" 46 | 47 | # Unzip the dataset 48 | unzip "$ds.zip" 49 | 50 | # Remove the zip file 51 | rm "$ds.zip" 52 | 53 | # Loop through the dataset folder and unzip any additional zip files 54 | for file in "$ds"/*.zip; do 55 | if [ -f "$file" ]; then 56 | echo "Unpacking $file" 57 | unzip "$file" -d "$ds" 58 | rm "$file" 59 | fi 60 | done 61 | else 62 | echo "Found data of scene $ds already. Assuming its complete and skipping download." 63 | fi 64 | done 65 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pathlib 2 | open3d 3 | omegaconf 4 | h5py 5 | numpy 6 | scipy 7 | matplotlib 8 | tqdm 9 | pyyaml 10 | opencv-python 11 | pathlib 12 | poselib 13 | visdom 14 | scikit-image 15 | numpy==1.26.3 16 | gdown 17 | 18 | ./third_party/pytlsd 19 | ./third_party/DeepLSD -------------------------------------------------------------------------------- /runners/eval.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | import util.config as utilcfg 6 | from omegaconf import OmegaConf 7 | from evaluator import Evaluator 8 | from util.logger import DualLogger 9 | 10 | def parse_config(): 11 | arg_parser = argparse.ArgumentParser(description='pre-processing for PL2Map dataset') 12 | arg_parser.add_argument('-d', '--dataset_dir', type=Path, default='train_test_datasets/imgs_datasets/', help='') 13 | arg_parser.add_argument('--sfm_dir', type=Path, default='train_test_datasets/gt_3Dmodels/', help='sfm ground truth directory') 14 | arg_parser.add_argument('--dataset', type=str, default="7scenes", help='dataset name') 15 | arg_parser.add_argument('-s', '--scene', type=str, default="pumpkin", help='scene name(s)') 16 | arg_parser.add_argument('-c','--cudaid', type=int, default=0, help='specify cuda device id') 17 | arg_parser.add_argument('-o','--outputs', type=Path, default='logs/', 18 | help='Path to the output directory, default: %(default)s') 19 | arg_parser.add_argument('-expv', '--experiment_version', type=str, default="pl2map", help='experiment version folder') 20 | args, _ = arg_parser.parse_known_args() 21 | args.outputs = os.path.join(args.outputs, args.scene + "_" + args.experiment_version) 22 | path_to_eval_cfg = f'{args.outputs}/config.yaml' 23 | cfg = utilcfg.load_config(path_to_eval_cfg, default_path='cfgs/default.yaml') 24 | cfg = OmegaConf.create(cfg) 25 | return args, cfg 26 | 27 | def main(): 28 | eval_cfg = { 29 | "eval_train": False, # evaluate train_loader 30 | "eval_test": True, # evaluate test_loader 31 | "vis_point3d": False, # visualize predicted 3D points, if eval_train/test = True 32 | "vis_line3d": False, # visualize predicted 3D lines, if eval_train/test = True 33 | "pnp_point": True, # use point-mode-only for PnP 34 | 
"pnp_pointline": True, # use point+line mode for PnP 35 | "uncer_threshold_point": 0.5, # threshold to remove uncertain points 36 | "uncer_threshold_line": 0.02, # threshold to remove uncertain lines 37 | "exist_results":False, # if True, skip running model,then use the existing results in the outputs folder 38 | "save_3dmap": False, # save predicted 3D map 39 | } 40 | args, cfg = parse_config() 41 | sys.stdout = DualLogger(f'{args.outputs}/eval_log.txt') 42 | evaler = Evaluator(args, cfg, eval_cfg) 43 | evaler.eval() 44 | sys.stdout.log.close() 45 | 46 | if __name__ == "__main__": 47 | main() -------------------------------------------------------------------------------- /runners/evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | import sys, os 5 | from omegaconf import OmegaConf 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | from models.pipeline import Pipeline 8 | from util.help_evaluation import Vis_Infor, pose_evaluator 9 | from datasets.dataloader import Collection_Loader 10 | from models.util_learner import CriterionPointLine 11 | from trainer import step_fwd, ShowLosses 12 | from util.pose_estimator import Pose_Estimator # require limap library 13 | from util.io import SAVING_MAP 14 | 15 | class Evaluator(): 16 | default_cfg = { 17 | "eval_train": True, # evaluate train_loader 18 | "eval_test": True, # evaluate test_loader 19 | "vis_point3d": False, # visualize predicted 3D points, if eval_train/test = True 20 | "vis_line3d": False, # visualize predicted 3D lines, if eval_train/test = True 21 | "pnp_point": True, # use point-mode-only for PnP 22 | "pnp_pointline": True, # use point+line mode for PnP 23 | "uncer_threshold_point": 0.5, # threshold to remove uncertain points 24 | "uncer_threshold_line": 0.1, # threshold to remove uncertain lines 25 | "exist_results":False, # if True, skip running model,then use the existing results in the outputs folder 26 | "save_3dmap": False, # save predicted 3D map 27 | } 28 | def __init__(self, args, cfg, eval_cfg=dict()): 29 | self.args = args 30 | self.cfg = cfg 31 | eval_cfg = eval_cfg if cfg.regressor.name == 'pl2map' \ 32 | else force_onlypoint_cfg(eval_cfg) 33 | self.eval_cfg = OmegaConf.merge(OmegaConf.create(self.default_cfg), eval_cfg) 34 | print(f"[INFO] Model: {cfg.regressor.name}") 35 | print("[INFO] Evaluation Config: ", self.eval_cfg) 36 | 37 | if not self.eval_cfg.exist_results: 38 | self.pipeline = Pipeline(cfg) 39 | self.criterion = CriterionPointLine(self.cfg.train.loss.reprojection, cfg.train.num_iters) 40 | self.device = torch.device(f'cuda:{args.cudaid}' \ 41 | if torch.cuda.is_available() else 'cpu') 42 | self.save_path = None 43 | # to device 44 | self.pipeline.to(self.device) 45 | self.criterion.to(self.device) 46 | # dataloader 47 | if self.eval_cfg.eval_train: self.train_collection = Collection_Loader(args, cfg, mode="traintest") 48 | self.eval_collection = Collection_Loader(args, cfg, mode="test") 49 | print("[INFO] Loaded data collection") 50 | if self.eval_cfg.eval_train: self.train_loader = torch.utils.data.DataLoader(self.train_collection, batch_size=1, 51 | shuffle=True) 52 | self.eval_loader = torch.utils.data.DataLoader(self.eval_collection, batch_size=1, 53 | shuffle=True) 54 | self.train_loss = ShowLosses() 55 | self.exp_name = str(args.dataset) + "_" + str(args.scene) + "_" + str(cfg.regressor.name) 56 | self.vis_infor_train = Vis_Infor(self.eval_cfg) 57 | 
self.vis_infor_test = Vis_Infor(self.eval_cfg) 58 | # self.vis_infor_test = Vis_Infor(self.eval_cfg, "seq-06/frame-000780.color.png", 20) 59 | if self.eval_cfg.save_3dmap: self.saving_map = SAVING_MAP(self.args.outputs) 60 | self.pose_estimator = Pose_Estimator(self.cfg.localization, self.eval_cfg, 61 | self.args.outputs) 62 | else: 63 | print("[INFO] Skip running model, then use the existing results in the outputs folder") 64 | 65 | def eval(self): 66 | if not self.eval_cfg.exist_results: 67 | epoch = self.pipeline.load_checkpoint(self.args.outputs, self.exp_name) 68 | self.pipeline.eval() 69 | print("[INFO] Start evaluating ...") 70 | if self.eval_cfg.eval_train: 71 | print("[INFO] Evaluating train_loader ...") 72 | for _, (data, target) in enumerate(tqdm(self.train_loader)): 73 | loss, output = step_fwd(self.pipeline, self.device, data,target, 74 | iteration=self.cfg.train.num_iters, 75 | criterion=self.criterion, train=True) 76 | self.train_loss.update(loss) 77 | self.vis_infor_train.update(output, data) 78 | self.pose_estimator.run(output, data, target, mode='train') 79 | self.train_loss.show(epoch) 80 | self.vis_infor_train.vis() 81 | if self.eval_cfg.eval_test: 82 | i = 0 83 | print("[INFO] Evaluating test_loader ...") 84 | for _, (data, target) in enumerate(tqdm(self.eval_loader)): 85 | _, output = step_fwd(self.pipeline, self.device, data, 86 | target, train=False) 87 | if self.eval_cfg.save_3dmap: self.saving_map.save(output, data) 88 | # if data['imgname'][0] == self.vis_infor_test.highlight_frame: 89 | pose_vis_infor = self.pose_estimator.run(output, data, target, mode='test') 90 | self.vis_infor_test.update(output, data, pose_vis_infor) 91 | # i += 1 92 | # if i > 20: break 93 | self.vis_infor_test.vis() 94 | else: 95 | print("[INFO] Skip evaluating and use the existing results") 96 | pose_evaluator(self.eval_cfg, self.args.outputs) 97 | print("[INFO] DONE evaluation") 98 | 99 | def force_onlypoint_cfg(cfg): 100 | ''' 101 | Force the evaluation config to be only point mode 102 | ''' 103 | if cfg["pnp_pointline"] or cfg["vis_line3d"]: # turn off line mode, if it is on 104 | print("[Warning] Force the evaluation config to be only point mode") 105 | cfg["vis_line3d"] = False 106 | cfg["pnp_pointline"] = False 107 | return cfg -------------------------------------------------------------------------------- /runners/train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import argparse 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | import util.config as utilcfg 6 | from omegaconf import OmegaConf 7 | from trainer import Trainer 8 | import time 9 | from util.logger import DualLogger 10 | 11 | def parse_config(): 12 | arg_parser = argparse.ArgumentParser(description='pre-processing for PL2Map dataset') 13 | arg_parser.add_argument('-d', '--dataset_dir', type=Path, default='train_test_datasets/imgs_datasets/', help='') 14 | arg_parser.add_argument('--sfm_dir', type=Path, default='train_test_datasets/gt_3Dmodels/', help='sfm ground truth directory') 15 | arg_parser.add_argument('--dataset', type=str, default="7scenes", help='dataset name') 16 | arg_parser.add_argument('-s', '--scene', type=str, default="pumpkin", help='scene name(s)') 17 | arg_parser.add_argument('-cp','--checkpoint', action= 'store_true', help='use pre-trained model') 18 | arg_parser.add_argument('--visdom', action= 'store_true', help='visualize loss using visdom') 19 | 
arg_parser.add_argument('-c','--cudaid', type=int, default=0, help='specify cuda device id') 20 | arg_parser.add_argument('-o','--outputs', type=Path, default='logs/', 21 | help='Path to the output directory, default: %(default)s') 22 | arg_parser.add_argument('-expv', '--experiment_version', type=str, default="pl2map", help='experiment version folder') 23 | args, _ = arg_parser.parse_known_args() 24 | args.outputs = os.path.join(args.outputs, args.scene + "_" + args.experiment_version) 25 | print("Dataset: {} | Scene: {}".format(args.dataset, args.scene)) 26 | cfg = utilcfg.load_config(f'cfgs/{args.dataset}.yaml', default_path='cfgs/default.yaml') 27 | cfg = OmegaConf.create(cfg) 28 | utilcfg.mkdir(args.outputs) 29 | 30 | # Save the config file for evaluation purposes 31 | config_file_path = os.path.join(args.outputs, 'config.yaml') 32 | OmegaConf.save(cfg, config_file_path) 33 | 34 | return args, cfg 35 | 36 | def main(): 37 | args, cfg = parse_config() 38 | sys.stdout = DualLogger(f'{args.outputs}/train_log.txt') 39 | trainer = Trainer(args, cfg) 40 | start_time = time.time() 41 | trainer.train() 42 | print("Training time: {:.2f} hours".format((time.time() - start_time) / (60*60))) 43 | sys.stdout.log.close() 44 | 45 | if __name__ == "__main__": 46 | main() 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /runners/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from models.pipeline import Pipeline 6 | from datasets.dataloader import Collection_Loader 7 | from models.util_learner import CriterionPointLine, Optimizer 8 | from tqdm import tqdm 9 | torch.manual_seed(0) 10 | 11 | class Trainer(): 12 | def __init__(self, args, cfg): 13 | self.args = args 14 | print(f"[INFO] Model: {cfg.regressor.name}") 15 | self.log_name = str(args.dataset) + "_" + str(args.scene) + "_" + str(cfg.regressor.name) 16 | self.pipeline = Pipeline(cfg) 17 | self.criterion = CriterionPointLine(cfg.train.loss.reprojection, cfg.train.num_iters) 18 | self.device = torch.device(f'cuda:{args.cudaid}' if torch.cuda.is_available() else 'cpu') 19 | 20 | # to device 21 | self.pipeline.to(self.device) 22 | self.criterion.to(self.device) 23 | 24 | # dataloader 25 | train_collection = Collection_Loader(args, cfg, mode="train") 26 | print("[INFO] Loaded data collection") 27 | self.train_loader = torch.utils.data.DataLoader(train_collection, batch_size=cfg.train.batch_size, 28 | shuffle=cfg.train.loader_shuffle, num_workers=cfg.train.loader_num_workers, 29 | pin_memory=True) 30 | 31 | self.length_train_loader = len(self.train_loader) 32 | self.epochs = int(cfg.train.num_iters / self.length_train_loader) 33 | print(f"[INFO] Total epochs: {self.epochs}") 34 | self.optimizer = Optimizer(self.pipeline.regressor.parameters(), self.epochs, **cfg.optimizer) 35 | 36 | if self.args.checkpoint: 37 | # load checkpoint and resume training 38 | self.start_epoch = self.pipeline.load_checkpoint(self.args.outputs, self.log_name) 39 | # self.start_epoch = 2024 40 | self.lr = self.optimizer.adjust_lr(self.start_epoch) 41 | else: 42 | self.start_epoch = 0 43 | self.lr = self.optimizer.lr 44 | self.train_log = Train_Log(args, cfg, self.length_train_loader, self.start_epoch, self.epochs) 45 | 46 | 47 | def train(self): 48 | print("[INFO] Start training") 49 | for epoch in 
range(self.start_epoch, self.epochs): 50 | if self.train_log.is_save_checkpoint(): 51 | self.pipeline.save_checkpoint(self.args.outputs, self.log_name, epoch) # overwrite(save) checkpoint per epoch 52 | for batch_idx, (data, target) in enumerate(tqdm(self.train_loader)): 53 | iters = epoch*self.length_train_loader + batch_idx 54 | loss,_ = step_fwd(self.pipeline, self.device, data, target, iters, 55 | self.criterion, self.optimizer, train=True) 56 | self.train_log.update(epoch, batch_idx, loss, self.lr) 57 | self.lr = self.optimizer.adjust_lr(epoch) # adjust learning rate 58 | self.train_log.show(epoch) # show loss per epoch 59 | # self.pipeline.save_checkpoint(self.args.outputs, self.log_name, epoch, True) 60 | 61 | 62 | def step_fwd(model, device, data, target=None, iteration=2500000, 63 | criterion=None, optim=None, train=False): 64 | """ 65 | A training/validation step.""" 66 | if train: 67 | assert criterion is not None 68 | assert target is not None 69 | for k,v in data.items(): 70 | if isinstance(v,list): 71 | continue 72 | data[k] = data[k].to(device) 73 | if target is not None: 74 | for k,_ in target.items(): 75 | target[k] = target[k].to(device) 76 | output = model(data) 77 | loss = None 78 | if train: 79 | loss = criterion(output, target, iteration) 80 | if optim is not None: 81 | optim.learner.zero_grad() 82 | loss[0].backward() 83 | optim.learner.step() 84 | return loss, output 85 | 86 | class Train_Log(): 87 | def __init__(self, args, cfg, length_loader, start_epoch, total_epoch=0) -> None: 88 | self.args = args 89 | self.cfg = cfg 90 | self.total_epoch = total_epoch 91 | self.log_interval = cfg.train.log_interval 92 | self.length_train_loader = length_loader 93 | self.vis_env = str(args.dataset) + "_" + str(args.scene) + \ 94 | "_" + str(cfg.regressor.name) +"_"+ str(args.experiment_version) 95 | self.showloss = ShowLosses(total_epoch=self.total_epoch) 96 | self.list_fignames = ['total_loss', 'point_loss', 'point_uncer_loss', 97 | 'line_loss', 'line_uncer_loss', 'points_prj_loss', 98 | 'lines_prj_loss', 'learning_rate'] 99 | if self.args.visdom: 100 | from visdom import Visdom 101 | print("[INFOR] Visdom is used for log visualization") 102 | self.vis = Visdom() 103 | for name in self.list_fignames: 104 | self.add_fig(name, start_epoch) 105 | 106 | def add_fig(self, name, start_epoch): 107 | self.vis.line(X=np.asarray([start_epoch]), Y=np.zeros(1), win=name, 108 | opts={'legend': [name], 'xlabel': 'epochs', 109 | 'ylabel': name}, env=self.vis_env) 110 | def update_fig(self, idx, epoch_count, value): 111 | name = self.list_fignames[idx] 112 | self.vis.line(X=np.asarray([epoch_count]), Y=np.asarray([value]), win=name, 113 | update='append', env=self.vis_env) 114 | 115 | def update(self, epoch, batch_idx, loss, lr): 116 | self.showloss.update(loss) 117 | self.lr = lr 118 | if self.args.visdom: 119 | if batch_idx % self.log_interval == 0: 120 | n_iter = epoch*self.length_train_loader + batch_idx 121 | epoch_count = float(n_iter)/self.length_train_loader 122 | l = len(self.list_fignames) 123 | for idx in range(l-1): 124 | self.update_fig(idx, epoch_count, loss[idx].item()) 125 | self.update_fig(l-1, epoch_count, lr) 126 | 127 | def show(self, epoch): 128 | self.showloss.show(epoch, self.lr) 129 | def is_save_checkpoint(self): 130 | his_epoch_loss = self.showloss.dict_losses[0].his_epoch_loss 131 | if len(his_epoch_loss) == 0: 132 | return False 133 | if min(his_epoch_loss) >= his_epoch_loss[-1]: 134 | return True 135 | else: 136 | return False 137 | 138 | class His_Loss(): 139 | 
def __init__(self)->None: 140 | self.his_epoch_loss = [] 141 | self.temp_batch_loss = [] 142 | def update_loss(self, loss): 143 | self.temp_batch_loss.append(loss) 144 | def show(self): 145 | avg_loss = np.mean(self.temp_batch_loss) 146 | self.his_epoch_loss.append(avg_loss) 147 | self.temp_batch_loss = [] # reset 148 | return avg_loss 149 | 150 | class ShowLosses(): 151 | # for debugging, showing all losses if needed 152 | def __init__(self, list_display=[True, True, False, True, False, True, True], total_epoch=0): 153 | ''' 154 | corresponding to show following losses: 155 | ['total_loss', 'point_loss', 'point_uncer_loss', 156 | 'line_loss', 'line_uncer_loss', 'points_prj_loss', 157 | 'lines_prj_loss'] 158 | ''' 159 | self.list_display = list_display 160 | self.length = len(self.list_display) 161 | self.names = ['Avg total loss', 'A.P.L', 'A.P.U.L', 'A.L.L', 'A.L.U.L', 'A.P.P.L', 'A.P.L.L'] 162 | # A.P.L means average point loss, A.P.P.L means average point projection loss, etc. 163 | self.create_dict_losses() 164 | self.total_epoch = total_epoch 165 | 166 | def create_dict_losses(self): 167 | self.dict_losses = {} 168 | for i in range(self.length): 169 | if self.list_display[i]: 170 | self.dict_losses[i] = His_Loss() 171 | 172 | def update(self, loss): 173 | for k,_ in self.dict_losses.items(): 174 | self.dict_losses[k].update_loss(loss[k].item()) 175 | 176 | 177 | def show(self, epoch, lr=0.0): 178 | content = f"Epoch {epoch}/{self.total_epoch} | " 179 | for k,_ in self.dict_losses.items(): 180 | avg_loss = self.dict_losses[k].show() 181 | content += self.names[k] + f": {avg_loss:.5f} | " 182 | content = content + f"lr: {lr:.6f}" 183 | print(content) 184 | 185 | -------------------------------------------------------------------------------- /util/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | def update_recursive(dict1, dictinfo): 5 | for k, v in dictinfo.items(): 6 | if k not in dict1: 7 | dict1[k] = dict() 8 | if isinstance(v, dict): 9 | update_recursive(dict1[k], v) 10 | else: 11 | dict1[k] = v 12 | def load_config(config_file, default_path=None): 13 | with open(config_file, 'r') as f: 14 | cfg_loaded = yaml.load(f, Loader=yaml.Loader) 15 | 16 | base_config_file = cfg_loaded.get('base_config_file') 17 | if base_config_file is not None: 18 | cfg = load_config(base_config_file) 19 | elif (default_path is not None) and (config_file != default_path): 20 | cfg = load_config(default_path) 21 | else: 22 | cfg = dict() 23 | update_recursive(cfg, cfg_loaded) 24 | return cfg 25 | 26 | def mkdir(path): 27 | if not os.path.exists(path): 28 | os.makedirs(path) 29 | return path -------------------------------------------------------------------------------- /util/help_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from util.read_write_model import qvec2rotmat 6 | 7 | class Vis_Infor(): 8 | ''' 9 | Store and Merge the 3D lines output from the model (by remove line with high uncertainty) 10 | lines3D: (N, 6) 11 | points3D: (N, 3) 12 | Visualize the 3D lines and 3D points 13 | ''' 14 | def __init__(self, eval_cfg, highlight_frame=None, limit_n_frames=None, save_list_imgs=False, 15 | output_path=None)->None: 16 | ''' 17 | highlight_frame: "seq-06/frame-000612.color.png", for example 18 | limit_n_frames: limit the number 
of frames to visualize the 3D lines and 3D points 19 | ''' 20 | self.eval_cfg = eval_cfg 21 | self.highlight_frame = highlight_frame 22 | self.limit_n_frames = np.inf if limit_n_frames is None else limit_n_frames 23 | self.save_list_imgs = save_list_imgs 24 | self.output_path = output_path 25 | self.lines3D = None 26 | self.points3D = None 27 | self.hightlight_lines3D = None 28 | self.hightlight_points3D = None 29 | self.threshold_point = eval_cfg.uncer_threshold_point 30 | self.threshold_line = eval_cfg.uncer_threshold_line 31 | self.current_num_frames = 0 32 | self.list_images = [] # list of images to visualize 3D lines / 3D points 33 | self.cameras = [] 34 | self.prd_poses = [] 35 | self.gt_poses = [] 36 | def update(self, output, data, vis_pose_infor=None): 37 | ''' 38 | args: 39 | output: dict of model output 40 | data: dict of data 41 | ''' 42 | 43 | if self.current_num_frames < self.limit_n_frames: 44 | if self.eval_cfg.vis_line3d: 45 | lines3D,_ = getLine3D_from_modeloutput(output['lines3D'], self.threshold_line) 46 | self.lines3D = lines3D if self.lines3D is None else np.concatenate((self.lines3D, lines3D)) 47 | self.list_images.append(data['imgname'][0]) 48 | if vis_pose_infor is not None: 49 | self.cameras.append(vis_pose_infor[0]) 50 | self.prd_poses.append(vis_pose_infor[1]) 51 | self.gt_poses.append(vis_pose_infor[2]) 52 | 53 | if self.eval_cfg.vis_point3d: 54 | points3D,_ = getPoint3D_from_modeloutput(output['points3D'], self.threshold_point) 55 | self.points3D = points3D if self.points3D is None else np.concatenate((self.points3D, points3D)) 56 | self.list_images.append(data['imgname'][0]) 57 | 58 | if self.limit_n_frames is not None and self.highlight_frame is not None: 59 | # save visualizations for the highlight 3d lines and 3d points 60 | current_frame = data['imgname'][0] 61 | if self.highlight_frame == current_frame: 62 | print("FOUND HIGHLIGHT FRAME") 63 | if self.eval_cfg.vis_line3d: 64 | self.hightlight_lines3D,_ = getLine3D_from_modeloutput(output['lines3D'], self.threshold) 65 | if self.eval_cfg.vis_point3d: 66 | self.hightlight_points3D,_ = getPoint3D_from_modeloutput(output['points3D'], self.threshold) 67 | if self.current_num_frames >= self.limit_n_frames: 68 | self.save_vis_highlights() 69 | self.current_num_frames += 1 70 | 71 | def vis(self): 72 | if self.eval_cfg.vis_line3d: 73 | print("[INFOR] Visualizing predicted 3D lines ...") 74 | from util.visualize import open3d_vis_3d_lines 75 | # open3d_vis_3d_lines(self.lines3D) 76 | open3d_vis_3d_lines(self.lines3D, self.cameras, self.prd_poses, self.gt_poses) 77 | if self.eval_cfg.vis_point3d: 78 | print("[INFOR] Visualizing predicted 3D points ...") 79 | from util.visualize import open3d_vis_3d_points 80 | open3d_vis_3d_points(self.points3D) 81 | if self.save_list_imgs: 82 | print("[INFOR] Saving list of images to visualize 3D lines / 3D points ...") 83 | with open(os.path.join(self.output_path, "list_vis_imgs.txt"), "w") as f: 84 | for img in self.list_images: 85 | f.write(img + "\n") 86 | 87 | def save_vis_highlights(self): 88 | if self.hightlight_lines3D is not None: 89 | from util.visualize import open3d_vis_3d_lines_with_hightlightFrame 90 | open3d_vis_3d_lines_with_hightlightFrame(self.lines3D, self.hightlight_lines3D) 91 | if self.hightlight_points3D is not None: 92 | from util.visualize import open3d_vis_3d_points_with_hightlightFrame 93 | open3d_vis_3d_points_with_hightlightFrame(self.points3D, self.hightlight_points3D) 94 | 95 | def getLine3D_from_modeloutput(lines3D, threshold=0.5): 96 | ''' 97 | 
get uncertainty and remove line with high uncertainty 98 | args: 99 | lines3D: numpy array (1, 7, N) 100 | return: lines3D (N, 6) 101 | ''' 102 | lines3D = np.squeeze(lines3D.detach().cpu().numpy()) 103 | uncertainty = 1/(1+100*np.abs(lines3D[6,:])) 104 | lines3D = lines3D[:6,:] 105 | uncertainty = [True if tmpc >= threshold else False for tmpc in uncertainty] 106 | lines3D = lines3D.T[uncertainty,:] 107 | return lines3D, uncertainty 108 | 109 | def getPoint3D_from_modeloutput(points3D, threshold=0.5): 110 | ''' 111 | get uncertainty and remove point with high uncertainty 112 | args: 113 | points3D: numpy array (1, 4, N) 114 | return: points3D (N, 3) 115 | ''' 116 | points3D = np.squeeze(points3D.detach().cpu().numpy()) 117 | uncertainty = 1/(1+100*np.abs(points3D[3,:])) 118 | points3D = points3D[:3,:] 119 | uncertainty = [True if tmpc >= threshold else False for tmpc in uncertainty] 120 | points3D = points3D.T[uncertainty,:] 121 | return points3D, uncertainty 122 | 123 | def pose_evaluator(eval_cfg, spath): 124 | ''' 125 | Evaluate the estimated poses with ground truth poses 126 | args: 127 | eval_cfg: evaluation config 128 | spath: path to save the estimated poses and ground truth poses 129 | ''' 130 | def eval(eval_cfg, spath, mode): 131 | if eval_cfg.pnp_point: 132 | evaluate_pose_results(spath, mode=mode, pnp='point') 133 | if eval_cfg.pnp_pointline: 134 | evaluate_pose_results(spath, mode=mode, pnp='pointline') 135 | 136 | if eval_cfg.eval_train: 137 | mode = 'train' 138 | eval(eval_cfg, spath, mode) 139 | if eval_cfg.eval_test: 140 | mode = 'test' 141 | eval(eval_cfg, spath, mode) 142 | 143 | def evaluate_pose_results(spath, mode='train', pnp='pointline'): 144 | ''' 145 | Evaluate the estimated poses with ground truth poses 146 | args: 147 | spath: path to save the estimated poses and ground truth poses 148 | mode: 'train' or 'test' 149 | pnp: 'point' or 'pointline' 150 | ''' 151 | gt_path = os.path.join(spath, f"gt_poses_{mode}.txt") 152 | prd_path = os.path.join(spath, f"est_poses_{mode}_{pnp}.txt") 153 | gt = pd.read_csv(gt_path, header=None, sep=" ") 154 | prd = pd.read_csv(prd_path, header=None, sep =" ") 155 | # assert len(gt) == len(prd) 156 | errors_t = [] 157 | errors_R = [] 158 | for i in range(len(prd)): 159 | R_gt = qvec2rotmat(gt.iloc[i,3:7].to_numpy()) 160 | t_gt = gt.iloc[i,:3].to_numpy() 161 | t = prd.iloc[i,:3].to_numpy() 162 | R = qvec2rotmat(prd.iloc[i,3:].to_numpy()) 163 | e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0) 164 | cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1., 1.) 
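# e_t above is the distance between the two camera centers: for COLMAP-style
# world-to-camera poses the center is C = -R^T t, so ||C - C_gt|| =
# ||-R_gt^T t_gt + R^T t||. The rotation error uses trace(R_gt^T R) = 1 + 2*cos(theta);
# the clip guards against values drifting slightly outside [-1, 1] before arccos.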
165 | e_R = np.rad2deg(np.abs(np.arccos(cos))) 166 | errors_t.append(e_t) 167 | errors_R.append(e_R) 168 | errors_t = np.array(errors_t) 169 | errors_R = np.array(errors_R) 170 | med_t = np.median(errors_t) 171 | med_R = np.median(errors_R) 172 | print(f'Evaluation results on {mode} set ({len(gt)}imgs) & PnP {pnp}:') 173 | print('Median errors: {:.4f}m, {:.4f}deg'.format(med_t, med_R)) 174 | print('Average PnP time: {:.4f}s'.format(np.mean(prd.iloc[:,7].to_numpy()))) 175 | print('Percentage of test images localized within:') 176 | threshs_t = [0.01, 0.02, 0.03, 0.05, 0.10] 177 | threshs_R = [1.0, 2.0, 3.0, 5.0, 10.0] 178 | for th_t, th_R in zip(threshs_t, threshs_R): 179 | ratio = np.mean((errors_t < th_t) & (errors_R < th_R)) 180 | print('\t{:.0f}cm, {:.0f}deg : {:.2f}%'.format(th_t*100, th_R, ratio*100)) 181 | return med_t, med_R 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /util/io.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | 5 | def read_image(path, grayscale=False): 6 | if grayscale: 7 | mode = cv2.IMREAD_GRAYSCALE 8 | else: 9 | mode = cv2.IMREAD_COLOR 10 | image = cv2.imread(str(path), mode) 11 | if image is None: 12 | raise ValueError(f'Cannot read image {path}.') 13 | if not grayscale and len(image.shape) == 3: 14 | image = image[:, :, ::-1] # BGR to RGB 15 | return image 16 | 17 | 18 | class SAVING_MAP(): 19 | def __init__(self, save_path) -> None: 20 | print("[INFOR] Saving prediction 3D map") 21 | self.save_path = os.path.join(save_path, "Map_Prediction") 22 | if not os.path.exists(self.save_path): 23 | os.makedirs(self.save_path) 24 | self.image_list = [] 25 | self.idx = 0 26 | def save(self, output, data): 27 | image_name = data['imgname'][0] 28 | self.image_list.append(image_name) 29 | 30 | p2ds = output['keypoints'][0].detach().cpu().numpy() 31 | # save 2D points 32 | np.savetxt(os.path.join(self.save_path, str(self.idx) + "_p2d.txt"), p2ds) 33 | 34 | points3D = np.squeeze(output['points3D'].detach().cpu().numpy()) 35 | np.savetxt(os.path.join(self.save_path, str(self.idx) + "_p3d.txt"), points3D) 36 | 37 | l2ds = data['lines'][0].detach().cpu().numpy() 38 | np.savetxt(os.path.join(self.save_path, str(self.idx) + "_l2d.txt"), l2ds) 39 | 40 | lines3D = np.squeeze(output['lines3D'].detach().cpu().numpy()) 41 | np.savetxt(os.path.join(self.save_path, str(self.idx) + "_l3d.txt"), lines3D) 42 | 43 | with open(os.path.join(self.save_path, "images.txt"), "a") as f: 44 | f.write(str(self.idx) + " " + image_name + "\n") 45 | 46 | self.idx += 1 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | class DualLogger: 5 | def __init__(self, filename): 6 | self.terminal = sys.stdout 7 | if os.path.exists(filename): 8 | os.remove(filename) 9 | self.log = open(filename, 'a') 10 | 11 | def write(self, message): 12 | self.terminal.write(message) 13 | self.log.write(message) 14 | 15 | def flush(self): 16 | # This method is needed for Python 3 compatibility. 17 | # This handles the flush command by doing nothing. 18 | # You might want to specify some behavior here. 
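# The no-op is acceptable here because the log file is closed (and therefore
# flushed) explicitly via sys.stdout.log.close() at the end of runners/train.py and
# runners/eval.py; forwarding flush() to self.terminal and self.log would be the
# straightforward change if more timely output is ever needed.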
19 | pass -------------------------------------------------------------------------------- /util/pose_estimator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from omegaconf import OmegaConf 3 | import sys, os 4 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 5 | from util.help_evaluation import getLine3D_from_modeloutput, getPoint3D_from_modeloutput 6 | import time 7 | import poselib 8 | 9 | class Pose_Estimator(): 10 | def __init__(self, localize_cfg, eval_cfg, spath): 11 | self.localize_cfg = localize_cfg # config file for localization 12 | self.eval_cfg = eval_cfg # local config for evaluation 13 | self.spath = spath 14 | self.uncertainty_point = eval_cfg.uncer_threshold_point 15 | self.uncertainty_line = eval_cfg.uncer_threshold_line 16 | self.pnppoint = eval_cfg.pnp_point 17 | self.pnppointline = eval_cfg.pnp_pointline 18 | if not self.eval_cfg.exist_results: 19 | self.checkexist() 20 | def checkexist(self): 21 | ''' 22 | Check if the files exist, if yes, remove them 23 | ''' 24 | trainfiles_list = ['est_poses_train_pointline.txt', 'est_poses_train_point.txt', 25 | 'gt_poses_train.txt'] 26 | testfiles_list = ['est_poses_test_pointline.txt', 'est_poses_test_point.txt', 27 | 'gt_poses_test.txt'] 28 | if self.eval_cfg.eval_train: 29 | self.rmfiles(trainfiles_list) 30 | if self.eval_cfg.eval_test: 31 | self.rmfiles(testfiles_list) 32 | 33 | def rmfiles(self, rm_list): 34 | for file in rm_list: 35 | if os.path.exists(os.path.join(self.spath, file)): 36 | os.remove(os.path.join(self.spath, file)) 37 | 38 | def run(self, output, data, target, mode='train'): 39 | return camera_pose_estimation(self.localize_cfg, output, data, target, self.spath, mode=mode, 40 | uncertainty_point=self.uncertainty_point, uncertainty_line=self.uncertainty_line, 41 | pnppoint=self.pnppoint, pnppointline=self.pnppointline) 42 | 43 | def camera_pose_estimation(localize_cfg, output, data, target, spath, mode='train', 44 | uncertainty_point=0.5, uncertainty_line=0.5, pnppoint=False, pnppointline=True): 45 | ''' 46 | Creating same inputs for limap library and estimate camera pose 47 | ''' 48 | p3ds_, point_uncer = getPoint3D_from_modeloutput(output['points3D'], uncertainty_point) 49 | p3ds = [i for i in p3ds_] 50 | p2ds = output['keypoints'][0].detach().cpu().numpy() + 0.5 # COLMAP 51 | p2ds = p2ds[point_uncer,:] 52 | p2ds = [i for i in p2ds] 53 | camera = target['camera'][0].detach().cpu().numpy() 54 | camera_model = "PINHOLE" if camera[0] == 1.0 else "SIMPLE_PINHOLE" 55 | poselibcamera = {'model': camera_model, 'width': camera[2], 'height': camera[1], 'params': camera[3:]} 56 | image_name = data['imgname'][0] 57 | 58 | if pnppoint: 59 | start = time.time() 60 | pose_point, _ = poselib.estimate_absolute_pose(p2ds, p3ds, poselibcamera, {'max_reproj_error': 12.0}, {}) 61 | est_time = time.time() - start 62 | with open(os.path.join(spath, f"est_poses_{mode}_point.txt"), 'a') as f: 63 | f.write(f"{pose_point.t[0]} {pose_point.t[1]} {pose_point.t[2]} {pose_point.q[0]} {pose_point.q[1]} {pose_point.q[2]} {pose_point.q[3]} {est_time} {image_name}\n") 64 | target_pose = target['pose'][0].detach().cpu().numpy() 65 | with open(os.path.join(spath, f"gt_poses_{mode}.txt"), 'a') as f: 66 | f.write(f"{target_pose[0]} {target_pose[1]} {target_pose[2]} {target_pose[3]} {target_pose[4]} {target_pose[5]} {target_pose[6]}\n") 67 | if not pnppointline: 68 | return None 69 | # modify the limap pnp to poselib pnp 70 | 71 | 72 | l3ds, line_uncer 
= getLine3D_from_modeloutput(output['lines3D'], uncertainty_line) 73 | l3d_ids = [i for i in range(len(l3ds))] 74 | l2ds = data['lines'][0].detach().cpu().numpy() 75 | l2ds = l2ds[line_uncer,:] 76 | 77 | localize_cfg = OmegaConf.to_container(localize_cfg, resolve=True) 78 | 79 | if pnppointline: 80 | start = time.time() 81 | ransac_opt = {"max_reproj_error": 12.0, "max_epipolar_error": 10.0} 82 | l2d_1 = [i for i in l2ds[:,:2]] 83 | l2d_2 = [i for i in l2ds[:,2:]] 84 | l3d_1 = [i for i in l3ds[:,:3]] 85 | l3d_2 = [i for i in l3ds[:,3:]] 86 | pose, _ = poselib.estimate_absolute_pose_pnpl(p2ds, p3ds, l2d_1, l2d_2, l3d_1, l3d_2, poselibcamera, ransac_opt) 87 | est_time = time.time() - start 88 | with open(os.path.join(spath, f"est_poses_{mode}_pointline.txt"), 'a') as f: 89 | f.write(f"{pose.t[0]} {pose.t[1]} {pose.t[2]} {pose.q[0]} {pose.q[1]} {pose.q[2]} {pose.q[3]} {est_time} {image_name}\n") 90 | return [poselibcamera, np.array([pose.t[0], pose.t[1], pose.t[2], pose.q[0], pose.q[1], pose.q[2], pose.q[3]]), target_pose] -------------------------------------------------------------------------------- /util/read_write_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. 
Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | import os 33 | import collections 34 | import numpy as np 35 | import struct 36 | import argparse 37 | 38 | 39 | CameraModel = collections.namedtuple( 40 | "CameraModel", ["model_id", "model_name", "num_params"]) 41 | Camera = collections.namedtuple( 42 | "Camera", ["id", "model", "width", "height", "params"]) 43 | BaseImage = collections.namedtuple( 44 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 45 | Point3D = collections.namedtuple( 46 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 47 | 48 | 49 | class Image(BaseImage): 50 | def qvec2rotmat(self): 51 | return qvec2rotmat(self.qvec) 52 | 53 | 54 | CAMERA_MODELS = { 55 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 56 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 57 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 58 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 59 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 60 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 61 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 62 | CameraModel(model_id=7, model_name="FOV", num_params=5), 63 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 64 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 65 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 66 | } 67 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 68 | for camera_model in CAMERA_MODELS]) 69 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 70 | for camera_model in CAMERA_MODELS]) 71 | 72 | 73 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 74 | """Read and unpack the next bytes from a binary file. 75 | :param fid: 76 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 77 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 78 | :param endian_character: Any of {@, =, <, >, !} 79 | :return: Tuple of read and unpacked values. 80 | """ 81 | data = fid.read(num_bytes) 82 | return struct.unpack(endian_character + format_char_sequence, data) 83 | 84 | 85 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): 86 | """pack and write to a binary file. 87 | :param fid: 88 | :param data: data to send, if multiple elements are sent at the same time, 89 | they should be encapsuled either in a list or a tuple 90 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
91 | should be the same length as the data list or tuple 92 | :param endian_character: Any of {@, =, <, >, !} 93 | """ 94 | if isinstance(data, (list, tuple)): 95 | bytes = struct.pack(endian_character + format_char_sequence, *data) 96 | else: 97 | bytes = struct.pack(endian_character + format_char_sequence, data) 98 | fid.write(bytes) 99 | 100 | 101 | def read_cameras_text(path): 102 | """ 103 | see: src/base/reconstruction.cc 104 | void Reconstruction::WriteCamerasText(const std::string& path) 105 | void Reconstruction::ReadCamerasText(const std::string& path) 106 | """ 107 | cameras = {} 108 | with open(path, "r") as fid: 109 | while True: 110 | line = fid.readline() 111 | if not line: 112 | break 113 | line = line.strip() 114 | if len(line) > 0 and line[0] != "#": 115 | elems = line.split() 116 | camera_id = int(elems[0]) 117 | model = elems[1] 118 | width = int(elems[2]) 119 | height = int(elems[3]) 120 | params = np.array(tuple(map(float, elems[4:]))) 121 | cameras[camera_id] = Camera(id=camera_id, model=model, 122 | width=width, height=height, 123 | params=params) 124 | return cameras 125 | 126 | 127 | def read_cameras_binary(path_to_model_file): 128 | """ 129 | see: src/base/reconstruction.cc 130 | void Reconstruction::WriteCamerasBinary(const std::string& path) 131 | void Reconstruction::ReadCamerasBinary(const std::string& path) 132 | """ 133 | cameras = {} 134 | with open(path_to_model_file, "rb") as fid: 135 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 136 | for _ in range(num_cameras): 137 | camera_properties = read_next_bytes( 138 | fid, num_bytes=24, format_char_sequence="iiQQ") 139 | camera_id = camera_properties[0] 140 | model_id = camera_properties[1] 141 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 142 | width = camera_properties[2] 143 | height = camera_properties[3] 144 | num_params = CAMERA_MODEL_IDS[model_id].num_params 145 | params = read_next_bytes(fid, num_bytes=8*num_params, 146 | format_char_sequence="d"*num_params) 147 | cameras[camera_id] = Camera(id=camera_id, 148 | model=model_name, 149 | width=width, 150 | height=height, 151 | params=np.array(params)) 152 | assert len(cameras) == num_cameras 153 | return cameras 154 | 155 | 156 | def write_cameras_text(cameras, path): 157 | """ 158 | see: src/base/reconstruction.cc 159 | void Reconstruction::WriteCamerasText(const std::string& path) 160 | void Reconstruction::ReadCamerasText(const std::string& path) 161 | """ 162 | HEADER = "# Camera list with one line of data per camera:\n" + \ 163 | "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \ 164 | "# Number of cameras: {}\n".format(len(cameras)) 165 | with open(path, "w") as fid: 166 | fid.write(HEADER) 167 | for _, cam in cameras.items(): 168 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] 169 | line = " ".join([str(elem) for elem in to_write]) 170 | fid.write(line + "\n") 171 | 172 | 173 | def write_cameras_binary(cameras, path_to_model_file): 174 | """ 175 | see: src/base/reconstruction.cc 176 | void Reconstruction::WriteCamerasBinary(const std::string& path) 177 | void Reconstruction::ReadCamerasBinary(const std::string& path) 178 | """ 179 | with open(path_to_model_file, "wb") as fid: 180 | write_next_bytes(fid, len(cameras), "Q") 181 | for _, cam in cameras.items(): 182 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id 183 | camera_properties = [cam.id, 184 | model_id, 185 | cam.width, 186 | cam.height] 187 | write_next_bytes(fid, camera_properties, "iiQQ") 188 | for p in cam.params: 189 | write_next_bytes(fid, 
float(p), "d") 190 | return cameras 191 | 192 | 193 | def read_images_text(path): 194 | """ 195 | see: src/base/reconstruction.cc 196 | void Reconstruction::ReadImagesText(const std::string& path) 197 | void Reconstruction::WriteImagesText(const std::string& path) 198 | """ 199 | images = {} 200 | with open(path, "r") as fid: 201 | while True: 202 | line = fid.readline() 203 | if not line: 204 | break 205 | line = line.strip() 206 | if len(line) > 0 and line[0] != "#": 207 | elems = line.split() 208 | image_id = int(elems[0]) 209 | qvec = np.array(tuple(map(float, elems[1:5]))) 210 | tvec = np.array(tuple(map(float, elems[5:8]))) 211 | camera_id = int(elems[8]) 212 | image_name = elems[9] 213 | elems = fid.readline().split() 214 | xys = np.column_stack([tuple(map(float, elems[0::3])), 215 | tuple(map(float, elems[1::3]))]) 216 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 217 | images[image_id] = Image( 218 | id=image_id, qvec=qvec, tvec=tvec, 219 | camera_id=camera_id, name=image_name, 220 | xys=xys, point3D_ids=point3D_ids) 221 | return images 222 | 223 | 224 | def read_images_binary(path_to_model_file): 225 | """ 226 | see: src/base/reconstruction.cc 227 | void Reconstruction::ReadImagesBinary(const std::string& path) 228 | void Reconstruction::WriteImagesBinary(const std::string& path) 229 | """ 230 | images = {} 231 | with open(path_to_model_file, "rb") as fid: 232 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 233 | for _ in range(num_reg_images): 234 | binary_image_properties = read_next_bytes( 235 | fid, num_bytes=64, format_char_sequence="idddddddi") 236 | image_id = binary_image_properties[0] 237 | qvec = np.array(binary_image_properties[1:5]) 238 | tvec = np.array(binary_image_properties[5:8]) 239 | camera_id = binary_image_properties[8] 240 | image_name = "" 241 | current_char = read_next_bytes(fid, 1, "c")[0] 242 | while current_char != b"\x00": # look for the ASCII 0 entry 243 | image_name += current_char.decode("utf-8") 244 | current_char = read_next_bytes(fid, 1, "c")[0] 245 | num_points2D = read_next_bytes(fid, num_bytes=8, 246 | format_char_sequence="Q")[0] 247 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 248 | format_char_sequence="ddq"*num_points2D) 249 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 250 | tuple(map(float, x_y_id_s[1::3]))]) 251 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 252 | images[image_id] = Image( 253 | id=image_id, qvec=qvec, tvec=tvec, 254 | camera_id=camera_id, name=image_name, 255 | xys=xys, point3D_ids=point3D_ids) 256 | return images 257 | 258 | 259 | def write_images_text(images, path): 260 | """ 261 | see: src/base/reconstruction.cc 262 | void Reconstruction::ReadImagesText(const std::string& path) 263 | void Reconstruction::WriteImagesText(const std::string& path) 264 | """ 265 | if len(images) == 0: 266 | mean_observations = 0 267 | else: 268 | mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images) 269 | HEADER = "# Image list with two lines of data per image:\n" + \ 270 | "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \ 271 | "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \ 272 | "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations) 273 | 274 | with open(path, "w") as fid: 275 | fid.write(HEADER) 276 | for _, img in images.items(): 277 | image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] 278 | first_line = " ".join(map(str, image_header)) 279 | fid.write(first_line + "\n") 280 
| 281 | points_strings = [] 282 | for xy, point3D_id in zip(img.xys, img.point3D_ids): 283 | points_strings.append(" ".join(map(str, [*xy, point3D_id]))) 284 | fid.write(" ".join(points_strings) + "\n") 285 | 286 | 287 | def write_images_binary(images, path_to_model_file): 288 | """ 289 | see: src/base/reconstruction.cc 290 | void Reconstruction::ReadImagesBinary(const std::string& path) 291 | void Reconstruction::WriteImagesBinary(const std::string& path) 292 | """ 293 | with open(path_to_model_file, "wb") as fid: 294 | write_next_bytes(fid, len(images), "Q") 295 | for _, img in images.items(): 296 | write_next_bytes(fid, img.id, "i") 297 | write_next_bytes(fid, img.qvec.tolist(), "dddd") 298 | write_next_bytes(fid, img.tvec.tolist(), "ddd") 299 | write_next_bytes(fid, img.camera_id, "i") 300 | for char in img.name: 301 | write_next_bytes(fid, char.encode("utf-8"), "c") 302 | write_next_bytes(fid, b"\x00", "c") 303 | write_next_bytes(fid, len(img.point3D_ids), "Q") 304 | for xy, p3d_id in zip(img.xys, img.point3D_ids): 305 | write_next_bytes(fid, [*xy, p3d_id], "ddq") 306 | 307 | 308 | def read_points3D_text(path): 309 | """ 310 | see: src/base/reconstruction.cc 311 | void Reconstruction::ReadPoints3DText(const std::string& path) 312 | void Reconstruction::WritePoints3DText(const std::string& path) 313 | """ 314 | points3D = {} 315 | with open(path, "r") as fid: 316 | while True: 317 | line = fid.readline() 318 | if not line: 319 | break 320 | line = line.strip() 321 | if len(line) > 0 and line[0] != "#": 322 | elems = line.split() 323 | point3D_id = int(elems[0]) 324 | xyz = np.array(tuple(map(float, elems[1:4]))) 325 | rgb = np.array(tuple(map(int, elems[4:7]))) 326 | error = float(elems[7]) 327 | image_ids = np.array(tuple(map(int, elems[8::2]))) 328 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 329 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 330 | error=error, image_ids=image_ids, 331 | point2D_idxs=point2D_idxs) 332 | return points3D 333 | 334 | 335 | def read_points3D_binary(path_to_model_file): 336 | """ 337 | see: src/base/reconstruction.cc 338 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 339 | void Reconstruction::WritePoints3DBinary(const std::string& path) 340 | """ 341 | points3D = {} 342 | with open(path_to_model_file, "rb") as fid: 343 | num_points = read_next_bytes(fid, 8, "Q")[0] 344 | for _ in range(num_points): 345 | binary_point_line_properties = read_next_bytes( 346 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 347 | point3D_id = binary_point_line_properties[0] 348 | xyz = np.array(binary_point_line_properties[1:4]) 349 | rgb = np.array(binary_point_line_properties[4:7]) 350 | error = np.array(binary_point_line_properties[7]) 351 | track_length = read_next_bytes( 352 | fid, num_bytes=8, format_char_sequence="Q")[0] 353 | track_elems = read_next_bytes( 354 | fid, num_bytes=8*track_length, 355 | format_char_sequence="ii"*track_length) 356 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 357 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 358 | points3D[point3D_id] = Point3D( 359 | id=point3D_id, xyz=xyz, rgb=rgb, 360 | error=error, image_ids=image_ids, 361 | point2D_idxs=point2D_idxs) 362 | return points3D 363 | 364 | 365 | def write_points3D_text(points3D, path): 366 | """ 367 | see: src/base/reconstruction.cc 368 | void Reconstruction::ReadPoints3DText(const std::string& path) 369 | void Reconstruction::WritePoints3DText(const std::string& path) 370 | """ 371 | if len(points3D) == 0: 
372 | mean_track_length = 0 373 | else: 374 | mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D) 375 | HEADER = "# 3D point list with one line of data per point:\n" + \ 376 | "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \ 377 | "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length) 378 | 379 | with open(path, "w") as fid: 380 | fid.write(HEADER) 381 | for _, pt in points3D.items(): 382 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] 383 | fid.write(" ".join(map(str, point_header)) + " ") 384 | track_strings = [] 385 | for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): 386 | track_strings.append(" ".join(map(str, [image_id, point2D]))) 387 | fid.write(" ".join(track_strings) + "\n") 388 | 389 | 390 | def write_points3D_binary(points3D, path_to_model_file): 391 | """ 392 | see: src/base/reconstruction.cc 393 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 394 | void Reconstruction::WritePoints3DBinary(const std::string& path) 395 | """ 396 | with open(path_to_model_file, "wb") as fid: 397 | write_next_bytes(fid, len(points3D), "Q") 398 | for _, pt in points3D.items(): 399 | write_next_bytes(fid, pt.id, "Q") 400 | write_next_bytes(fid, pt.xyz.tolist(), "ddd") 401 | write_next_bytes(fid, pt.rgb.tolist(), "BBB") 402 | write_next_bytes(fid, pt.error, "d") 403 | track_length = pt.image_ids.shape[0] 404 | write_next_bytes(fid, track_length, "Q") 405 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): 406 | write_next_bytes(fid, [image_id, point2D_id], "ii") 407 | 408 | 409 | def detect_model_format(path, ext): 410 | if os.path.isfile(os.path.join(path, "cameras" + ext)) and \ 411 | os.path.isfile(os.path.join(path, "images" + ext)) and \ 412 | os.path.isfile(os.path.join(path, "points3D" + ext)): 413 | print("Detected model format: '" + ext + "'") 414 | return True 415 | 416 | return False 417 | 418 | 419 | def read_model(path, ext=""): 420 | # try to detect the extension automatically 421 | if ext == "": 422 | if detect_model_format(path, ".bin"): 423 | ext = ".bin" 424 | elif detect_model_format(path, ".txt"): 425 | ext = ".txt" 426 | else: 427 | print("Provide model format: '.bin' or '.txt'") 428 | return 429 | 430 | if ext == ".txt": 431 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 432 | images = read_images_text(os.path.join(path, "images" + ext)) 433 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 434 | else: 435 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 436 | images = read_images_binary(os.path.join(path, "images" + ext)) 437 | points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) 438 | return cameras, images, points3D 439 | 440 | 441 | def write_model(cameras, images, points3D, path, ext=".bin"): 442 | if ext == ".txt": 443 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) 444 | write_images_text(images, os.path.join(path, "images" + ext)) 445 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext) 446 | else: 447 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) 448 | write_images_binary(images, os.path.join(path, "images" + ext)) 449 | write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) 450 | return cameras, images, points3D 451 | 452 | 453 | def qvec2rotmat(qvec): 454 | return np.array([ 455 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 456 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * 
qvec[3], 457 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 458 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 459 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 460 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 461 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 462 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 463 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 464 | 465 | 466 | def rotmat2qvec(R): 467 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 468 | K = np.array([ 469 | [Rxx - Ryy - Rzz, 0, 0, 0], 470 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 471 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 472 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 473 | eigvals, eigvecs = np.linalg.eigh(K) 474 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 475 | if qvec[0] < 0: 476 | qvec *= -1 477 | return qvec 478 | 479 | 480 | def main(): 481 | parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models") 482 | parser.add_argument("--input_model", help="path to input model folder") 483 | parser.add_argument("--input_format", choices=[".bin", ".txt"], 484 | help="input model format", default="") 485 | parser.add_argument("--output_model", 486 | help="path to output model folder") 487 | parser.add_argument("--output_format", choices=[".bin", ".txt"], 488 | help="outut model format", default=".txt") 489 | args = parser.parse_args() 490 | 491 | cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) 492 | 493 | print("num_cameras:", len(cameras)) 494 | print("num_images:", len(images)) 495 | print("num_points3D:", len(points3D)) 496 | 497 | if args.output_model is not None: 498 | write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format) 499 | 500 | 501 | if __name__ == "__main__": 502 | main() 503 | -------------------------------------------------------------------------------- /util/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import os 4 | import sys 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | from datasets._base import Line3D 7 | from util.io import read_image 8 | import matplotlib.pyplot as plt 9 | 10 | def test_point_inside_ranges(point, ranges): 11 | point = np.array(point) 12 | if ~np.all(point > ranges[0]) or ~np.all(point < ranges[1]): 13 | return False 14 | return True 15 | 16 | def test_line_inside_ranges(line, ranges): 17 | if not test_point_inside_ranges(line.start, ranges): 18 | return False 19 | if not test_point_inside_ranges(line.end, ranges): 20 | return False 21 | return True 22 | 23 | def open3d_get_line_set(lines, color=[0.0, 0.0, 0.0], width=2, ranges=None, scale=1.0): 24 | """ 25 | convert a list of line3D objects to an Open3D lines set 26 | Args: 27 | lines (list[:class:`datasets._base.Line3D`] or numpy array of Nx6): The 3D line map 28 | color (list[float]): The color of the lines 29 | width (float, optional): width of the line 30 | """ 31 | o3d_points, o3d_lines, o3d_colors = [], [], [] 32 | counter = 0 33 | for line in lines: 34 | if isinstance(line, np.ndarray): 35 | line = Line3D(line[:3], line[3:]) 36 | if ranges is not None: 37 | if not test_line_inside_ranges(line, ranges): 38 | continue 39 | o3d_points.append(line.start * scale) 40 | o3d_points.append(line.end * scale) 41 | o3d_lines.append([2*counter, 2*counter+1]) 42 | counter += 1 43 | o3d_colors.append(color) 44 | line_set = o3d.geometry.LineSet() 45 | line_set.points = 
o3d.utility.Vector3dVector(o3d_points) 46 | line_set.lines = o3d.utility.Vector2iVector(o3d_lines) 47 | line_set.colors = o3d.utility.Vector3dVector(o3d_colors) 48 | return line_set 49 | 50 | 51 | def open3d_vis_3d_lines_with_hightlightFrame(lines3D, hightlight_lines3D, width=2, ranges=None, scale=1.0): 52 | """ 53 | Save 3D line map with `Open3D `_ 54 | 55 | Args: 56 | lines3D: numpy array of Nx6 57 | hightlight_lines3D: numpy array of Nx6 58 | width (float, optional): width of the line 59 | """ 60 | 61 | line_set = open3d_get_line_set(lines3D, width=width, ranges=ranges, scale=scale) 62 | line_set_highlight = open3d_get_line_set(hightlight_lines3D, color=[0.0, 1.0, 0.0], width=width*2, ranges=ranges, scale=scale) 63 | 64 | # Save the line_set 65 | o3d.io.write_line_set("visualization/line_set.ply", line_set) 66 | o3d.io.write_line_set("visualization/line_set_highlight.ply", line_set_highlight) 67 | 68 | ''' 69 | vis = o3d.visualization.Visualizer() 70 | vis.create_window(height=1080, width=1920) 71 | vis.add_geometry(line_set) 72 | vis.add_geometry(line_set_highlight) 73 | vis.run() 74 | vis.destroy_window() 75 | ''' 76 | 77 | 78 | 79 | 80 | def open3d_vis_3d_lines(lines3D, cameras=None, poses=None, gt_pose=None, width=2, ranges=None, scale=1.0): 81 | """ 82 | Visualize a 3D line map with `Open3D `_ 83 | 84 | Args: 85 | lines (list[:class:`datasets._base.Line3D` and/or None]): The 3D line map 86 | width (float, optional): width of the line 87 | """ 88 | if isinstance(lines3D, list): 89 | lines = [] 90 | for line in lines3D: 91 | if line is not None: 92 | lines.append(line) 93 | elif isinstance(lines3D, np.ndarray): 94 | lines = lines3D 95 | else: 96 | raise ValueError("lines3D must be either a list or a numpy array") 97 | 98 | vis = o3d.visualization.Visualizer() 99 | vis.create_window(height=1080, width=1920) 100 | 101 | prune = len(lines) 102 | # prune = int(0.8*len(lines)) 103 | line_set = open3d_get_line_set(lines[:prune,:], width=width, ranges=ranges, scale=scale) 104 | 105 | vis.add_geometry(line_set) 106 | if poses is not None: 107 | assert cameras is not None 108 | assert gt_pose is not None 109 | def get_t(pose): 110 | R = qvec2rotmat(pose[3:]) 111 | # translation 112 | t = pose[:3] 113 | # invert 114 | t = -R.T @ t 115 | return t 116 | connect_poses_lines = [] 117 | is_draws = [] 118 | for i in range(len(poses)): 119 | est_t = get_t(poses[i]) 120 | gt_t = get_t(gt_pose[i]) 121 | # calculate distance between two points 122 | is_draw = True if np.linalg.norm(est_t - gt_t) < 100 else False 123 | is_draws.append(is_draw) 124 | if is_draw: 125 | tmp_line = np.array([est_t[0], est_t[1], est_t[2], gt_t[0], gt_t[1], gt_t[2]]) 126 | connect_poses_lines.append(tmp_line) 127 | connect_line_set = open3d_get_line_set(connect_poses_lines, width=width, ranges=ranges, scale=scale, color=[0,1,0]) 128 | vis.add_geometry(connect_line_set) 129 | i = 0 130 | for pose, camera in zip(poses, cameras): 131 | if is_draws[i]: add_camera(vis, pose, camera, scale=0.2, gt = False) 132 | i+=1 133 | i = 0 134 | for pose, camera in zip(gt_pose, cameras): 135 | if is_draws[i]: add_camera(vis, pose, camera, scale=0.2, gt = True) 136 | i+=1 137 | vis.run() 138 | vis.destroy_window() 139 | 140 | def open3d_vis_3d_lines_from_datacollection(datacollection, train_or_test="train"): 141 | ''' 142 | Visualize 3D lines from datasetcollection 143 | Args: 144 | datacollection: DataCollection object 145 | train_or_test: string, "train" or "test"''' 146 | if train_or_test != "train": 147 | raise ValueError("Currently only 
support 'train' mode.") 148 | vis_lines = [] 149 | imgs_list = datacollection.train_imgs if train_or_test=="train" else datacollection.test_imgs 150 | import random 151 | random.shuffle(imgs_list) 152 | cameras = [] 153 | poses = [] 154 | i = 0 155 | for img in imgs_list: 156 | vis_lines += datacollection.imgname2imgclass[img].line3Ds 157 | poses.append(datacollection.imgname2imgclass[img].pose.get_pose_vector()) 158 | cameras.append(datacollection.imgname2imgclass[img].camera.get_dict_camera()) 159 | # i += 1 160 | # if i > 20: 161 | # break 162 | open3d_vis_3d_lines(vis_lines, cameras=cameras, poses=poses) 163 | 164 | def open3d_vis_3d_lines_from_single_imgandcollection(datacollection, img_name): 165 | ''' 166 | Visualize 3D lines from datasetcollection 167 | Args: 168 | datacollection: DataCollection object 169 | img_name: string, image name 170 | ''' 171 | if img_name in datacollection.test_imgs: 172 | raise ValueError("Only train images have 3D labeled lines.") 173 | vis_lines = datacollection.imgname2imgclass[img_name].line3Ds 174 | open3d_vis_3d_lines(vis_lines) 175 | 176 | def visualize_2d_lines(img_path, savename,lines2D, lines3D, save_path="visualization/"): 177 | """ Plot lines for existing images. 178 | Args: 179 | img_path: string, path to the image. 180 | lines2D: list of ndarrays of size (N, 4). 181 | lines3D: list of objects with size of (N, 1). 182 | save_path: string, path to save the image. 183 | """ 184 | save_path = os.path.join(save_path,savename) 185 | img = read_image(img_path) 186 | plt.figure() 187 | plt.imshow(img) 188 | length = lines2D.shape[0] 189 | for i in range(length): 190 | k = lines2D[i,:] 191 | x = [k[0], k[2]] 192 | y = [k[1], k[3]] 193 | if lines3D is not None: 194 | c = 'lime' if lines3D[i] is None else 'red' 195 | else: 196 | c = 'lime' 197 | plt.plot(x, y, color=c) 198 | plt.savefig(save_path) 199 | # Close the figure to free up memory 200 | plt.close() 201 | 202 | def visualize_2d_lines_from_collection(datacollection, img_name, mode="offline"): 203 | """ 204 | Visualize 2D lines from datasetcollection 205 | Args: 206 | datacollection: DataCollection object 207 | img_name: string, image name 208 | mode: string, "offline" (take from exiting labels) or "online" (use detector model to get 2D points) 209 | """ 210 | if mode == "offline": 211 | line2Ds = datacollection.imgname2imgclass[img_name].line2Ds 212 | line3Ds = datacollection.imgname2imgclass[img_name].line3Ds 213 | elif mode == "online": 214 | line2Ds = datacollection.detect_lines2D(img_name) 215 | line3Ds = None 216 | else: 217 | raise ValueError("mode must be either 'offline' or 'online'") 218 | img_path = datacollection.get_image_path(img_name) 219 | save_name = img_name.replace("/","_") + "_lines_" + mode +".png" 220 | visualize_2d_lines(img_path, save_name, line2Ds, line3Ds) 221 | 222 | # -------------------------------- end line visualization -------------------------------- 223 | ########################################################################################## 224 | # -------------------------------- start point visualization ----------------------------- 225 | 226 | 227 | def visualize_2d_points(img_path, points2D, savename, colors='lime', ps=4, save_path="visualization/"): 228 | """Plot keypoints for existing images. 229 | Args: 230 | img_path: string, path to the image. 231 | points2D: list of ndarrays of size (N, 2). 232 | colors: string, or list of list of tuples (one for each keypoints). 233 | ps: size of the keypoints as float. 
234 | save_path: string, path to save the image. 235 | """ 236 | save_path = os.path.join(save_path,savename) 237 | img = read_image(img_path) 238 | plt.figure() 239 | plt.imshow(img) 240 | if not isinstance(colors, list): 241 | colors = [colors] * len(points2D) 242 | for k, c in zip(points2D, colors): 243 | plt.scatter(k[0], k[1], c=c, s=ps, linewidths=0) 244 | plt.savefig(save_path) 245 | # Close the figure to free up memory 246 | plt.close() 247 | 248 | def visualize_2d_points_from_collection(datacollection, img_name, mode="offline"): 249 | """ 250 | Visualize 2D points from datasetcollection 251 | Args: 252 | datacollection: DataCollection object 253 | img_name: string, image name 254 | mode: string, "offline" (take from exiting labels) or "online" (use detector model to get 2D points) 255 | """ 256 | if mode == "offline": 257 | if img_name in datacollection.test_imgs: 258 | raise ValueError("Only train images have 2D labeled points.") 259 | points2D = datacollection.imgname2imgclass[img_name].points2Ds 260 | elif mode == "online": 261 | data = datacollection.detect_points2D(img_name) 262 | points2D = data["keypoints"][0].detach().cpu().numpy() 263 | else: 264 | raise ValueError("mode must be either 'offline' or 'online'") 265 | img_path = datacollection.get_image_path(img_name) 266 | save_name = img_name.replace("/","_") + "_points_" + mode +".png" 267 | visualize_2d_points(img_path, points2D, save_name) 268 | 269 | def open3d_get_point_set(points, color=[0.0, 0.0, 0.0], width=2, scale=1.0): 270 | """ 271 | convert a numpy array of points3D to an Open3D lines set 272 | Args: 273 | points (numpy array of Nx3): The 3D point map 274 | color (list[float]): The color of the lines 275 | width (float, optional): width of the line 276 | """ 277 | o3d_points, o3d_colors = [], [] 278 | for point in points: 279 | if np.sum(point) == 0: 280 | continue 281 | o3d_points.append(point) 282 | o3d_colors.append(color) 283 | point_set = o3d.geometry.PointCloud() 284 | point_set.points = o3d.utility.Vector3dVector(o3d_points) 285 | point_set.colors = o3d.utility.Vector3dVector(o3d_colors) 286 | return point_set 287 | 288 | def open3d_vis_3d_points(points3D:np.asanyarray, width=2, ranges=None, scale=1.0): 289 | """ 290 | Visualize a 3D point map with `Open3D `_ 291 | 292 | Args: 293 | points3D (list[:class:`datasets._base.Line3D` and/or None]): The 3D line map 294 | width (float, optional): width of the line 295 | """ 296 | 297 | vis = o3d.visualization.Visualizer() 298 | vis.create_window(height=1080, width=1920) 299 | point_set = open3d_get_point_set(points3D, width=width, scale=scale) 300 | vis.add_geometry(point_set) 301 | vis.run() 302 | vis.destroy_window() 303 | 304 | def open3d_vis_3d_points_from_datacollection(datacollection, mode="train"): 305 | ''' 306 | Visualize 3D points from datasetcollection 307 | Args: 308 | datacollection: DataCollection object 309 | mode: string, "train" or "test" 310 | ''' 311 | if mode != "train": 312 | raise ValueError("Currently only support 'train' mode.") 313 | vis_points = np.array([[0,0,0]]) 314 | imgs_list = datacollection.train_imgs if mode=="train" else datacollection.test_imgs 315 | for img in imgs_list: 316 | vis_points = np.concatenate((vis_points, datacollection.imgname2imgclass[img].points3Ds)) 317 | open3d_vis_3d_points(vis_points) 318 | 319 | 320 | def open3d_vis_3d_points_with_hightlightFrame(points3D, hightlight_points3D, width=2, ranges=None, scale=1.0): 321 | """ 322 | Save 3D point map with `Open3D `_ 323 | 324 | Args: 325 | points3D 
(list[:class:`datasets._base.Line3D` and/or None]): The 3D line map 326 | width (float, optional): width of the line 327 | """ 328 | 329 | point_set = open3d_get_point_set(points3D, width=width, scale=scale) 330 | highlight_point_set = open3d_get_point_set(hightlight_points3D, color=[0.0, 1.0, 0.0], width=width*2, scale=scale) 331 | 332 | # save the point_set 333 | o3d.io.write_point_cloud("visualization/point_set.ply", point_set) 334 | o3d.io.write_point_cloud("visualization/highlight_point_set.ply", highlight_point_set) 335 | ''' 336 | vis = o3d.visualization.Visualizer() 337 | vis.create_window(height=1080, width=1920) 338 | vis.add_geometry(point_set) 339 | vis.add_geometry(highlight_point_set) 340 | vis.run() 341 | vis.destroy_window() 342 | ''' 343 | 344 | 345 | ########################################################################################## 346 | # -------------------- merging points and lines for visualization together -------------- 347 | def visualize_2d_points_lines(img_path, points2D, lines2D, lines3D, savename, 348 | colors='lime', ps=4, save_path="visualization/"): 349 | """Plot keypoints for existing images. 350 | Args: 351 | img_path: string, path to the image. 352 | points2D: list of ndarrays of size (N, 2). 353 | lines2D: list of ndarrays of size (N, 4). 354 | lines3D: list of objects with size of (N, 1). 355 | colors: string, or list of list of tuples (one for each keypoints). 356 | ps: size of the keypoints as float. 357 | save_path: string, path to save the image. 358 | """ 359 | save_path = os.path.join(save_path,savename) 360 | img = read_image(img_path) 361 | plt.figure() 362 | plt.imshow(img) 363 | if not isinstance(colors, list): 364 | colors = [colors] * len(points2D) 365 | # visualize points 366 | for k, c in zip(points2D, colors): 367 | plt.scatter(k[0], k[1], c=c, s=ps, linewidths=0) 368 | 369 | # visualize lines 370 | length = lines2D.shape[0] 371 | for i in range(length): 372 | k = lines2D[i,:] 373 | x = [k[0], k[2]] 374 | y = [k[1], k[3]] 375 | if lines3D is not None: 376 | c = 'lime' if lines3D[i] is None else 'red' 377 | else: 378 | c = 'lime' 379 | plt.plot(x, y, color=c) 380 | plt.savefig(save_path) 381 | # Close the figure to free up memory 382 | plt.close() 383 | 384 | 385 | def visualize_2d_points_lines_from_collection(datacollection, img_name, mode="offline"): 386 | """ 387 | Visualize 2D points and lines from datasetcollection 388 | Args: 389 | datacollection: DataCollection object 390 | img_name: string, image name 391 | mode: string, "offline" (take from exiting labels) or "online" (use detector model to get 2D points) 392 | """ 393 | if mode == "offline": 394 | if img_name in datacollection.test_imgs: 395 | raise ValueError("Only train images have 2D labeled points.") 396 | points2D = datacollection.imgname2imgclass[img_name].points2Ds 397 | 398 | line2Ds = datacollection.imgname2imgclass[img_name].line2Ds 399 | line3Ds = datacollection.imgname2imgclass[img_name].line3Ds 400 | 401 | elif mode == "online": 402 | data = datacollection.detect_points2D(img_name) 403 | points2D = data["keypoints"][0].detach().cpu().numpy() 404 | 405 | line2Ds = datacollection.detect_lines2D(img_name) 406 | line3Ds = None 407 | else: 408 | raise ValueError("mode must be either 'offline' or 'online'") 409 | img_path = datacollection.get_image_path(img_name) 410 | save_name = img_name.replace("/","_") + "_points_lines_" + mode +".svg" 411 | 412 | visualize_2d_points_lines(img_path, points2D, line2Ds, line3Ds, save_name) 413 | 414 | 415 | 
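# ------------------------------------------------------------------------------------------
# Editor's hedged usage sketch (not part of the original file): the 2D visualization helpers
# above all read from a DataCollection object; the attribute and method names used here
# (train_imgs, detect_points2D, ...) are exactly the ones referenced in this module, but
# constructing the DataCollection itself is out of scope for this sketch. Nothing in the
# pipeline calls this function; it is illustrative only.
def _demo_visualize_first_train_image(datacollection):
    """Render point/line overlays for one training image, offline (labels) and online (detectors)."""
    img_name = datacollection.train_imgs[0]
    # keypoints from the stored labels vs. keypoints re-detected by the point detector
    visualize_2d_points_from_collection(datacollection, img_name, mode="offline")
    visualize_2d_points_from_collection(datacollection, img_name, mode="online")
    # line segments: drawn in red when a 3D line label exists, lime otherwise
    visualize_2d_lines_from_collection(datacollection, img_name, mode="offline")
    # combined overlay, saved as an .svg under the default "visualization/" folder
    visualize_2d_points_lines_from_collection(datacollection, img_name, mode="offline")
# ------------------------------------------------------------------------------------------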
########################################################################################## 416 | # -------------------------------- Augmentation Visualization Debug ---------------------- 417 | 418 | import cv2 419 | 420 | def visualize_img_withlinesandpoints(image, points, lines, augmented=False): 421 | 422 | save_path = "visualization/" 423 | point_size = 1 424 | 425 | # Draw the original positions on the original image 426 | for position in points: 427 | cv2.circle(image, (int(position[0]), int(position[1])), point_size, (0, 0, 255), -1) 428 | # Draw the original lines on the original image 429 | for line in lines: 430 | cv2.line(image, (int(line[0]), int(line[1])), (int(line[2]), int(line[3])), (255, 0, 0), 1) 431 | cv2.circle(image, (int(line[0]), int(line[1])), point_size*3, (0, 0, 255), -1) 432 | cv2.circle(image, (int(line[2]), int(line[3])), point_size*3, (0, 0, 255), -1) 433 | 434 | if augmented: 435 | cv2.imwrite(save_path+'Transformed_Image.jpg', image) 436 | else: 437 | cv2.imwrite(save_path+'Original_Image.jpg', image) 438 | 439 | ########################################################################################## 440 | # -------------------------------- draw camera poses ------------------------------------- 441 | from util.read_write_model import qvec2rotmat 442 | def draw_camera(K, R, t, w, h, 443 | scale=1, color=[1, 0, 0]): 444 | """Create axis, plane and pyramed geometries in Open3D format. 445 | :param K: calibration matrix (camera intrinsics) 446 | :param R: rotation matrix 447 | :param t: translation 448 | :param w: image width 449 | :param h: image height 450 | :param scale: camera model scale 451 | :param color: color of the image plane and pyramid lines 452 | :return: camera model geometries (axis, plane and pyramid) 453 | """ 454 | 455 | # intrinsics 456 | K = K.copy() / scale 457 | Kinv = np.linalg.inv(K) 458 | 459 | # 4x4 transformation 460 | T = np.column_stack((R, t)) 461 | T = np.vstack((T, (0, 0, 0, 1))) 462 | 463 | # axis 464 | axis = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.5 * scale) 465 | axis.transform(T) 466 | 467 | # points in pixel 468 | points_pixel = [ 469 | [0, 0, 0], 470 | [0, 0, 1], 471 | [w, 0, 1], 472 | [0, h, 1], 473 | [w, h, 1], 474 | ] 475 | 476 | # pixel to camera coordinate system 477 | points = [Kinv @ p for p in points_pixel] 478 | 479 | # image plane 480 | width = abs(points[1][0]) + abs(points[3][0]) 481 | height = abs(points[1][1]) + abs(points[3][1]) 482 | plane = o3d.geometry.TriangleMesh.create_box(width, height, depth=1e-6) 483 | # plane.paint_uniform_color([0.5,0,0]) 484 | # plane.paint_uniform_color(color) 485 | plane.translate([points[1][0], points[1][1], scale]) 486 | plane.transform(T) 487 | 488 | # pyramid 489 | points_in_world = [(R @ p + t) for p in points] 490 | lines = [ 491 | [0, 1], 492 | [0, 2], 493 | [0, 3], 494 | [0, 4], 495 | [1, 2], 496 | [2, 4], 497 | [4, 3], 498 | [3, 1], 499 | ] 500 | colors = [color for i in range(len(lines))] 501 | line_set = o3d.geometry.LineSet( 502 | points=o3d.utility.Vector3dVector(points_in_world), 503 | lines=o3d.utility.Vector2iVector(lines)) 504 | line_set.colors = o3d.utility.Vector3dVector(colors) 505 | 506 | # return as list in Open3D format 507 | # return [axis, plane, line_set] 508 | # return [plane, line_set] 509 | return [line_set] 510 | 511 | 512 | 513 | def add_camera(vis, pose, camera, scale=0.1, gt = False, othermethod = False): 514 | plane_scale = 1 515 | # rotation 516 | R = qvec2rotmat(pose[3:]) 517 | # translation 518 | t = pose[:3] 519 | # 
invert 520 | t = -R.T @ t 521 | R = R.T 522 | # intrinsics 523 | 524 | if camera['model'] in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"): 525 | fx = fy = camera['params'][0] 526 | cx = camera['params'][1] 527 | cy = camera['params'][2] 528 | elif camera['model'] in ("PINHOLE", "OPENCV", "OPENCV_FISHEYE", "FULL_OPENCV"): 529 | fx = camera['params'][0] 530 | fy = camera['params'][1] 531 | cx = camera['params'][2] 532 | cy = camera['params'][3] 533 | else: 534 | raise Exception("Camera model not supported") 535 | 536 | # intrinsics 537 | K = np.identity(3) 538 | K[0, 0] = fx 539 | K[1, 1] = fy 540 | K[0, 2] = cx 541 | K[1, 2] = cy 542 | if othermethod: 543 | color = [0,1,0] 544 | else: 545 | color = [1, 0, 0] if gt else [0, 0, 1] 546 | # create axis, plane and pyramed geometries that will be drawn 547 | cam_model = draw_camera(K, R, t, camera['width']*plane_scale, camera['height']*plane_scale, scale, color) 548 | for i in cam_model: 549 | vis.add_geometry(i) 550 | 551 | --------------------------------------------------------------------------------
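Usage note (editor's addition, hedged): the sketch below is not a file from this repository; it only shows how the utilities above can be driven. Paths and the sample correspondences are placeholders, and the imports assume the repository root is on PYTHONPATH, as in util/pose_estimator.py.

# ------------------------------------------------------------------------------------------
# Editor's hedged usage sketch: exercising util/read_write_model.py and the poselib call
# pattern from util/pose_estimator.py. All paths and correspondences are placeholders.
import numpy as np
import poselib
from util.read_write_model import read_model, write_model, qvec2rotmat, rotmat2qvec

# 1) util/read_write_model.py as a library: read a COLMAP model, sanity-check the
#    qvec <-> rotation-matrix round trip, and re-export the model as text.
cameras, images, points3D = read_model("path/to/colmap_model")   # placeholder; .bin/.txt auto-detected
print("num_cameras:", len(cameras), "num_images:", len(images), "num_points3D:", len(points3D))

img = next(iter(images.values()))
R = img.qvec2rotmat()                                            # helper on the Image namedtuple above
assert np.allclose(qvec2rotmat(rotmat2qvec(R)), R, atol=1e-6)    # round trip, insensitive to quaternion sign
write_model(cameras, images, points3D, "path/to/output_model", ext=".txt")  # output folder must already exist

# The same conversion is also available from the command line:
#   python util/read_write_model.py --input_model path/to/colmap_model \
#       --output_model path/to/output_model --output_format .txt

# 2) The poselib call pattern used in util/pose_estimator.py, with random placeholder
#    correspondences and a COLMAP-style camera dict shaped like the one built in
#    camera_pose_estimation (the estimated pose is meaningless here; only the API is shown).
p2ds = [np.random.rand(2) * 100 for _ in range(50)]              # 2D keypoints (pixels)
p3ds = [np.random.rand(3) for _ in range(50)]                    # corresponding 3D points
camera = {"model": "PINHOLE", "width": 640, "height": 480, "params": [525.0, 525.0, 320.0, 240.0]}

pose, info = poselib.estimate_absolute_pose(p2ds, p3ds, camera, {"max_reproj_error": 12.0}, {})
print(pose.t, pose.q)   # translation and quaternion, in the order written to the est_poses_*.txt files
# ------------------------------------------------------------------------------------------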