├── DEMO.md ├── LICENSE ├── README.md ├── assets └── onepose-github-teaser.gif ├── configs ├── config.yaml ├── experiment │   ├── object_detector.yaml │   ├── test_GATsSPG.yaml │   ├── test_demo.yaml │   ├── test_sample.yaml │   └── train_GATsSPG.yaml └── preprocess │   ├── merge_anno.yaml │   ├── sfm_spp_spg_demo.yaml │   ├── sfm_spp_spg_sample.yaml │   ├── sfm_spp_spg_test.yaml │   ├── sfm_spp_spg_train.yaml │   └── sfm_spp_spg_val.yaml ├── environment.yaml ├── feature_matching_object_detector.py ├── inference.py ├── inference_demo.py ├── parse_scanned_data.py ├── requirements.txt ├── run.py ├── scripts ├── demo_pipeline.sh ├── parse_full_img.sh └── prepare_2D_matching_resources.sh ├── src ├── callbacks │   ├── custom_callbacks.py │   └── wandb_callbacks.py ├── datamodules │   └── GATs_spg_datamodule.py ├── datasets │   ├── GATs_spg_dataset.py │   └── normalized_dataset.py ├── evaluators │   └── cmd_evaluator.py ├── local_feature_2D_detector │   ├── __init__.py │   └── local_feature_2D_detector.py ├── losses │   └── focal_loss.py ├── models │   ├── GATsSPG_architectures │   │   ├── GATs.py │   │   └── GATs_SuperGlue.py │   ├── GATsSPG_lightning_model.py │   ├── extractors │   │   └── SuperPoint │   │   │   └── superpoint.py │   └── matchers │   │   ├── SuperGlue │   │   └── superglue.py │   │   └── nn │   │   └── nearest_neighbour.py ├── sfm │   ├── extract_features.py │   ├── generate_empty.py │   ├── global_ba.py │   ├── match_features.py │   ├── pairs_from_poses.py │   ├── postprocess │   │   ├── feature_process.py │   │   ├── filter_points.py │   │   └── filter_tkl.py │   └── triangulation.py ├── tracker │   ├── __init__.py │   ├── ba_tracker.py │   └── tracking_utils.py └── utils │   ├── colmap │   ├── database.py │   └── read_write_model.py │   ├── comm.py │   ├── data_utils.py │   ├── eval_utils.py │   ├── model_io.py │   ├── path_utils.py │   ├── template_utils.py │   └── vis_utils.py ├── train.py └── video2img.py /DEMO.md: -------------------------------------------------------------------------------- 1 | # OnePose Demo on Custom Data (WIP) 2 | In this tutorial we introduce the OnePose demo running on data captured 3 | with our **OnePose Cap** application for iOS devices. 4 | The app is still being prepared for release. 5 | However, you can try it with the [sample data]() and skip the first step. 6 | 7 | ### Step 1: Capture the mapping sequence and the test sequence with OnePose Cap. 8 | #### The app is still brewing🍺 and is coming soon. 9 | 10 | ### Step 2: Organize the file structure of collected sequences 11 | 1. Export the collected mapping sequence and the test sequence to the PC. 12 | 2. Rename the **annotate** and **test** sequence directories to `your_obj_name-annotate` and `your_obj_name-test` respectively, and organize the data in the following structure: 13 | ``` 14 | |--- /your/path/to/scanned_data 15 | | |--- your_obj_name 16 | | | |---your_obj_name-annotate 17 | | | |---your_obj_name-test 18 | ``` 19 | Refer to the [sample data]() as an example. 20 | 3. Link the collected data to the project directory: 21 | ```shell 22 | REPO_ROOT=/path/to/OnePose 23 | ln -s /path/to/scanned_data $REPO_ROOT/data/demo 24 | ``` 25 | 26 | Now the data is prepared! A minimal layout check is sketched below.
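As a quick sanity check before running the demo, the layout described in Step 2 can be verified with a few lines of Python. This is only an illustrative sketch based on the structure and symlink above; the helper name `check_demo_layout` is ours and not part of the repository:

```python
from pathlib import Path

def check_demo_layout(repo_root: str, obj_name: str) -> None:
    """Verify that data/demo/<obj_name>/<obj_name>-annotate and <obj_name>-test exist."""
    obj_dir = Path(repo_root) / "data" / "demo" / obj_name
    for seq in (f"{obj_name}-annotate", f"{obj_name}-test"):
        seq_dir = obj_dir / seq
        if not seq_dir.is_dir():
            raise FileNotFoundError(f"Missing sequence directory: {seq_dir}")
    print(f"Layout for '{obj_name}' looks good.")

# Example usage (paths are placeholders):
# check_demo_layout("/path/to/OnePose", "your_obj_name")
```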
27 | 28 | ### Step 3: Run OnePose with collected data 29 | Download the [pretrained OnePose model](https://drive.google.com/drive/folders/1VjLLjJ9oxjKV5Xy3Aty0uQUVwyEhgtIE?usp=sharing) and move it to `${REPO_ROOT}/data/models/checkpoints/onepose/GATsSPG.ckpt`. 30 | 31 | [Optional] To run OnePose with the tracking module, please install [DeepLM](https://github.com/hjwdzh/DeepLM.git). 32 | Please make sure the sample program in `DeepLM` can be executed correctly to confirm a successful installation. 33 | 34 | 35 | Execute the following commands, and a demo video named `demo_video.mp4` will be saved in the folder of the test sequence. 36 | ```shell 37 | REPO_ROOT=/path/to/OnePose 38 | OBJ_NAME=your_obj_name 39 | 40 | cd $REPO_ROOT 41 | conda activate onepose 42 | 43 | bash scripts/demo_pipeline.sh $OBJ_NAME 44 | 45 | # [Optional] Run OnePose with tracking 46 | export PYTHONPATH=$PYTHONPATH:/path/to/DeepLM/build 47 | export TORCH_USE_RTLD_GLOBAL=YES 48 | 49 | bash scripts/demo_pipeline.sh $OBJ_NAME --WITH_TRACKING 50 | 51 | ``` 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OnePose: One-Shot Object Pose Estimation without CAD Models 2 | ### [Project Page](https://zju3dv.github.io/onepose) | [Paper](https://arxiv.org/pdf/2205.12257.pdf) 3 |
4 | 5 | > OnePose: One-Shot Object Pose Estimation without CAD Models 6 | > [Jiaming Sun](https://jiamingsun.ml)\*, [Zihao Wang](http://zihaowang.xyz/)\*, [Siyu Zhang](https://derizsy.github.io/)\*, [Xingyi He](https://github.com/hxy-123/), [Hongcheng Zhao](https://github.com/HongchengZhao), [Guofeng Zhang](http://www.cad.zju.edu.cn/home/gfzhang/), [Xiaowei Zhou](https://xzhou.me) 7 | > CVPR 2022 8 | 9 | ![demo_vid](assets/onepose-github-teaser.gif) 10 | 11 | ## TODO List 12 | - [x] Training and inference code. 13 | - [x] Pipeline to reproduce the evaluation results on the proposed OnePose dataset. 14 | - [x] The `OnePose Cap` app is available at the [App Store](https://apps.apple.com/cn/app/onepose-capture/id6447052065?l=en-GB) (iOS only) to capture your own training and test data. 15 | - [x] Demo pipeline for running OnePose with custom-captured data. 16 | 17 | ## Installation 18 | 19 | ```shell 20 | conda env create -f environment.yaml 21 | conda activate onepose 22 | ``` 23 | We use [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork) and [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) 24 | for 2D feature detection and matching in this project. 25 | We cannot provide the code directly due to its license requirements; please download the inference code and pretrained models using the following script: 26 | ```shell 27 | REPO_ROOT=/path/to/OnePose 28 | cd $REPO_ROOT 29 | sh ./scripts/prepare_2D_matching_resources.sh 30 | ``` 31 | 32 | [COLMAP](https://colmap.github.io/) is used in this project for Structure-from-Motion. 33 | Please refer to the official [instructions](https://colmap.github.io/install.html) for the installation. 34 | 35 | [Optional, WIP] You may try out our web-based 3D visualization tool [Wis3D](https://github.com/zju3dv/Wis3D) for convenient and interactive visualization of feature matches. We also provide many other cool visualization features in Wis3D; welcome to try it out. 36 | 37 | ```bash 38 | # Work in progress; should be ready very soon. Only available on test-pypi for now. 39 | pip install -i https://test.pypi.org/simple/ wis3d 40 | ``` 41 | 42 | ## Training and Evaluation on OnePose dataset 43 | ### Dataset setup 44 | 1. Download the OnePose dataset from [onedrive storage](https://zjueducn-my.sharepoint.com/:f:/g/personal/zihaowang_zju_edu_cn/ElfzHE0sTXxNndx6uDLWlbYB-2zWuLfjNr56WxF11_DwSg?e=GKI0Df) and extract it into `/your/path/to/onepose_datasets`. 45 | The directory should be organized in the following structure: 46 | ``` 47 | |--- /your/path/to/onepose_datasets 48 | | |--- train_data 49 | | |--- val_data 50 | | |--- test_data 51 | | |--- sample_data 52 | ``` 53 | 54 | 2. Build the dataset symlinks 55 | ```shell 56 | REPO_ROOT=/path/to/OnePose 57 | ln -s /your/path/to/onepose_datasets $REPO_ROOT/data/onepose_datasets 58 | ``` 59 | 60 | 3. Run Structure-from-Motion for the data sequences 61 | 62 | The reconstructed object point clouds and 2D-3D correspondences are needed for both training and test objects (see the sketch after the commands for how the `data_dir` entries in these preprocess configs are structured): 63 | ```python 64 | python run.py +preprocess=sfm_spp_spg_train.yaml # for training data 65 | python run.py +preprocess=sfm_spp_spg_test.yaml # for testing data 66 | python run.py +preprocess=sfm_spp_spg_val.yaml # for val data 67 | python run.py +preprocess=sfm_spp_spg_sample.yaml # an example, if you don't want to test the full dataset 68 | ``` 69 |
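Each `data_dir` entry in the preprocess configs (and each `data_dirs` entry in the experiment configs) is a single string: the object directory followed by the space-separated names of the sequences to use, e.g. `${scan_data_dir}/0410-huiyuan-box huiyuan-1 huiyuan-2 huiyuan-3`. Below is a minimal sketch of how such an entry can be split, assuming each sequence is a sub-directory of the object directory; this is illustrative parsing, not the repository's exact code:

```python
import os

def parse_data_dir_entry(entry: str):
    """Split an 'object_root seq_1 seq_2 ...' config entry."""
    object_root, *seq_names = entry.split()
    # Assumed layout: every sequence lives directly under the object directory.
    seq_dirs = [os.path.join(object_root, name) for name in seq_names]
    return object_root, seq_dirs

root, seqs = parse_data_dir_entry(
    "/path/to/onepose_datasets/train_data/0410-huiyuan-box huiyuan-1 huiyuan-2"
)
print(root)  # .../0410-huiyuan-box
print(seqs)  # ['.../0410-huiyuan-box/huiyuan-1', '.../0410-huiyuan-box/huiyuan-2']
```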
70 | ### Inference on OnePose dataset 71 | 1. Download the [pretrained model](https://drive.google.com/drive/folders/1VjLLjJ9oxjKV5Xy3Aty0uQUVwyEhgtIE?usp=sharing) and move it to `${REPO_ROOT}/data/models/checkpoints/onepose/GATsSPG.ckpt`. 72 | 73 | 2. Inference with category-agnostic 2D object detection. 74 | 75 | When deploying OnePose to a real-world system, 76 | an off-the-shelf category-level 2D object detector like [YOLOv5](https://github.com/ultralytics/yolov5) can be used. 77 | However, this could defeat the category-agnostic nature of OnePose. 78 | We can instead use a feature-matching-based pipeline for 2D object detection, which locates the scanned object in the query image through 2D feature matching. 79 | Note that the 2D object detection is only necessary during the initialization. 80 | After the initialization, the 2D bounding box can be obtained by projecting the previously detected 3D bounding box onto the current camera frame (a minimal projection sketch is given after this section). 81 | Please refer to the [supplementary material](https://zju3dv.github.io/onepose/files/onepose_supp.pdf) for more details. 82 | 83 | ```python 84 | # Obtain category-agnostic 2D object detection results first. 85 | # Increasing `n_ref_view` improves detection robustness, at the cost of a slower initialization. 86 | python feature_matching_object_detector.py +experiment=object_detector.yaml n_ref_view=15 87 | 88 | # Run pose estimation with `object_detect_mode` set to `feature_matching`. 89 | # Note that enabling visualization will slow down the inference. 90 | python inference.py +experiment=test_GATsSPG.yaml object_detect_mode=feature_matching save_wis3d=False 91 | ``` 92 | 93 | 3. Running inference with ground-truth 2D bounding boxes 94 | 95 | The following command should reproduce the results in the paper, which use 2D boxes projected from the 3D boxes as object detection results. 96 | 97 | ```python 98 | # Note that enabling visualization will slow down the inference. 99 | python inference.py +experiment=test_GATsSPG.yaml object_detect_mode=GT_box save_wis3d=False # for testing data 100 | ``` 101 | 102 | 103 | 104 | 4. [Optional] Visualize matching and estimated poses with Wis3D. Make sure the flag `save_wis3d` is set to True during testing 105 | and the full images are extracted from `Frames.m4v` with the script `scripts/parse_full_img.sh`. 106 | The visualization files will be saved under the `cfg.output.vis_dir` directory, which is set to `GATsSPG` by default. 107 | Run the following commands for visualization: 108 | ```shell 109 | sh ./scripts/parse_full_img.sh path_to_Frames_m4v # parse full images from the m4v file 110 | 111 | cd runs/vis/GATsSPG 112 | wis3d --vis_dir ./ --host localhost --port 11020 113 | ``` 114 | This launches a web service for visualization on port 11020. 115 | 116 |
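The projection mentioned in step 2 can be written in a few lines. This is only an illustrative sketch, not the repository's implementation; the helper name `project_box_to_2d` and the assumption that the box is given as its eight 3D corners in the object frame, together with camera intrinsics `K` and an estimated object-to-camera pose `(R, t)`, are ours:

```python
import numpy as np

def project_box_to_2d(corners_3d: np.ndarray, K: np.ndarray,
                      R: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Project (8, 3) box corners and return [x_min, y_min, x_max, y_max]."""
    # Object frame -> camera frame: X_cam = R @ X_obj + t
    pts_cam = corners_3d @ R.T + t.reshape(1, 3)
    # Pinhole projection with intrinsics K (assumes all corners lie in front of the camera)
    pts_img = pts_cam @ K.T
    pts_2d = pts_img[:, :2] / pts_img[:, 2:3]
    x_min, y_min = pts_2d.min(axis=0)
    x_max, y_max = pts_2d.max(axis=0)
    return np.array([x_min, y_min, x_max, y_max])
```

In practice the resulting box would typically be clipped to the image bounds and slightly enlarged before cropping the next query frame.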
117 | ### Training the GATs Network 118 | 1. Prepare ground-truth annotations. Merge annotations of training/val data: 119 | ```python 120 | python run.py +preprocess=merge_anno task_name=onepose split=train 121 | python run.py +preprocess=merge_anno task_name=onepose split=val 122 | ``` 123 | 124 | 2. Begin training 125 | ```python 126 | python train.py +experiment=train_GATsSPG task_name=onepose exp_name=training_onepose 127 | ``` 128 | 129 | All model weights will be saved under `${REPO_ROOT}/data/models/checkpoints/${exp_name}` and logs will be saved under `${REPO_ROOT}/data/logs/${exp_name}`. 130 | 134 | 135 | ## Citation 136 | If you find this code useful for your research, please use the following BibTeX entry. 137 | 138 | ```bibtex 139 | @article{sun2022onepose, 140 | title={{OnePose}: One-Shot Object Pose Estimation without {CAD} Models}, 141 | author = {Sun, Jiaming and Wang, Zihao and Zhang, Siyu and He, Xingyi and Zhao, Hongcheng and Zhang, Guofeng and Zhou, Xiaowei}, 142 | journal={CVPR}, 143 | year={2022}, 144 | } 145 | ``` 146 | 147 | ## Copyright 148 | 149 | This work is affiliated with ZJU-SenseTime Joint Lab of 3D Vision, and its intellectual property belongs to SenseTime Group Ltd. 150 | 151 | ``` 152 | Copyright SenseTime. All Rights Reserved. 153 | 154 | Licensed under the Apache License, Version 2.0 (the "License"); 155 | you may not use this file except in compliance with the License. 156 | You may obtain a copy of the License at 157 | 158 | http://www.apache.org/licenses/LICENSE-2.0 159 | 160 | Unless required by applicable law or agreed to in writing, software 161 | distributed under the License is distributed on an "AS IS" BASIS, 162 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 163 | See the License for the specific language governing permissions and 164 | limitations under the License. 165 | ``` 166 | 167 | ## Acknowledgement 168 | Part of our code is borrowed from [hloc](https://github.com/cvg/Hierarchical-Localization) and [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork); thanks to the authors for their great work. 169 | -------------------------------------------------------------------------------- /assets/onepose-github-teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zju3dv/OnePose/d0c313a5c36994658e63eac7ca3a63d7e3573d7b/assets/onepose-github-teaser.gif -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - trainer: null 6 | - model: null 7 | - datamodule: null 8 | - callbacks: null # set this to null if you don't want to use callbacks 9 | - logger: null # set logger here or use command line (e.g.
`python train.py logger=wandb`) 10 | 11 | # enable color logging 12 | # - override hydra/hydra_logging: colorlog 13 | # - override hydra/job_logging: colorlog 14 | 15 | 16 | # path to original working directory (that `train.py` was executed from in command line) 17 | # hydra hijacks working directory by changing it to the current log directory, 18 | # so it's useful to have path to original working directory as a special variable 19 | # read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory 20 | work_dir: ${hydra:runtime.cwd} 21 | 22 | 23 | # path to folder with data 24 | data_dir: ${work_dir}/data 25 | 26 | 27 | # pretty print config at the start of the run using Rich library 28 | print_config: True 29 | 30 | 31 | # output paths for hydra logs 32 | # hydra: 33 | # run: 34 | # dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} 35 | # sweep: 36 | # dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S} 37 | # subdir: ${hydra.job.num} 38 | 39 | hydra: 40 | run: 41 | dir: ${work_dir} -------------------------------------------------------------------------------- /configs/experiment/object_detector.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: inference 4 | task_name: local_feature_object_detector 5 | suffix: '' 6 | 7 | model: 8 | extractor_model_path: ${data_dir}/models/extractors/SuperPoint/superpoint_v1.pth 9 | match_model_path: ${data_dir}/models/matchers/SuperGlue/superglue_outdoor.pth 10 | 11 | network: 12 | detection: superpoint 13 | matching: superglue 14 | 15 | n_ref_view: 15 16 | scan_data_dir: ${data_dir}/onepose_datasets/test_data 17 | sfm_model_dir: ${data_dir}/sfm_model 18 | 19 | input: 20 | data_dirs: 21 | - ${scan_data_dir}/0408-colorbox-box colorbox-4 22 | - ${scan_data_dir}/0409-aptamil-box aptamil-3 23 | - ${scan_data_dir}/0419-cookies2-others cookies2-4 24 | - ${scan_data_dir}/0422-qvduoduo-box qvduoduo-4 25 | - ${scan_data_dir}/0423-oreo-box oreo-4 26 | - ${scan_data_dir}/0424-chocbox-box chocbox-4 27 | - ${scan_data_dir}/0447-nabati-box nabati-5 28 | - ${scan_data_dir}/0450-hlychocpie-box hlychocpie-4 29 | - ${scan_data_dir}/0452-hlymatchapie-box hlymatchapie-4 30 | - ${scan_data_dir}/0455-strawberryoreo-box strawberryoreo-4 31 | - ${scan_data_dir}/0456-chocoreo-box chocoreo-4 32 | - ${scan_data_dir}/0458-hetaocakes-box hetaocakes-4 33 | - ${scan_data_dir}/0459-jzhg-box jzhg-4 34 | - ${scan_data_dir}/0466-mfmilkcake-box mfmilkcake-4 35 | - ${scan_data_dir}/0468-minipuff-box minipuff-4 36 | - ${scan_data_dir}/0469-diycookies-box diycookies-4 37 | - ${scan_data_dir}/0470-eggrolls-box eggrolls-4 38 | - ${scan_data_dir}/0471-hlyormosiapie-box hlyormosiapie-4 39 | - ${scan_data_dir}/0472-chocoreo-bottle chocoreo-4 40 | - ${scan_data_dir}/0473-twgrassjelly1-box twgrassjelly1-4 41 | - ${scan_data_dir}/0474-twgrassjelly2-box twgrassjelly2-4 42 | - ${scan_data_dir}/0476-giraffecup-bottle giraffecup-4 43 | - ${scan_data_dir}/0480-ljcleaner-others ljcleaner-4 44 | - ${scan_data_dir}/0483-ambrosial-box ambrosial-4 45 | - ${scan_data_dir}/0486-sanqitoothpaste-box sanqitoothpaste-4 46 | - ${scan_data_dir}/0487-jindiantoothpaste-box jindiantoothpaste-4 47 | - ${scan_data_dir}/0488-jijiantoothpaste-box jijiantoothpaste-4 48 | - ${scan_data_dir}/0489-taipingcookies-others taipingcookies-4 49 | - ${scan_data_dir}/0490-haochidiancookies-others haochidiancookies-4 50 | - ${scan_data_dir}/0492-tuccookies-box tuccookies-4 51 | - ${scan_data_dir}/0493-haochidianeggroll-box 
haochidianeggroll-4 52 | - ${scan_data_dir}/0494-qvduoduocookies-box qvduoduocookies-4 53 | - ${scan_data_dir}/0495-fulingstapler-box fulingstapler-4 54 | - ${scan_data_dir}/0496-delistapler-box delistapler-4 55 | - ${scan_data_dir}/0497-delistaplerlarger-box delistaplerlarger-4 56 | - ${scan_data_dir}/0498-yousuanru-box yousuanru-4 57 | - ${scan_data_dir}/0500-chocfranzzi-box chocfranzzi-4 58 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-4 59 | - ${scan_data_dir}/0502-shufujia-box shufujia-5 60 | - ${scan_data_dir}/0503-shufujiawhite-box shufujiawhite-3 61 | - ${scan_data_dir}/0504-lux-box lux-4 62 | - ${scan_data_dir}/0508-yqsl-others yqsl-4 63 | - ${scan_data_dir}/0510-yqslmilk-others yqslmilk-4 64 | - ${scan_data_dir}/0511-policecar-others policecar-4 65 | - ${scan_data_dir}/0517-nationalgeo-box nationalgeo-4 66 | - ${scan_data_dir}/0518-jasmine-box jasmine-4 67 | - ${scan_data_dir}/0519-backpack1-box backpack1-4 68 | - ${scan_data_dir}/0520-lipault-box lipault-4 69 | - ${scan_data_dir}/0521-ranova-box ranova-4 70 | - ${scan_data_dir}/0522-milkbox-box milkbox-4 71 | - ${scan_data_dir}/0523-edibleoil-others edibleoil-4 72 | - ${scan_data_dir}/0525-toygrab-others toygrab-2 73 | - ${scan_data_dir}/0526-toytable-others toytable-3 74 | - ${scan_data_dir}/0527-spalding-others spalding-2 75 | - ${scan_data_dir}/0534-tonkotsuramen-box tonkotsuramen-4 76 | - ${scan_data_dir}/0535-odbmilk-box odbmilk-4 77 | - ${scan_data_dir}/0537-petsnack-box petsnack-4 78 | - ${scan_data_dir}/0539-spamwrapper-others spamwrapper-5 79 | - ${scan_data_dir}/0543-brownhouse-others brownhouse-4 80 | - ${scan_data_dir}/0547-cubebox-box cubebox-4 81 | - ${scan_data_dir}/0548-duck-others duck-4 82 | - ${scan_data_dir}/0550-greenbox-box greenbox-4 83 | - ${scan_data_dir}/0551-milk-others milk-4 84 | - ${scan_data_dir}/0552-mushroom-others mushroom-4 85 | - ${scan_data_dir}/0557-santachoc-others santachoc-4 86 | - ${scan_data_dir}/0558-teddychoc-others teddychoc-4 87 | - ${scan_data_dir}/0559-tissuebox-box tissuebox-4 88 | - ${scan_data_dir}/0560-tofubox-box tofubox-4 89 | - ${scan_data_dir}/0564-biatee-others biatee-4 90 | - ${scan_data_dir}/0565-biscuits-box biscuits-4 91 | - ${scan_data_dir}/0568-cornflakes-box cornflakes-5 92 | - ${scan_data_dir}/0570-kasekuchen-box kasekuchen-4 93 | - ${scan_data_dir}/0577-schoko-box schoko-4 94 | - ${scan_data_dir}/0578-tee-others tee-4 95 | - ${scan_data_dir}/0579-tomatocan-bottle tomatocan-4 96 | - ${scan_data_dir}/0580-xmaxbox-others xmaxbox-4 97 | - ${scan_data_dir}/0582-yogurtlarge-others yogurtlarge-4 98 | - ${scan_data_dir}/0583-yogurtmedium-others yogurtmedium-4 99 | - ${scan_data_dir}/0594-martinBootsLeft-others martinBootsLeft-2 100 | - ${scan_data_dir}/0595-martinBootsRight-others martinBootsRight-4 101 | 102 | sfm_model_dirs: 103 | - ${sfm_model_dir}/0408-colorbox-box 104 | - ${sfm_model_dir}/0409-aptamil-box 105 | - ${sfm_model_dir}/0419-cookies2-others 106 | - ${sfm_model_dir}/0422-qvduoduo-box 107 | - ${sfm_model_dir}/0423-oreo-box 108 | - ${sfm_model_dir}/0424-chocbox-box 109 | - ${sfm_model_dir}/0447-nabati-box 110 | - ${sfm_model_dir}/0450-hlychocpie-box 111 | - ${sfm_model_dir}/0452-hlymatchapie-box 112 | - ${sfm_model_dir}/0455-strawberryoreo-box 113 | - ${sfm_model_dir}/0456-chocoreo-box 114 | - ${sfm_model_dir}/0458-hetaocakes-box 115 | - ${sfm_model_dir}/0459-jzhg-box 116 | - ${sfm_model_dir}/0466-mfmilkcake-box 117 | - ${sfm_model_dir}/0468-minipuff-box 118 | - ${sfm_model_dir}/0469-diycookies-box 119 | - ${sfm_model_dir}/0470-eggrolls-box 120 | 
- ${sfm_model_dir}/0471-hlyormosiapie-box 121 | - ${sfm_model_dir}/0472-chocoreo-bottle 122 | - ${sfm_model_dir}/0473-twgrassjelly1-box 123 | - ${sfm_model_dir}/0474-twgrassjelly2-box 124 | - ${sfm_model_dir}/0476-giraffecup-bottle 125 | - ${sfm_model_dir}/0480-ljcleaner-others 126 | - ${sfm_model_dir}/0483-ambrosial-box 127 | - ${sfm_model_dir}/0486-sanqitoothpaste-box 128 | - ${sfm_model_dir}/0487-jindiantoothpaste-box 129 | - ${sfm_model_dir}/0488-jijiantoothpaste-box 130 | - ${sfm_model_dir}/0489-taipingcookies-others 131 | - ${sfm_model_dir}/0490-haochidiancookies-others 132 | - ${sfm_model_dir}/0492-tuccookies-box 133 | - ${sfm_model_dir}/0493-haochidianeggroll-box 134 | - ${sfm_model_dir}/0494-qvduoduocookies-box 135 | - ${sfm_model_dir}/0495-fulingstapler-box 136 | - ${sfm_model_dir}/0496-delistapler-box 137 | - ${sfm_model_dir}/0497-delistaplerlarger-box 138 | - ${sfm_model_dir}/0498-yousuanru-box 139 | - ${sfm_model_dir}/0500-chocfranzzi-box 140 | - ${sfm_model_dir}/0501-matchafranzzi-box 141 | - ${sfm_model_dir}/0502-shufujia-box 142 | - ${sfm_model_dir}/0503-shufujiawhite-box 143 | - ${sfm_model_dir}/0504-lux-box 144 | - ${sfm_model_dir}/0508-yqsl-others 145 | - ${sfm_model_dir}/0510-yqslmilk-others 146 | - ${sfm_model_dir}/0511-policecar-others 147 | - ${sfm_model_dir}/0517-nationalgeo-box 148 | - ${sfm_model_dir}/0518-jasmine-box 149 | - ${sfm_model_dir}/0519-backpack1-box 150 | - ${sfm_model_dir}/0520-lipault-box 151 | - ${sfm_model_dir}/0521-ranova-box 152 | - ${sfm_model_dir}/0522-milkbox-box 153 | - ${sfm_model_dir}/0523-edibleoil-others 154 | - ${sfm_model_dir}/0525-toygrab-others 155 | - ${sfm_model_dir}/0526-toytable-others 156 | - ${sfm_model_dir}/0527-spalding-others 157 | - ${sfm_model_dir}/0534-tonkotsuramen-box 158 | - ${sfm_model_dir}/0535-odbmilk-box 159 | - ${sfm_model_dir}/0537-petsnack-box 160 | - ${sfm_model_dir}/0539-spamwrapper-others 161 | - ${sfm_model_dir}/0543-brownhouse-others 162 | - ${sfm_model_dir}/0547-cubebox-box 163 | - ${sfm_model_dir}/0548-duck-others 164 | - ${sfm_model_dir}/0550-greenbox-box 165 | - ${sfm_model_dir}/0551-milk-others 166 | - ${sfm_model_dir}/0552-mushroom-others 167 | - ${sfm_model_dir}/0557-santachoc-others 168 | - ${sfm_model_dir}/0558-teddychoc-others 169 | - ${sfm_model_dir}/0559-tissuebox-box 170 | - ${sfm_model_dir}/0560-tofubox-box 171 | - ${sfm_model_dir}/0564-biatee-others 172 | - ${sfm_model_dir}/0565-biscuits-box 173 | - ${sfm_model_dir}/0568-cornflakes-box 174 | - ${sfm_model_dir}/0570-kasekuchen-box 175 | - ${sfm_model_dir}/0577-schoko-box 176 | - ${sfm_model_dir}/0578-tee-others 177 | - ${sfm_model_dir}/0579-tomatocan-bottle 178 | - ${sfm_model_dir}/0580-xmaxbox-others 179 | - ${sfm_model_dir}/0582-yogurtlarge-others 180 | - ${sfm_model_dir}/0583-yogurtmedium-others 181 | - ${sfm_model_dir}/0594-martinBootsLeft-others 182 | - ${sfm_model_dir}/0595-martinBootsRight-others -------------------------------------------------------------------------------- /configs/experiment/test_GATsSPG.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: inference 4 | task_name: test_onepose 5 | num_leaf: 8 6 | suffix: '' 7 | save_demo: False 8 | save_wis3d: False 9 | demo_root: ${data_dir}/runs/demo 10 | 11 | model: 12 | onepose_model_path: ${data_dir}/models/checkpoints/onepose/GATsSPG.ckpt 13 | extractor_model_path: ${data_dir}/models/extractors/SuperPoint/superpoint_v1.pth 14 | 15 | network: 16 | detection: superpoint 17 | matching: superglue 18 | 19 | # 
object_detect_mode: 'GT_box' # ["GT_box", "feature_matching"] 20 | object_detect_mode: 'GT_box' # ["GT_box", "feature_matching"] 21 | max_num_kp3d: 2500 22 | scan_data_dir: ${data_dir}/onepose_datasets/test_data 23 | sfm_model_dir: ${data_dir}/sfm_model 24 | 25 | input: 26 | data_dirs: 27 | - ${scan_data_dir}/0408-colorbox-box colorbox-4 28 | - ${scan_data_dir}/0409-aptamil-box aptamil-3 29 | - ${scan_data_dir}/0419-cookies2-others cookies2-4 30 | - ${scan_data_dir}/0422-qvduoduo-box qvduoduo-4 31 | - ${scan_data_dir}/0423-oreo-box oreo-4 32 | - ${scan_data_dir}/0424-chocbox-box chocbox-4 33 | - ${scan_data_dir}/0447-nabati-box nabati-5 34 | - ${scan_data_dir}/0450-hlychocpie-box hlychocpie-4 35 | - ${scan_data_dir}/0452-hlymatchapie-box hlymatchapie-4 36 | - ${scan_data_dir}/0455-strawberryoreo-box strawberryoreo-4 37 | - ${scan_data_dir}/0456-chocoreo-box chocoreo-4 38 | - ${scan_data_dir}/0458-hetaocakes-box hetaocakes-4 39 | - ${scan_data_dir}/0459-jzhg-box jzhg-4 40 | - ${scan_data_dir}/0466-mfmilkcake-box mfmilkcake-4 41 | - ${scan_data_dir}/0468-minipuff-box minipuff-4 42 | - ${scan_data_dir}/0469-diycookies-box diycookies-4 43 | - ${scan_data_dir}/0470-eggrolls-box eggrolls-4 44 | - ${scan_data_dir}/0471-hlyormosiapie-box hlyormosiapie-4 45 | - ${scan_data_dir}/0472-chocoreo-bottle chocoreo-4 46 | - ${scan_data_dir}/0473-twgrassjelly1-box twgrassjelly1-4 47 | - ${scan_data_dir}/0474-twgrassjelly2-box twgrassjelly2-4 48 | - ${scan_data_dir}/0476-giraffecup-bottle giraffecup-4 49 | - ${scan_data_dir}/0480-ljcleaner-others ljcleaner-4 50 | - ${scan_data_dir}/0483-ambrosial-box ambrosial-4 51 | - ${scan_data_dir}/0486-sanqitoothpaste-box sanqitoothpaste-4 52 | - ${scan_data_dir}/0487-jindiantoothpaste-box jindiantoothpaste-4 53 | - ${scan_data_dir}/0488-jijiantoothpaste-box jijiantoothpaste-4 54 | - ${scan_data_dir}/0489-taipingcookies-others taipingcookies-4 55 | - ${scan_data_dir}/0490-haochidiancookies-others haochidiancookies-4 56 | - ${scan_data_dir}/0492-tuccookies-box tuccookies-4 57 | - ${scan_data_dir}/0493-haochidianeggroll-box haochidianeggroll-4 58 | - ${scan_data_dir}/0494-qvduoduocookies-box qvduoduocookies-4 59 | - ${scan_data_dir}/0495-fulingstapler-box fulingstapler-4 60 | - ${scan_data_dir}/0496-delistapler-box delistapler-4 61 | - ${scan_data_dir}/0497-delistaplerlarger-box delistaplerlarger-4 62 | - ${scan_data_dir}/0498-yousuanru-box yousuanru-4 63 | - ${scan_data_dir}/0500-chocfranzzi-box chocfranzzi-4 64 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-4 65 | - ${scan_data_dir}/0502-shufujia-box shufujia-5 66 | - ${scan_data_dir}/0503-shufujiawhite-box shufujiawhite-3 67 | - ${scan_data_dir}/0504-lux-box lux-4 68 | - ${scan_data_dir}/0508-yqsl-others yqsl-4 69 | - ${scan_data_dir}/0510-yqslmilk-others yqslmilk-4 70 | - ${scan_data_dir}/0511-policecar-others policecar-4 71 | - ${scan_data_dir}/0517-nationalgeo-box nationalgeo-4 72 | - ${scan_data_dir}/0518-jasmine-box jasmine-4 73 | - ${scan_data_dir}/0519-backpack1-box backpack1-4 74 | - ${scan_data_dir}/0520-lipault-box lipault-4 75 | - ${scan_data_dir}/0521-ranova-box ranova-4 76 | - ${scan_data_dir}/0522-milkbox-box milkbox-4 77 | - ${scan_data_dir}/0523-edibleoil-others edibleoil-4 78 | - ${scan_data_dir}/0525-toygrab-others toygrab-2 79 | - ${scan_data_dir}/0526-toytable-others toytable-3 80 | - ${scan_data_dir}/0527-spalding-others spalding-2 81 | - ${scan_data_dir}/0534-tonkotsuramen-box tonkotsuramen-4 82 | - ${scan_data_dir}/0535-odbmilk-box odbmilk-4 83 | - ${scan_data_dir}/0537-petsnack-box 
petsnack-4 84 | - ${scan_data_dir}/0539-spamwrapper-others spamwrapper-5 85 | - ${scan_data_dir}/0543-brownhouse-others brownhouse-4 86 | - ${scan_data_dir}/0547-cubebox-box cubebox-4 87 | - ${scan_data_dir}/0548-duck-others duck-4 88 | - ${scan_data_dir}/0550-greenbox-box greenbox-4 89 | - ${scan_data_dir}/0551-milk-others milk-4 90 | - ${scan_data_dir}/0552-mushroom-others mushroom-4 91 | - ${scan_data_dir}/0557-santachoc-others santachoc-4 92 | - ${scan_data_dir}/0558-teddychoc-others teddychoc-4 93 | - ${scan_data_dir}/0559-tissuebox-box tissuebox-4 94 | - ${scan_data_dir}/0560-tofubox-box tofubox-4 95 | - ${scan_data_dir}/0564-biatee-others biatee-4 96 | - ${scan_data_dir}/0565-biscuits-box biscuits-4 97 | - ${scan_data_dir}/0568-cornflakes-box cornflakes-5 98 | - ${scan_data_dir}/0570-kasekuchen-box kasekuchen-4 99 | - ${scan_data_dir}/0577-schoko-box schoko-4 100 | - ${scan_data_dir}/0578-tee-others tee-4 101 | - ${scan_data_dir}/0579-tomatocan-bottle tomatocan-4 102 | - ${scan_data_dir}/0580-xmaxbox-others xmaxbox-4 103 | - ${scan_data_dir}/0582-yogurtlarge-others yogurtlarge-4 104 | - ${scan_data_dir}/0583-yogurtmedium-others yogurtmedium-4 105 | - ${scan_data_dir}/0594-martinBootsLeft-others martinBootsLeft-2 106 | - ${scan_data_dir}/0595-martinBootsRight-others martinBootsRight-4 107 | 108 | sfm_model_dirs: 109 | - ${sfm_model_dir}/0408-colorbox-box 110 | - ${sfm_model_dir}/0409-aptamil-box 111 | - ${sfm_model_dir}/0419-cookies2-others 112 | - ${sfm_model_dir}/0422-qvduoduo-box 113 | - ${sfm_model_dir}/0423-oreo-box 114 | - ${sfm_model_dir}/0424-chocbox-box 115 | - ${sfm_model_dir}/0447-nabati-box 116 | - ${sfm_model_dir}/0450-hlychocpie-box 117 | - ${sfm_model_dir}/0452-hlymatchapie-box 118 | - ${sfm_model_dir}/0455-strawberryoreo-box 119 | - ${sfm_model_dir}/0456-chocoreo-box 120 | - ${sfm_model_dir}/0458-hetaocakes-box 121 | - ${sfm_model_dir}/0459-jzhg-box 122 | - ${sfm_model_dir}/0466-mfmilkcake-box 123 | - ${sfm_model_dir}/0468-minipuff-box 124 | - ${sfm_model_dir}/0469-diycookies-box 125 | - ${sfm_model_dir}/0470-eggrolls-box 126 | - ${sfm_model_dir}/0471-hlyormosiapie-box 127 | - ${sfm_model_dir}/0472-chocoreo-bottle 128 | - ${sfm_model_dir}/0473-twgrassjelly1-box 129 | - ${sfm_model_dir}/0474-twgrassjelly2-box 130 | - ${sfm_model_dir}/0476-giraffecup-bottle 131 | - ${sfm_model_dir}/0480-ljcleaner-others 132 | - ${sfm_model_dir}/0483-ambrosial-box 133 | - ${sfm_model_dir}/0486-sanqitoothpaste-box 134 | - ${sfm_model_dir}/0487-jindiantoothpaste-box 135 | - ${sfm_model_dir}/0488-jijiantoothpaste-box 136 | - ${sfm_model_dir}/0489-taipingcookies-others 137 | - ${sfm_model_dir}/0490-haochidiancookies-others 138 | - ${sfm_model_dir}/0492-tuccookies-box 139 | - ${sfm_model_dir}/0493-haochidianeggroll-box 140 | - ${sfm_model_dir}/0494-qvduoduocookies-box 141 | - ${sfm_model_dir}/0495-fulingstapler-box 142 | - ${sfm_model_dir}/0496-delistapler-box 143 | - ${sfm_model_dir}/0497-delistaplerlarger-box 144 | - ${sfm_model_dir}/0498-yousuanru-box 145 | - ${sfm_model_dir}/0500-chocfranzzi-box 146 | - ${sfm_model_dir}/0501-matchafranzzi-box 147 | - ${sfm_model_dir}/0502-shufujia-box 148 | - ${sfm_model_dir}/0503-shufujiawhite-box 149 | - ${sfm_model_dir}/0504-lux-box 150 | - ${sfm_model_dir}/0508-yqsl-others 151 | - ${sfm_model_dir}/0510-yqslmilk-others 152 | - ${sfm_model_dir}/0511-policecar-others 153 | - ${sfm_model_dir}/0517-nationalgeo-box 154 | - ${sfm_model_dir}/0518-jasmine-box 155 | - ${sfm_model_dir}/0519-backpack1-box 156 | - ${sfm_model_dir}/0520-lipault-box 157 | - 
${sfm_model_dir}/0521-ranova-box 158 | - ${sfm_model_dir}/0522-milkbox-box 159 | - ${sfm_model_dir}/0523-edibleoil-others 160 | - ${sfm_model_dir}/0525-toygrab-others 161 | - ${sfm_model_dir}/0526-toytable-others 162 | - ${sfm_model_dir}/0527-spalding-others 163 | - ${sfm_model_dir}/0534-tonkotsuramen-box 164 | - ${sfm_model_dir}/0535-odbmilk-box 165 | - ${sfm_model_dir}/0537-petsnack-box 166 | - ${sfm_model_dir}/0539-spamwrapper-others 167 | - ${sfm_model_dir}/0543-brownhouse-others 168 | - ${sfm_model_dir}/0547-cubebox-box 169 | - ${sfm_model_dir}/0548-duck-others 170 | - ${sfm_model_dir}/0550-greenbox-box 171 | - ${sfm_model_dir}/0551-milk-others 172 | - ${sfm_model_dir}/0552-mushroom-others 173 | - ${sfm_model_dir}/0557-santachoc-others 174 | - ${sfm_model_dir}/0558-teddychoc-others 175 | - ${sfm_model_dir}/0559-tissuebox-box 176 | - ${sfm_model_dir}/0560-tofubox-box 177 | - ${sfm_model_dir}/0564-biatee-others 178 | - ${sfm_model_dir}/0565-biscuits-box 179 | - ${sfm_model_dir}/0568-cornflakes-box 180 | - ${sfm_model_dir}/0570-kasekuchen-box 181 | - ${sfm_model_dir}/0577-schoko-box 182 | - ${sfm_model_dir}/0578-tee-others 183 | - ${sfm_model_dir}/0579-tomatocan-bottle 184 | - ${sfm_model_dir}/0580-xmaxbox-others 185 | - ${sfm_model_dir}/0582-yogurtlarge-others 186 | - ${sfm_model_dir}/0583-yogurtmedium-others 187 | - ${sfm_model_dir}/0594-martinBootsLeft-others 188 | - ${sfm_model_dir}/0595-martinBootsRight-others 189 | 190 | 191 | output: 192 | vis_dir: ${work_dir}/runs/vis/GATsSPG 193 | eval_dir: ${work_dir}/runs/eval/GATsSPG -------------------------------------------------------------------------------- /configs/experiment/test_demo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: inference 4 | task_name: demo 5 | num_leaf: 8 6 | suffix: '' 7 | save_demo: False 8 | save_wis3d: False 9 | use_tracking: False 10 | 11 | model: 12 | onepose_model_path: ${work_dir}/data/models/checkpoints/onepose/GATsSPG.ckpt 13 | extractor_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth 14 | match_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth 15 | 16 | network: 17 | detection: superpoint 18 | matching: superglue 19 | 20 | max_num_kp3d: 2500 21 | 22 | input: 23 | data_dirs: null 24 | sfm_model_dirs: null 25 | 26 | output: 27 | vis_dir: ${work_dir}/runs/vis/demo 28 | eval_dir: ${work_dir}/runs/eval/demo -------------------------------------------------------------------------------- /configs/experiment/test_sample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: inference 4 | task_name: test_sample 5 | num_leaf: 8 6 | suffix: '' 7 | save_demo: False 8 | save_wis3d: False 9 | demo_root: ${data_dir}/runs/demo 10 | 11 | model: 12 | onepose_model_path: ${data_dir}/models/checkpoints/onepose/GATsSPG.ckpt 13 | extractor_model_path: ${data_dir}/models/extractors/SuperPoint/superpoint_v1.pth 14 | 15 | network: 16 | detection: superpoint 17 | matching: superglue 18 | 19 | object_detect_mode: 'GT_box' # ["GT_box", "feature_matching"] 20 | max_num_kp3d: 2500 21 | scan_data_dir: ${data_dir}/onepose_datasets/sample_data 22 | sfm_model_dir: ${data_dir}/sfm_model 23 | 24 | input: 25 | data_dirs: 26 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-4 27 | 28 | sfm_model_dirs: 29 | - ${sfm_model_dir}/0501-matchafranzzi-box 30 | 31 | output: 32 | vis_dir: ${work_dir}/runs/vis/test_sample 33 | eval_dir: 
${work_dir}/runs/eval/test_sample -------------------------------------------------------------------------------- /configs/experiment/train_GATsSPG.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py +experiment=train_GATsSPG 5 | 6 | defaults: 7 | - override /trainer: null # override trainer to null so it's not loaded from main config defaults... 8 | - override /model: null 9 | - override /datamodule: null 10 | - override /callbacks: null 11 | - override /logger: null 12 | 13 | # we override default configurations with nulls to prevent them from loading at all 14 | # instead we define all modules and their paths directly in this config, 15 | # so everything is stored in one place for more readability 16 | 17 | seed: 12345 18 | 19 | task_name: null 20 | exp_name: train_onepose 21 | trainer: 22 | _target_: pytorch_lightning.Trainer 23 | gpus: 24 | - 6 25 | min_epochs: 1 26 | max_epochs: 10 27 | gradient_clip_val: 0.5 28 | accumulate_grad_batches: 2 29 | weights_summary: null 30 | num_sanity_val_steps: 2 31 | 32 | 33 | model: 34 | # _target_: src.models.spg_model.LitModelSPG 35 | _target_: src.models.GATsSPG_lightning_model.LitModelGATsSPG 36 | optimizer: adam 37 | lr: 1e-3 38 | weight_decay: 0. 39 | architecture: SuperGlue 40 | 41 | milestones: [5, 10, 15, 20] 42 | gamma: 0.5 43 | 44 | descriptor_dim: 256 45 | keypoints_encoder: [32, 64, 128] 46 | sinkhorn_iterations: 100 47 | match_threshold: 0.2 48 | match_type: 'softmax' 49 | scale_factor: 0.07 50 | 51 | # focal loss 52 | focal_loss_alpha: 0.5 53 | focal_loss_gamma: 2 54 | pos_weights: 0.5 55 | neg_weights: 0.5 56 | 57 | # GATs 58 | include_self: True 59 | with_linear_transform: False 60 | additional: False 61 | 62 | # SuperPoint 63 | spp_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth 64 | 65 | # trainer: 66 | # n_val_pairs_to_plot: 5 67 | 68 | datamodule: 69 | _target_: src.datamodules.GATs_spg_datamodule.GATsSPGDataModule 70 | data_dirs: ${data_dir}/sfm_model 71 | anno_dirs: outputs_${model.match_type}/anno 72 | train_anno_file: ${work_dir}/data/cache/${task_name}/train.json 73 | val_anno_file: ${work_dir}/data/cache/${task_name}/val.json 74 | batch_size: 8 75 | num_workers: 16 76 | num_leaf: 8 77 | pin_memory: True 78 | shape2d: 1000 79 | shape3d: 2000 80 | assign_pad_val: 0 81 | 82 | callbacks: 83 | model_checkpoint: 84 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 85 | monitor: "val/loss" 86 | save_top_k: -1 87 | save_last: True 88 | mode: "min" 89 | dirpath: '${data_dir}/models/checkpoints/${exp_name}' 90 | filename: '{epoch}' 91 | lr_monitor: 92 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 93 | logging_interval: 'step' 94 | 95 | logger: 96 | tensorboard: 97 | _target_: pytorch_lightning.loggers.TensorBoardLogger 98 | save_dir: '${data_dir}/logs' 99 | name: ${exp_name} 100 | default_hp_metric: False 101 | 102 | neptune: 103 | tags: ["best_model"] 104 | csv_logger: 105 | save_dir: "."
106 | 107 | hydra: 108 | run: 109 | dir: ${work_dir} -------------------------------------------------------------------------------- /configs/preprocess/merge_anno.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: merge_anno 4 | task_name: null 5 | split: 'train' 6 | 7 | train: 8 | names: 9 | - 0410-huiyuan-box 10 | - 0413-juliecookies-box 11 | - 0414-babydiapers-others 12 | - 0415-captaincup-others 13 | - 0416-tongyinoodles-others 14 | - 0418-cookies1-others 15 | - 0420-liuliumei-others 16 | - 0421-cannedfish-others 17 | - 0443-wheatbear-others 18 | - 0445-pistonhead-others 19 | - 0448-soyabeancookies-bottle 20 | - 0460-bhdsoyabeancookies-bottle 21 | - 0461-cranberrycookies-bottle 22 | - 0462-ylmilkpowder-bottle 23 | - 0463-camelmilk-bottle 24 | - 0464-mfchoccake-box 25 | - 0465-mfcreamcake-box 26 | - 0477-cutlet-bottle 27 | - 0479-ggbondcutlet-others 28 | - 0484-bigroll-box 29 | - 0499-tiramisufranzzi-box 30 | - 0506-sauerkrautnoodles-others 31 | - 0507-hotsournoodles-others 32 | - 0509-bscola-others 33 | - 0512-ugreenhub-box 34 | - 0513-busbox-box 35 | - 0516-wewarm-box 36 | - 0529-onionnoodles-box 37 | - 0530-trufflenoodles-box 38 | - 0531-whiskware-box 39 | - 0532-delonghi-box 40 | - 0533-shiramyun-box 41 | - 0536-ranovarect-box 42 | - 0542-bueno-box 43 | - 0545-book-others 44 | - 0546-can-bottle 45 | - 0549-footballcan-bottle 46 | - 0556-pinkbox-box 47 | - 0561-yellowbottle-bottle 48 | - 0562-yellowbox-box 49 | - 0563-applejuice-box 50 | - 0566-chillisauce-box 51 | - 0567-coffeebox-box 52 | - 0569-greentea-bottle 53 | - 0571-cakebox-box 54 | - 0572-milkbox-others 55 | - 0573-redchicken-others 56 | - 0574-rubberduck-others 57 | - 0575-saltbottle-bottle 58 | 59 | val: 60 | names: 61 | - 0601-loquat-box 62 | - 0606-tiger-others 63 | - 0611-pikachubowl-others 64 | - 0616-hmbb-others 65 | 66 | network: 67 | detection: superpoint 68 | matching: superglue 69 | 70 | datamodule: 71 | scan_data_dir: ${work_dir}/data/onepose_datasets 72 | data_dir: ${work_dir}/data/sfm_model 73 | out_path: ${work_dir}/data/cache/${task_name}/${split}.json 74 | 75 | 76 | hydra: 77 | run: 78 | dir: ${work_dir} -------------------------------------------------------------------------------- /configs/preprocess/sfm_spp_spg_demo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: sfm 4 | work_dir: ${hydra:runtime.cwd} 5 | redo: True 6 | 7 | dataset: 8 | max_num_kp3d: 2500 9 | max_num_kp2d: 1000 10 | 11 | data_dir: null 12 | outputs_dir: ${work_dir}/data/sfm_model/{} 13 | 14 | network: 15 | detection: superpoint 16 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth 17 | 18 | matching: superglue 19 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth 20 | 21 | sfm: 22 | down_ratio: 5 23 | covis_num: 10 24 | rotation_thresh: 50 25 | 26 | disable_lightning_logs: True 27 | 28 | -------------------------------------------------------------------------------- /configs/preprocess/sfm_spp_spg_sample.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: sfm 4 | work_dir: ${hydra:runtime.cwd} 5 | redo: False 6 | 7 | scan_data_dir: ${work_dir}/data/onepose_datasets/sample_data 8 | 9 | dataset: 10 | max_num_kp3d: 2500 11 | max_num_kp2d: 1000 12 | 13 | data_dir: 14 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-1 15 | 16 | 
outputs_dir: ${work_dir}/data/sfm_model/{} 17 | 18 | network: 19 | detection: superpoint 20 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth 21 | 22 | matching: superglue 23 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth 24 | 25 | sfm: 26 | down_ratio: 5 27 | covis_num: 10 28 | rotation_thresh: 50 29 | 30 | disable_lightning_logs: True -------------------------------------------------------------------------------- /configs/preprocess/sfm_spp_spg_test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: sfm 4 | work_dir: ${hydra:runtime.cwd} 5 | redo: False 6 | 7 | scan_data_dir: ${work_dir}/data/onepose_datasets/test_data 8 | 9 | dataset: 10 | max_num_kp3d: 2500 11 | max_num_kp2d: 1000 12 | 13 | data_dir: 14 | - ${scan_data_dir}/0408-colorbox-box colorbox-1 15 | - ${scan_data_dir}/0409-aptamil-box aptamil-1 16 | - ${scan_data_dir}/0419-cookies2-others cookies2-1 17 | - ${scan_data_dir}/0422-qvduoduo-box qvduoduo-1 18 | - ${scan_data_dir}/0423-oreo-box oreo-1 19 | - ${scan_data_dir}/0424-chocbox-box chocbox-1 20 | - ${scan_data_dir}/0447-nabati-box nabati-1 21 | - ${scan_data_dir}/0450-hlychocpie-box hlychocpie-1 22 | - ${scan_data_dir}/0452-hlymatchapie-box hlymatchapie-1 23 | - ${scan_data_dir}/0455-strawberryoreo-box strawberryoreo-1 24 | - ${scan_data_dir}/0456-chocoreo-box chocoreo-1 25 | - ${scan_data_dir}/0458-hetaocakes-box hetaocakes-1 26 | - ${scan_data_dir}/0459-jzhg-box jzhg-1 27 | - ${scan_data_dir}/0466-mfmilkcake-box mfmilkcake-1 28 | - ${scan_data_dir}/0468-minipuff-box minipuff-1 29 | - ${scan_data_dir}/0469-diycookies-box diycookies-1 30 | - ${scan_data_dir}/0470-eggrolls-box eggrolls-1 31 | - ${scan_data_dir}/0471-hlyormosiapie-box hlyormosiapie-1 32 | - ${scan_data_dir}/0472-chocoreo-bottle chocoreo-1 33 | - ${scan_data_dir}/0473-twgrassjelly1-box twgrassjelly1-1 34 | - ${scan_data_dir}/0474-twgrassjelly2-box twgrassjelly2-1 35 | - ${scan_data_dir}/0476-giraffecup-bottle giraffecup-1 36 | - ${scan_data_dir}/0480-ljcleaner-others ljcleaner-1 37 | - ${scan_data_dir}/0483-ambrosial-box ambrosial-1 38 | - ${scan_data_dir}/0486-sanqitoothpaste-box sanqitoothpaste-1 39 | - ${scan_data_dir}/0487-jindiantoothpaste-box jindiantoothpaste-1 40 | - ${scan_data_dir}/0488-jijiantoothpaste-box jijiantoothpaste-1 41 | - ${scan_data_dir}/0489-taipingcookies-others taipingcookies-1 42 | - ${scan_data_dir}/0490-haochidiancookies-others haochidiancookies-1 43 | - ${scan_data_dir}/0492-tuccookies-box tuccookies-1 44 | - ${scan_data_dir}/0493-haochidianeggroll-box haochidianeggroll-1 45 | - ${scan_data_dir}/0494-qvduoduocookies-box qvduoduocookies-1 46 | - ${scan_data_dir}/0495-fulingstapler-box fulingstapler-1 47 | - ${scan_data_dir}/0496-delistapler-box delistapler-1 48 | - ${scan_data_dir}/0497-delistaplerlarger-box delistaplerlarger-1 49 | - ${scan_data_dir}/0498-yousuanru-box yousuanru-1 50 | - ${scan_data_dir}/0500-chocfranzzi-box chocfranzzi-1 51 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-1 52 | - ${scan_data_dir}/0502-shufujia-box shufujia-1 53 | - ${scan_data_dir}/0503-shufujiawhite-box shufujiawhite-1 54 | - ${scan_data_dir}/0504-lux-box lux-1 55 | - ${scan_data_dir}/0508-yqsl-others yqsl-1 56 | - ${scan_data_dir}/0510-yqslmilk-others yqslmilk-1 57 | - ${scan_data_dir}/0511-policecar-others policecar-1 58 | - ${scan_data_dir}/0517-nationalgeo-box nationalgeo-1 59 | - ${scan_data_dir}/0518-jasmine-box jasmine-1 60 | - 
${scan_data_dir}/0519-backpack1-box backpack1-1 61 | - ${scan_data_dir}/0520-lipault-box lipault-1 62 | - ${scan_data_dir}/0521-ranova-box ranova-1 63 | - ${scan_data_dir}/0522-milkbox-box milkbox-1 64 | - ${scan_data_dir}/0523-edibleoil-others edibleoil-1 65 | - ${scan_data_dir}/0525-toygrab-others toygrab-1 66 | - ${scan_data_dir}/0526-toytable-others toytable-1 67 | - ${scan_data_dir}/0527-spalding-others spalding-1 68 | - ${scan_data_dir}/0534-tonkotsuramen-box tonkotsuramen-1 69 | - ${scan_data_dir}/0535-odbmilk-box odbmilk-1 70 | - ${scan_data_dir}/0537-petsnack-box petsnack-1 71 | - ${scan_data_dir}/0539-spamwrapper-others spamwrapper-1 72 | - ${scan_data_dir}/0543-brownhouse-others brownhouse-1 73 | - ${scan_data_dir}/0547-cubebox-box cubebox-1 74 | - ${scan_data_dir}/0548-duck-others duck-1 75 | - ${scan_data_dir}/0550-greenbox-box greenbox-1 76 | - ${scan_data_dir}/0551-milk-others milk-1 77 | - ${scan_data_dir}/0552-mushroom-others mushroom-1 78 | - ${scan_data_dir}/0557-santachoc-others santachoc-1 79 | - ${scan_data_dir}/0558-teddychoc-others teddychoc-1 80 | - ${scan_data_dir}/0559-tissuebox-box tissuebox-1 81 | - ${scan_data_dir}/0560-tofubox-box tofubox-1 82 | - ${scan_data_dir}/0564-biatee-others biatee-1 83 | - ${scan_data_dir}/0565-biscuits-box biscuits-1 84 | - ${scan_data_dir}/0568-cornflakes-box cornflakes-1 85 | - ${scan_data_dir}/0570-kasekuchen-box kasekuchen-1 86 | - ${scan_data_dir}/0577-schoko-box schoko-1 87 | - ${scan_data_dir}/0578-tee-others tee-1 88 | - ${scan_data_dir}/0579-tomatocan-bottle tomatocan-1 89 | - ${scan_data_dir}/0580-xmaxbox-others xmaxbox-1 90 | - ${scan_data_dir}/0582-yogurtlarge-others yogurtlarge-1 91 | - ${scan_data_dir}/0583-yogurtmedium-others yogurtmedium-1 92 | - ${scan_data_dir}/0594-martinBootsLeft-others martinBootsLeft-1 93 | - ${scan_data_dir}/0595-martinBootsRight-others martinBootsRight-1 94 | 95 | outputs_dir: ${work_dir}/data/sfm_model/{} 96 | 97 | network: 98 | detection: superpoint 99 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth 100 | 101 | matching: superglue 102 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth 103 | 104 | sfm: 105 | down_ratio: 5 106 | covis_num: 10 107 | rotation_thresh: 50 108 | 109 | disable_lightning_logs: True 110 | 111 | -------------------------------------------------------------------------------- /configs/preprocess/sfm_spp_spg_train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: sfm 4 | work_dir: ${hydra:runtime.cwd} 5 | redo: False 6 | 7 | scan_data_dir: ${work_dir}/data/onepose_datasets/train_data 8 | 9 | dataset: 10 | max_num_kp3d: 2500 11 | max_num_kp2d: 1000 12 | 13 | data_dir: 14 | - ${scan_data_dir}/0410-huiyuan-box huiyuan-1 huiyuan-2 huiyuan-3 15 | - ${scan_data_dir}/0413-juliecookies-box juliecookies-1 juliecookies-2 juliecookies-3 16 | - ${scan_data_dir}/0414-babydiapers-others babydiapers-1 babydiapers-2 babydiapers-3 17 | - ${scan_data_dir}/0415-captaincup-others captaincup-1 captaincup-2 captaincup-3 18 | - ${scan_data_dir}/0416-tongyinoodles-others tongyinoodles-1 tongyinoodles-2 tongyinoodles-3 19 | - ${scan_data_dir}/0418-cookies1-others cookies1-1 cookies1-2 20 | - ${scan_data_dir}/0420-liuliumei-others liuliumei-1 liuliumei-2 liuliumei-3 21 | - ${scan_data_dir}/0421-cannedfish-others cannedfish-1 cannedfish-2 cannedfish-3 22 | - ${scan_data_dir}/0443-wheatbear-others wheatbear-1 wheatbear-2 23 | - 
${scan_data_dir}/0445-pistonhead-others pistonhead-1 pistonhead-2 pistonhead-3 24 | - ${scan_data_dir}/0448-soyabeancookies-bottle soyabeancookies-1 soyabeancookies-2 25 | - ${scan_data_dir}/0460-hbdsoyabeancookies-bottle hbdsoyabeancookies-1 bhdsoyabeancookies-2 bhdsoyabeancookies-3 26 | - ${scan_data_dir}/0461-cranberrycookies-bottle cranberrycookies-1 cranberrycookies-2 cranberrycookies-3 27 | - ${scan_data_dir}/0462-ylmilkpowder-bottle ylmilkpowder-1 ylmilkpowder-2 ylmilkpowder-3 28 | - ${scan_data_dir}/0463-camelmilk-bottle camelmilk-1 camelmilk-2 camelmilk-3 29 | - ${scan_data_dir}/0464-mfchoccake-box mfchoccake-1 mfchoccake-2 mfchoccake-3 30 | - ${scan_data_dir}/0465-mfcreamcake-box mfcreamcake-1 mfcreamcake-2 mfcreamcake-3 31 | - ${scan_data_dir}/0477-cutlet-bottle cutlet-1 cutlet-2 cutlet-3 32 | - ${scan_data_dir}/0479-ggbondcutlet-others ggbondcutlet-1 ggbondcutlet-2 ggbondcutlet-3 33 | - ${scan_data_dir}/0484-bigroll-box bigroll-1 bigroll-2 bigroll-3 34 | - ${scan_data_dir}/0499-tiramisufranzzi-box tiramisufranzzi-1 tiramisufranzzi-2 tiramisufranzzi-3 35 | - ${scan_data_dir}/0506-sauerkrautnoodles-others sauerkrautnoodles-1 sauerkrautnoodles-2 sauerkrautnoodles-3 36 | - ${scan_data_dir}/0507-hotsournoodles-others hotsournoodles-1 hotsournoodles-2 hotsournoodles-3 37 | - ${scan_data_dir}/0509-bscola-others bscola-1 bscola-2 bscola-3 38 | - ${scan_data_dir}/0512-ugreenhub-box ugreenhub-1 ugreenhub-2 ugreenhub-3 39 | - ${scan_data_dir}/0513-busbox-box busbox-1 busbox-2 busbox-3 40 | - ${scan_data_dir}/0516-wewarm-box wewarm-1 wewarm-2 wewarm-3 41 | - ${scan_data_dir}/0529-onionnoodles-box onionnoodles-1 onionnoodles-2 42 | - ${scan_data_dir}/0530-trufflenoodles-box trufflenoodles-1 trufflenoodles-2 trufflenoodles-3 43 | - ${scan_data_dir}/0531-whiskware-box whiskware-1 whiskware-2 whiskware-3 44 | - ${scan_data_dir}/0532-delonghi-box delonghi-1 delonghi-2 45 | - ${scan_data_dir}/0533-shiramyun-box shiramyun-1 shiramyun-2 shiramyun-3 46 | - ${scan_data_dir}/0536-ranovarect-box ranovarect-1 ranovarect-2 ranvorect-3 ranvorect-4 47 | - ${scan_data_dir}/0542-bueno-box bueno-1 bueno-2 bueno-3 48 | - ${scan_data_dir}/0545-book-others book-1 book-2 book-3 49 | - ${scan_data_dir}/0546-can-bottle can-1 can-2 can-3 50 | - ${scan_data_dir}/0549-footballcan-bottle footballcan-1 footballcan-2 footballcan-3 51 | - ${scan_data_dir}/0556-pinkbox-box pinkbox-1 pinkbox-2 pinkbox-3 52 | - ${scan_data_dir}/0561-yellowbottle-bottle yellowbottle-1 yellowbottle-2 yellowbottle-3 53 | - ${scan_data_dir}/0562-yellowbox-box yellowbox-1 yellowbox-2 yellowbox-3 54 | - ${scan_data_dir}/0563-applejuice-box applejuice-1 applejuice-2 applejuice-3 55 | - ${scan_data_dir}/0566-chillisauce-box chillisauce-1 chillisauce-2 chillisauce-3 56 | - ${scan_data_dir}/0567-coffeebox-box coffeebox-1 coffeebox-2 coffeebox-3 57 | - ${scan_data_dir}/0569-greentea-bottle greentea-1 greentea-2 greentea-3 58 | - ${scan_data_dir}/0571-cakebox-box cakebox-1 cakebox-2 cakebox-3 59 | - ${scan_data_dir}/0572-milkbox-others milkbox-1 milkbox-2 milkbox-3 60 | - ${scan_data_dir}/0573-redchicken-others redchicken-1 redchicken-2 redchicken-3 61 | - ${scan_data_dir}/0574-rubberduck-others rubberduck-1 rubberduck-2 rubberduck-3 62 | - ${scan_data_dir}/0575-saltbottle-bottle saltbottle-1 saltbottle-2 satlbottle-3 63 | 64 | outputs_dir: ${work_dir}/data/sfm_model/{} 65 | 66 | network: 67 | detection: superpoint 68 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth 69 | 70 | matching: superglue 71 | 
matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth 72 | 73 | sfm: 74 | down_ratio: 5 75 | covis_num: 10 76 | rotation_thresh: 50 77 | 78 | 79 | disable_lightning_logs: True 80 | 81 | -------------------------------------------------------------------------------- /configs/preprocess/sfm_spp_spg_val.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | type: sfm 4 | work_dir: ${hydra:runtime.cwd} 5 | redo: False 6 | 7 | scan_data_dir: ${work_dir}/data/onepose_datasets/val_data 8 | 9 | dataset: 10 | max_num_kp3d: 2500 11 | max_num_kp2d: 1000 12 | 13 | data_dir: 14 | - ${scan_data_dir}/0601-loquat-box loquat-1 15 | - ${scan_data_dir}/0602-aficion-box aficion-1 16 | - ${scan_data_dir}/0603-redbook-box redbook-1 17 | - ${scan_data_dir}/0604-pillbox1-box pillbox1-1 18 | - ${scan_data_dir}/0605-pillbox2-box pillbox2-1 19 | - ${scan_data_dir}/0606-tiger-others tiger-1 20 | - ${scan_data_dir}/0607-admilk-others admilk-1 21 | - ${scan_data_dir}/0608-teacan-others teacan-1 22 | - ${scan_data_dir}/0609-doll-others doll-1 23 | - ${scan_data_dir}/0610-calendar-box calendar-1 24 | - ${scan_data_dir}/0611-pikachubowl-others pikachubowl-1 25 | - ${scan_data_dir}/0612-originaloreo-box originaloreo-1 26 | - ${scan_data_dir}/0613-adidasshoeright-others adidasshoeright-1 27 | - ${scan_data_dir}/0614-darlietoothpaste-box darlietoothpaste-1 28 | - ${scan_data_dir}/0615-nabati-bottle nabati-1 29 | - ${scan_data_dir}/0616-hmbb-others hmbb-1 30 | - ${scan_data_dir}/0617-porcelain-others porcelain-1 31 | - ${scan_data_dir}/0618-yogurt-bottle yogurt-1 32 | - ${scan_data_dir}/0619-newtolmeat-others newtolmeat-1 33 | - ${scan_data_dir}/0620-dinosaurcup-bottle dinosaurcup-1 34 | - ${scan_data_dir}/0621-saltbox-box saltbox-1 35 | 36 | outputs_dir: ${work_dir}/data/sfm_model/{} 37 | 38 | network: 39 | detection: superpoint 40 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth 41 | 42 | matching: superglue 43 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth 44 | 45 | sfm: 46 | down_ratio: 5 47 | covis_num: 10 48 | rotation_thresh: 50 49 | 50 | disable_lightning_logs: True -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: onepose 2 | channels: 3 | # - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.7 9 | - pytorch=1.8.0 10 | - torchvision=0.9.1 11 | - cudatoolkit=10.2 12 | - ipython 13 | - tqdm 14 | - matplotlib 15 | - pylint 16 | - conda-forge::jupyterlab 17 | - conda-forge::h5py=3.1.0 18 | - conda-forge::loguru=0.5.3 19 | - conda-forge::scipy 20 | - conda-forge::numba 21 | - conda-forge::ipdb 22 | - conda-forge::albumentations=0.5.1 23 | - pip 24 | - pip: 25 | - -r requirements.txt -------------------------------------------------------------------------------- /feature_matching_object_detector.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | import hydra 4 | from tqdm import tqdm 5 | import os 6 | import os.path as osp 7 | import natsort 8 | 9 | from loguru import logger 10 | from torch.utils.data import DataLoader 11 | from src.utils import data_utils 12 | from src.utils.model_io import load_network 13 | from src.local_feature_2D_detector import 
LocalFeatureObjectDetector 14 | 15 | from pytorch_lightning import seed_everything 16 | 17 | seed_everything(12345) 18 | 19 | 20 | def get_default_paths(cfg, data_root, data_dir, sfm_model_dir): 21 | anno_dir = osp.join( 22 | sfm_model_dir, f"outputs_{cfg.network.detection}_{cfg.network.matching}", "anno" 23 | ) 24 | avg_anno_3d_path = osp.join(anno_dir, "anno_3d_average.npz") 25 | clt_anno_3d_path = osp.join(anno_dir, "anno_3d_collect.npz") 26 | idxs_path = osp.join(anno_dir, "idxs.npy") 27 | sfm_ws_dir = osp.join( 28 | sfm_model_dir, 29 | f"outputs_{cfg.network.detection}_{cfg.network.matching}", 30 | "sfm_ws", 31 | "model", 32 | ) 33 | 34 | img_lists = [] 35 | color_dir = osp.join(data_dir, "color_full") 36 | if not osp.exists(color_dir): 37 | logger.info('color_full directory not exists! Try to parse from Frames.m4v') 38 | scan_video_dir = osp.join(data_dir, 'Frames.m4v') 39 | assert osp.exists(scan_video_dir), 'Frames.m4v not found! Run detector fail!' 40 | data_utils.video2img(scan_video_dir, color_dir) 41 | img_lists += glob.glob(color_dir + "/*.png", recursive=True) 42 | 43 | img_lists = natsort.natsorted(img_lists) 44 | 45 | # Save detect results: 46 | detect_img_dir = osp.join(data_dir, "color_det") 47 | if osp.exists(detect_img_dir): 48 | os.system(f"rm -rf {detect_img_dir}") 49 | os.makedirs(detect_img_dir, exist_ok=True) 50 | 51 | detect_K_dir = osp.join(data_dir, "intrin_det") 52 | if osp.exists(detect_K_dir): 53 | os.system(f"rm -rf {detect_K_dir}") 54 | os.makedirs(detect_K_dir, exist_ok=True) 55 | 56 | intrin_full_path = osp.join(data_dir, "intrinsics.txt") 57 | paths = { 58 | "data_root": data_root, 59 | "data_dir": data_dir, 60 | "sfm_model_dir": sfm_model_dir, 61 | "sfm_ws_dir": sfm_ws_dir, 62 | "avg_anno_3d_path": avg_anno_3d_path, 63 | "clt_anno_3d_path": clt_anno_3d_path, 64 | "idxs_path": idxs_path, 65 | "intrin_full_path": intrin_full_path, 66 | "output_detect_img_dir": detect_img_dir, 67 | "output_K_crop_dir": detect_K_dir 68 | } 69 | return img_lists, paths 70 | 71 | def load_2D_matching_model(cfg): 72 | 73 | def load_extractor_model(cfg, model_path): 74 | """Load extractor model(SuperPoint)""" 75 | from src.models.extractors.SuperPoint.superpoint import SuperPoint 76 | from src.sfm.extract_features import confs 77 | 78 | extractor_model = SuperPoint(confs[cfg.network.detection]["conf"]) 79 | extractor_model.cuda() 80 | extractor_model.eval() 81 | load_network(extractor_model, model_path, force=True) 82 | 83 | return extractor_model 84 | 85 | def load_2D_matcher(cfg): 86 | """Load matching model(SuperGlue)""" 87 | from src.models.matchers.SuperGlue.superglue import SuperGlue 88 | from src.sfm.match_features import confs 89 | 90 | match_model = SuperGlue(confs[cfg.network.matching]["conf"]) 91 | match_model.eval() 92 | load_network(match_model, cfg.model.match_model_path) 93 | return match_model 94 | 95 | extractor_model = load_extractor_model(cfg, cfg.model.extractor_model_path) 96 | matcher = load_2D_matcher(cfg) 97 | return extractor_model, matcher 98 | 99 | 100 | def pack_data(avg_descriptors3d, clt_descriptors, keypoints3d, detection, image_size): 101 | """Prepare data for OnePose inference""" 102 | keypoints2d = torch.Tensor(detection["keypoints"]) 103 | descriptors2d = torch.Tensor(detection["descriptors"]) 104 | 105 | inp_data = { 106 | "keypoints2d": keypoints2d[None].cuda(), # [1, n1, 2] 107 | "keypoints3d": keypoints3d[None].cuda(), # [1, n2, 3] 108 | "descriptors2d_query": descriptors2d[None].cuda(), # [1, dim, n1] 109 | "descriptors3d_db": 
avg_descriptors3d[None].cuda(), # [1, dim, n2] 110 | "descriptors2d_db": clt_descriptors[None].cuda(), # [1, dim, n2*num_leaf] 111 | "image_size": image_size, 112 | } 113 | 114 | return inp_data 115 | 116 | 117 | @torch.no_grad() 118 | def inference_core(cfg, data_root, seq_dir, sfm_model_dir): 119 | """Inference & visualize""" 120 | from src.datasets.normalized_dataset import NormalizedDataset 121 | from src.sfm.extract_features import confs 122 | 123 | # Load models and prepare data: 124 | extractor_model, matching_2D_model = load_2D_matching_model(cfg) 125 | img_lists, paths = get_default_paths(cfg, data_root, seq_dir, sfm_model_dir) 126 | K, _ = data_utils.get_K(paths["intrin_full_path"]) 127 | 128 | local_feature_obj_detector = LocalFeatureObjectDetector( 129 | extractor_model, 130 | matching_2D_model, 131 | sfm_ws_dir=paths["sfm_ws_dir"], 132 | n_ref_view=cfg.n_ref_view, 133 | output_results=True, 134 | detect_save_dir=paths['output_detect_img_dir'], 135 | K_crop_save_dir=paths['output_K_crop_dir'] 136 | ) 137 | dataset = NormalizedDataset( 138 | img_lists, confs[cfg.network.detection]["preprocessing"] 139 | ) 140 | loader = DataLoader(dataset, num_workers=1) 141 | 142 | # Begin Object detection: 143 | for id, data in enumerate(tqdm(loader)): 144 | img_path = data["path"][0] 145 | inp = data["image"].cuda() 146 | 147 | # Detect object by 2D local feature matching for the first frame: 148 | local_feature_obj_detector.detect(inp, img_path, K) 149 | 150 | def inference(cfg): 151 | data_dirs = cfg.input.data_dirs 152 | sfm_model_dirs = cfg.input.sfm_model_dirs 153 | if isinstance(data_dirs, str) and isinstance(sfm_model_dirs, str): 154 | data_dirs = [data_dirs] 155 | sfm_model_dirs = [sfm_model_dirs] 156 | 157 | for data_dir, sfm_model_dir in tqdm( 158 | zip(data_dirs, sfm_model_dirs), total=len(data_dirs) 159 | ): 160 | splits = data_dir.split(" ") 161 | data_root = splits[0] 162 | for seq_name in splits[1:]: 163 | seq_dir = osp.join(data_root, seq_name) 164 | logger.info(f"Run feature matching object detector for: {seq_dir}") 165 | inference_core(cfg, data_root, seq_dir, sfm_model_dir) 166 | 167 | 168 | @hydra.main(config_path="configs/", config_name="config.yaml") 169 | def main(cfg): 170 | globals()[cfg.type](cfg) 171 | 172 | 173 | if __name__ == "__main__": 174 | main() 175 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | import hydra 4 | from tqdm import tqdm 5 | import os.path as osp 6 | import numpy as np 7 | 8 | from PIL import Image 9 | from loguru import logger 10 | from torch.utils.data import DataLoader 11 | from src.utils import data_utils, path_utils, eval_utils, vis_utils 12 | 13 | from pytorch_lightning import seed_everything 14 | seed_everything(12345) 15 | 16 | 17 | def get_default_paths(cfg, data_root, data_dir, sfm_model_dir): 18 | anno_dir = osp.join(sfm_model_dir, f'outputs_{cfg.network.detection}_{cfg.network.matching}', 'anno') 19 | avg_anno_3d_path = osp.join(anno_dir, 'anno_3d_average.npz') 20 | clt_anno_3d_path = osp.join(anno_dir, 'anno_3d_collect.npz') 21 | idxs_path = osp.join(anno_dir, 'idxs.npy') 22 | 23 | object_detect_mode = cfg.object_detect_mode 24 | logger.info(f"Use {object_detect_mode} as object detector") 25 | if object_detect_mode == 'GT_box': 26 | color_dir = osp.join(data_dir, 'color') 27 | elif object_detect_mode == 'feature_matching': 28 | color_dir = osp.join(data_dir, 
'color_det') 29 | assert osp.exists(color_dir), "color_det directory not exists! You need to run local_feature_2D_detector.py for object detection. Please refer to README.md for the instructions" 30 | else: 31 | raise NotImplementedError 32 | 33 | img_lists = [] 34 | img_lists += glob.glob(color_dir + '/*.png', recursive=True) 35 | 36 | intrin_full_path = osp.join(data_dir, 'intrinsics.txt') 37 | paths = { 38 | "data_root": data_root, 39 | 'data_dir': data_dir, 40 | 'sfm_model_dir': sfm_model_dir, 41 | 'avg_anno_3d_path': avg_anno_3d_path, 42 | 'clt_anno_3d_path': clt_anno_3d_path, 43 | 'idxs_path': idxs_path, 44 | 'intrin_full_path': intrin_full_path 45 | } 46 | return img_lists, paths 47 | 48 | 49 | def load_model(cfg): 50 | """ Load model """ 51 | def load_matching_model(model_path): 52 | """ Load onepose model """ 53 | from src.models.GATsSPG_lightning_model import LitModelGATsSPG 54 | 55 | trained_model = LitModelGATsSPG.load_from_checkpoint(checkpoint_path=model_path) 56 | trained_model.cuda() 57 | trained_model.eval() 58 | trained_model.freeze() 59 | 60 | return trained_model 61 | 62 | def load_extractor_model(cfg, model_path): 63 | """ Load extractor model(SuperPoint) """ 64 | from src.models.extractors.SuperPoint.superpoint import SuperPoint 65 | from src.sfm.extract_features import confs 66 | from src.utils.model_io import load_network 67 | 68 | extractor_model = SuperPoint(confs[cfg.network.detection]['conf']) 69 | extractor_model.cuda() 70 | extractor_model.eval() 71 | load_network(extractor_model, model_path) 72 | 73 | return extractor_model 74 | 75 | matching_model = load_matching_model(cfg.model.onepose_model_path) 76 | extractor_model = load_extractor_model(cfg, cfg.model.extractor_model_path) 77 | return matching_model, extractor_model 78 | 79 | 80 | def pack_data(avg_descriptors3d, clt_descriptors, keypoints3d, detection, image_size): 81 | """ Prepare data for OnePose inference """ 82 | keypoints2d = torch.Tensor(detection['keypoints']) 83 | descriptors2d = torch.Tensor(detection['descriptors']) 84 | 85 | inp_data = { 86 | 'keypoints2d': keypoints2d[None].cuda(), # [1, n1, 2] 87 | 'keypoints3d': keypoints3d[None].cuda(), # [1, n2, 3] 88 | 'descriptors2d_query': descriptors2d[None].cuda(), # [1, dim, n1] 89 | 'descriptors3d_db': avg_descriptors3d[None].cuda(), # [1, dim, n2] 90 | 'descriptors2d_db': clt_descriptors[None].cuda(), # [1, dim, n2*num_leaf] 91 | 'image_size': image_size 92 | } 93 | 94 | return inp_data 95 | 96 | 97 | @torch.no_grad() 98 | def inference_core(cfg, data_root, seq_dir, sfm_model_dir): 99 | """ Inference & visualize""" 100 | from src.datasets.normalized_dataset import NormalizedDataset 101 | from src.sfm.extract_features import confs 102 | from src.evaluators.cmd_evaluator import Evaluator 103 | 104 | matching_model, extractor_model = load_model(cfg) 105 | img_lists, paths = get_default_paths(cfg, data_root, seq_dir, sfm_model_dir) 106 | 107 | dataset = NormalizedDataset(img_lists, confs[cfg.network.detection]['preprocessing']) 108 | loader = DataLoader(dataset, num_workers=1) 109 | evaluator = Evaluator() 110 | 111 | idx = 0 112 | num_leaf = cfg.num_leaf 113 | avg_data = np.load(paths['avg_anno_3d_path']) 114 | clt_data = np.load(paths['clt_anno_3d_path']) 115 | idxs = np.load(paths['idxs_path']) 116 | 117 | keypoints3d = torch.Tensor(clt_data['keypoints3d']).cuda() 118 | num_3d = keypoints3d.shape[0] 119 | # Load average 3D features: 120 | avg_descriptors3d, _ = data_utils.pad_features3d_random( 121 | avg_data['descriptors3d'], 122 | 
avg_data['scores3d'], 123 | num_3d 124 | ) 125 | # Load corresponding 2D features of each 3D point: 126 | clt_descriptors, _ = data_utils.build_features3d_leaves( 127 | clt_data['descriptors3d'], 128 | clt_data['scores3d'], 129 | idxs, num_3d, num_leaf 130 | ) 131 | 132 | for data in tqdm(loader): 133 | img_path = data['path'][0] 134 | inp = data['image'].cuda() 135 | 136 | intrin_path = path_utils.get_intrin_path_by_color(img_path, det_type=cfg.object_detect_mode) 137 | K_crop = np.loadtxt(intrin_path) 138 | 139 | # Detect query image keypoints and extract descriptors: 140 | pred_detection = extractor_model(inp) 141 | pred_detection = {k: v[0].cpu().numpy() for k, v in pred_detection.items()} 142 | 143 | # 2D-3D matching by GATsSPG: 144 | inp_data = pack_data(avg_descriptors3d, clt_descriptors, 145 | keypoints3d, pred_detection, data['size']) 146 | pred, _ = matching_model(inp_data) 147 | matches = pred['matches0'].detach().cpu().numpy() 148 | valid = matches > -1 149 | kpts2d = pred_detection['keypoints'] 150 | kpts3d = inp_data['keypoints3d'][0].detach().cpu().numpy() 151 | confidence = pred['matching_scores0'].detach().cpu().numpy() 152 | mkpts2d, mkpts3d, mconf = kpts2d[valid], kpts3d[matches[valid]], confidence[valid] 153 | 154 | # Estimate object pose by 2D-3D correspondences: 155 | pose_pred, pose_pred_homo, inliers = eval_utils.ransac_PnP(K_crop, mkpts2d, mkpts3d, scale=1000) 156 | 157 | # Evaluate: 158 | gt_pose_path = path_utils.get_gt_pose_path_by_color(img_path, det_type=cfg.object_detect_mode) 159 | pose_gt = np.loadtxt(gt_pose_path) 160 | evaluator.evaluate(pose_pred, pose_gt) 161 | 162 | # Visualize: 163 | if cfg.save_wis3d: 164 | poses = [pose_gt, pose_pred_homo] 165 | box3d_path = path_utils.get_3d_box_path(data_root) 166 | intrin_full_path = path_utils.get_intrin_full_path(seq_dir) 167 | image_full_path = path_utils.get_img_full_path_by_color(img_path, det_type=cfg.object_detect_mode) 168 | 169 | image_full = vis_utils.vis_reproj(image_full_path, poses, box3d_path, intrin_full_path, 170 | save_demo=cfg.save_demo, demo_root=cfg.demo_root) 171 | mkpts3d_2d = vis_utils.reproj(K_crop, pose_gt, mkpts3d) 172 | image0 = Image.open(img_path).convert('LA') 173 | image1 = image0.copy() 174 | vis_utils.dump_wis3d(idx, cfg, seq_dir, image0, image1, image_full, 175 | mkpts2d, mkpts3d_2d, mconf, inliers) 176 | 177 | idx += 1 178 | 179 | eval_result = evaluator.summarize() 180 | obj_name = sfm_model_dir.split('/')[-1] 181 | seq_name = seq_dir.split('/')[-1] 182 | eval_utils.record_eval_result(cfg.output.eval_dir, obj_name, seq_name, eval_result) 183 | 184 | 185 | def inference(cfg): 186 | data_dirs = cfg.input.data_dirs 187 | sfm_model_dirs = cfg.input.sfm_model_dirs 188 | if isinstance(data_dirs, str) and isinstance(sfm_model_dirs, str): 189 | data_dirs = [data_dirs] 190 | sfm_model_dirs = [sfm_model_dirs] 191 | 192 | for data_dir, sfm_model_dir in tqdm(zip(data_dirs, sfm_model_dirs), total=len(data_dirs)): 193 | splits = data_dir.split(" ") 194 | data_root = splits[0] 195 | for seq_name in splits[1:]: 196 | seq_dir = osp.join(data_root, seq_name) 197 | logger.info(f'Eval {seq_dir}') 198 | inference_core(cfg, data_root, seq_dir, sfm_model_dir) 199 | 200 | 201 | @hydra.main(config_path='configs/', config_name='config.yaml') 202 | def main(cfg): 203 | globals()[cfg.type](cfg) 204 | 205 | if __name__ == "__main__": 206 | main() 207 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | pytorch-lightning==1.5.10 2 | aiohttp==3.7 3 | aioredis==1.3.1 4 | pydegensac==0.1.2 5 | opencv_python==4.4.0.46 6 | yacs>=0.1.8 7 | einops==0.3.0 8 | kornia==0.4.1 9 | pickle5==0.0.11 10 | timm>=0.3.2 11 | hydra-core==1.1.1 12 | omegaconf==2.1.2 13 | pycocotools==2.0.4 14 | wandb==0.12.17 15 | rich==12.4.4 16 | transforms3d==0.3.1 17 | natsort==8.1.0 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import glob 4 | import hydra 5 | 6 | import os.path as osp 7 | from loguru import logger 8 | from pathlib import Path 9 | from omegaconf import DictConfig 10 | 11 | 12 | def merge_(anno_2d_file, avg_anno_3d_file, collect_anno_3d_file, 13 | idxs_file, img_id, ann_id, images, annotations): 14 | """ To prepare training and test objects, we merge annotations about difference objs""" 15 | with open(anno_2d_file, 'r') as f: 16 | annos_2d = json.load(f) 17 | 18 | for anno_2d in annos_2d: 19 | img_id += 1 20 | info = { 21 | 'id': img_id, 22 | 'img_file': anno_2d['img_file'], 23 | } 24 | images.append(info) 25 | 26 | ann_id += 1 27 | anno = { 28 | 'image_id': img_id, 29 | 'id': ann_id, 30 | 'pose_file': anno_2d['pose_file'], 31 | 'anno2d_file': anno_2d['anno_file'], 32 | 'avg_anno3d_file': avg_anno_3d_file, 33 | 'collect_anno3d_file': collect_anno_3d_file, 34 | 'idxs_file': idxs_file 35 | } 36 | annotations.append(anno) 37 | return img_id, ann_id 38 | 39 | 40 | def merge_anno(cfg): 41 | """ Merge different objects' anno file into one anno file """ 42 | anno_dirs = [] 43 | 44 | if cfg.split == 'train': 45 | names = cfg.train.names 46 | elif cfg.split == 'val': 47 | names = cfg.val.names 48 | 49 | for name in names: 50 | anno_dir = osp.join(cfg.datamodule.data_dir, name, f'outputs_{cfg.network.detection}_{cfg.network.matching}', 'anno') 51 | anno_dirs.append(anno_dir) 52 | 53 | img_id = 0 54 | ann_id = 0 55 | images = [] 56 | annotations = [] 57 | for anno_dir in anno_dirs: 58 | logger.info(f'Merging anno dir: {anno_dir}') 59 | anno_2d_file = osp.join(anno_dir, 'anno_2d.json') 60 | avg_anno_3d_file = osp.join(anno_dir, 'anno_3d_average.npz') 61 | collect_anno_3d_file = osp.join(anno_dir, 'anno_3d_collect.npz') 62 | idxs_file = osp.join(anno_dir, 'idxs.npy') 63 | 64 | if not osp.isfile(anno_2d_file) or not osp.isfile(avg_anno_3d_file) or not osp.isfile(collect_anno_3d_file): 65 | logger.info(f'No annotation in: {anno_dir}') 66 | continue 67 | 68 | img_id, ann_id = merge_(anno_2d_file, avg_anno_3d_file, collect_anno_3d_file, 69 | idxs_file, img_id, ann_id, images, annotations) 70 | 71 | logger.info(f'Total num: {len(images)}') 72 | instance = {'images': images, 'annotations': annotations} 73 | 74 | out_dir = osp.dirname(cfg.datamodule.out_path) 75 | Path(out_dir).mkdir(exist_ok=True, parents=True) 76 | with open(cfg.datamodule.out_path, 'w') as f: 77 | json.dump(instance, f) 78 | 79 | 80 | def sfm(cfg): 81 | """ Reconstruct and postprocess sparse object point cloud, and store point cloud features""" 82 | data_dirs = cfg.dataset.data_dir 83 | down_ratio = cfg.sfm.down_ratio 84 | data_dirs = [data_dirs] if isinstance(data_dirs, str) else data_dirs 85 | 86 | for data_dir in data_dirs: 87 | logger.info(f"Processing {data_dir}.") 88 | root_dir, sub_dirs = data_dir.split(' ')[0], data_dir.split(' ')[1:] 89 | 90 | # Parse image directory and downsample images: 91 | img_lists = 
[] 92 | for sub_dir in sub_dirs: 93 | seq_dir = osp.join(root_dir, sub_dir) 94 | img_lists += glob.glob(str(Path(seq_dir)) + '/color/*.png', recursive=True) 95 | 96 | down_img_lists = [] 97 | for img_file in img_lists: 98 | index = int(img_file.split('/')[-1].split('.')[0]) 99 | if index % down_ratio == 0: 100 | down_img_lists.append(img_file) 101 | img_lists = down_img_lists 102 | 103 | if len(img_lists) == 0: 104 | logger.info(f"No png image in {root_dir}") 105 | continue 106 | 107 | obj_name = root_dir.split('/')[-1] 108 | outputs_dir_root = cfg.dataset.outputs_dir.format(obj_name) 109 | 110 | # Begin SfM and postprocess: 111 | sfm_core(cfg, img_lists, outputs_dir_root) 112 | postprocess(cfg, img_lists, root_dir, outputs_dir_root) 113 | 114 | 115 | def sfm_core(cfg, img_lists, outputs_dir_root): 116 | """ Sparse reconstruction: extract features, match features, triangulation""" 117 | from src.sfm import extract_features, match_features, \ 118 | generate_empty, triangulation, pairs_from_poses 119 | 120 | # Construct output directory structure: 121 | outputs_dir = osp.join(outputs_dir_root, 'outputs' + '_' + cfg.network.detection + '_' + cfg.network.matching) 122 | feature_out = osp.join(outputs_dir, f'feats-{cfg.network.detection}.h5') 123 | covis_pairs_out = osp.join(outputs_dir, f'pairs-covis{cfg.sfm.covis_num}.txt') 124 | matches_out = osp.join(outputs_dir, f'matches-{cfg.network.matching}.h5') 125 | empty_dir = osp.join(outputs_dir, 'sfm_empty') 126 | deep_sfm_dir = osp.join(outputs_dir, 'sfm_ws') 127 | 128 | if cfg.redo: 129 | os.system(f'rm -rf {outputs_dir}') 130 | Path(outputs_dir).mkdir(exist_ok=True, parents=True) 131 | 132 | # Extract image features, construct image pairs and then match: 133 | extract_features.main(img_lists, feature_out, cfg) 134 | pairs_from_poses.covis_from_pose(img_lists, covis_pairs_out, cfg.sfm.covis_num, max_rotation=cfg.sfm.rotation_thresh) 135 | match_features.main(cfg, feature_out, covis_pairs_out, matches_out, vis_match=False) 136 | 137 | # Reconstruct 3D point cloud with known image poses: 138 | generate_empty.generate_model(img_lists, empty_dir) 139 | triangulation.main(deep_sfm_dir, empty_dir, outputs_dir, covis_pairs_out, feature_out, matches_out, image_dir=None) 140 | 141 | 142 | def postprocess(cfg, img_lists, root_dir, outputs_dir_root): 143 | """ Filter points and average feature""" 144 | from src.sfm.postprocess import filter_points, feature_process, filter_tkl 145 | 146 | bbox_path = osp.join(root_dir, "box3d_corners.txt") 147 | # Construct output directory structure: 148 | outputs_dir = osp.join(outputs_dir_root, 'outputs' + '_' + cfg.network.detection + '_' + cfg.network.matching) 149 | feature_out = osp.join(outputs_dir, f'feats-{cfg.network.detection}.h5') 150 | deep_sfm_dir = osp.join(outputs_dir, 'sfm_ws') 151 | model_path = osp.join(deep_sfm_dir, 'model') 152 | 153 | # Select feature track length to limit the number of 3D points below the 'max_num_kp3d' threshold: 154 | track_length, points_count_list = filter_tkl.get_tkl(model_path, thres=cfg.dataset.max_num_kp3d, show=False) 155 | filter_tkl.vis_tkl_filtered_pcds(model_path, points_count_list, track_length, outputs_dir) # For visualization only 156 | 157 | # Leverage the selected feature track length threshold and 3D BBox to filter 3D points: 158 | xyzs, points_idxs = filter_points.filter_3d(model_path, track_length, bbox_path) 159 | # Merge 3d points by distance between points 160 | merge_xyzs, merge_idxs = filter_points.merge(xyzs, points_idxs, dist_threshold=1e-3) 161 | 162 | # 
Save features of the filtered point cloud: 163 | feature_process.get_kpt_ann(cfg, img_lists, feature_out, outputs_dir, merge_idxs, merge_xyzs) 164 | 165 | 166 | @hydra.main(config_path='configs/', config_name='config.yaml') 167 | def main(cfg: DictConfig): 168 | globals()[cfg.type](cfg) 169 | 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /scripts/demo_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROJECT_DIR="$(pwd)" 3 | OBJ_NAME=$1 4 | echo "Current work dir: $PROJECT_DIR" 5 | 6 | echo '-------------------' 7 | echo 'Parse scanned data:' 8 | echo '-------------------' 9 | # Parse scanned annotated & test sequence: 10 | python $PROJECT_DIR/parse_scanned_data.py \ 11 | --scanned_object_path \ 12 | "$PROJECT_DIR/data/demo/$OBJ_NAME" 13 | 14 | echo '--------------------------------------------------------------' 15 | echo 'Run SfM to reconstruct object point cloud for pose estimation:' 16 | echo '--------------------------------------------------------------' 17 | # Run SfM to reconstruct object sparse point cloud from $OBJ_NAME-annotate sequence: 18 | python $PROJECT_DIR/run.py \ 19 | +preprocess="sfm_spp_spg_demo" \ 20 | dataset.data_dir="$PROJECT_DIR/data/demo/$OBJ_NAME $OBJ_NAME-annotate" \ 21 | dataset.outputs_dir="$PROJECT_DIR/data/demo/$OBJ_NAME/sfm_model" \ 22 | 23 | echo "-----------------------------------" 24 | echo "Run inference and output demo video:" 25 | echo "-----------------------------------" 26 | 27 | WITH_TRACKING=False 28 | while [[ "$#" -gt 0 ]]; do 29 | case $1 in 30 | -u|--WITH_TRACKING) WITH_TRACKING=True ;; 31 | esac 32 | shift 33 | done 34 | 35 | # Run inference on $OBJ_NAME-test and output demo video: 36 | python $PROJECT_DIR/inference_demo.py \ 37 | +experiment="test_demo" \ 38 | input.data_dirs="$PROJECT_DIR/data/demo/$OBJ_NAME $OBJ_NAME-test" \ 39 | input.sfm_model_dirs="$PROJECT_DIR/data/demo/$OBJ_NAME/sfm_model" \ 40 | use_tracking=${WITH_TRACKING} 41 | -------------------------------------------------------------------------------- /scripts/parse_full_img.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROJECT_DIR="$(pwd)" 3 | VIDEO_PATH=$1 4 | 5 | echo '-------------------' 6 | echo 'Parse full image: ' 7 | echo '-------------------' 8 | 9 | # Parse full image from Frames.m4v 10 | python $PROJECT_DIR/video2img.py \ 11 | --video_file ${VIDEO_PATH} 12 | -------------------------------------------------------------------------------- /scripts/prepare_2D_matching_resources.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | REPO_ROOT=$(pwd) 3 | echo "work space:$REPO_ROOT" 4 | 5 | # Download code and pretrained model of SuperPoint: 6 | mkdir -p $REPO_ROOT/data/models/extractors/SuperPoint 7 | # cd $REPO_ROOT/src/models/extractors/SuperPoint 8 | # wget https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/superpoint.py 9 | cd $REPO_ROOT/data/models/extractors/SuperPoint 10 | wget https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superpoint_v1.pth 11 | 12 | # Download code and pretrained model of SuperGlue: 13 | mkdir -p $REPO_ROOT/data/models/matchers/SuperGlue 14 | # cd $REPO_ROOT/src/models/matchers/SuperGlue 15 | # wget https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/superglue.py 16 | cd 
$REPO_ROOT/data/models/matchers/SuperGlue 17 | wget https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superglue_outdoor.pth -------------------------------------------------------------------------------- /src/callbacks/custom_callbacks.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Callback 2 | 3 | 4 | class ExampleCallback(Callback): 5 | def __init__(self): 6 | pass 7 | 8 | def on_init_start(self, trainer): 9 | print("Starting to initialize trainer!") 10 | 11 | def on_init_end(self, trainer): 12 | print("Trainer is initialized now.") 13 | 14 | def on_train_end(self, trainer, pl_module): 15 | print("Do something when training ends.") 16 | 17 | 18 | class UnfreezeModelCallback(Callback): 19 | """ 20 | Unfreeze all model parameters after a few epochs. 21 | """ 22 | 23 | def __init__(self, wait_epochs=5): 24 | self.wait_epochs = wait_epochs 25 | 26 | def on_epoch_end(self, trainer, pl_module): 27 | if trainer.current_epoch == self.wait_epochs: 28 | for param in pl_module.model.model.parameters(): 29 | param.requires_grad = True 30 | -------------------------------------------------------------------------------- /src/callbacks/wandb_callbacks.py: -------------------------------------------------------------------------------- 1 | # wandb 2 | from pytorch_lightning.loggers import WandbLogger 3 | import wandb 4 | 5 | # pytorch 6 | from pytorch_lightning import Callback 7 | import pytorch_lightning as pl 8 | import torch 9 | 10 | # others 11 | from sklearn.metrics import precision_score, recall_score, f1_score 12 | from typing import List 13 | import glob 14 | import os 15 | 16 | 17 | def get_wandb_logger(trainer: pl.Trainer) -> WandbLogger: 18 | logger = None 19 | for lg in trainer.logger: 20 | if isinstance(lg, WandbLogger): 21 | logger = lg 22 | 23 | if not logger: 24 | raise Exception( 25 | "You're using wandb related callback, " 26 | "but WandbLogger was not found for some reason..." 
27 | ) 28 | 29 | return logger 30 | 31 | 32 | class UploadCodeToWandbAsArtifact(Callback): 33 | """Upload all *.py files to wandb as an artifact at the beginning of the run.""" 34 | 35 | def __init__(self, code_dir: str): 36 | self.code_dir = code_dir 37 | 38 | def on_train_start(self, trainer, pl_module): 39 | logger = get_wandb_logger(trainer=trainer) 40 | experiment = logger.experiment 41 | 42 | code = wandb.Artifact("project-source", type="code") 43 | for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True): 44 | code.add_file(path) 45 | 46 | experiment.use_artifact(code) 47 | 48 | 49 | class UploadCheckpointsToWandbAsArtifact(Callback): 50 | """Upload experiment checkpoints to wandb as an artifact at the end of training.""" 51 | 52 | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): 53 | self.ckpt_dir = ckpt_dir 54 | self.upload_best_only = upload_best_only 55 | 56 | def on_train_end(self, trainer, pl_module): 57 | logger = get_wandb_logger(trainer=trainer) 58 | experiment = logger.experiment 59 | 60 | ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") 61 | 62 | if self.upload_best_only: 63 | ckpts.add_file(trainer.checkpoint_callback.best_model_path) 64 | else: 65 | for path in glob.glob( 66 | os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True 67 | ): 68 | ckpts.add_file(path) 69 | 70 | experiment.use_artifact(ckpts) 71 | 72 | 73 | class WatchModelWithWandb(Callback): 74 | """Make WandbLogger watch model at the beginning of the run.""" 75 | 76 | def __init__(self, log: str = "gradients", log_freq: int = 100): 77 | self.log = log 78 | self.log_freq = log_freq 79 | 80 | def on_train_start(self, trainer, pl_module): 81 | logger = get_wandb_logger(trainer=trainer) 82 | logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) 83 | 84 | 85 | class LogF1PrecisionRecallHeatmapToWandb(Callback): 86 | """ 87 | Generate f1, precision and recall heatmap from validation step outputs. 88 | Expects validation step to return predictions and targets. 89 | Works only for single label classification! 
90 | """ 91 | 92 | def __init__(self, class_names: List[str] = None): 93 | self.class_names = class_names 94 | self.preds = [] 95 | self.targets = [] 96 | self.ready = False 97 | 98 | def on_sanity_check_end(self, trainer, pl_module): 99 | """Start executing this callback only after all validation sanity checks end.""" 100 | self.ready = True 101 | 102 | def on_validation_batch_end( 103 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 104 | ): 105 | """Gather data from single batch.""" 106 | if self.ready: 107 | preds, targets = outputs["preds"], outputs["targets"] 108 | self.preds.append(preds) 109 | self.targets.append(targets) 110 | 111 | def on_validation_epoch_end(self, trainer, pl_module): 112 | """Generate f1, precision and recall heatmap.""" 113 | if self.ready: 114 | logger = get_wandb_logger(trainer=trainer) 115 | experiment = logger.experiment 116 | 117 | self.preds = torch.cat(self.preds).cpu() 118 | self.targets = torch.cat(self.targets).cpu() 119 | f1 = f1_score(self.preds, self.targets, average=None) 120 | r = recall_score(self.preds, self.targets, average=None) 121 | p = precision_score(self.preds, self.targets, average=None) 122 | 123 | experiment.log( 124 | { 125 | f"f1_p_r_heatmap/{trainer.current_epoch}_{experiment.id}": wandb.plots.HeatMap( 126 | x_labels=self.class_names, 127 | y_labels=["f1", "precision", "recall"], 128 | matrix_values=[f1, p, r], 129 | show_text=True, 130 | ) 131 | }, 132 | commit=False, 133 | ) 134 | 135 | self.preds = [] 136 | self.targets = [] 137 | 138 | 139 | class LogConfusionMatrixToWandb(Callback): 140 | """ 141 | Generate Confusion Matrix. 142 | Expects validation step to return predictions and targets. 143 | Works only for single label classification! 144 | """ 145 | 146 | def __init__(self, class_names: List[str] = None): 147 | self.class_names = class_names 148 | self.preds = [] 149 | self.targets = [] 150 | self.ready = False 151 | 152 | def on_sanity_check_end(self, trainer, pl_module): 153 | """Start executing this callback only after all validation sanity checks end.""" 154 | self.ready = True 155 | 156 | def on_validation_batch_end( 157 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 158 | ): 159 | """Gather data from single batch.""" 160 | if self.ready: 161 | preds, targets = outputs["preds"], outputs["targets"] 162 | self.preds.append(preds) 163 | self.targets.append(targets) 164 | 165 | def on_validation_epoch_end(self, trainer, pl_module): 166 | """Generate confusion matrix.""" 167 | if self.ready: 168 | logger = get_wandb_logger(trainer=trainer) 169 | experiment = logger.experiment 170 | 171 | self.preds = torch.cat(self.preds).tolist() 172 | self.targets = torch.cat(self.targets).tolist() 173 | 174 | experiment.log( 175 | { 176 | f"confusion_matrix/{trainer.current_epoch}_{experiment.id}": wandb.plot.confusion_matrix( 177 | preds=self.preds, 178 | y_true=self.targets, 179 | class_names=self.class_names, 180 | ) 181 | }, 182 | commit=False, 183 | ) 184 | 185 | self.preds = [] 186 | self.targets = [] 187 | 188 | 189 | ''' BUGGED :( 190 | class LogBestMetricScoresToWandb(Callback): 191 | """ 192 | Store in wandb: 193 | - max train acc 194 | - min train loss 195 | - max val acc 196 | - min val loss 197 | Useful for comparing runs in table views, as wandb doesn't currently support column aggregation. 
198 | """ 199 | 200 | def __init__(self): 201 | self.train_loss_best = None 202 | self.train_acc_best = None 203 | self.val_loss_best = None 204 | self.val_acc_best = None 205 | self.ready = False 206 | 207 | def on_sanity_check_end(self, trainer, pl_module): 208 | """Start executing this callback only after all validation sanity checks end.""" 209 | self.ready = True 210 | 211 | def on_epoch_end(self, trainer, pl_module): 212 | if self.ready: 213 | logger = get_wandb_logger(trainer=trainer) 214 | experiment = logger.experiment 215 | 216 | metrics = trainer.callback_metrics 217 | 218 | if not self.train_loss_best or metrics["train/loss"] < self.train_loss_best: 219 | self.train_loss_best = metrics["train_loss"] 220 | 221 | if not self.train_acc_best or metrics["train/acc"] > self.train_acc_best: 222 | self.train_acc_best = metrics["train/acc"] 223 | 224 | if not self.val_loss_best or metrics["val/loss"] < self.val_loss_best: 225 | self.val_loss_best = metrics["val/loss"] 226 | 227 | if not self.val_acc_best or metrics["val/acc"] > self.val_acc_best: 228 | self.val_acc_best = metrics["val/acc"] 229 | 230 | experiment.log({"train/loss_best": self.train_loss_best}, commit=False) 231 | experiment.log({"train/acc_best": self.train_acc_best}, commit=False) 232 | experiment.log({"val/loss_best": self.val_loss_best}, commit=False) 233 | experiment.log({"val/acc_best": self.val_acc_best}, commit=False) 234 | ''' 235 | -------------------------------------------------------------------------------- /src/datamodules/GATs_spg_datamodule.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningDataModule 2 | from torch.utils.data.dataloader import DataLoader 3 | from src.datasets.GATs_spg_dataset import GATsSPGDataset 4 | 5 | 6 | class GATsSPGDataModule(LightningDataModule): 7 | def __init__(self, *args, **kwargs): 8 | super().__init__() 9 | 10 | self.train_anno_file = kwargs['train_anno_file'] 11 | self.val_anno_file = kwargs['val_anno_file'] 12 | self.batch_size = kwargs['batch_size'] 13 | self.num_workers = kwargs['num_workers'] 14 | self.pin_memory = kwargs['pin_memory'] 15 | self.num_leaf = kwargs['num_leaf'] 16 | self.shape2d = kwargs['shape2d'] 17 | self.shape3d = kwargs['shape3d'] 18 | self.assign_pad_val = kwargs['assign_pad_val'] 19 | 20 | self.data_train = None 21 | self.data_val = None 22 | self.data_test = None 23 | 24 | # Loader parameters 25 | self.train_loader_params = { 26 | 'batch_size': self.batch_size, 27 | 'shuffle': True, 28 | 'num_workers': self.num_workers, 29 | 'pin_memory': self.pin_memory, 30 | } 31 | self.val_loader_params = { 32 | 'batch_size': 1, 33 | 'shuffle': False, 34 | 'num_workers': self.num_workers, 35 | 'pin_memory': self.pin_memory, 36 | } 37 | self.test_loader_params = { 38 | 'batch_size': 1, 39 | 'shuffle': False, 40 | 'num_workers': self.num_workers, 41 | 'pin_memory': self.pin_memory, 42 | } 43 | 44 | def prepare_data(self): 45 | pass 46 | 47 | def setup(self, stage=None): 48 | """ Load data. 
Set variable: self.data_train, self.data_val, self.data_test""" 49 | trainset = GATsSPGDataset( 50 | anno_file=self.train_anno_file, 51 | num_leaf=self.num_leaf, 52 | split='train', 53 | shape2d=self.shape2d, 54 | shape3d=self.shape3d, 55 | pad_val=self.assign_pad_val 56 | ) 57 | print("=> Read train anno file: ", self.train_anno_file) 58 | 59 | valset = GATsSPGDataset( 60 | anno_file=self.val_anno_file, 61 | num_leaf=self.num_leaf, 62 | split='val', 63 | shape2d=self.shape2d, 64 | shape3d=self.shape3d, 65 | pad_val=self.assign_pad_val, 66 | load_pose_gt=True 67 | ) 68 | print("=> Read validation anno file: ", self.val_anno_file) 69 | 70 | self.data_train = trainset 71 | self.data_val = valset 72 | self.data_test = valset 73 | 74 | def train_dataloader(self): 75 | return DataLoader(dataset=self.data_train, **self.train_loader_params) 76 | 77 | def val_dataloader(self): 78 | return DataLoader(dataset=self.data_val, **self.val_loader_params) 79 | 80 | def test_dataloader(self): 81 | return DataLoader(dataset=self.data_test, **self.test_loader_params) 82 | -------------------------------------------------------------------------------- /src/datasets/GATs_spg_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | try: 3 | import ujson as json 4 | except ImportError: 5 | import json 6 | import torch 7 | import numpy as np 8 | 9 | from pycocotools.coco import COCO 10 | from torch.utils.data import Dataset 11 | from src.utils import data_utils 12 | 13 | 14 | class GATsSPGDataset(Dataset): 15 | def __init__( 16 | self, 17 | anno_file, 18 | num_leaf, 19 | split, 20 | pad=True, 21 | shape2d=1000, 22 | shape3d=2000, 23 | pad_val=0, 24 | load_pose_gt=False, 25 | ): 26 | super(Dataset, self).__init__() 27 | 28 | self.coco = COCO(anno_file) 29 | self.anns = np.array(self.coco.getImgIds()) 30 | self.num_leaf = num_leaf 31 | 32 | self.split = split 33 | self.pad = pad 34 | self.shape2d = shape2d 35 | self.shape3d = shape3d 36 | self.pad_val = pad_val 37 | self.load_pose_gt = load_pose_gt 38 | 39 | def get_intrin_by_color_path(self, color_path): 40 | intrin_path = color_path.replace('/color/', '/intrin_ba/').replace( 41 | '.png', '.txt' 42 | ) 43 | K_crop = torch.from_numpy(np.loadtxt(intrin_path)) # [3, 3] 44 | return K_crop 45 | 46 | def get_gt_pose_by_color_path(self, color_path): 47 | gt_pose_path = color_path.replace('/color/', '/poses_ba/').replace( 48 | '.png', '.txt' 49 | ) 50 | pose_gt = torch.from_numpy(np.loadtxt(gt_pose_path)) # [4, 4] 51 | return pose_gt 52 | 53 | def read_anno2d(self, anno2d_file, height, width, pad=True): 54 | """ Read (and pad) 2d info""" 55 | with open(anno2d_file, 'r') as f: 56 | data = json.load(f) 57 | 58 | keypoints2d = torch.Tensor(data['keypoints2d']) # [n, 2] 59 | descriptors2d = torch.Tensor(data['descriptors2d']) # [dim, n] 60 | scores2d = torch.Tensor(data['scores2d']) # [n, 1] 61 | assign_matrix = torch.Tensor(data['assign_matrix']) # [2, k] 62 | 63 | num_2d_orig = keypoints2d.shape[0] 64 | 65 | if pad: 66 | keypoints2d, descriptors2d, scores2d = data_utils.pad_keypoints2d_random(keypoints2d, descriptors2d, scores2d, 67 | height, width, self.shape2d) 68 | return keypoints2d, descriptors2d, scores2d, assign_matrix, num_2d_orig 69 | 70 | def read_anno3d(self, avg_anno3d_file, clt_anno3d_file, idxs_file, pad=True): 71 | """ Read(and pad) 3d info""" 72 | avg_data = np.load(avg_anno3d_file) 73 | clt_data = np.load(clt_anno3d_file) 74 | idxs = np.load(idxs_file) 75 | 76 | keypoints3d = 
torch.Tensor(clt_data['keypoints3d']) # [m, 3] 77 | avg_descriptors3d = torch.Tensor(avg_data['descriptors3d']) # [dim, m] 78 | clt_descriptors = torch.Tensor(clt_data['descriptors3d']) # [dim, k] 79 | avg_scores = torch.Tensor(avg_data['scores3d']) # [m, 1] 80 | clt_scores = torch.Tensor(clt_data['scores3d']) # [k, 1] 81 | 82 | num_3d_orig = keypoints3d.shape[0] 83 | if pad: 84 | keypoints3d = data_utils.pad_keypoints3d_random(keypoints3d, self.shape3d) 85 | avg_descriptors3d, avg_scores = data_utils.pad_features3d_random(avg_descriptors3d, avg_scores, self.shape3d) 86 | clt_descriptors, clt_scores = data_utils.build_features3d_leaves(clt_descriptors, clt_scores, idxs, 87 | self.shape3d, num_leaf=self.num_leaf) 88 | return keypoints3d, avg_descriptors3d, avg_scores, clt_descriptors, clt_scores, num_3d_orig 89 | 90 | def read_anno(self, img_id): 91 | """ 92 | Read image, 2d info and 3d info. 93 | Pad 2d info and 3d info to a constant size. 94 | """ 95 | ann_ids = self.coco.getAnnIds(imgIds=img_id) 96 | anno = self.coco.loadAnns(ann_ids)[0] 97 | 98 | color_path = self.coco.loadImgs(int(img_id))[0]['img_file'] 99 | image = cv2.imread(color_path) 100 | height, width, _ = image.shape 101 | 102 | idxs_file = anno['idxs_file'] 103 | avg_anno3d_file = anno['avg_anno3d_file'] 104 | collect_anno3d_file = anno['collect_anno3d_file'] 105 | 106 | # Load 3D points and features: 107 | ( 108 | keypoints3d, 109 | avg_descriptors3d, 110 | avg_scores3d, 111 | clt_descriptors2d, 112 | clt_scores2d, 113 | num_3d_orig 114 | ) = self.read_anno3d(avg_anno3d_file, collect_anno3d_file, idxs_file, pad=self.pad) 115 | 116 | if self.split == 'train': 117 | anno2d_file = anno['anno2d_file'] 118 | # Load 2D keypoints, features and GT 2D-3D correspondences: 119 | ( 120 | keypoints2d, 121 | descriptors2d, 122 | scores2d, 123 | assign_matrix, 124 | num_2d_orig 125 | ) = self.read_anno2d(anno2d_file, height, width, pad=self.pad) 126 | 127 | # Construct GT conf_matrix: 128 | conf_matrix = data_utils.reshape_assign_matrix( 129 | assign_matrix, 130 | num_2d_orig, 131 | num_3d_orig, 132 | self.shape2d, 133 | self.shape3d, 134 | pad=True, 135 | pad_val=self.pad_val 136 | ) 137 | 138 | data = { 139 | 'keypoints2d': keypoints2d, # [n1, 2] 140 | 'descriptors2d_query': descriptors2d, # [dim, n1] 141 | } 142 | 143 | elif self.split == 'val': 144 | image_gray = data_utils.read_gray_scale(color_path) / 255. 
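# For the 'val' split, pre-extracted 2D annotations are not loaded; only the
# normalized grayscale query image is returned (so 2D features can be extracted
# from it downstream), and the GT assignment matrix below is left empty.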
145 | data = { 146 | 'image': image_gray 147 | } 148 | conf_matrix = torch.Tensor([]) 149 | 150 | data.update({ 151 | 'keypoints3d': keypoints3d, # [n2, 3] 152 | 'descriptors3d_db': avg_descriptors3d, # [dim, n2] 153 | 'descriptors2d_db': clt_descriptors2d, # [dim, n2 * num_leaf] 154 | 'image_size': torch.Tensor([height, width]) 155 | }) 156 | 157 | if self.load_pose_gt: 158 | K_crop = self.get_intrin_by_color_path(color_path) 159 | pose_gt = self.get_gt_pose_by_color_path(color_path) 160 | data.update({'query_intrinsic': K_crop, 'query_pose_gt': pose_gt, 'query_image': image}) 161 | 162 | return data, conf_matrix 163 | 164 | def __getitem__(self, index): 165 | img_id = self.anns[index] 166 | 167 | data, conf_matrix = self.read_anno(img_id) 168 | return data, conf_matrix 169 | 170 | def __len__(self): 171 | return len(self.anns) -------------------------------------------------------------------------------- /src/datasets/normalized_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from types import SimpleNamespace 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class NormalizedDataset(Dataset): 9 | """read images(suppose images have been cropped)""" 10 | default_conf = { 11 | 'globs': ['*.jpg', '*.png'], 12 | 'grayscale': True, 13 | } 14 | 15 | def __init__(self, img_lists, conf): 16 | self.img_lists = img_lists 17 | self.conf = SimpleNamespace(**{**self.default_conf, **conf}) 18 | 19 | if len(img_lists) == 0: 20 | raise ValueError('Could not find any image.') 21 | 22 | def __getitem__(self, index): 23 | img_path = self.img_lists[index] 24 | 25 | mode = cv2.IMREAD_GRAYSCALE if self.conf.grayscale else cv2.IMREAD_COLOR 26 | image = cv2.imread(img_path, mode) 27 | size = image.shape[:2] 28 | 29 | image = image.astype(np.float32) 30 | if self.conf.grayscale: 31 | image = image[None] 32 | else: 33 | image = image.transpose((2, 0, 1)) 34 | image /= 255. 35 | 36 | data = { 37 | 'path': str(img_path), 38 | 'image': image, 39 | 'size': np.array(size), 40 | } 41 | return data 42 | 43 | def __len__(self): 44 | return len(self.img_lists) -------------------------------------------------------------------------------- /src/evaluators/cmd_evaluator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Evaluator(): 4 | def __init__(self): 5 | self.cmd1 = [] 6 | self.cmd3 = [] 7 | self.cmd5 = [] 8 | self.cmd7 = [] 9 | self.add = [] 10 | 11 | def cm_degree_1_metric(self, pose_pred, pose_target): 12 | translation_distance = np.linalg.norm(pose_pred[:, 3] - pose_target[:, 3]) * 100 13 | rotation_diff = np.dot(pose_pred[:, :3], pose_target[:, :3].T) 14 | trace = np.trace(rotation_diff) 15 | trace = trace if trace <= 3 else 3 16 | angular_distance = np.rad2deg(np.arccos((trace - 1.) / 2.)) 17 | self.cmd1.append(translation_distance < 1 and angular_distance < 1) 18 | 19 | def cm_degree_5_metric(self, pose_pred, pose_target): 20 | translation_distance = np.linalg.norm(pose_pred[:, 3] - pose_target[:, 3]) * 100 21 | rotation_diff = np.dot(pose_pred[:, :3], pose_target[:, :3].T) 22 | trace = np.trace(rotation_diff) 23 | trace = trace if trace <= 3 else 3 24 | angular_distance = np.rad2deg(np.arccos((trace - 1.) 
/ 2.)) 25 | self.cmd5.append(translation_distance < 5 and angular_distance < 5) 26 | 27 | def cm_degree_3_metric(self, pose_pred, pose_target): 28 | translation_distance = np.linalg.norm(pose_pred[:, 3] - pose_target[:, 3]) * 100 29 | rotation_diff = np.dot(pose_pred[:, :3], pose_target[:, :3].T) 30 | trace = np.trace(rotation_diff) 31 | trace = trace if trace <= 3 else 3 32 | angular_distance = np.rad2deg(np.arccos((trace - 1.) / 2.)) 33 | self.cmd3.append(translation_distance < 3 and angular_distance < 3) 34 | 35 | def evaluate(self, pose_pred, pose_gt): 36 | if pose_pred is None: 37 | self.cmd5.append(False) 38 | self.cmd1.append(False) 39 | self.cmd3.append(False) 40 | self.cmd7.append(False) 41 | else: 42 | if pose_pred.shape == (4, 4): 43 | pose_pred = pose_pred[:3, :4] 44 | if pose_gt.shape == (4, 4): 45 | pose_gt = pose_gt[:3, :4] 46 | self.cm_degree_1_metric(pose_pred, pose_gt) 47 | self.cm_degree_3_metric(pose_pred, pose_gt) 48 | self.cm_degree_5_metric(pose_pred, pose_gt) 49 | 50 | def summarize(self): 51 | cmd1 = np.mean(self.cmd1) 52 | cmd3 = np.mean(self.cmd3) 53 | cmd5 = np.mean(self.cmd5) 54 | print('1 cm 1 degree metric: {}'.format(cmd1)) 55 | print('3 cm 3 degree metric: {}'.format(cmd3)) 56 | print('5 cm 5 degree metric: {}'.format(cmd5)) 57 | 58 | self.cmd1 = [] 59 | self.cmd3 = [] 60 | self.cmd5 = [] 61 | self.cmd7 = [] 62 | return {'cmd1': cmd1, 'cmd3': cmd3, 'cmd5': cmd5} -------------------------------------------------------------------------------- /src/local_feature_2D_detector/__init__.py: -------------------------------------------------------------------------------- 1 | from .local_feature_2D_detector import LocalFeatureObjectDetector -------------------------------------------------------------------------------- /src/local_feature_2D_detector/local_feature_2D_detector.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import time 3 | import cv2 4 | import torch 5 | import numpy as np 6 | from src.utils.colmap.read_write_model import read_model 7 | from src.utils.data_utils import get_K_crop_resize, get_image_crop_resize 8 | from src.utils.vis_utils import reproj 9 | 10 | 11 | def pack_extract_data(img_path): 12 | image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) 13 | 14 | image = image[None] / 255.0 15 | return torch.Tensor(image) 16 | 17 | 18 | def pack_match_data(db_detection, query_detection, db_size, query_size): 19 | data = {} 20 | for k in db_detection.keys(): 21 | data[k + "0"] = db_detection[k].__array__() 22 | for k in query_detection.keys(): 23 | data[k + "1"] = query_detection[k].__array__() 24 | data = {k: torch.from_numpy(v)[None].float().cuda() for k, v in data.items()} 25 | 26 | data["image0"] = torch.empty( 27 | ( 28 | 1, 29 | 1, 30 | ) 31 | + tuple(db_size)[::-1] 32 | ) 33 | data["image1"] = torch.empty( 34 | ( 35 | 1, 36 | 1, 37 | ) 38 | + tuple(query_size)[::-1] 39 | ) 40 | return data 41 | 42 | 43 | class LocalFeatureObjectDetector(): 44 | def __init__(self, extractor, matcher, sfm_ws_dir, n_ref_view=15, output_results=False, detect_save_dir=None, K_crop_save_dir=None): 45 | self.extractor = extractor.cuda() 46 | self.matcher = matcher.cuda() 47 | self.db_dict = self.extract_ref_view_features(sfm_ws_dir, n_ref_view) 48 | self.output_results = output_results 49 | self.detect_save_dir = detect_save_dir 50 | self.K_crop_save_dir = K_crop_save_dir 51 | 52 | def extract_ref_view_features(self, sfm_ws_dir, n_ref_views): 53 | assert osp.exists(sfm_ws_dir), f"SfM work 
space:{sfm_ws_dir} not exists!" 54 | cameras, images, points3D = read_model(sfm_ws_dir) 55 | idx = 0 56 | sample_gap = len(images) // n_ref_views 57 | 58 | # Prepare reference input to matcher: 59 | db_dict = {} # id: image 60 | for idx in range(1, len(images), sample_gap): 61 | db_img_path = images[idx].name 62 | 63 | db_img = pack_extract_data(db_img_path) 64 | 65 | # Detect DB image keypoints: 66 | db_inp = db_img[None].cuda() 67 | db_detection = self.extractor(db_inp) 68 | db_detection = { 69 | k: v[0].detach().cpu().numpy() for k, v in db_detection.items() 70 | } 71 | db_detection["size"] = np.array(db_img.shape[-2:]) 72 | db_dict[idx] = db_detection 73 | 74 | return db_dict 75 | 76 | @torch.no_grad() 77 | def match_worker(self, query): 78 | detect_results_dict = {} 79 | for idx, db in self.db_dict.items(): 80 | db_shape = db["size"] 81 | query_shape = query["size"] 82 | 83 | match_data = pack_match_data(db, query, db["size"], query["size"]) 84 | match_pred = self.matcher(match_data) 85 | matches = match_pred["matches0"][0].detach().cpu().numpy() 86 | confs = match_pred["matching_scores0"][0].detach().cpu().numpy() 87 | valid = matches > -1 88 | 89 | mkpts0 = db["keypoints"][valid] 90 | mkpts1 = query["keypoints"][matches[valid]] 91 | confs = confs[valid] 92 | 93 | if mkpts0.shape[0] < 6: 94 | affine = None 95 | inliers = np.empty((0)) 96 | detect_results_dict[idx] = { 97 | "inliers": inliers, 98 | "bbox": np.array([0, 0, query["size"][0], query["size"][1]]), 99 | } 100 | continue 101 | 102 | # Estimate affine and warp source image: 103 | affine, inliers = cv2.estimateAffinePartial2D( 104 | mkpts0, mkpts1, ransacReprojThreshold=6 105 | ) 106 | 107 | # Estimate box: 108 | four_corner = np.array( 109 | [ 110 | [0, 0, 1], 111 | [db_shape[1], 0, 1], 112 | [0, db_shape[0], 1], 113 | [db_shape[1], db_shape[0], 1], 114 | ] 115 | ).T # 3*4 116 | 117 | bbox = (affine @ four_corner).T.astype(np.int32) # 4*2 118 | 119 | left_top = np.min(bbox, axis=0) 120 | right_bottom = np.max(bbox, axis=0) 121 | 122 | w, h = right_bottom - left_top 123 | offset_percent = 0.0 124 | x0 = left_top[0] - int(w * offset_percent) 125 | y0 = left_top[1] - int(h * offset_percent) 126 | x1 = right_bottom[0] + int(w * offset_percent) 127 | y1 = right_bottom[1] + int(h * offset_percent) 128 | 129 | detect_results_dict[idx] = { 130 | "inliers": inliers, 131 | "bbox": np.array([x0, y0, x1, y1]), 132 | } 133 | return detect_results_dict 134 | 135 | def detect_by_matching(self, query): 136 | detect_results_dict = self.match_worker(query) 137 | 138 | # Sort multiple bbox candidate and use bbox with maxium inliers: 139 | idx_sorted = [ 140 | k 141 | for k, _ in sorted( 142 | detect_results_dict.items(), 143 | reverse=True, 144 | key=lambda item: item[1]["inliers"].shape[0], 145 | ) 146 | ] 147 | return detect_results_dict[idx_sorted[0]]["bbox"] 148 | 149 | def robust_crop(self, query_img_path, bbox, K, crop_size=512): 150 | x0, y0 = bbox[0], bbox[1] 151 | x1, y1 = bbox[2], bbox[3] 152 | x_c = (x0 + x1) / 2 153 | y_c = (y0 + y1) / 2 154 | 155 | origin_img = cv2.imread(query_img_path, cv2.IMREAD_GRAYSCALE) 156 | image_crop = origin_img 157 | K_crop, K_crop_homo = get_K_crop_resize(bbox, K, [crop_size, crop_size]) 158 | return image_crop, K_crop 159 | 160 | def crop_img_by_bbox(self, query_img_path, bbox, K=None, crop_size=512): 161 | """ 162 | Crop image by detect bbox 163 | Input: 164 | query_img_path: str, 165 | bbox: np.ndarray[x0, y0, x1, y1], 166 | K[optional]: 3*3 167 | Output: 168 | image_crop: np.ndarray[crop_size * 
crop_size], 169 | K_crop[optional]: 3*3 170 | """ 171 | x0, y0 = bbox[0], bbox[1] 172 | x1, y1 = bbox[2], bbox[3] 173 | origin_img = cv2.imread(query_img_path, cv2.IMREAD_GRAYSCALE) 174 | 175 | resize_shape = np.array([y1 - y0, x1 - x0]) 176 | if K is not None: 177 | K_crop, K_crop_homo = get_K_crop_resize(bbox, K, resize_shape) 178 | image_crop, trans1 = get_image_crop_resize(origin_img, bbox, resize_shape) 179 | 180 | bbox_new = np.array([0, 0, x1 - x0, y1 - y0]) 181 | resize_shape = np.array([crop_size, crop_size]) 182 | if K is not None: 183 | K_crop, K_crop_homo = get_K_crop_resize(bbox_new, K_crop, resize_shape) 184 | image_crop, trans2 = get_image_crop_resize(image_crop, bbox_new, resize_shape) 185 | 186 | return image_crop, K_crop if K is not None else None 187 | 188 | def save_detection(self, crop_img, query_img_path): 189 | if self.output_results and self.detect_save_dir is not None: 190 | cv2.imwrite(osp.join(self.detect_save_dir, osp.basename(query_img_path)), crop_img) 191 | 192 | def save_K_crop(self, K_crop, query_img_path): 193 | if self.output_results and self.K_crop_save_dir is not None: 194 | np.savetxt(osp.join(self.K_crop_save_dir, osp.splitext(osp.basename(query_img_path))[0] + '.txt'), K_crop) # K_crop: 3*3 195 | 196 | def detect(self, query_img, query_img_path, K, crop_size=512): 197 | """ 198 | Detect object by local feature matching and crop image. 199 | Input: 200 | query_image: np.ndarray[1*1*H*W], 201 | query_img_path: str, 202 | K: np.ndarray[3*3], intrinsic matrix of original image 203 | Output: 204 | bounding_box: np.ndarray[x0, y0, x1, y1] 205 | cropped_image: torch.tensor[1 * 1 * crop_size * crop_size] (normalized), 206 | cropped_K: np.ndarray[3*3]; 207 | """ 208 | if len(query_img.shape) != 4: 209 | query_inp = query_img[None].cuda() 210 | else: 211 | query_inp = query_img.cuda() 212 | 213 | # Extract query image features: 214 | query_inp = self.extractor(query_inp) 215 | query_inp = {k: v[0].detach().cpu().numpy() for k, v in query_inp.items()} 216 | query_inp["size"] = np.array(query_img.shape[-2:]) 217 | 218 | # Detect bbox and crop image: 219 | bbox = self.detect_by_matching( 220 | query=query_inp, 221 | ) 222 | image_crop, K_crop = self.crop_img_by_bbox(query_img_path, bbox, K, crop_size=crop_size) 223 | self.save_detection(image_crop, query_img_path) 224 | self.save_K_crop(K_crop, query_img_path) 225 | 226 | # To Tensor: 227 | image_crop = image_crop.astype(np.float32) / 255 228 | image_crop_tensor = torch.from_numpy(image_crop)[None][None].cuda() 229 | 230 | return bbox, image_crop_tensor, K_crop 231 | 232 | def previous_pose_detect(self, query_img_path, K, pre_pose, bbox3D_corner, crop_size=512): 233 | """ 234 | Detect object by projecting 3D bbox with estimated last frame pose. 
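        The eight projected corners of the annotated 3D box are collapsed to an
        axis-aligned 2D box via per-axis min/max, giving a conservative crop region
        whenever a pose estimate from the previous frame is available.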
235 | Input: 236 | query_image_path: str, 237 | K: np.ndarray[3*3], intrinsic matrix of original image 238 | pre_pose: np.ndarray[3*4] or [4*4], pose of last frame 239 | bbox3D_corner: np.ndarray[8*3], corner coordinate of annotated 3D bbox 240 | Output: 241 | bounding_box: np.ndarray[x0, y0, x1, y1] 242 | cropped_image: torch.tensor[1 * 1 * crop_size * crop_size] (normalized), 243 | cropped_K: np.ndarray[3*3]; 244 | """ 245 | # Project 3D bbox: 246 | proj_2D_coor = reproj(K, pre_pose, bbox3D_corner) 247 | x0, y0 = np.min(proj_2D_coor, axis=0) 248 | x1, y1 = np.max(proj_2D_coor, axis=0) 249 | bbox = np.array([x0, y0, x1, y1]).astype(np.int32) 250 | 251 | image_crop, K_crop = self.crop_img_by_bbox(query_img_path, bbox, K, crop_size=crop_size) 252 | self.save_detection(image_crop, query_img_path) 253 | self.save_K_crop(K_crop, query_img_path) 254 | 255 | # To Tensor: 256 | image_crop = image_crop.astype(np.float32) / 255 257 | image_crop_tensor = torch.from_numpy(image_crop)[None][None].cuda() 258 | 259 | return bbox, image_crop_tensor, K_crop -------------------------------------------------------------------------------- /src/losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class FocalLoss(nn.Module): 5 | 6 | def __init__(self, alpha=1, gamma=2, neg_weights=0.5, pos_weights=0.5): 7 | super(FocalLoss, self).__init__() 8 | self.alpha = alpha 9 | self.gamma = gamma 10 | self.neg_weights = neg_weights 11 | self.pos_weights = pos_weights 12 | 13 | def forward(self, pred, target): 14 | loss_pos = - self.alpha * torch.pow(1 - pred[target==1], self.gamma) * (pred[target==1]).log() 15 | loss_neg = - (1 - self.alpha) * torch.pow(pred[target==0], self.gamma) * (1 - pred[target==0]).log() 16 | 17 | assert len(loss_pos) != 0 or len(loss_neg) != 0, 'Invalid loss.' 
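        # Both terms above follow the alpha-balanced focal-loss form of Lin et al.:
        #   loss_pos = -alpha       * (1 - p)^gamma * log(p)      for entries with target == 1
        #   loss_neg = -(1 - alpha) *      p ^gamma * log(1 - p)  for entries with target == 0
        # so well-classified entries are down-weighted and hard examples dominate the gradient.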
18 | # operate mean operation on an empty list will lead to NaN 19 | if len(loss_pos) == 0: 20 | loss_mean = self.neg_weights * loss_neg.mean() 21 | elif len(loss_neg) == 0: 22 | loss_mean = self.pos_weights * loss_pos.mean() 23 | else: 24 | loss_mean = self.pos_weights * loss_pos.mean() + self.neg_weights * loss_neg.mean() 25 | 26 | return loss_mean -------------------------------------------------------------------------------- /src/models/GATsSPG_architectures/GATs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GraphAttentionLayer(nn.Module): 7 | def __init__( 8 | self, 9 | in_features, 10 | out_features, 11 | dropout, 12 | alpha, 13 | concat=True, 14 | include_self=True, 15 | additional=False, 16 | with_linear_transform=True 17 | ): 18 | super(GraphAttentionLayer, self).__init__() 19 | self.dropout = dropout 20 | self.in_features = in_features 21 | self.out_features = out_features 22 | self.alpha = alpha 23 | self.concat = concat 24 | 25 | self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) 26 | nn.init.xavier_normal_(self.W.data, gain=1.414) 27 | self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1))) 28 | nn.init.xavier_normal_(self.a.data, gain=1.414) 29 | 30 | self.leakyrelu = nn.LeakyReLU(self.alpha) 31 | self.include_self = include_self 32 | self.with_linear_transform = with_linear_transform 33 | self.additional = additional 34 | 35 | def forward(self, h_2d, h_3d): 36 | b, n1, dim = h_3d.shape 37 | b, n2, dim = h_2d.shape 38 | num_leaf = int(n2 / n1) 39 | 40 | wh_2d = torch.matmul(h_2d, self.W) 41 | wh_3d = torch.matmul(h_3d, self.W) 42 | 43 | e = self._prepare_attentional_mechanism_input(wh_2d, wh_3d, num_leaf, self.include_self) 44 | attention = F.softmax(e, dim=2) 45 | 46 | h_2d = torch.reshape(h_2d, (b, n1, num_leaf, dim)) 47 | wh_2d = torch.reshape(wh_2d, (b, n1, num_leaf, dim)) 48 | if self.include_self: 49 | wh_2d = torch.cat( 50 | [wh_3d.unsqueeze(-2), wh_2d], dim=-2 51 | ) # [b, N, 1+num_leaf, d_out] 52 | h_2d = torch.cat( 53 | [h_3d.unsqueeze(-2), h_2d], dim=-2 54 | ) 55 | 56 | if self.with_linear_transform: 57 | h_prime = torch.einsum('bncd,bncq->bnq', attention, wh_2d) 58 | else: 59 | h_prime = torch.einsum('bncd,bncq->bnq', attention, h_2d) 60 | 61 | if self.additional: 62 | h_prime = h_prime + h_3d 63 | else: 64 | if self.with_linear_transform: 65 | h_prime = torch.einsum('bncd,bncq->bnq', attention, wh_2d) / 2. + wh_3d 66 | else: 67 | h_prime = torch.einsum('bncd,bncq->bnq', attention, h_2d) / 2. 
+ h_3d 68 | 69 | if self.concat: 70 | return F.elu(h_prime) 71 | else: 72 | return h_prime 73 | 74 | def _prepare_attentional_mechanism_input(self, wh_2d, wh_3d, num_leaf, include_self=False): 75 | b, n1, dim = wh_3d.shape 76 | b, n2, dim = wh_2d.shape 77 | 78 | wh_2d_ = torch.matmul(wh_2d, self.a[:self.out_features, :]) # [b, N2, 1] 79 | wh_2d_ = torch.reshape(wh_2d_, (b, n1, num_leaf, -1)) # [b, n1, 6, 1] 80 | wh_3d_ = torch.matmul(wh_3d, self.a[self.out_features:, :]) # [b, N1, 1] 81 | 82 | if include_self: 83 | wh_2d_ = torch.cat( 84 | [wh_3d_.unsqueeze(2), wh_2d_], dim=-2 85 | ) # [b, N1, 1 + num_leaf, 1] 86 | 87 | e = wh_3d_.unsqueeze(2) + wh_2d_ 88 | return self.leakyrelu(e) 89 | -------------------------------------------------------------------------------- /src/models/GATsSPG_architectures/GATs_SuperGlue.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .GATs import GraphAttentionLayer 6 | 7 | 8 | def arange_like(x, dim: int): 9 | return x.new_ones(x.shape[dim]).cumsum(0) - 1 10 | 11 | 12 | def buildAdjMatrix(num_2d, num_3d): 13 | num_leaf = int(num_2d / num_3d) 14 | 15 | adj_matrix = torch.zeros(num_3d, num_2d) 16 | for i in range(num_3d): 17 | adj_matrix[i, num_leaf*i: num_leaf*(i+1)] = 1 / num_leaf 18 | return adj_matrix.cuda() 19 | 20 | 21 | class AttentionalGNN(nn.Module): 22 | 23 | def __init__( 24 | self, 25 | feature_dim: int, 26 | layer_names: list, 27 | include_self: bool, 28 | additional: bool, 29 | with_linear_transform: bool 30 | ): 31 | super().__init__() 32 | 33 | self.layers = nn.ModuleList([ 34 | GraphAttentionLayer( 35 | in_features=256, 36 | out_features=256, 37 | dropout=0.6, 38 | alpha=0.2, 39 | concat=True, 40 | include_self=include_self, 41 | additional=additional, 42 | with_linear_transform=with_linear_transform, 43 | ) if i % 3 == 0 else AttentionPropagation(feature_dim, 4) 44 | for i in range(len(layer_names)) 45 | ]) 46 | self.names = layer_names 47 | 48 | def forward(self, desc2d_query, desc3d_db, desc2d_db): 49 | for layer, name in zip(self.layers, self.names): 50 | if name == 'GATs': 51 | desc2d_db_ = torch.einsum('bdn->bnd', desc2d_db) 52 | desc3d_db_ = torch.einsum('bdn->bnd', desc3d_db) 53 | desc3d_db_ = layer(desc2d_db_, desc3d_db_) 54 | desc3d_db = torch.einsum('bnd->bdn', desc3d_db_) 55 | elif name == 'cross': 56 | layer.attn.prob = [] 57 | src0, src1 = desc3d_db, desc2d_query # [b, c, l1], [b, c, l2] 58 | delta0, delta1 = layer(desc2d_query, src0), layer(desc3d_db, src1) 59 | desc2d_query, desc3d_db = (desc2d_query + delta0), (desc3d_db + delta1) 60 | elif name == 'self': 61 | layer.attn.prob = [] 62 | src0, src1 = desc2d_query, desc3d_db 63 | delta0, delta1 = layer(desc2d_query, src0), layer(desc3d_db, src1) 64 | desc2d_query, desc3d_db = (desc2d_query + delta0), (desc3d_db + delta1) 65 | 66 | return desc2d_query, desc3d_db 67 | 68 | 69 | def linear_attention(query, key, value): 70 | eps = 1e-6 71 | query = F.elu(query) + 1 72 | key = F.elu(key) + 1 73 | 74 | v_length = value.size(3) 75 | value = value / v_length 76 | 77 | KV = torch.einsum('bdhm,bqhm->bqdh', key, value) 78 | Z = 1 / (torch.einsum('bdhm,bdh->bhm', query, key.sum(3)) + eps) 79 | queried_values = torch.einsum('bdhm,bqdh,bhm->bqhm', query, KV, Z) * v_length 80 | return queried_values.contiguous() 81 | 82 | 83 | class MultiHeadedAttention(nn.Module): 84 | """ Multi-head attention to increase model expressivity""" 85 | def 
__init__(self, num_heads: int, d_model: int): 86 | super().__init__() 87 | assert d_model % num_heads == 0 88 | self.dim = d_model // num_heads 89 | self.num_heads = num_heads 90 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1) 91 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) 92 | 93 | def forward(self, query, key, value): 94 | batch_dim = query.size(0) 95 | 96 | query, key, value = [ 97 | l(x).view(batch_dim, self.dim, self.num_heads, -1) 98 | for l, x in zip(self.proj, (query, key, value)) 99 | ] 100 | x = linear_attention(query, key, value) 101 | return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1)) 102 | 103 | 104 | class AttentionPropagation(nn.Module): 105 | def __init__(self, feature_dim: int, num_heads: int): 106 | super().__init__() 107 | self.attn = MultiHeadedAttention(num_heads, feature_dim) 108 | self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim]) 109 | nn.init.constant_(self.mlp[-1].bias, 0.) 110 | 111 | def forward(self, x, source): 112 | message = self.attn(x, source, source) 113 | return self.mlp(torch.cat([x, message], dim=1)) # [b, 2c, 1000] / [b, 2c, 2000] 114 | 115 | 116 | def MLP(channels: list, do_bn=True): 117 | """ Multi-layer perceptron""" 118 | n = len(channels) 119 | layers = [] 120 | for i in range(1, n): 121 | layers.append( 122 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True) 123 | ) 124 | if i < n -1: 125 | if do_bn: 126 | layers.append(nn.InstanceNorm1d(channels[i])) 127 | layers.append(nn.ReLU()) 128 | return nn.Sequential(*layers) 129 | 130 | 131 | class KeypointEncoder(nn.Module): 132 | """ Joint encoding of visual appearance and location using MLPs """ 133 | def __init__(self, inp_dim, feature_dim, layers): 134 | super().__init__() 135 | self.encoder = MLP([inp_dim] + list(layers) + [feature_dim]) 136 | nn.init.constant_(self.encoder[-1].bias, 0.0) 137 | 138 | def forward(self, kpts, scores): 139 | inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)] 140 | return self.encoder(torch.cat(inputs, dim=1)) 141 | 142 | 143 | class GATsSuperGlue(nn.Module): 144 | 145 | def __init__(self, hparams): 146 | super().__init__() 147 | self.hparams = hparams 148 | self.match_type = hparams['match_type'] 149 | 150 | self.kenc_2d = KeypointEncoder( 151 | inp_dim=3, 152 | feature_dim=hparams['descriptor_dim'], 153 | layers=hparams['keypoints_encoder'] 154 | ) 155 | 156 | self.kenc_3d = KeypointEncoder( 157 | inp_dim=4, 158 | feature_dim=hparams['descriptor_dim'], 159 | layers=hparams['keypoints_encoder'] 160 | ) 161 | 162 | GNN_layers = ['GATs', 'self', 'cross'] * 4 163 | self.gnn = AttentionalGNN( 164 | feature_dim=hparams['descriptor_dim'], 165 | layer_names=GNN_layers, 166 | include_self=hparams['include_self'], 167 | additional=hparams['additional'], 168 | with_linear_transform=hparams['with_linear_transform'] 169 | ) 170 | self.final_proj = nn.Conv1d( 171 | in_channels=hparams['descriptor_dim'], 172 | out_channels=hparams['descriptor_dim'], 173 | kernel_size=1, 174 | bias=True 175 | ) 176 | bin_score = nn.Parameter(torch.tensor(1.)) 177 | self.register_parameter('bin_score', bin_score) 178 | 179 | def forward(self, data): 180 | """ 181 | Keys of data: 182 | keypoints2d: [b, n1, 2] 183 | keypoints3d: [b, n2, 3] 184 | descriptors2d_query: [b, dim, n1] 185 | descriptors3d_db: [b, dim, n2] 186 | descriptors2d_db: [b, dim, n2 * num_leaf] 187 | scores2d_query: [b, n1, 1] 188 | scores3d_db: [b, n2, 1] 189 | scores2d_db: [b, n2 * num_leaf, 1] 190 | """ 191 | kpts2d, kpts3d = 
data['keypoints2d'].float(), data['keypoints3d'].float() 192 | desc2d_query = data['descriptors2d_query'].float() 193 | desc3d_db, desc2d_db = data['descriptors3d_db'].float(), data['descriptors2d_db'].float() 194 | 195 | if kpts2d.shape[1] == 0 or kpts3d.shape[1] == 0: 196 | shape0, shape1 = kpts2d.shape[:-1], kpts3d.shape[:-1] 197 | return { 198 | 'matches0': kpts2d.new_full(shape0, -1, dtype=torch.int)[0], 199 | 'matches1': kpts3d.new_full(shape1, -1, dtype=torch.int)[0], 200 | 'matching_scores0': kpts2d.new_zeros(shape0)[0], 201 | 'matching_scores1': kpts3d.new_zeros(shape1)[0], 202 | 'skip_train': True 203 | } 204 | 205 | # Multi-layer Transformer network 206 | desc2d_query, desc3d_db = self.gnn(desc2d_query, desc3d_db, desc2d_db) 207 | 208 | # Final MLP projection 209 | mdesc2d_query, mdesc3d_db = self.final_proj(desc2d_query), self.final_proj(desc3d_db) 210 | 211 | # Normalize mdesc to avoid NaN 212 | mdesc2d_query = F.normalize(mdesc2d_query, p=2, dim=1) 213 | mdesc3d_db = F.normalize(mdesc3d_db, p=2, dim=1) 214 | 215 | # Get the matches with score above "match_threshold" 216 | if self.match_type == "softmax": 217 | scores = torch.einsum('bdn,bdm->bnm', mdesc2d_query, mdesc3d_db) / self.hparams['scale_factor'] 218 | conf_matrix = F.softmax(scores, 1) * F.softmax(scores, 2) 219 | 220 | max0, max1 = conf_matrix[:, :, :].max(2), conf_matrix[:, :, :].max(1) 221 | indices0, indices1 = max0.indices, max1.indices 222 | mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) 223 | mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1) 224 | zero = conf_matrix.new_tensor(0) 225 | mscores0 = torch.where(mutual0, max0.values, zero) 226 | mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) 227 | valid0 = mutual0 & (mscores0 > self.hparams['match_threshold']) 228 | valid1 = mutual1 & valid0.gather(1, indices1) 229 | indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) 230 | indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) 231 | 232 | pred = { 233 | 'matches0': indices0[0], # use -1 for invalid match 234 | 'matches1': indices1[0], # use -1 for invalid match 235 | 'matching_scores0': mscores0[0], 236 | 'matching_scores1': mscores1[0], 237 | } 238 | else: 239 | raise NotImplementedError 240 | 241 | return pred, conf_matrix 242 | -------------------------------------------------------------------------------- /src/models/GATsSPG_lightning_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from itertools import chain 5 | from src.models.GATsSPG_architectures.GATs_SuperGlue import GATsSuperGlue 6 | from src.losses.focal_loss import FocalLoss 7 | from src.utils.eval_utils import compute_query_pose_errors, aggregate_metrics 8 | from src.utils.vis_utils import draw_reprojection_pair 9 | from src.utils.comm import gather 10 | from src.models.extractors.SuperPoint.superpoint import SuperPoint 11 | from src.sfm.extract_features import confs 12 | from src.utils.model_io import load_network 13 | 14 | 15 | class LitModelGATsSPG(pl.LightningModule): 16 | 17 | def __init__(self, *args, **kwargs): 18 | super().__init__() 19 | 20 | self.save_hyperparameters() 21 | self.extractor = SuperPoint(confs['superpoint']) 22 | load_network(self.extractor, self.hparams.spp_model_path, force=False) 23 | self.matcher = GATsSuperGlue(hparams=self.hparams) 24 | self.crit = FocalLoss( 25 | alpha=self.hparams.focal_loss_alpha, 26 | 
gamma=self.hparams.focal_loss_gamma, 27 | neg_weights=self.hparams.neg_weights, 28 | pos_weights=self.hparams.pos_weights 29 | ) 30 | self.n_vals_plot = 10 31 | 32 | self.train_loss_hist = [] 33 | self.val_loss_hist = [] 34 | self.save_flag = True 35 | 36 | def forward(self, x): 37 | return self.matcher(x) 38 | 39 | def training_step(self, batch, batch_idx): 40 | self.save_flag = False 41 | data, conf_matrix_gt = batch 42 | preds, conf_matrix_pred = self.matcher(data) 43 | 44 | loss_mean = self.crit(conf_matrix_pred, conf_matrix_gt) 45 | if ( 46 | self.trainer.global_rank == 0 47 | and self.global_step % self.trainer.log_every_n_steps == 0 48 | ): 49 | self.logger.experiment.add_scalar('train/loss', loss_mean, self.global_step) 50 | 51 | return {'loss': loss_mean, 'preds': preds} 52 | 53 | def validation_step(self, batch, batch_idx): 54 | data, _ = batch 55 | extraction = self.extractor(data['image']) 56 | data.update({ 57 | 'keypoints2d': extraction['keypoints'][0][None], 58 | 'descriptors2d_query': extraction['descriptors'][0][None], 59 | }) 60 | preds, conf_matrix_pred = self.matcher(data) 61 | 62 | pose_pred, val_results = compute_query_pose_errors(data, preds) 63 | 64 | # Visualize match: 65 | val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1) 66 | figures = {'evaluation': []} 67 | if batch_idx % val_plot_interval == 0: 68 | figures = draw_reprojection_pair(data, val_results, visual_color_type='conf') 69 | 70 | loss_mean = 0 71 | self.log('val/loss', loss_mean, on_step=False, on_epoch=True, prog_bar=False) 72 | del data 73 | return {'figures': figures, 'metrics': val_results} 74 | 75 | def test_step(self, batch, batch_idx): 76 | pass 77 | 78 | def training_epoch_end(self, outputs): 79 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean() 80 | if self.trainer.global_rank == 0: 81 | self.logger.experiment.add_scalar( 82 | 'train/avg_loss_on_epoch', avg_loss, global_step=self.current_epoch 83 | ) 84 | 85 | def validation_epoch_end(self, outputs): 86 | self.val_loss_hist.append(self.trainer.callback_metrics['val/loss']) 87 | self.log('val/loss_best', min(self.val_loss_hist), prog_bar=False) 88 | 89 | multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs 90 | for valset_idx, outputs in enumerate(multi_outputs): 91 | cur_epoch = self.trainer.current_epoch 92 | if not self.trainer.resume_from_checkpoint and self.trainer.sanity_checking: 93 | cur_epoch = -1 94 | 95 | def flattenList(x): return list(chain(*x)) 96 | 97 | # Log val metrics: dict of list, numpy 98 | _metrics = [o['metrics'] for o in outputs] 99 | metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]} 100 | 101 | # Log figures: 102 | _figures = [o['figures'] for o in outputs] 103 | figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]} 104 | 105 | # NOTE: tensorboard records only on rank 0 106 | if self.trainer.global_rank == 0: 107 | val_metrics_4tb = aggregate_metrics(metrics) 108 | for k, v in val_metrics_4tb.items(): 109 | self.logger.experiment.add_scalar(f'metrics_{valset_idx}/{k}', v, global_step=cur_epoch) 110 | 111 | for k, v in figures.items(): 112 | for plot_idx, fig in enumerate(v): 113 | self.logger.experiment.add_figure( 114 | f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True 115 | ) 116 | 117 | def configure_optimizers(self): 118 | if self.hparams.optimizer == 'adam': 119 | optimizer = torch.optim.Adam( 120 | self.parameters(), 121 | 
lr=self.hparams.lr, 122 | weight_decay=self.hparams.weight_decay 123 | ) 124 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 125 | milestones=self.hparams.milestones, 126 | gamma=self.hparams.gamma) 127 | return [optimizer], [lr_scheduler] 128 | else: 129 | raise Exception("Invalid optimizer name.") -------------------------------------------------------------------------------- /src/models/extractors/SuperPoint/superpoint.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 
32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | from pathlib import Path 44 | import torch 45 | from torch import nn 46 | 47 | def simple_nms(scores, nms_radius: int): 48 | """ Fast Non-maximum suppression to remove nearby points """ 49 | assert(nms_radius >= 0) 50 | 51 | def max_pool(x): 52 | return torch.nn.functional.max_pool2d( 53 | x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) 54 | 55 | zeros = torch.zeros_like(scores) 56 | max_mask = scores == max_pool(scores) 57 | for _ in range(2): 58 | supp_mask = max_pool(max_mask.float()) > 0 59 | supp_scores = torch.where(supp_mask, zeros, scores) 60 | new_max_mask = supp_scores == max_pool(supp_scores) 61 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 62 | return torch.where(max_mask, scores, zeros) 63 | 64 | 65 | def remove_borders(keypoints, scores, border: int, height: int, width: int): 66 | """ Removes keypoints too close to the border """ 67 | mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) 68 | mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) 69 | mask = mask_h & mask_w 70 | return keypoints[mask], scores[mask] 71 | 72 | 73 | def top_k_keypoints(keypoints, scores, k: int): 74 | if k >= len(keypoints): 75 | return keypoints, scores 76 | scores, indices = torch.topk(scores, k, dim=0) 77 | return keypoints[indices], scores 78 | 79 | 80 | def sample_descriptors(keypoints, descriptors, s: int = 8): 81 | """ Interpolate descriptors at keypoint locations """ 82 | b, c, h, w = descriptors.shape 83 | keypoints = keypoints - s / 2 + 0.5 84 | keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], 85 | ).to(keypoints)[None] 86 | keypoints = keypoints*2 - 1 # normalize to (-1, 1) 87 | args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {} 88 | descriptors = torch.nn.functional.grid_sample( 89 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) 90 | descriptors = torch.nn.functional.normalize( 91 | descriptors.reshape(b, c, -1), p=2, dim=1) 92 | return descriptors 93 | 94 | 95 | class SuperPoint(nn.Module): 96 | """SuperPoint Convolutional Detector and Descriptor 97 | 98 | SuperPoint: Self-Supervised Interest Point Detection and 99 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew 100 | Rabinovich. In CVPRW, 2019. 
https://arxiv.org/abs/1712.07629 101 | 102 | """ 103 | default_config = { 104 | 'descriptor_dim': 256, 105 | 'nms_radius': 4, 106 | 'keypoint_threshold': 0.005, 107 | 'max_keypoints': -1, 108 | 'remove_borders': 4, 109 | } 110 | 111 | def __init__(self, config): 112 | super().__init__() 113 | self.config = {**self.default_config, **config} 114 | 115 | self.relu = nn.ReLU(inplace=True) 116 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 117 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 118 | 119 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 120 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 121 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 122 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 123 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 124 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 125 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 126 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 127 | 128 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 129 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 130 | 131 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 132 | self.convDb = nn.Conv2d( 133 | c5, self.config['descriptor_dim'], 134 | kernel_size=1, stride=1, padding=0) 135 | 136 | mk = self.config['max_keypoints'] 137 | if mk == 0 or mk < -1: 138 | raise ValueError('\"max_keypoints\" must be positive or \"-1\"') 139 | 140 | def forward(self, inp): 141 | """ Compute keypoints, scores, descriptors for image """ 142 | # Shared Encoder 143 | x = self.relu(self.conv1a(inp)) 144 | x = self.relu(self.conv1b(x)) 145 | x = self.pool(x) 146 | x = self.relu(self.conv2a(x)) 147 | x = self.relu(self.conv2b(x)) 148 | x = self.pool(x) 149 | x = self.relu(self.conv3a(x)) 150 | x = self.relu(self.conv3b(x)) 151 | x = self.pool(x) 152 | x = self.relu(self.conv4a(x)) 153 | x = self.relu(self.conv4b(x)) 154 | 155 | # Compute the dense keypoint scores 156 | cPa = self.relu(self.convPa(x)) 157 | scores = self.convPb(cPa) 158 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1] 159 | b, _, h, w = scores.shape 160 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 161 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) 162 | scores = simple_nms(scores, self.config['nms_radius']) 163 | 164 | # Extract keypoints 165 | keypoints = [ 166 | torch.nonzero(s > self.config['keypoint_threshold']) 167 | for s in scores] 168 | scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] 169 | 170 | # Discard keypoints near the image borders 171 | keypoints, scores = list(zip(*[ 172 | remove_borders(k, s, self.config['remove_borders'], h*8, w*8) 173 | for k, s in zip(keypoints, scores)])) 174 | 175 | # Keep the k keypoints with highest score 176 | if self.config['max_keypoints'] >= 0: 177 | keypoints, scores = list(zip(*[ 178 | top_k_keypoints(k, s, self.config['max_keypoints']) 179 | for k, s in zip(keypoints, scores)])) 180 | 181 | # Convert (h, w) to (x, y) 182 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 183 | 184 | # Compute the dense descriptors 185 | cDa = self.relu(self.convDa(x)) 186 | descriptors = self.convDb(cDa) 187 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) 188 | 189 | # Extract descriptors 190 | descriptors = [sample_descriptors(k[None], d[None], 8)[0] 191 | for k, d in zip(keypoints, 
descriptors)] 192 | 193 | return { 194 | 'keypoints': keypoints, 195 | 'scores': scores, 196 | 'descriptors': descriptors, 197 | } 198 | -------------------------------------------------------------------------------- /src/models/matchers/nn/nearest_neighbour.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def find_nn(sim, ratio_thresh, distance_thresh): 6 | sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True) 7 | dist_nn = 2 * (1 - sim_nn) 8 | mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device) 9 | if ratio_thresh: 10 | mask = mask & (dist_nn[..., 0] <= (ratio_thresh ** 2) * dist_nn[..., 1]) 11 | if distance_thresh: 12 | mask = mask & (dist_nn[..., 0] <= distance_thresh ** 2) 13 | matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1)) 14 | scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0)) 15 | return matches, scores 16 | 17 | 18 | def mutual_check(m0, m1): 19 | inds0 = torch.arange(m0.shape[-1], device=m0.device) 20 | loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0))) 21 | ok = (m0 > -1) & (inds0 == loop) 22 | m0_new = torch.where(ok, m0, m0.new_tensor(-1)) 23 | return m0_new 24 | 25 | 26 | class NearestNeighbour(nn.Module): 27 | default_conf = { 28 | 'match_threshold': None, 29 | 'ratio_threshold': None, 30 | 'distance_threshold': None, 31 | 'do_mutual_check': True 32 | } 33 | 34 | def __init__(self, conf=None): 35 | super().__init__() 36 | if conf: 37 | self.conf = conf = {**self.default_conf, **conf} 38 | else: 39 | self.conf = self.default_conf 40 | 41 | def forward(self, data): 42 | sim = torch.einsum('bdn, bdm->bnm', data['descriptors0'], data['descriptors1']) 43 | matches0, scores0 = find_nn( 44 | sim, self.conf['ratio_threshold'], self.conf['distance_threshold'] 45 | ) 46 | 47 | matches1, scores1 = find_nn( 48 | sim.transpose(1, 2), self.conf['ratio_threshold'], self.conf['distance_threshold'] 49 | ) 50 | 51 | if self.conf['do_mutual_check']: 52 | matches1, scores1 = find_nn( 53 | sim.transpose(1, 2), self.conf['ratio_threshold'], 54 | self.conf['distance_threshold'] 55 | ) 56 | matches0 = mutual_check(matches0, matches1) 57 | 58 | return { 59 | 'matches0': matches0, 60 | 'matches1': matches1, 61 | 'matching_scores0': scores0, 62 | 'matching_scores1': scores1 63 | } -------------------------------------------------------------------------------- /src/sfm/extract_features.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import tqdm 3 | import torch 4 | import logging 5 | 6 | from torch.utils.data import DataLoader 7 | 8 | confs = { 9 | 'superpoint': { 10 | 'output': 'feats-spp', 11 | 'model': { 12 | 'name': 'spp_det', 13 | }, 14 | 'preprocessing': { 15 | 'grayscale': True, 16 | 'resize_h': 512, 17 | 'resize_w': 512 18 | }, 19 | 'conf': { 20 | 'descriptor_dim': 256, 21 | 'nms_radius': 3, 22 | 'max_keypoints': 4096, 23 | 'keypoints_threshold': 0.6 24 | } 25 | } 26 | } 27 | 28 | 29 | @torch.no_grad() 30 | def spp(img_lists, feature_out, cfg): 31 | """extract keypoints info by superpoint""" 32 | from src.utils.model_io import load_network 33 | from src.models.extractors.SuperPoint.superpoint import SuperPoint as spp_det 34 | from src.datasets.normalized_dataset import NormalizedDataset 35 | 36 | conf = confs[cfg.network.detection] 37 | model = spp_det(conf['conf']).cuda() 38 | model.eval() 39 | load_network(model, 
cfg.network.detection_model_path, force=True) 40 | 41 | dataset = NormalizedDataset(img_lists, conf['preprocessing']) 42 | loader = DataLoader(dataset, num_workers=1) 43 | 44 | feature_file = h5py.File(feature_out, 'w') 45 | logging.info(f'Exporting features to {feature_out}') 46 | for data in tqdm.tqdm(loader): 47 | inp = data['image'].cuda() 48 | pred = model(inp) 49 | 50 | pred = {k: v[0].cpu().numpy() for k, v in pred.items()} 51 | pred['image_size'] = data['size'][0].numpy() 52 | 53 | grp = feature_file.create_group(data['path'][0]) 54 | for k, v in pred.items(): 55 | grp.create_dataset(k, data=v) 56 | 57 | del pred 58 | 59 | feature_file.close() 60 | logging.info('Finishing exporting features.') 61 | 62 | 63 | def main(img_lists, feature_out, cfg): 64 | if cfg.network.detection == 'superpoint': 65 | spp(img_lists, feature_out, cfg) 66 | else: 67 | raise NotImplementedError -------------------------------------------------------------------------------- /src/sfm/generate_empty.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import logging 3 | import os.path as osp 4 | import numpy as np 5 | 6 | from pathlib import Path 7 | from src.utils import path_utils 8 | from src.utils.colmap.read_write_model import Camera, Image, Point3D 9 | from src.utils.colmap.read_write_model import rotmat2qvec 10 | from src.utils.colmap.read_write_model import write_model 11 | 12 | 13 | def get_pose_from_txt(img_index, pose_dir): 14 | """ Read 4x4 transformation matrix from txt """ 15 | pose_file = osp.join(pose_dir, '{}.txt'.format(img_index)) 16 | pose = np.loadtxt(pose_file) 17 | 18 | tvec = pose[:3, 3].reshape(3, ) 19 | qvec = rotmat2qvec(pose[:3, :3]).reshape(4, ) 20 | return pose, tvec, qvec 21 | 22 | 23 | def get_intrin_from_txt(img_index, intrin_dir): 24 | """ Read 3x3 intrinsic matrix from txt """ 25 | intrin_file = osp.join(intrin_dir, '{}.txt'.format(img_index)) 26 | intrin = np.loadtxt(intrin_file) 27 | 28 | return intrin 29 | 30 | 31 | def import_data(img_lists, do_ba=False): 32 | """ Import intrinsics and camera pose info """ 33 | points3D_out = {} 34 | images_out = {} 35 | cameras_out = {} 36 | 37 | def compare(img_name): 38 | key = img_name.split('/')[-1] 39 | return int(key.split('.')[0]) 40 | img_lists.sort(key=compare) 41 | 42 | key, img_id, camera_id = 0, 0, 0 43 | xys_ = np.zeros((0, 2), float) 44 | point3D_ids_ = np.full(0, -1, int) # will be filled after triangulation 45 | 46 | # import data 47 | for img_path in img_lists: 48 | key += 1 49 | img_id += 1 50 | camera_id += 1 51 | 52 | img_name = img_path.split('/')[-1] 53 | base_dir = osp.dirname(osp.dirname(img_path)) 54 | img_index = int(img_name.split('.')[0]) 55 | 56 | # read pose 57 | pose_dir = path_utils.get_gt_pose_dir(base_dir) 58 | _, tvec, qvec = get_pose_from_txt(img_index, pose_dir) 59 | 60 | # read intrinsic 61 | intrin_dir = path_utils.get_intrin_dir(base_dir) 62 | K = get_intrin_from_txt(img_index, intrin_dir) 63 | fx, fy, cx, cy = K[0][0], K[1][1], K[0, 2], K[1, 2] 64 | 65 | image = cv2.imread(img_path) 66 | h, w, _ = image.shape 67 | 68 | image = Image( 69 | id=img_id, 70 | qvec=qvec, 71 | tvec=tvec, 72 | camera_id=camera_id, 73 | name=img_path, 74 | xys=xys_, 75 | point3D_ids=point3D_ids_ 76 | ) 77 | 78 | camera = Camera( 79 | id=camera_id, 80 | model='PINHOLE', 81 | width=w, 82 | height=h, 83 | params=np.array([fx, fy, cx, cy]) 84 | ) 85 | 86 | images_out[key] = image 87 | cameras_out[key] = camera 88 | 89 | return cameras_out, images_out, points3D_out 90 | 91 | 
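# A minimal usage sketch (the paths below are hypothetical and only illustrate the call):
# `import_data` above fills the camera and image dicts from the per-frame pose/intrinsic
# txt files while `points3D_out` stays empty -- 3D points are only created later by
# COLMAP's point_triangulator (see src/sfm/triangulation.py).
#
#   img_lists = sorted(glob.glob('/path/to/obj-annotate/color/*.png'))  # assumed layout
#   generate_model(img_lists, empty_dir='/path/to/obj-annotate/sfm_empty')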
92 | def generate_model(img_lists, empty_dir, do_ba=False): 93 | """ Write intrinsics and camera poses into COLMAP format model""" 94 | logging.info('Generate empty model...') 95 | model = import_data(img_lists, do_ba) 96 | 97 | logging.info(f'Writing the COLMAP model to {empty_dir}') 98 | Path(empty_dir).mkdir(exist_ok=True, parents=True) 99 | write_model(*model, path=str(empty_dir), ext='.bin') 100 | logging.info('Finishing writing model.') 101 | -------------------------------------------------------------------------------- /src/sfm/global_ba.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import subprocess 4 | import os.path as osp 5 | 6 | from pathlib import Path 7 | 8 | 9 | def run_bundle_adjuster(deep_sfm_dir, ba_dir, colmap_path): 10 | logging.info("Running the bundle adjuster.") 11 | 12 | deep_sfm_model_dir = osp.join(deep_sfm_dir, 'model') 13 | cmd = [ 14 | str(colmap_path), 'bundle_adjuster', 15 | '--input_path', str(deep_sfm_model_dir), 16 | '--output_path', str(ba_dir), 17 | '--BundleAdjustment.max_num_iterations', '150', 18 | '--BundleAdjustment.max_linear_solver_iterations', '500', 19 | '--BundleAdjustment.function_tolerance', '0', 20 | '--BundleAdjustment.gradient_tolerance', '0', 21 | '--BundleAdjustment.parameter_tolerance', '0', 22 | '--BundleAdjustment.refine_focal_length', '0', 23 | '--BundleAdjustment.refine_principal_point', '0', 24 | '--BundleAdjustment.refine_extra_params', '0', 25 | '--BundleAdjustment.refine_extrinsics', '1' 26 | ] 27 | logging.info(' '.join(cmd)) 28 | 29 | ret = subprocess.call(cmd) 30 | if ret != 0: 31 | logging.warning('Problem with point_triangulator, existing.') 32 | exit(ret) 33 | 34 | 35 | def main(deep_sfm_dir, ba_dir, colmap_path='colmap'): 36 | assert Path(deep_sfm_dir).exists(), deep_sfm_dir 37 | 38 | Path(ba_dir).mkdir(parents=True, exist_ok=True) 39 | run_bundle_adjuster(deep_sfm_dir, ba_dir, colmap_path) -------------------------------------------------------------------------------- /src/sfm/match_features.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import torch 3 | import logging 4 | import tqdm 5 | 6 | import os.path as osp 7 | 8 | confs = { 9 | 'superglue': { 10 | 'output': 'matches-spg', 11 | 'conf': { 12 | 'descriptor_dim': 256, 13 | 'weights': 'outdoor', 14 | 'match_threshold': 0.7 15 | } 16 | } 17 | } 18 | 19 | 20 | def names_to_pair(name0, name1): 21 | return '_'.join((name0.replace('/', '-'), name1.replace('/', '-'))) 22 | 23 | 24 | @torch.no_grad() 25 | def spg(cfg, feature_path, covis_pairs, matches_out, vis_match=False): 26 | """Match features by SuperGlue""" 27 | from src.models.matchers.SuperGlue.superglue import SuperGlue as spg_matcher 28 | from src.utils.model_io import load_network 29 | from src.utils.vis_utils import vis_match_pairs 30 | 31 | assert osp.exists(feature_path), feature_path 32 | feature_file = h5py.File(feature_path, 'r') 33 | logging.info(f'Exporting matches to {matches_out}') 34 | 35 | with open(covis_pairs, 'r') as f: 36 | pair_list = f.read().rstrip('\n').split('\n') 37 | 38 | # load superglue model 39 | conf = confs[cfg.network.matching]['conf'] 40 | model = spg_matcher(conf).cuda() 41 | model.eval() 42 | load_network(model, cfg.network.matching_model_path, force=True) 43 | 44 | # match features by superglue 45 | match_file = h5py.File(matches_out, 'w') 46 | matched = set() 47 | for pair in tqdm.tqdm(pair_list): 48 | name0, name1 = pair.split(' ') 49 | pair = 
names_to_pair(name0, name1) 50 | 51 | if len({(name0, name1), (name1, name0)} & matched) \ 52 | or pair in match_file: 53 | continue 54 | 55 | data = {} 56 | feats0, feats1 = feature_file[name0], feature_file[name1] 57 | for k in feats0.keys(): 58 | data[k+'0'] = feats0[k].__array__() 59 | for k in feats1.keys(): 60 | data[k+'1'] = feats1[k].__array__() 61 | data = {k: torch.from_numpy(v)[None].float().cuda() for k, v in data.items()} 62 | 63 | data['image0'] = torch.empty((1, 1, ) + tuple(feats0['image_size'])[::-1]) 64 | data['image1'] = torch.empty((1, 1, ) + tuple(feats1['image_size'])[::-1]) 65 | pred = model(data) 66 | 67 | grp = match_file.create_group(pair) 68 | matches0 = pred['matches0'][0].cpu().short().numpy() 69 | grp.create_dataset('matches0', data=matches0) 70 | 71 | matches1 = pred['matches1'][0].cpu().short().numpy() 72 | grp.create_dataset('matches1', data=matches1) 73 | 74 | if 'matching_scores0' in pred: 75 | scores = pred['matching_scores0'][0].cpu().half().numpy() 76 | grp.create_dataset('matching_scores0', data=scores) 77 | 78 | if 'matching_scores1' in pred: 79 | scores = pred['matching_scores1'][0].cpu().half().numpy() 80 | grp.create_dataset('matching_scores1', data=scores) 81 | 82 | matched |= {(name0, name1), (name1, name0)} 83 | 84 | if vis_match: 85 | vis_match_pairs(pred, feats0, feats1, name0, name1) 86 | 87 | match_file.close() 88 | logging.info('Finishing exporting matches.') 89 | 90 | 91 | def main(cfg, feature_out, covis_pairs_out, matches_out, vis_match=False): 92 | if cfg.network.matching == 'superglue': 93 | spg(cfg, feature_out, covis_pairs_out, matches_out, vis_match) 94 | else: 95 | raise NotImplementedError -------------------------------------------------------------------------------- /src/sfm/pairs_from_poses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.spatial.distance as distance 3 | from src.utils import path_utils 4 | 5 | 6 | def get_pairswise_distances(pose_files): 7 | Rs = [] 8 | ts = [] 9 | 10 | seqs_ids = {} 11 | for i in range(len(pose_files)): 12 | pose_file = pose_files[i] 13 | seq_name = pose_file.split('/')[-3] 14 | if seq_name not in seqs_ids.keys(): 15 | seqs_ids[seq_name] = [i] 16 | else: 17 | seqs_ids[seq_name].append(i) 18 | 19 | for pose_file in pose_files: 20 | pose = np.loadtxt(pose_file) 21 | R = pose[:3, :3] 22 | t = pose[:3, 3:] 23 | Rs.append(R) 24 | ts.append(t) 25 | 26 | Rs = np.stack(Rs, axis=0) 27 | ts = np.stack(ts, axis=0) 28 | 29 | Rs = Rs.transpose(0, 2, 1) # [n, 3, 3] 30 | ts = -(Rs @ ts)[:, :, 0] # [n, 3, 3] @ [n, 3, 1] 31 | 32 | dist = distance.squareform(distance.pdist(ts)) 33 | trace = np.einsum('nji,mji->mn', Rs, Rs, optimize=True) 34 | dR = np.clip((trace - 1) / 2, -1., 1.) 
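    # For two rotations, trace(R_i @ R_j.T) = 1 + 2 * cos(theta), where theta is the
    # relative rotation angle; the einsum above evaluates this trace for every camera
    # pair, so (trace - 1) / 2 is cos(theta) and the next line recovers the pairwise
    # angle in degrees.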
35 | dR = np.rad2deg(np.abs(np.arccos(dR))) 36 | 37 | return dist, dR, seqs_ids 38 | 39 | 40 | def covis_from_pose(img_lists, covis_pairs_out, num_matched, max_rotation, do_ba=False): 41 | pose_lists = [path_utils.get_gt_pose_path_by_color(color_path) for color_path in img_lists] 42 | dist, dR, seqs_ids = get_pairswise_distances(pose_lists) 43 | 44 | min_rotation = 10 45 | valid = dR > min_rotation 46 | np.fill_diagonal(valid, False) 47 | dist = np.where(valid, dist, np.inf) 48 | 49 | pairs = [] 50 | num_matched_per_seq = num_matched // len(seqs_ids.keys()) 51 | for i in range(len(img_lists)): 52 | dist_i = dist[i] 53 | for seq_id in seqs_ids: 54 | ids = np.array(seqs_ids[seq_id]) 55 | try: 56 | idx = np.argpartition(dist_i[ids], num_matched_per_seq * 2)[: num_matched_per_seq:2] 57 | except: 58 | idx = np.argpartition(dist_i[ids], dist_i.shape[0]-1) 59 | idx = ids[idx] 60 | idx = idx[np.argsort(dist_i[idx])] 61 | idx = idx[valid[i][idx]] 62 | 63 | for j in idx: 64 | name0 = img_lists[i] 65 | name1 = img_lists[j] 66 | 67 | pairs.append((name0, name1)) 68 | 69 | with open(covis_pairs_out, 'w') as f: 70 | f.write('\n'.join(' '.join([i, j]) for i, j in pairs)) 71 | -------------------------------------------------------------------------------- /src/sfm/postprocess/filter_points.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | import os.path as osp 5 | from src.utils.colmap import read_write_model 6 | 7 | 8 | def filter_by_track_length(points3D, track_length): 9 | """ 10 | Filter 3d points by track length. 11 | Return new pcds and corresponding point ids in origin pcds. 12 | """ 13 | idxs_3d = list(points3D.keys()) 14 | idxs_3d.sort() 15 | xyzs = np.empty(shape=[0, 3]) 16 | points_idxs = np.empty(shape=[0], dtype=int) 17 | for i in range(len(idxs_3d)): 18 | idx_3d = idxs_3d[i] 19 | if len(points3D[idx_3d].point2D_idxs) < track_length: 20 | continue 21 | xyz = points3D[idx_3d].xyz.reshape(1, -1) 22 | xyzs = np.append(xyzs, xyz, axis=0) 23 | points_idxs = np.append(points_idxs, idx_3d) 24 | 25 | return xyzs, points_idxs 26 | 27 | 28 | def filter_by_3d_box(points, points_idxs, box_path): 29 | """ Filter 3d points by 3d box.""" 30 | corner_in_cano = np.loadtxt(box_path) 31 | 32 | assert points.shape[1] == 3, "Input pcds must have shape (n, 3)" 33 | if not isinstance(points, torch.Tensor): 34 | points = torch.as_tensor(points, dtype=torch.float32) 35 | if not isinstance(corner_in_cano, torch.Tensor): 36 | corner_in_cano = torch.as_tensor(corner_in_cano, dtype=torch.float32) 37 | 38 | def filter_(bbox_3d, points): 39 | """ 40 | @param bbox_3d: corners (8, 3) 41 | @param points: (n, 3) 42 | """ 43 | v45 = bbox_3d[5] - bbox_3d[4] 44 | v40 = bbox_3d[0] - bbox_3d[4] 45 | v47 = bbox_3d[7] - bbox_3d[4] 46 | 47 | points = points - bbox_3d[4] 48 | m0 = torch.matmul(points, v45) 49 | m1 = torch.matmul(points, v40) 50 | m2 = torch.matmul(points, v47) 51 | 52 | cs = [] 53 | for m, v in zip([m0, m1, m2], [v45, v40, v47]): 54 | c0 = 0 < m 55 | c1 = m < torch.matmul(v, v) 56 | c = c0 & c1 57 | cs.append(c) 58 | cs = cs[0] & cs[1] & cs[2] 59 | passed_inds = torch.nonzero(cs).squeeze(1) 60 | num_passed = torch.sum(cs) 61 | return num_passed, passed_inds, cs 62 | 63 | num_passed, passed_inds, keeps = filter_(corner_in_cano, points) 64 | 65 | xyzs_filtered = np.empty(shape=(0, 3), dtype=np.float32) 66 | for i in range(int(num_passed)): 67 | ind = passed_inds[i] 68 | xyzs_filtered = np.append(xyzs_filtered, points[ind, 
None], axis=0) 69 | 70 | filtered_xyzs = points[passed_inds] 71 | passed_inds = points_idxs[passed_inds] 72 | return filtered_xyzs, passed_inds 73 | 74 | 75 | def filter_3d(model_path, track_length, box_path): 76 | """ Filter 3d points by tracke length and 3d box """ 77 | points_model_path = osp.join(model_path, 'points3D.bin') 78 | points3D = read_write_model.read_points3d_binary(points_model_path) 79 | 80 | xyzs, points_idxs = filter_by_track_length(points3D, track_length) 81 | xyzs, points_idxs = filter_by_3d_box(xyzs, points_idxs, box_path) 82 | 83 | return xyzs, points_idxs 84 | 85 | 86 | def merge(xyzs, points_idxs, dist_threshold=1e-3): 87 | """ 88 | Merge points which are close to others. ({[x1, y1], [x2, y2], ...} => [mean(x_i), mean(y_i)]) 89 | """ 90 | from scipy.spatial.distance import pdist, squareform 91 | 92 | if not isinstance(xyzs, np.ndarray): 93 | xyzs = np.array(xyzs) 94 | 95 | dist = pdist(xyzs, 'euclidean') 96 | distance_matrix = squareform(dist) 97 | close_than_thresh = distance_matrix < dist_threshold 98 | 99 | ret_points_count = 0 # num of return points 100 | ret_points = np.empty(shape=[0, 3]) # pcds after merge 101 | ret_idxs = {} # {new_point_idx: points idxs in Points3D} 102 | 103 | points3D_idx_record = [] # points that have been merged 104 | for j in range(distance_matrix.shape[0]): 105 | idxs = close_than_thresh[j] 106 | 107 | if np.isin(points_idxs[idxs], points3D_idx_record).any(): 108 | continue 109 | 110 | points = np.mean(xyzs[idxs], axis=0) # new point 111 | ret_points = np.append(ret_points, points.reshape(1, 3), axis=0) 112 | ret_idxs[ret_points_count] = points_idxs[idxs] 113 | ret_points_count += 1 114 | 115 | points3D_idx_record = points3D_idx_record + points_idxs[idxs].tolist() 116 | 117 | return ret_points, ret_idxs 118 | 119 | 120 | -------------------------------------------------------------------------------- /src/sfm/postprocess/filter_tkl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import os.path as osp 4 | import matplotlib.pyplot as plt 5 | 6 | from src.utils.colmap.read_write_model import write_model 7 | 8 | 9 | def get_points_count(points3D, show=False): 10 | """ Count track length for each point """ 11 | points_count_list = [] # track length for each point 12 | for points_id, points_item in points3D.items(): 13 | points_count = len(points_item.point2D_idxs) 14 | points_count_list.append(points_count) 15 | 16 | count_dict = dict() # {track length: num of 3d points} 17 | for count in points_count_list: 18 | if count not in count_dict.keys(): 19 | count_dict[count] = 0 20 | count_dict[count] += 1 21 | counts = list(count_dict.keys()) 22 | counts.sort() 23 | 24 | count_list_ordered = [] 25 | for count in counts: 26 | count_list_ordered.append(count_dict[count]) 27 | 28 | if show: 29 | plt.plot(counts, count_list_ordered) 30 | plt.show() 31 | 32 | return count_dict, points_count_list 33 | 34 | 35 | def get_tkl(model_path, thres, show=False): 36 | """ Get the track length value which can limit the number of 3d points below thres""" 37 | from src.utils.colmap.read_write_model import read_model 38 | 39 | cameras, images, points3D = read_model(model_path, ext='.bin') 40 | count_dict, points_count_list = get_points_count(points3D, show) 41 | 42 | ret_points = len(points3D.keys()) 43 | count_keys = np.array(list(count_dict.keys())) 44 | count_keys.sort() 45 | 46 | for key in count_keys: 47 | ret_points -= count_dict[key] 48 | if ret_points <= thres: 49 | 
track_length = key 50 | break 51 | 52 | return track_length, points_count_list 53 | 54 | 55 | def vis_tkl_filtered_pcds(model_path, points_count_list, track_length, output_path): 56 | """ 57 | Given a track length value, filter 3d points. 58 | Output filtered pcds for visualization. 59 | """ 60 | from src.utils.colmap.read_write_model import read_model 61 | 62 | cameras, images, points3D = read_model(model_path, ext='.bin') 63 | 64 | invalid_points = np.where(np.array(points_count_list) < track_length)[0] 65 | point_ids = [] 66 | for k, v in points3D.items(): 67 | point_ids.append(k) 68 | 69 | invalid_points_ids = [] 70 | for invalid_count in invalid_points: 71 | points3D.pop(point_ids[invalid_count]) 72 | invalid_points_ids.append(point_ids[invalid_count]) 73 | 74 | output_path = osp.join(output_path, 'tkl_model') 75 | output_file_path = osp.join(output_path, 'tl-{}.ply'.format(track_length)) 76 | if not osp.exists(output_path): 77 | os.makedirs(output_path) 78 | 79 | write_model(cameras, images, points3D, output_path, '.bin') 80 | os.system(f'colmap model_converter --input_path {output_path} --output_path {output_file_path} --output_type PLY') 81 | return output_file_path -------------------------------------------------------------------------------- /src/sfm/triangulation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import logging 4 | import tqdm 5 | import subprocess 6 | import os.path as osp 7 | import numpy as np 8 | 9 | from pathlib import Path 10 | from src.utils.colmap.read_write_model import CAMERA_MODEL_NAMES, Image, read_cameras_binary, read_images_binary 11 | from src.utils.colmap.database import COLMAPDatabase 12 | 13 | 14 | def names_to_pair(name0, name1): 15 | return '_'.join((name0.replace('/', '-'), name1.replace('/', '-'))) 16 | 17 | 18 | def geometric_verification(colmap_path, database_path, pairs_path): 19 | """ Geometric verfication """ 20 | logging.info('Performing geometric verification of the matches...') 21 | cmd = [ 22 | str(colmap_path), 'matches_importer', 23 | '--database_path', str(database_path), 24 | '--match_list_path', str(pairs_path), 25 | '--match_type', 'pairs' 26 | ] 27 | ret = subprocess.call(cmd) 28 | if ret != 0: 29 | logging.warning('Problem with matches_importer, existing.') 30 | exit(ret) 31 | 32 | 33 | def create_db_from_model(empty_model, database_path): 34 | """ Create COLMAP database file from empty COLMAP binary file. """ 35 | if database_path.exists(): 36 | logging.warning('Database already exists.') 37 | 38 | cameras = read_cameras_binary(str(empty_model / 'cameras.bin')) 39 | images = read_images_binary(str(empty_model / 'images.bin')) 40 | 41 | db = COLMAPDatabase.connect(database_path) 42 | db.create_tables() 43 | 44 | for i, camera in cameras.items(): 45 | model_id = CAMERA_MODEL_NAMES[camera.model].model_id 46 | db.add_camera(model_id, camera.width, camera.height, camera.params, 47 | camera_id=i, prior_focal_length=True) 48 | 49 | for i, image in images.items(): 50 | db.add_image(image.name, image.camera_id, image_id=i) 51 | 52 | db.commit() 53 | db.close() 54 | return {image.name: i for i, image in images.items()} 55 | 56 | 57 | def import_features(image_ids, database_path, feature_path): 58 | """ Import keypoints info into COLMAP database. 
""" 59 | logging.info("Importing features into the database...") 60 | feature_file = h5py.File(str(feature_path), 'r') 61 | db = COLMAPDatabase.connect(database_path) 62 | 63 | for image_name, image_id in tqdm.tqdm(image_ids.items()): 64 | keypoints = feature_file[image_name]['keypoints'].__array__() 65 | keypoints += 0.5 66 | db.add_keypoints(image_id, keypoints) 67 | 68 | feature_file.close() 69 | db.commit() 70 | db.close() 71 | 72 | 73 | def import_matches(image_ids, database_path, pairs_path, matches_path, feature_path, 74 | min_match_score=None, skip_geometric_verification=False): 75 | """ Import matches info into COLMAP database. """ 76 | logging.info("Importing matches into the database...") 77 | 78 | with open(str(pairs_path), 'r') as f: 79 | pairs = [p.split(' ') for p in f.read().split('\n')] 80 | 81 | match_file = h5py.File(str(matches_path), 'r') 82 | db = COLMAPDatabase.connect(database_path) 83 | 84 | matched = set() 85 | for name0, name1 in tqdm.tqdm(pairs): 86 | id0, id1 = image_ids[name0], image_ids[name1] 87 | if len({(id0, id1), (id1, id0)} & matched) > 0: 88 | continue 89 | 90 | pair = names_to_pair(name0, name1) 91 | if pair not in match_file: 92 | raise ValueError( 93 | f'Could not find pair {(name0, name1)}... ' 94 | 'Maybe you matched with a different list of pairs? ' 95 | f'Reverse in file: {names_to_pair(name0, name1) in match_file}.' 96 | ) 97 | 98 | matches = match_file[pair]['matches0'].__array__() 99 | valid = matches > -1 100 | if min_match_score: 101 | scores = match_file[pair]['matching_scores0'].__array__() 102 | valid = valid & (scores > min_match_score) 103 | 104 | matches = np.stack([np.where(valid)[0], matches[valid]], -1) 105 | 106 | db.add_matches(id0, id1, matches) 107 | matched |= {(id0, id1), (id1, id0)} 108 | 109 | if skip_geometric_verification: 110 | db.add_two_view_geometry(id0, id1, matches) 111 | 112 | match_file.close() 113 | db.commit() 114 | db.close() 115 | 116 | 117 | def run_triangulation(colmap_path, model_path, database_path, image_dir, empty_model): 118 | """ run triangulation on given database """ 119 | logging.info('Running the triangulation...') 120 | 121 | cmd = [ 122 | str(colmap_path), 'point_triangulator', 123 | '--database_path', str(database_path), 124 | '--image_path', str(image_dir), 125 | '--input_path', str(empty_model), 126 | '--output_path', str(model_path), 127 | '--Mapper.ba_refine_focal_length', '0', 128 | '--Mapper.ba_refine_principal_point', '0', 129 | '--Mapper.ba_refine_extra_params', '0' 130 | ] 131 | logging.info(' '.join(cmd)) 132 | ret = subprocess.call(cmd) 133 | if ret != 0: 134 | logging.warning('Problem with point_triangulator, existing.') 135 | exit(ret) 136 | 137 | stats_raw = subprocess.check_output( 138 | [str(colmap_path), 'model_analyzer', '--path', model_path] 139 | ) 140 | stats_raw = stats_raw.decode().split('\n') 141 | stats = dict() 142 | for stat in stats_raw: 143 | if stat.startswith('Register images'): 144 | stats['num_reg_images'] = int(stat.split()[-1]) 145 | elif stat.startswith('Points'): 146 | stats['num_sparse_points'] = int(stat.split()[-1]) 147 | elif stat.startswith('Observation'): 148 | stats['num_observations'] = int(stat.split()[-1]) 149 | elif stat.startswith('Mean track length'): 150 | stats['mean_track_length'] = float(stat.split()[-1]) 151 | elif stat.startswith('Mean observation per image'): 152 | stats['num_observations_per_image'] = float(stat.split()[-1]) 153 | elif stat.startswith('Mean reprojection error'): 154 | stats['mean_reproj_error'] = 
float(stat.split()[-1][:-2]) 155 | return stats 156 | 157 | 158 | def main(sfm_dir, empty_sfm_model, outputs_dir, pairs, features, matches, \ 159 | colmap_path='colmap', skip_geometric_verification=False, min_match_score=None, image_dir=None): 160 | """ 161 | Import keypoints, matches. 162 | Given keypoints and matches, reconstruct sparse model from given camera poses. 163 | """ 164 | assert Path(empty_sfm_model).exists(), empty_sfm_model 165 | assert Path(features).exists(), features 166 | assert Path(pairs).exists(), pairs 167 | assert Path(matches).exists(), matches 168 | 169 | Path(sfm_dir).mkdir(parents=True, exist_ok=True) 170 | database = osp.join(sfm_dir, 'database.db') 171 | model = osp.join(sfm_dir, 'model') 172 | Path(model).mkdir(exist_ok=True) 173 | 174 | image_ids = create_db_from_model(Path(empty_sfm_model), Path(database)) 175 | import_features(image_ids, database, features) 176 | import_matches(image_ids, database, pairs, matches, features, 177 | min_match_score, skip_geometric_verification) 178 | 179 | if not skip_geometric_verification: 180 | geometric_verification(colmap_path, database, pairs) 181 | 182 | if not image_dir: 183 | image_dir = '/' 184 | stats = run_triangulation(colmap_path, model, database, image_dir, empty_sfm_model) 185 | os.system(f'colmap model_converter --input_path {model} --output_path {outputs_dir}/model.ply --output_type PLY') -------------------------------------------------------------------------------- /src/tracker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zju3dv/OnePose/d0c313a5c36994658e63eac7ca3a63d7e3573d7b/src/tracker/__init__.py -------------------------------------------------------------------------------- /src/tracker/tracking_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Timer(object): 6 | def __init__(self): 7 | self.time_dict = dict() 8 | self.res_dict = dict() 9 | self.stash_dict = dict() 10 | 11 | def set(self, label, value): 12 | if label not in self.stash_dict.keys(): 13 | self.stash_dict[label] = [] 14 | self.stash_dict[label].append(value) 15 | 16 | def tick(self, tick_name): 17 | import time 18 | self.time_dict[tick_name] = time.time() * 1000.0 19 | return self.time_dict[tick_name] 20 | 21 | def tock(self, tock_name, pop=False): 22 | if tock_name not in self.time_dict: 23 | return 0.0 24 | else: 25 | import time 26 | t2 = time.time() * 1000.0 27 | t1 = self.time_dict[tock_name] 28 | self.res_dict[tock_name] = t2 - t1 29 | if pop: 30 | self.time_dict.pop(tock_name) 31 | return self.res_dict[tock_name] 32 | 33 | def stash(self): 34 | for k, v in self.res_dict.items(): 35 | if k not in self.stash_dict.keys(): 36 | self.stash_dict[k] = [] 37 | self.stash_dict[k].append(v) 38 | 39 | def report_stash(self): 40 | res_dict = dict() 41 | for k, v in self.stash_dict.items(): 42 | res_dict[k] = np.mean(v) 43 | return res_dict 44 | 45 | def report(self): 46 | return self.res_dict 47 | 48 | 49 | def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1): 50 | def to_homogeneous(points): 51 | return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1) 52 | 53 | kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None] 54 | kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None] 55 | kpts0 = to_homogeneous(kpts0) 56 | kpts1 = to_homogeneous(kpts1) 57 | 58 | t0, t1, t2 = T_0to1[:3, 3] 59 | t_skew = np.array([ 60 | [0, 
-t2, t1], 61 | [t2, 0, -t0], 62 | [-t1, t0, 0] 63 | ]) 64 | E = t_skew @ T_0to1[:3, :3] 65 | 66 | Ep0 = kpts0 @ E.T # N x 3 67 | p1Ep0 = np.sum(kpts1 * Ep0, -1) # N 68 | Etp1 = kpts1 @ E # N x 3 69 | d = p1Ep0 ** 2 * (1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2) 70 | + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2)) 71 | return d 72 | 73 | 74 | def project(xyz, K, RT, need_depth=False): 75 | """ 76 | xyz: [N, 3] 77 | K: [3, 3] 78 | RT: [3, 4] 79 | """ 80 | xyz = np.dot(xyz, RT[:, :3].T) 81 | xyz += RT[:, 3:].T 82 | depth = xyz[:, 2:].flatten() 83 | xyz = np.dot(xyz, K.T) 84 | xy = xyz[:, :2] / xyz[:, 2:] 85 | if need_depth: 86 | return xy, depth 87 | else: 88 | return xy 89 | 90 | 91 | def AngleAxisRotatePoint(angleAxis, pt): 92 | theta2 = (angleAxis * angleAxis).sum(dim=1) 93 | 94 | mask = (theta2 > 0).float() 95 | 96 | theta = torch.sqrt(theta2 + (1 - mask)) 97 | 98 | mask = mask.reshape((mask.shape[0], 1)) 99 | mask = torch.cat([mask, mask, mask], dim=1) 100 | 101 | costheta = torch.cos(theta) 102 | sintheta = torch.sin(theta) 103 | thetaInverse = 1.0 / theta 104 | 105 | w0 = angleAxis[:, 0] * thetaInverse 106 | w1 = angleAxis[:, 1] * thetaInverse 107 | w2 = angleAxis[:, 2] * thetaInverse 108 | 109 | wCrossPt0 = w1 * pt[:, 2] - w2 * pt[:, 1] 110 | wCrossPt1 = w2 * pt[:, 0] - w0 * pt[:, 2] 111 | wCrossPt2 = w0 * pt[:, 1] - w1 * pt[:, 0] 112 | 113 | tmp = (w0 * pt[:, 0] + w1 * pt[:, 1] + w2 * pt[:, 2]) * (1.0 - costheta) 114 | 115 | r0 = pt[:, 0] * costheta + wCrossPt0 * sintheta + w0 * tmp 116 | r1 = pt[:, 1] * costheta + wCrossPt1 * sintheta + w1 * tmp 117 | r2 = pt[:, 2] * costheta + wCrossPt2 * sintheta + w2 * tmp 118 | 119 | r0 = r0.reshape((r0.shape[0], 1)) 120 | r1 = r1.reshape((r1.shape[0], 1)) 121 | r2 = r2.reshape((r2.shape[0], 1)) 122 | 123 | res1 = torch.cat([r0, r1, r2], dim=1) 124 | 125 | wCrossPt0 = angleAxis[:, 1] * pt[:, 2] - angleAxis[:, 2] * pt[:, 1] 126 | wCrossPt1 = angleAxis[:, 2] * pt[:, 0] - angleAxis[:, 0] * pt[:, 2] 127 | wCrossPt2 = angleAxis[:, 0] * pt[:, 1] - angleAxis[:, 1] * pt[:, 0] 128 | 129 | r00 = pt[:, 0] + wCrossPt0 130 | r01 = pt[:, 1] + wCrossPt1 131 | r02 = pt[:, 2] + wCrossPt2 132 | 133 | r00 = r00.reshape((r00.shape[0], 1)) 134 | r01 = r01.reshape((r01.shape[0], 1)) 135 | r02 = r02.reshape((r02.shape[0], 1)) 136 | 137 | res2 = torch.cat([r00, r01, r02], dim=1) 138 | 139 | return res1 * mask + res2 * (1 - mask) 140 | 141 | 142 | def SnavelyReprojectionErrorV2(points_ob, cameras_ob, features): 143 | if (len(points_ob.shape) == 3): 144 | points_ob = points_ob[:,0,:] 145 | cameras_ob = cameras_ob[:,0,:] 146 | focals = features[:, 2] 147 | l1 = features[:, 3] 148 | l2 = features[:, 4] 149 | 150 | # camera[0,1,2] are the angle-axis rotation. 
151 | p = AngleAxisRotatePoint(cameras_ob[:, :3], points_ob) 152 | p = p + cameras_ob[:, 3:6] 153 | 154 | xp = p[:,0] / p[:,2] 155 | yp = p[:,1] / p[:,2] 156 | 157 | # predicted_x, predicted_y = DistortV2(xp, yp, cameras_ob, cam_K) 158 | 159 | predicted_x = focals * xp + l1 160 | predicted_y = focals * yp + l2 161 | 162 | residual_0 = predicted_x - features[:, 0] 163 | residual_1 = predicted_y - features[:, 1] 164 | 165 | residual_0 = residual_0.reshape((residual_0.shape[0], 1)) 166 | residual_1 = residual_1.reshape((residual_1.shape[0], 1)) 167 | 168 | #return torch.sqrt(residual_0**2 + residual_1 ** 2) 169 | return torch.cat([residual_0, residual_1], dim=1) 170 | 171 | 172 | def put_text(img, inform_text, color=None): 173 | import cv2 174 | fontScale = 1 175 | if color is None: 176 | color = (255, 0, 0) 177 | org = (50, 50) 178 | font = cv2.FONT_HERSHEY_SIMPLEX 179 | thickness = 2 180 | img = cv2.putText(img, inform_text, org, font, 181 | fontScale, color, thickness, cv2.LINE_AA) 182 | return img 183 | 184 | 185 | def draw_kpt2d(image, kpt2d, color=(0, 0, 255), radius=2, thikness=1): 186 | import cv2 187 | for coord in kpt2d: 188 | cv2.circle(image, (int(coord[0]), int(coord[1])), radius, color, thikness, 1) 189 | # cv2.circle(image, (int(coord[0]), int(coord[1])), 7, color, 1, 1) 190 | return image 191 | 192 | 193 | class MovieWriter: 194 | def __init__(self): 195 | self.video_out_path = '' 196 | self.movie_cap = None 197 | self.id = 0 198 | 199 | def start(self): 200 | if self.movie_cap is not None: 201 | self.movie_cap.release() 202 | 203 | def write(self, im_bgr, video_out_path, text_info=[], fps=20): 204 | import cv2 205 | if self.movie_cap is None: 206 | fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') 207 | length = fps 208 | self.video_out_path = video_out_path 209 | self.movie_cap = cv2.VideoWriter(video_out_path, fourcc, 210 | length, (im_bgr.shape[1], im_bgr.shape[0])) 211 | 212 | if len(text_info) > 0: 213 | self.put_text(im_bgr, text_info[0], color=text_info[1]) 214 | self.movie_cap.write(im_bgr) 215 | self.id += 1 216 | 217 | def put_text(self, img, inform_text, color=None): 218 | import cv2 219 | fontScale = 1 220 | if color is None: 221 | color = (255, 0, 0) 222 | org = (200, 200) 223 | font = cv2.FONT_HERSHEY_SIMPLEX 224 | thickness = 2 225 | img = cv2.putText(img, inform_text, org, font, 226 | fontScale, color, thickness, cv2.LINE_AA) 227 | return img 228 | 229 | def end(self): 230 | if self.movie_cap is not None: 231 | self.movie_cap.release() 232 | self.movie_cap = None 233 | print(f"Output frames:{self.id} to {self.video_out_path}") 234 | self.id = 0 235 | -------------------------------------------------------------------------------- /src/utils/comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | [Copied from detectron2] 4 | This file contains primitives for multi-gpu communication. 5 | This is useful when doing distributed training. 6 | """ 7 | 8 | import functools 9 | import logging 10 | import numpy as np 11 | import pickle 12 | import torch 13 | import torch.distributed as dist 14 | 15 | _LOCAL_PROCESS_GROUP = None 16 | """ 17 | A torch process group which only includes processes that on the same machine as the current process. 18 | This variable is set when processes are spawned by `launch()` in "engine/launch.py". 
19 | """ 20 | 21 | 22 | def get_world_size() -> int: 23 | if not dist.is_available(): 24 | return 1 25 | if not dist.is_initialized(): 26 | return 1 27 | return dist.get_world_size() 28 | 29 | 30 | def get_rank() -> int: 31 | if not dist.is_available(): 32 | return 0 33 | if not dist.is_initialized(): 34 | return 0 35 | return dist.get_rank() 36 | 37 | 38 | def get_local_rank() -> int: 39 | """ 40 | Returns: 41 | The rank of the current process within the local (per-machine) process group. 42 | """ 43 | if not dist.is_available(): 44 | return 0 45 | if not dist.is_initialized(): 46 | return 0 47 | assert _LOCAL_PROCESS_GROUP is not None 48 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 49 | 50 | 51 | def get_local_size() -> int: 52 | """ 53 | Returns: 54 | The size of the per-machine process group, 55 | i.e. the number of processes per machine. 56 | """ 57 | if not dist.is_available(): 58 | return 1 59 | if not dist.is_initialized(): 60 | return 1 61 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 62 | 63 | 64 | def is_main_process() -> bool: 65 | return get_rank() == 0 66 | 67 | 68 | def synchronize(): 69 | """ 70 | Helper function to synchronize (barrier) among all processes when 71 | using distributed training 72 | """ 73 | if not dist.is_available(): 74 | return 75 | if not dist.is_initialized(): 76 | return 77 | world_size = dist.get_world_size() 78 | if world_size == 1: 79 | return 80 | dist.barrier() 81 | 82 | 83 | @functools.lru_cache() 84 | def _get_global_gloo_group(): 85 | """ 86 | Return a process group based on gloo backend, containing all the ranks 87 | The result is cached. 88 | """ 89 | if dist.get_backend() == "nccl": 90 | return dist.new_group(backend="gloo") 91 | else: 92 | return dist.group.WORLD 93 | 94 | 95 | def _serialize_to_tensor(data, group): 96 | backend = dist.get_backend(group) 97 | assert backend in ["gloo", "nccl"] 98 | device = torch.device("cpu" if backend == "gloo" else "cuda") 99 | 100 | buffer = pickle.dumps(data) 101 | if len(buffer) > 1024 ** 3: 102 | logger = logging.getLogger(__name__) 103 | logger.warning( 104 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 105 | get_rank(), len(buffer) / (1024 ** 3), device 106 | ) 107 | ) 108 | storage = torch.ByteStorage.from_buffer(buffer) 109 | tensor = torch.ByteTensor(storage).to(device=device) 110 | return tensor 111 | 112 | 113 | def _pad_to_largest_tensor(tensor, group): 114 | """ 115 | Returns: 116 | list[int]: size of the tensor, on each rank 117 | Tensor: padded tensor that has the max size 118 | """ 119 | world_size = dist.get_world_size(group=group) 120 | assert ( 121 | world_size >= 1 122 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
123 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 124 | size_list = [ 125 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 126 | ] 127 | dist.all_gather(size_list, local_size, group=group) 128 | 129 | size_list = [int(size.item()) for size in size_list] 130 | 131 | max_size = max(size_list) 132 | 133 | # we pad the tensor because torch all_gather does not support 134 | # gathering tensors of different shapes 135 | if local_size != max_size: 136 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 137 | tensor = torch.cat((tensor, padding), dim=0) 138 | return size_list, tensor 139 | 140 | 141 | def all_gather(data, group=None): 142 | """ 143 | Run all_gather on arbitrary picklable data (not necessarily tensors). 144 | Args: 145 | data: any picklable object 146 | group: a torch process group. By default, will use a group which 147 | contains all ranks on gloo backend. 148 | Returns: 149 | list[data]: list of data gathered from each rank 150 | """ 151 | if get_world_size() == 1: 152 | return [data] 153 | if group is None: 154 | group = _get_global_gloo_group() 155 | if dist.get_world_size(group) == 1: 156 | return [data] 157 | 158 | tensor = _serialize_to_tensor(data, group) 159 | 160 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 161 | max_size = max(size_list) 162 | 163 | # receiving Tensor from all ranks 164 | tensor_list = [ 165 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 166 | ] 167 | dist.all_gather(tensor_list, tensor, group=group) 168 | 169 | data_list = [] 170 | for size, tensor in zip(size_list, tensor_list): 171 | buffer = tensor.cpu().numpy().tobytes()[:size] 172 | data_list.append(pickle.loads(buffer)) 173 | 174 | return data_list 175 | 176 | 177 | def gather(data, dst=0, group=None): 178 | """ 179 | Run gather on arbitrary picklable data (not necessarily tensors). 180 | Args: 181 | data: any picklable object 182 | dst (int): destination rank 183 | group: a torch process group. By default, will use a group which 184 | contains all ranks on gloo backend. 185 | Returns: 186 | list[data]: on dst, a list of data gathered from each rank. Otherwise, 187 | an empty list. 188 | """ 189 | if get_world_size() == 1: 190 | return [data] 191 | if group is None: 192 | group = _get_global_gloo_group() 193 | if dist.get_world_size(group=group) == 1: 194 | return [data] 195 | rank = dist.get_rank(group=group) 196 | 197 | tensor = _serialize_to_tensor(data, group) 198 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 199 | 200 | # receiving Tensor from all ranks 201 | if rank == dst: 202 | max_size = max(size_list) 203 | tensor_list = [ 204 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 205 | ] 206 | dist.gather(tensor, tensor_list, dst=dst, group=group) 207 | 208 | data_list = [] 209 | for size, tensor in zip(size_list, tensor_list): 210 | buffer = tensor.cpu().numpy().tobytes()[:size] 211 | data_list.append(pickle.loads(buffer)) 212 | return data_list 213 | else: 214 | dist.gather(tensor, [], dst=dst, group=group) 215 | return [] 216 | 217 | 218 | def shared_random_seed(): 219 | """ 220 | Returns: 221 | int: a random number that is the same across all workers. 222 | If workers need a shared RNG, they can use this shared seed to 223 | create one. 224 | All workers must call this function, otherwise it will deadlock. 
225 | """ 226 | ints = np.random.randint(2 ** 31) 227 | all_ints = all_gather(ints) 228 | return all_ints[0] 229 | 230 | 231 | def reduce_dict(input_dict, average=True): 232 | """ 233 | Reduce the values in the dictionary from all processes so that process with rank 234 | 0 has the reduced results. 235 | Args: 236 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. 237 | average (bool): whether to do average or sum 238 | Returns: 239 | a dict with the same keys as input_dict, after reduction. 240 | """ 241 | world_size = get_world_size() 242 | if world_size < 2: 243 | return input_dict 244 | with torch.no_grad(): 245 | names = [] 246 | values = [] 247 | # sort the keys so that they are consistent across processes 248 | for k in sorted(input_dict.keys()): 249 | names.append(k) 250 | values.append(input_dict[k]) 251 | values = torch.stack(values, dim=0) 252 | dist.reduce(values, dst=0) 253 | if dist.get_rank() == 0 and average: 254 | # only main process gets accumulated, so only divide by 255 | # world_size in this case 256 | values /= world_size 257 | reduced_dict = {k: v for k, v in zip(names, values)} 258 | return reduced_dict 259 | -------------------------------------------------------------------------------- /src/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os.path as osp 4 | 5 | from pathlib import Path 6 | 7 | def record_eval_result(out_dir, obj_name, seq_name, eval_result): 8 | Path(out_dir).mkdir(exist_ok=True, parents=True) 9 | 10 | out_file = osp.join(out_dir, obj_name + seq_name + '.txt') 11 | f = open(out_file, 'w') 12 | for k, v in eval_result.items(): 13 | f.write(f'{k}: {v}\n') 14 | 15 | f.close() 16 | 17 | 18 | def ransac_PnP(K, pts_2d, pts_3d, scale=1): 19 | """ solve pnp """ 20 | dist_coeffs = np.zeros(shape=[8, 1], dtype='float64') 21 | 22 | pts_2d = np.ascontiguousarray(pts_2d.astype(np.float64)) 23 | pts_3d = np.ascontiguousarray(pts_3d.astype(np.float64)) 24 | K = K.astype(np.float64) 25 | 26 | pts_3d *= scale 27 | try: 28 | _, rvec, tvec, inliers = cv2.solvePnPRansac(pts_3d, pts_2d, K, dist_coeffs, reprojectionError=5, 29 | iterationsCount=10000, flags=cv2.SOLVEPNP_EPNP) 30 | 31 | rotation = cv2.Rodrigues(rvec)[0] 32 | 33 | tvec /= scale 34 | pose = np.concatenate([rotation, tvec], axis=-1) 35 | pose_homo = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0) 36 | 37 | inliers = [] if inliers is None else inliers 38 | 39 | return pose, pose_homo, inliers 40 | except cv2.error: 41 | print("CV ERROR") 42 | return np.eye(4)[:3], np.eye(4), [] 43 | 44 | 45 | def query_pose_error(pose_pred, pose_gt): 46 | """ 47 | Input: 48 | --------- 49 | pose_pred: np.array 3*4 or 4*4 50 | pose_gt: np.array 3*4 or 4*4 51 | """ 52 | # Dim check: 53 | if pose_pred.shape[0] == 4: 54 | pose_pred = pose_pred[:3] 55 | if pose_gt.shape[0] == 4: 56 | pose_gt = pose_gt[:3] 57 | 58 | translation_distance = np.linalg.norm(pose_pred[:, 3] - pose_gt[:, 3]) * 100 59 | rotation_diff = np.dot(pose_pred[:, :3], pose_gt[:, :3].T) 60 | trace = np.trace(rotation_diff) 61 | trace = trace if trace <= 3 else 3 62 | angular_distance = np.rad2deg(np.arccos((trace - 1.0) / 2.0)) 63 | return angular_distance, translation_distance 64 | 65 | 66 | def compute_query_pose_errors(data, preds): 67 | query_pose_gt = data['query_pose_gt'][0].cpu().numpy() 68 | query_K = data['query_intrinsic'][0].cpu().numpy() 69 | query_kpts2d = data['keypoints2d'][0].cpu().numpy() 70 | 
query_kpts3d = data['keypoints3d'][0].cpu().numpy() 71 | 72 | matches0 = preds['matches0'].cpu().numpy() 73 | confidence = preds['matching_scores0'].cpu().numpy() 74 | valid = matches0 > -1 75 | mkpts2d = query_kpts2d[valid] 76 | mkpts3d = query_kpts3d[matches0[valid]] 77 | mconf = confidence[valid] 78 | 79 | pose_pred = [] 80 | val_results = {'R_errs': [], 't_errs': [], 'inliers': []} 81 | 82 | query_pose_pred, query_pose_pred_homo, inliers = ransac_PnP( 83 | query_K, 84 | mkpts2d, 85 | mkpts3d 86 | ) 87 | pose_pred.append(query_pose_pred_homo) 88 | 89 | if query_pose_pred is None: 90 | val_results['R_errs'].append(np.inf) 91 | val_results['t_errs'].append(np.inf) 92 | val_results['inliers'].append(np.array([]).astype(bool)) 93 | else: 94 | R_err, t_err = query_pose_error(query_pose_pred, query_pose_gt) 95 | val_results['R_errs'].append(R_err) 96 | val_results['t_errs'].append(t_err) 97 | val_results['inliers'].append(inliers) 98 | 99 | pose_pred = np.stack(pose_pred) 100 | 101 | val_results.update({'mkpts2d': mkpts2d, 'mkpts3d': mkpts3d, 'mconf': mconf}) 102 | return pose_pred, val_results 103 | 104 | 105 | def aggregate_metrics(metrics, thres=[1, 3, 5]): 106 | """ Aggregate metrics for the whole dataset: 107 | (This method should be called once per dataset) 108 | Reports the cm-degree pose recall, i.e. the fraction of queries whose 109 | rotation error (degree) and translation error (cm) both fall below each threshold in thres. 110 | """ 111 | R_errs = metrics['R_errs'] 112 | t_errs = metrics['t_errs'] 113 | 114 | degree_distance_metric = {} 115 | for threshold in thres: 116 | degree_distance_metric[f'{threshold}cm@{threshold}degree'] = np.mean( 117 | (np.array(R_errs) < threshold) & (np.array(t_errs) < threshold) 118 | ) 119 | 120 | return degree_distance_metric -------------------------------------------------------------------------------- /src/utils/model_io.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | 5 | 6 | def load_model(net, optim, scheduler, recorder, model_dir, resume=True, epoch=-1): 7 | if not resume: 8 | os.system('rm -rf {}'.format(model_dir)) 9 | 10 | if not os.path.exists(model_dir): 11 | return 0 12 | 13 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir)] 14 | if len(pths) == 0: 15 | return 0 16 | if epoch == -1: 17 | pth = max(pths) 18 | else: 19 | pth = epoch 20 | print('Load model: {}'.format(os.path.join(model_dir, '{}.pth'.format(pth)))) 21 | pretrained_model = torch.load(os.path.join(model_dir, '{}.pth'.format(pth))) 22 | net.load_state_dict(pretrained_model['net']) 23 | optim.load_state_dict(pretrained_model['optim']) 24 | scheduler.load_state_dict(pretrained_model['scheduler']) 25 | recorder.load_state_dict(pretrained_model['recorder']) 26 | return pretrained_model['epoch'] + 1 27 | 28 | 29 | def save_model(net, optim, scheduler, recorder, epoch, model_dir): 30 | os.system('mkdir -p {}'.format(model_dir)) 31 | torch.save({ 32 | 'net': net.state_dict(), 33 | 'optim': optim.state_dict(), 34 | 'scheduler': scheduler.state_dict(), 35 | 'recorder': recorder.state_dict(), 36 | 'epoch': epoch 37 | }, os.path.join(model_dir, '{}.pth'.format(epoch))) 38 | 39 | # remove previous pretrained model if the number of models is too big 40 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir)] 41 | if len(pths) <= 200: 42 | return 43 | os.system('rm {}'.format(os.path.join(model_dir, '{}.pth'.format(min(pths))))) 44 | 45 | 46 | def load_network_ckpt(net, ckpt_path): 47 |
pretrained_model = torch.load(ckpt_path, torch.device('cpu')) 48 | pretrained_model = pretrained_model['state_dict'] 49 | 50 | pretrained_model = remove_net_layer(pretrained_model, 'detector') 51 | pretrained_model = remove_net_prefix(pretrained_model, 'superglue.') 52 | 53 | print('=> load weights: ', ckpt_path) 54 | net.load_state_dict(pretrained_model) 55 | return None 56 | 57 | 58 | def load_network(net, model_dir, resume=True, epoch=-1, strict=True, force=False): 59 | """ 60 | Load latest network-weights from dir or path 61 | """ 62 | if not resume: 63 | return 0 64 | 65 | if not os.path.exists(model_dir): 66 | if force: 67 | raise NotImplementedError 68 | else: 69 | print('pretrained model does not exist') 70 | return 0 71 | 72 | if os.path.isdir(model_dir): 73 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir) if 'pth' in pth] 74 | if len(pths) == 0: 75 | return 0 76 | if epoch == -1: 77 | pth = max(pths) 78 | else: 79 | pth = epoch 80 | model_path = os.path.join(model_dir, '{}.pth'.format(pth)) 81 | else: 82 | model_path = model_dir 83 | 84 | print('=> load weights: ', model_path) 85 | pretrained_model = torch.load(model_path, torch.device("cpu")) 86 | if 'net' in pretrained_model.keys(): 87 | net.load_state_dict(pretrained_model['net'], strict=strict) 88 | else: 89 | net.load_state_dict(pretrained_model, strict=strict) 90 | return pretrained_model.get('epoch', 0) + 1 91 | 92 | 93 | def remove_net_prefix(net, prefix): 94 | net_ = OrderedDict() 95 | for k in net.keys(): 96 | if k.startswith(prefix): 97 | net_[k[len(prefix):]] = net[k] 98 | else: 99 | net_[k] = net[k] 100 | return net_ 101 | 102 | 103 | def add_net_prefix(net, prefix): 104 | net_ = OrderedDict() 105 | for k in net.keys(): 106 | net_[prefix + k] = net[k] 107 | return net_ 108 | 109 | 110 | def replace_net_prefix(net, orig_prefix, prefix): 111 | net_ = OrderedDict() 112 | for k in net.keys(): 113 | if k.startswith(orig_prefix): 114 | net_[prefix + k[len(orig_prefix):]] = net[k] 115 | else: 116 | net_[k] = net[k] 117 | return net_ 118 | 119 | 120 | def remove_net_layer(net, layers): 121 | keys = list(net.keys()) 122 | for k in keys: 123 | for layer in layers: 124 | if k.startswith(layer): 125 | del net[k] 126 | return net 127 | 128 | 129 | def to_cuda(data): 130 | if type(data).__name__ == "Tensor": 131 | data = data.cuda() 132 | elif type(data).__name__ == 'list': 133 | data = [d.cuda() for d in data] 134 | elif type(data).__name__ == 'dict': 135 | data = {k: v.cuda() for k, v in data.items()} 136 | else: 137 | raise NotImplementedError 138 | return data 139 | -------------------------------------------------------------------------------- /src/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | 5 | """ 6 | For each object, we store in the following directory format: 7 | 8 | data_root: 9 | - box3d_corners.txt 10 | - seq1_root 11 | - intrinsics.txt 12 | - color/ 13 | - color_det[optional]/ 14 | - poses_ba/ 15 | - intrin_ba/ 16 | - intrin_det[optional]/ 17 | - ...... 18 | - seq2_root 19 | - ...... 
20 | """ 21 | 22 | def get_gt_pose_path_by_color(color_path, det_type='GT_box'): 23 | if det_type == "GT_box": 24 | return color_path.replace("/color/", "/poses_ba/").replace( 25 | ".png", ".txt" 26 | ) 27 | elif det_type == 'feature_matching': 28 | return color_path.replace("/color_det/", "/poses_ba/").replace( 29 | ".png", ".txt" 30 | ) 31 | else: 32 | raise NotImplementedError 33 | 34 | def get_img_full_path_by_color(color_path, det_type='GT_box'): 35 | if det_type == "GT_box": 36 | return color_path.replace("/color/", "/color_full/") 37 | elif det_type == 'feature_matching': 38 | return color_path.replace("/color_det/", "/color_full/") 39 | else: 40 | raise NotImplementedError 41 | 42 | def get_intrin_path_by_color(color_path, det_type='GT_box'): 43 | if det_type == "GT_box": 44 | return color_path.replace("/color/", "/intrin_ba/").replace( 45 | ".png", ".txt" 46 | ) 47 | elif det_type == 'feature_matching': 48 | return color_path.replace("/color_det/", "/intrin_det/").replace( 49 | ".png", ".txt" 50 | ) 51 | else: 52 | raise NotImplementedError 53 | 54 | def get_intrin_dir(seq_root): 55 | return osp.join(seq_root, "intrin_ba") 56 | 57 | def get_gt_pose_dir(seq_root): 58 | return osp.join(seq_root, "poses_ba") 59 | 60 | def get_intrin_full_path(seq_root): 61 | return osp.join(seq_root, "intrinsics.txt") 62 | 63 | def get_3d_box_path(data_root): 64 | return osp.join(data_root, "box3d_corners.txt") 65 | 66 | -------------------------------------------------------------------------------- /src/utils/template_utils.py: -------------------------------------------------------------------------------- 1 | # pytorch lightning imports 2 | import pytorch_lightning as pl 3 | 4 | # hydra imports 5 | from omegaconf import DictConfig, OmegaConf 6 | from hydra.utils import get_original_cwd, to_absolute_path 7 | 8 | # loggers 9 | import wandb 10 | from pytorch_lightning.loggers.wandb import WandbLogger 11 | 12 | # from pytorch_lightning.loggers.neptune import NeptuneLogger 13 | # from pytorch_lightning.loggers.comet import CometLogger 14 | # from pytorch_lightning.loggers.mlflow import MLFlowLogger 15 | # from pytorch_lightning.loggers.tensorboard import TensorBoardLogger 16 | 17 | # rich imports 18 | from rich import print 19 | from rich.syntax import Syntax 20 | from rich.tree import Tree 21 | 22 | # normal imports 23 | from typing import List 24 | 25 | 26 | def print_config(config: DictConfig): 27 | """Prints content of Hydra config using Rich library. 
28 | 29 | Args: 30 | config (DictConfig): [description] 31 | """ 32 | 33 | # TODO print main config path and experiment config path 34 | # directory = to_absolute_path("configs/config.yaml") 35 | # print(f"Main config path: [link file://{directory}]{directory}") 36 | 37 | style = "dim" 38 | 39 | tree = Tree(f":gear: FULL HYDRA CONFIG", style=style, guide_style=style) 40 | 41 | trainer = OmegaConf.to_yaml(config["trainer"], resolve=True) 42 | trainer_branch = tree.add("Trainer", style=style, guide_style=style) 43 | trainer_branch.add(Syntax(trainer, "yaml")) 44 | 45 | model = OmegaConf.to_yaml(config["model"], resolve=True) 46 | model_branch = tree.add("Model", style=style, guide_style=style) 47 | model_branch.add(Syntax(model, "yaml")) 48 | 49 | datamodule = OmegaConf.to_yaml(config["datamodule"], resolve=True) 50 | datamodule_branch = tree.add("Datamodule", style=style, guide_style=style) 51 | datamodule_branch.add(Syntax(datamodule, "yaml")) 52 | 53 | callbacks_branch = tree.add("Callbacks", style=style, guide_style=style) 54 | if "callbacks" in config: 55 | for cb_name, cb_conf in config["callbacks"].items(): 56 | cb = callbacks_branch.add(cb_name, style=style, guide_style=style) 57 | cb.add(Syntax(OmegaConf.to_yaml(cb_conf, resolve=True), "yaml")) 58 | else: 59 | callbacks_branch.add("None") 60 | 61 | logger_branch = tree.add("Logger", style=style, guide_style=style) 62 | if "logger" in config: 63 | for lg_name, lg_conf in config["logger"].items(): 64 | lg = logger_branch.add(lg_name, style=style, guide_style=style) 65 | lg.add(Syntax(OmegaConf.to_yaml(lg_conf, resolve=True), "yaml")) 66 | else: 67 | logger_branch.add("None") 68 | 69 | seed = config.get("seed", "None") 70 | seed_branch = tree.add(f"Seed", style=style, guide_style=style) 71 | seed_branch.add(str(seed) + "\n") 72 | 73 | print(tree) 74 | 75 | 76 | def log_hparams_to_all_loggers( 77 | config: DictConfig, 78 | model: pl.LightningModule, 79 | datamodule: pl.LightningDataModule, 80 | trainer: pl.Trainer, 81 | callbacks: List[pl.Callback], 82 | logger: List[pl.loggers.LightningLoggerBase], 83 | ): 84 | """This method controls which parameters from Hydra config are saved by Lightning loggers. 
85 | 86 | Args: 87 | config (DictConfig): [description] 88 | model (pl.LightningModule): [description] 89 | datamodule (pl.LightningDataModule): [description] 90 | trainer (pl.Trainer): [description] 91 | callbacks (List[pl.Callback]): [description] 92 | logger (List[pl.loggers.LightningLoggerBase]): [description] 93 | """ 94 | 95 | hparams = {} 96 | 97 | # save all params of model, datamodule and trainer 98 | hparams.update(config["model"]) 99 | hparams.update(config["datamodule"]) 100 | hparams.update(config["trainer"]) 101 | hparams.pop("_target_") 102 | 103 | # save seed 104 | hparams["seed"] = config.get("seed", "None") 105 | 106 | # save targets 107 | hparams["_class_model"] = config["model"]["_target_"] 108 | hparams["_class_datamodule"] = config["datamodule"]["_target_"] 109 | 110 | # save sizes of each dataset 111 | if hasattr(datamodule, "data_train") and datamodule.data_train: 112 | hparams["train_size"] = len(datamodule.data_train) 113 | if hasattr(datamodule, "data_val") and datamodule.data_val: 114 | hparams["val_size"] = len(datamodule.data_val) 115 | if hasattr(datamodule, "data_test") and datamodule.data_test: 116 | hparams["test_size"] = len(datamodule.data_test) 117 | 118 | # save number of model parameters 119 | hparams["#params_total"] = sum(p.numel() for p in model.parameters()) 120 | hparams["#params_trainable"] = sum( 121 | p.numel() for p in model.parameters() if p.requires_grad 122 | ) 123 | hparams["#params_not_trainable"] = sum( 124 | p.numel() for p in model.parameters() if not p.requires_grad 125 | ) 126 | 127 | # send hparams to all loggers 128 | for lg in logger: 129 | lg.log_hyperparams(hparams) 130 | 131 | 132 | def finish( 133 | config: DictConfig, 134 | model: pl.LightningModule, 135 | datamodule: pl.LightningDataModule, 136 | trainer: pl.Trainer, 137 | callbacks: List[pl.Callback], 138 | logger: List[pl.loggers.LightningLoggerBase], 139 | ): 140 | """Makes sure everything closed properly. 141 | 142 | Args: 143 | config (DictConfig): [description] 144 | model (pl.LightningModule): [description] 145 | datamodule (pl.LightningDataModule): [description] 146 | trainer (pl.Trainer): [description] 147 | callbacks (List[pl.Callback]): [description] 148 | logger (List[pl.loggers.LightningLoggerBase]): [description] 149 | """ 150 | 151 | # without this sweeps with wandb logger might crash! 
152 | for lg in logger: 153 | if isinstance(lg, WandbLogger): 154 | wandb.finish() 155 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningModule, Callback, Trainer 2 | from pytorch_lightning import seed_everything 3 | from pytorch_lightning.loggers import LightningLoggerBase 4 | 5 | import hydra 6 | from omegaconf import DictConfig 7 | from typing import List 8 | from src.utils import template_utils as utils 9 | 10 | import warnings 11 | warnings.filterwarnings('ignore') 12 | 13 | 14 | def train(config: DictConfig): 15 | if config['print_config']: 16 | utils.print_config(config) 17 | 18 | if "seed" in config: 19 | seed_everything(config['seed']) 20 | 21 | # Init PyTorch Lightning model ⚡ 22 | model: LightningModule = hydra.utils.instantiate(config['model']) 23 | 24 | # Init PyTorch Lightning datamodule ⚡ 25 | datamodule: LightningModule = hydra.utils.instantiate(config['datamodule']) 26 | datamodule.setup() 27 | 28 | # Init PyTorch Lightning callbacks ⚡ 29 | callbacks: List[Callback] = [] 30 | if "callbacks" in config: 31 | for _, cb_conf in config['callbacks'].items(): 32 | if "_target_" in cb_conf: 33 | callbacks.append(hydra.utils.instantiate(cb_conf)) 34 | 35 | # Init PyTorch Lightning loggers ⚡ 36 | logger: List[LightningLoggerBase] = [] 37 | if "logger" in config: 38 | for _, lg_conf in config['logger'].items(): 39 | if "_target_" in lg_conf: 40 | logger.append(hydra.utils.instantiate(lg_conf)) 41 | 42 | # Init PyTorch Lightning trainer ⚡ 43 | trainer: Trainer = hydra.utils.instantiate( 44 | config['trainer'], callbacks=callbacks, logger=logger 45 | ) 46 | 47 | # Send some parameters from config to all lightning loggers 48 | utils.log_hparams_to_all_loggers( 49 | config=config, 50 | model=model, 51 | datamodule=datamodule, 52 | trainer=trainer, 53 | callbacks=callbacks, 54 | logger=logger 55 | ) 56 | 57 | # Train the model 58 | trainer.fit(model=model, datamodule=datamodule) 59 | 60 | # Evaluate model on test set after training 61 | # trainer.test() 62 | 63 | # Make sure everything closed properly 64 | utils.finish( 65 | config=config, 66 | model=model, 67 | datamodule=datamodule, 68 | trainer=trainer, 69 | callbacks=callbacks, 70 | logger=logger 71 | ) 72 | 73 | # Return best achieved metric score for optuna 74 | optimized_metric = config.get("optimized_metric", None) 75 | if optimized_metric: 76 | return trainer.callback_metrics[optimized_metric] 77 | 78 | 79 | @hydra.main(config_path="configs/", config_name="config.yaml") 80 | def main(config: DictConfig): 81 | return train(config) 82 | 83 | 84 | if __name__ == "__main__": 85 | main() -------------------------------------------------------------------------------- /video2img.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | from src.utils.data_utils import video2img 5 | 6 | 7 | def main(): 8 | parser = ArgumentParser() 9 | parser.add_argument("--input", required=True, help="The video file or directory to be parsed") 10 | parser.add_argument("--downsample", default=1, type=int) 11 | args = parser.parse_args() 12 | 13 | input = args.input 14 | 15 | if osp.isdir(input): # in case of directory which contains video file 16 | video_file = osp.join(input, 'Frames.m4v') 17 | else: # in case of video file 18 | video_file = input 19 | assert 
osp.exists(video_file), "Please input a valid video file!" 20 | 21 | data_dir = osp.dirname(video_file) 22 | out_dir = osp.join(data_dir, 'color_full') 23 | Path(out_dir).mkdir(exist_ok=True, parents=True) 24 | 25 | video2img(video_file, out_dir, args.downsample) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() --------------------------------------------------------------------------------
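A minimal usage sketch for `video2img.py`, based on the argparse options above (the path is illustrative; the exact downsampling behaviour is implemented in `src/utils/data_utils.py`, which is not shown here):

```shell
# Extract frames from a captured sequence into <sequence_dir>/color_full/.
# --input may be the video file itself or a directory containing Frames.m4v.
python video2img.py --input /path/to/sequence/Frames.m4v --downsample 5
```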