├── DEMO.md
├── LICENSE
├── README.md
├── assets
│   └── onepose-github-teaser.gif
├── configs
│   ├── config.yaml
│   ├── experiment
│   │   ├── object_detector.yaml
│   │   ├── test_GATsSPG.yaml
│   │   ├── test_demo.yaml
│   │   ├── test_sample.yaml
│   │   └── train_GATsSPG.yaml
│   └── preprocess
│       ├── merge_anno.yaml
│       ├── sfm_spp_spg_demo.yaml
│       ├── sfm_spp_spg_sample.yaml
│       ├── sfm_spp_spg_test.yaml
│       ├── sfm_spp_spg_train.yaml
│       └── sfm_spp_spg_val.yaml
├── environment.yaml
├── feature_matching_object_detector.py
├── inference.py
├── inference_demo.py
├── parse_scanned_data.py
├── requirements.txt
├── run.py
├── scripts
│   ├── demo_pipeline.sh
│   ├── parse_full_img.sh
│   └── prepare_2D_matching_resources.sh
├── src
│   ├── callbacks
│   │   ├── custom_callbacks.py
│   │   └── wandb_callbacks.py
│   ├── datamodules
│   │   └── GATs_spg_datamodule.py
│   ├── datasets
│   │   ├── GATs_spg_dataset.py
│   │   └── normalized_dataset.py
│   ├── evaluators
│   │   └── cmd_evaluator.py
│   ├── local_feature_2D_detector
│   │   ├── __init__.py
│   │   └── local_feature_2D_detector.py
│   ├── losses
│   │   └── focal_loss.py
│   ├── models
│   │   ├── GATsSPG_architectures
│   │   │   ├── GATs.py
│   │   │   └── GATs_SuperGlue.py
│   │   ├── GATsSPG_lightning_model.py
│   │   ├── extractors
│   │   │   └── SuperPoint
│   │   │       └── superpoint.py
│   │   └── matchers
│   │       ├── SuperGlue
│   │       │   └── superglue.py
│   │       └── nn
│   │           └── nearest_neighbour.py
│   ├── sfm
│   │   ├── extract_features.py
│   │   ├── generate_empty.py
│   │   ├── global_ba.py
│   │   ├── match_features.py
│   │   ├── pairs_from_poses.py
│   │   ├── postprocess
│   │   │   ├── feature_process.py
│   │   │   ├── filter_points.py
│   │   │   └── filter_tkl.py
│   │   └── triangulation.py
│   ├── tracker
│   │   ├── __init__.py
│   │   ├── ba_tracker.py
│   │   └── tracking_utils.py
│   └── utils
│       ├── colmap
│       │   ├── database.py
│       │   └── read_write_model.py
│       ├── comm.py
│       ├── data_utils.py
│       ├── eval_utils.py
│       ├── model_io.py
│       ├── path_utils.py
│       ├── template_utils.py
│       └── vis_utils.py
├── train.py
└── video2img.py
/DEMO.md:
--------------------------------------------------------------------------------
1 | # OnePose Demo on Custom Data (WIP)
2 | In this tutorial we introduce the OnePose demo running on data captured
3 | with our **OnePose Cap** application for iOS devices.
4 | The app is still being prepared for release.
5 | However, you can try the demo with the [sample data]() and skip the first step.
6 |
7 | ### Step 1: Capture the mapping sequence and the test sequence with OnePose Cap.
8 | #### The app is still brewing 🍺 and will be released soon.
9 |
10 | ### Step 2: Organize the file structure of collected sequences
11 | 1. Export the collected mapping sequence and the test sequence to the PC.
12 | 2. Rename the **annotate** and **test** sequence directories to `your_obj_name-annotate` and `your_obj_name-test` respectively, and organize the data in the following structure:
13 | ```
14 | |--- /your/path/to/scanned_data
15 | | |--- your_obj_name
16 | | | |---your_obj_name-annotate
17 | | | |---your_obj_name-test
18 | ```
19 | Refer to the [sample data]() as an example.
20 | 3. Link the collected data to the project directory
21 | ```shell
22 | REPO_ROOT=/path/to/OnePose
23 | ln -s /path/to/scanned_data $REPO_ROOT/data/demo
24 | ```
25 |
26 | Now the data is prepared!
27 |
28 | ### Step 3: Run OnePose with collected data
29 | Download the [pretrained OnePose model](https://drive.google.com/drive/folders/1VjLLjJ9oxjKV5Xy3Aty0uQUVwyEhgtIE?usp=sharing) and move it to `${REPO_ROOT}/data/models/checkpoints/onepose/GATsSPG.ckpt`.
30 |
31 | [Optional] To run OnePose with the tracking module, please install [DeepLM](https://github.com/hjwdzh/DeepLM.git).
32 | Make sure the sample program in `DeepLM` runs correctly to confirm a successful installation.
33 |
34 |
35 | Execute the following commands, and a demo video named `demo_video.mp4` will be saved in the test sequence folder.
36 | ```shell
37 | REPO_ROOT=/path/to/OnePose
38 | OBJ_NAME=your_obj_name
39 |
40 | cd $REPO_ROOT
41 | conda activate OnePose
42 |
43 | bash scripts/demo_pipeline.sh $OBJ_NAME
44 |
45 | # [Optional] running OnePose with tracking
46 | export PYTHONPATH=$PYTHONPATH:/path/to/DeepLM/build
47 | export TORCH_USE_RTLD_GLOBAL=YES
48 |
49 | bash scripts/demo_pipeline.sh $OBJ_NAME --WITH_TRACKING
50 |
51 | ```
52 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OnePose: One-Shot Object Pose Estimation without CAD Models
2 | ### [Project Page](https://zju3dv.github.io/onepose) | [Paper](https://arxiv.org/pdf/2205.12257.pdf)
3 |
4 |
5 | > OnePose: One-Shot Object Pose Estimation without CAD Models
6 | > [Jiaming Sun](https://jiamingsun.ml)\*, [Zihao Wang](http://zihaowang.xyz/)\*, [Siyu Zhang](https://derizsy.github.io/)\*, [Xingyi He](https://github.com/hxy-123/), [Hongcheng Zhao](https://github.com/HongchengZhao), [Guofeng Zhang](http://www.cad.zju.edu.cn/home/gfzhang/), [Xiaowei Zhou](https://xzhou.me)
7 | > CVPR 2022
8 |
9 | 
10 |
11 | ## TODO List
12 | - [x] Training and inference code.
13 | - [x] Pipeline to reproduce the evaluation results on the proposed OnePose dataset.
14 | - [x] The `OnePose Cap` app is available on the [App Store](https://apps.apple.com/cn/app/onepose-capture/id6447052065?l=en-GB) (iOS only) for capturing your own training and test data.
15 | - [x] Demo pipeline for running OnePose with custom-captured data.
16 |
17 | ## Installation
18 |
19 | ```shell
20 | conda env create -f environment.yaml
21 | conda activate onepose
22 | ```
23 | We use [SuperPoint](https://github.com/magicleap/SuperPointPretrainedNetwork) and [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork)
24 | for 2D feature detection and matching in this project.
25 | We can't include their code directly due to license restrictions. Please download the inference code and pretrained models with the following script:
26 | ```shell
27 | REPO_ROOT=/path/to/OnePose
28 | cd $REPO_ROOT
29 | sh ./scripts/prepare_2D_matching_resources.sh
30 | ```
31 |
32 | [COLMAP](https://colmap.github.io/) is used in this project for Structure-from-Motion.
33 | Please refer to the official [instructions](https://colmap.github.io/install.html) for the installation.
34 |
35 | [Optional, WIP] You can try our web-based 3D visualization tool [Wis3D](https://github.com/zju3dv/Wis3D) for convenient and interactive visualization of feature matches. Wis3D also provides many other visualization features; feel free to try it out.
36 |
37 | ```bash
38 | # Work in progress; should be ready soon, currently only available on test PyPI.
39 | pip install -i https://test.pypi.org/simple/ wis3d
40 | ```
41 |
42 | ## Training and Evaluation on OnePose dataset
43 | ### Dataset setup
44 | 1. Download the OnePose dataset from [OneDrive storage](https://zjueducn-my.sharepoint.com/:f:/g/personal/zihaowang_zju_edu_cn/ElfzHE0sTXxNndx6uDLWlbYB-2zWuLfjNr56WxF11_DwSg?e=GKI0Df) and extract it into `/your/path/to/onepose_datasets`.
45 | The directory should be organized in the following structure:
46 | ```
47 | |--- /your/path/to/onepose_datasets
48 | | |--- train_data
49 | | |--- val_data
50 | | |--- test_data
51 | | |--- sample_data
52 | ```
53 |
54 | 2. Build the dataset symlinks
55 | ```shell
56 | REPO_ROOT=/path/to/OnePose
57 | ln -s /your/path/to/onepose_datasets $REPO_ROOT/data/onepose_datasets
58 | ```
59 |
60 | 3. Run Structure-from-Motion for the data sequences
61 |
62 | The reconstructed object point clouds and 2D-3D correspondences are needed for both training and test objects; a quick sanity check of the SfM output is sketched after the commands below:
63 | ```python
64 | python run.py +preprocess=sfm_spp_spg_train.yaml # for training data
65 | python run.py +preprocess=sfm_spp_spg_test.yaml # for testing data
66 | python run.py +preprocess=sfm_spp_spg_val.yaml # for val data
67 | python run.py +preprocess=sfm_spp_spg_sample.yaml # an example, if you don't want to test the full dataset
68 | ```
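
The snippet below is a minimal sketch (not part of the pipeline) for checking that an object was reconstructed: it loads the COLMAP model with the bundled `src/utils/colmap/read_write_model.py` helper and counts registered images and 3D points. It assumes the repository root is on `PYTHONPATH`, and the `model_dir` path is only a guess at the output layout under `data/sfm_model`; adjust it to the location produced on your machine.
```python
from src.utils.colmap.read_write_model import read_model

# Assumed output location; adjust to the actual sub-directory created by the preprocess run.
model_dir = "data/sfm_model/outputs_softmax/0501-matchafranzzi-box/sfm_ws/model"

# read_model returns dicts of cameras, registered images, and triangulated 3D points.
cameras, images, points3D = read_model(model_dir, ext=".bin")
print(f"registered images: {len(images)}")
print(f"3D points (2D-3D correspondences): {len(points3D)}")
```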
69 |
70 | ### Inference on OnePose dataset
71 | 1. Download the [pretrained model](https://drive.google.com/drive/folders/1VjLLjJ9oxjKV5Xy3Aty0uQUVwyEhgtIE?usp=sharing) and move it to `${REPO_ROOT}/data/models/checkpoints/onepose/GATsSPG.ckpt`.
72 |
73 | 2. Inference with category-agnostic 2D object detection.
74 |
75 | When deploying OnePose to a real-world system,
76 | an off-the-shelf category-level 2D object detector such as [YOLOv5](https://github.com/ultralytics/yolov5) can be used.
77 | However, this would defeat the category-agnostic nature of OnePose.
78 | We can instead use a feature-matching-based pipeline for 2D object detection, which locates the scanned object in the query image through 2D feature matching.
79 | Note that the 2D object detection is only necessary during initialization.
80 | After initialization, the 2D bounding box can be obtained by projecting the previously detected 3D bounding box onto the current camera frame (a minimal sketch of this projection is given after the commands below).
81 | Please refer to the [supplementary material](https://zju3dv.github.io/onepose/files/onepose_supp.pdf) for more details.
82 |
83 | ```python
84 | # Obtain category-agnostic 2D object detection results first.
85 | # Increasing `n_ref_view` improves detection robustness at the cost of slower initialization.
86 | python feature_matching_object_detector.py +experiment=object_detector.yaml n_ref_view=15
87 |
88 | # Running pose estimation with `object_detect_mode` set to `feature_matching`.
89 | # Note that enabling visualization will slow down the inference.
90 | python inference.py +experiment=test_GATsSPG.yaml object_detect_mode=feature_matching save_wis3d=False
91 | ```
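
As mentioned above, after initialization the 2D box is obtained by projecting the 3D bounding box with the current pose estimate. The following is a minimal, self-contained sketch of that projection (not the repository's implementation); it assumes the eight box corners in the object frame, the camera intrinsics `K`, and an object-to-camera rotation `R` and translation `t`.
```python
import numpy as np

def project_3d_box_to_2d(corners_obj, K, R, t):
    """Project the 8 corners of a 3D bounding box (object frame, shape (8, 3))
    and return an axis-aligned 2D box (x_min, y_min, x_max, y_max) in pixels."""
    # Object frame -> camera frame: X_cam = R @ X_obj + t
    corners_cam = corners_obj @ R.T + t.reshape(1, 3)
    # Pinhole projection, then divide by depth
    uv_hom = corners_cam @ K.T
    uv = uv_hom[:, :2] / uv_hom[:, 2:3]
    x_min, y_min = uv.min(axis=0)
    x_max, y_max = uv.max(axis=0)
    return float(x_min), float(y_min), float(x_max), float(y_max)
```
In the pipeline, the pose used for this projection is the estimate from the previous frame, so no detector needs to run after the first frame.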
92 |
93 | 3. Running inference with ground-truth 2D bounding boxes
94 |
95 | The following command should reproduce the results in the paper, where 2D boxes projected from the 3D boxes are used as the object detection results.
96 |
97 | ```python
98 | # Note that enabling visualization will slow down the inference.
99 | python inference.py +experiment=test_GATsSPG.yaml object_detect_mode=GT_box save_wis3d=False # for testing data
100 | ```
101 |
102 |
103 |
104 | 4. [Optional] Visualize matches and estimated poses with Wis3D. Make sure the flag `save_wis3d` is set to True during testing
105 | and that the full images have been extracted from `Frames.m4v` by the script `scripts/parse_full_img.sh`.
106 | The visualization files will be saved under the `cfg.output.vis_dir` directory, which is set to `GATsSPG` by default.
107 | Run the following commands for visualization:
108 | ```shell
109 | sh ./scripts/parse_full_img.sh path_to_Frames_m4v # parse full image from m4v file
110 |
111 | cd runs/vis/GATsSPG
112 | wis3d --vis_dir ./ --host localhost --port 11020
113 | ```
114 | This launches a web service for visualization on port 11020.
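
If running the shell script is inconvenient, the frame extraction it performs can be reproduced with a short OpenCV sketch like the one below. This is only a rough equivalent of `scripts/parse_full_img.sh` / `video2img.py`, and the output directory name is a placeholder; match it to whatever the visualization code expects.
```python
import os
import cv2

def video_to_images(video_path, out_dir):
    """Dump every frame of an .m4v/.mp4 video as a numbered PNG."""
    os.makedirs(out_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        cv2.imwrite(os.path.join(out_dir, f"{idx}.png"), frame)
        idx += 1
    cap.release()
    return idx

# Example (paths are placeholders):
# n = video_to_images("path_to/Frames.m4v", "path_to/color_full")
```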
115 |
116 |
117 | ### Training the GATs Network
118 | 1. Prepare ground-truth annotations by merging the annotations of the training/validation data:
119 | ```python
120 | python run.py +preprocess=merge_anno task_name=onepose split=train
121 | python run.py +preprocess=merge_anno task_name=onepose split=val
122 | ```
123 |
124 | 2. Begin training
125 | ```python
126 | python train.py +experiment=train_GATsSPG task_name=onepose exp_name=training_onepose
127 | ```
128 |
129 | All model weights will be saved under `${REPO_ROOT}/data/models/checkpoints/${exp_name}` and logs will be saved under `${REPO_ROOT}/data/logs/${exp_name}`.
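
To reuse a trained checkpoint outside the training loop, a minimal sketch is shown below. It assumes the LightningModule saves its hyperparameters so that the standard `load_from_checkpoint` works; the filename `epoch=9.ckpt` is only an example of the `{epoch}` naming pattern from the training config.
```python
from src.models.GATsSPG_lightning_model import LitModelGATsSPG

# Example path; the actual filename depends on exp_name and the saved epoch.
ckpt_path = "data/models/checkpoints/training_onepose/epoch=9.ckpt"

model = LitModelGATsSPG.load_from_checkpoint(ckpt_path)
model.eval()
```
To evaluate your own checkpoint with the provided pipeline, you can instead point the inference config at it via a Hydra override, e.g. `python inference.py +experiment=test_GATsSPG.yaml model.onepose_model_path=/path/to/your.ckpt`.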
130 |
134 |
135 | ## Citation
136 | If you find this code useful for your research, please use the following BibTeX entry.
137 |
138 | ```bibtex
139 | @article{sun2022onepose,
140 | title={{OnePose}: One-Shot Object Pose Estimation without {CAD} Models},
141 | author = {Sun, Jiaming and Wang, Zihao and Zhang, Siyu and He, Xingyi and Zhao, Hongcheng and Zhang, Guofeng and Zhou, Xiaowei},
142 | journal={CVPR},
143 | year={2022},
144 | }
145 | ```
146 |
147 | ## Copyright
148 |
149 | This work is affiliated with ZJU-SenseTime Joint Lab of 3D Vision, and its intellectual property belongs to SenseTime Group Ltd.
150 |
151 | ```
152 | Copyright SenseTime. All Rights Reserved.
153 |
154 | Licensed under the Apache License, Version 2.0 (the "License");
155 | you may not use this file except in compliance with the License.
156 | You may obtain a copy of the License at
157 |
158 | http://www.apache.org/licenses/LICENSE-2.0
159 |
160 | Unless required by applicable law or agreed to in writing, software
161 | distributed under the License is distributed on an "AS IS" BASIS,
162 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
163 | See the License for the specific language governing permissions and
164 | limitations under the License.
165 | ```
166 |
167 | ## Acknowledgement
168 | Part of our code is borrowed from [hloc](https://github.com/cvg/Hierarchical-Localization) and [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork); thanks to the authors for their great work.
169 |
--------------------------------------------------------------------------------
/assets/onepose-github-teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/OnePose/d0c313a5c36994658e63eac7ca3a63d7e3573d7b/assets/onepose-github-teaser.gif
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default training configuration
4 | defaults:
5 | - trainer: null
6 | - model: null
7 | - datamodule: null
8 | - callbacks: null # set this to null if you don't want to use callbacks
9 | - logger: null # set logger here or use command line (e.g. `python train.py logger=wandb`)
10 |
11 | # enable color logging
12 | # - override hydra/hydra_logging: colorlog
13 | # - override hydra/job_logging: colorlog
14 |
15 |
16 | # path to original working directory (that `train.py` was executed from in command line)
17 | # hydra hijacks working directory by changing it to the current log directory,
18 | # so it's useful to have path to original working directory as a special variable
19 | # read more here: https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
20 | work_dir: ${hydra:runtime.cwd}
21 |
22 |
23 | # path to folder with data
24 | data_dir: ${work_dir}/data
25 |
26 |
27 | # pretty print config at the start of the run using Rich library
28 | print_config: True
29 |
30 |
31 | # output paths for hydra logs
32 | # hydra:
33 | # run:
34 | # dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
35 | # sweep:
36 | # dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
37 | # subdir: ${hydra.job.num}
38 |
39 | hydra:
40 | run:
41 | dir: ${work_dir}
--------------------------------------------------------------------------------
/configs/experiment/object_detector.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: inference
4 | task_name: local_feature_object_detector
5 | suffix: ''
6 |
7 | model:
8 | extractor_model_path: ${data_dir}/models/extractors/SuperPoint/superpoint_v1.pth
9 | match_model_path: ${data_dir}/models/matchers/SuperGlue/superglue_outdoor.pth
10 |
11 | network:
12 | detection: superpoint
13 | matching: superglue
14 |
15 | n_ref_view: 15
16 | scan_data_dir: ${data_dir}/onepose_datasets/test_data
17 | sfm_model_dir: ${data_dir}/sfm_model
18 |
19 | input:
20 | data_dirs:
21 | - ${scan_data_dir}/0408-colorbox-box colorbox-4
22 | - ${scan_data_dir}/0409-aptamil-box aptamil-3
23 | - ${scan_data_dir}/0419-cookies2-others cookies2-4
24 | - ${scan_data_dir}/0422-qvduoduo-box qvduoduo-4
25 | - ${scan_data_dir}/0423-oreo-box oreo-4
26 | - ${scan_data_dir}/0424-chocbox-box chocbox-4
27 | - ${scan_data_dir}/0447-nabati-box nabati-5
28 | - ${scan_data_dir}/0450-hlychocpie-box hlychocpie-4
29 | - ${scan_data_dir}/0452-hlymatchapie-box hlymatchapie-4
30 | - ${scan_data_dir}/0455-strawberryoreo-box strawberryoreo-4
31 | - ${scan_data_dir}/0456-chocoreo-box chocoreo-4
32 | - ${scan_data_dir}/0458-hetaocakes-box hetaocakes-4
33 | - ${scan_data_dir}/0459-jzhg-box jzhg-4
34 | - ${scan_data_dir}/0466-mfmilkcake-box mfmilkcake-4
35 | - ${scan_data_dir}/0468-minipuff-box minipuff-4
36 | - ${scan_data_dir}/0469-diycookies-box diycookies-4
37 | - ${scan_data_dir}/0470-eggrolls-box eggrolls-4
38 | - ${scan_data_dir}/0471-hlyormosiapie-box hlyormosiapie-4
39 | - ${scan_data_dir}/0472-chocoreo-bottle chocoreo-4
40 | - ${scan_data_dir}/0473-twgrassjelly1-box twgrassjelly1-4
41 | - ${scan_data_dir}/0474-twgrassjelly2-box twgrassjelly2-4
42 | - ${scan_data_dir}/0476-giraffecup-bottle giraffecup-4
43 | - ${scan_data_dir}/0480-ljcleaner-others ljcleaner-4
44 | - ${scan_data_dir}/0483-ambrosial-box ambrosial-4
45 | - ${scan_data_dir}/0486-sanqitoothpaste-box sanqitoothpaste-4
46 | - ${scan_data_dir}/0487-jindiantoothpaste-box jindiantoothpaste-4
47 | - ${scan_data_dir}/0488-jijiantoothpaste-box jijiantoothpaste-4
48 | - ${scan_data_dir}/0489-taipingcookies-others taipingcookies-4
49 | - ${scan_data_dir}/0490-haochidiancookies-others haochidiancookies-4
50 | - ${scan_data_dir}/0492-tuccookies-box tuccookies-4
51 | - ${scan_data_dir}/0493-haochidianeggroll-box haochidianeggroll-4
52 | - ${scan_data_dir}/0494-qvduoduocookies-box qvduoduocookies-4
53 | - ${scan_data_dir}/0495-fulingstapler-box fulingstapler-4
54 | - ${scan_data_dir}/0496-delistapler-box delistapler-4
55 | - ${scan_data_dir}/0497-delistaplerlarger-box delistaplerlarger-4
56 | - ${scan_data_dir}/0498-yousuanru-box yousuanru-4
57 | - ${scan_data_dir}/0500-chocfranzzi-box chocfranzzi-4
58 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-4
59 | - ${scan_data_dir}/0502-shufujia-box shufujia-5
60 | - ${scan_data_dir}/0503-shufujiawhite-box shufujiawhite-3
61 | - ${scan_data_dir}/0504-lux-box lux-4
62 | - ${scan_data_dir}/0508-yqsl-others yqsl-4
63 | - ${scan_data_dir}/0510-yqslmilk-others yqslmilk-4
64 | - ${scan_data_dir}/0511-policecar-others policecar-4
65 | - ${scan_data_dir}/0517-nationalgeo-box nationalgeo-4
66 | - ${scan_data_dir}/0518-jasmine-box jasmine-4
67 | - ${scan_data_dir}/0519-backpack1-box backpack1-4
68 | - ${scan_data_dir}/0520-lipault-box lipault-4
69 | - ${scan_data_dir}/0521-ranova-box ranova-4
70 | - ${scan_data_dir}/0522-milkbox-box milkbox-4
71 | - ${scan_data_dir}/0523-edibleoil-others edibleoil-4
72 | - ${scan_data_dir}/0525-toygrab-others toygrab-2
73 | - ${scan_data_dir}/0526-toytable-others toytable-3
74 | - ${scan_data_dir}/0527-spalding-others spalding-2
75 | - ${scan_data_dir}/0534-tonkotsuramen-box tonkotsuramen-4
76 | - ${scan_data_dir}/0535-odbmilk-box odbmilk-4
77 | - ${scan_data_dir}/0537-petsnack-box petsnack-4
78 | - ${scan_data_dir}/0539-spamwrapper-others spamwrapper-5
79 | - ${scan_data_dir}/0543-brownhouse-others brownhouse-4
80 | - ${scan_data_dir}/0547-cubebox-box cubebox-4
81 | - ${scan_data_dir}/0548-duck-others duck-4
82 | - ${scan_data_dir}/0550-greenbox-box greenbox-4
83 | - ${scan_data_dir}/0551-milk-others milk-4
84 | - ${scan_data_dir}/0552-mushroom-others mushroom-4
85 | - ${scan_data_dir}/0557-santachoc-others santachoc-4
86 | - ${scan_data_dir}/0558-teddychoc-others teddychoc-4
87 | - ${scan_data_dir}/0559-tissuebox-box tissuebox-4
88 | - ${scan_data_dir}/0560-tofubox-box tofubox-4
89 | - ${scan_data_dir}/0564-biatee-others biatee-4
90 | - ${scan_data_dir}/0565-biscuits-box biscuits-4
91 | - ${scan_data_dir}/0568-cornflakes-box cornflakes-5
92 | - ${scan_data_dir}/0570-kasekuchen-box kasekuchen-4
93 | - ${scan_data_dir}/0577-schoko-box schoko-4
94 | - ${scan_data_dir}/0578-tee-others tee-4
95 | - ${scan_data_dir}/0579-tomatocan-bottle tomatocan-4
96 | - ${scan_data_dir}/0580-xmaxbox-others xmaxbox-4
97 | - ${scan_data_dir}/0582-yogurtlarge-others yogurtlarge-4
98 | - ${scan_data_dir}/0583-yogurtmedium-others yogurtmedium-4
99 | - ${scan_data_dir}/0594-martinBootsLeft-others martinBootsLeft-2
100 | - ${scan_data_dir}/0595-martinBootsRight-others martinBootsRight-4
101 |
102 | sfm_model_dirs:
103 | - ${sfm_model_dir}/0408-colorbox-box
104 | - ${sfm_model_dir}/0409-aptamil-box
105 | - ${sfm_model_dir}/0419-cookies2-others
106 | - ${sfm_model_dir}/0422-qvduoduo-box
107 | - ${sfm_model_dir}/0423-oreo-box
108 | - ${sfm_model_dir}/0424-chocbox-box
109 | - ${sfm_model_dir}/0447-nabati-box
110 | - ${sfm_model_dir}/0450-hlychocpie-box
111 | - ${sfm_model_dir}/0452-hlymatchapie-box
112 | - ${sfm_model_dir}/0455-strawberryoreo-box
113 | - ${sfm_model_dir}/0456-chocoreo-box
114 | - ${sfm_model_dir}/0458-hetaocakes-box
115 | - ${sfm_model_dir}/0459-jzhg-box
116 | - ${sfm_model_dir}/0466-mfmilkcake-box
117 | - ${sfm_model_dir}/0468-minipuff-box
118 | - ${sfm_model_dir}/0469-diycookies-box
119 | - ${sfm_model_dir}/0470-eggrolls-box
120 | - ${sfm_model_dir}/0471-hlyormosiapie-box
121 | - ${sfm_model_dir}/0472-chocoreo-bottle
122 | - ${sfm_model_dir}/0473-twgrassjelly1-box
123 | - ${sfm_model_dir}/0474-twgrassjelly2-box
124 | - ${sfm_model_dir}/0476-giraffecup-bottle
125 | - ${sfm_model_dir}/0480-ljcleaner-others
126 | - ${sfm_model_dir}/0483-ambrosial-box
127 | - ${sfm_model_dir}/0486-sanqitoothpaste-box
128 | - ${sfm_model_dir}/0487-jindiantoothpaste-box
129 | - ${sfm_model_dir}/0488-jijiantoothpaste-box
130 | - ${sfm_model_dir}/0489-taipingcookies-others
131 | - ${sfm_model_dir}/0490-haochidiancookies-others
132 | - ${sfm_model_dir}/0492-tuccookies-box
133 | - ${sfm_model_dir}/0493-haochidianeggroll-box
134 | - ${sfm_model_dir}/0494-qvduoduocookies-box
135 | - ${sfm_model_dir}/0495-fulingstapler-box
136 | - ${sfm_model_dir}/0496-delistapler-box
137 | - ${sfm_model_dir}/0497-delistaplerlarger-box
138 | - ${sfm_model_dir}/0498-yousuanru-box
139 | - ${sfm_model_dir}/0500-chocfranzzi-box
140 | - ${sfm_model_dir}/0501-matchafranzzi-box
141 | - ${sfm_model_dir}/0502-shufujia-box
142 | - ${sfm_model_dir}/0503-shufujiawhite-box
143 | - ${sfm_model_dir}/0504-lux-box
144 | - ${sfm_model_dir}/0508-yqsl-others
145 | - ${sfm_model_dir}/0510-yqslmilk-others
146 | - ${sfm_model_dir}/0511-policecar-others
147 | - ${sfm_model_dir}/0517-nationalgeo-box
148 | - ${sfm_model_dir}/0518-jasmine-box
149 | - ${sfm_model_dir}/0519-backpack1-box
150 | - ${sfm_model_dir}/0520-lipault-box
151 | - ${sfm_model_dir}/0521-ranova-box
152 | - ${sfm_model_dir}/0522-milkbox-box
153 | - ${sfm_model_dir}/0523-edibleoil-others
154 | - ${sfm_model_dir}/0525-toygrab-others
155 | - ${sfm_model_dir}/0526-toytable-others
156 | - ${sfm_model_dir}/0527-spalding-others
157 | - ${sfm_model_dir}/0534-tonkotsuramen-box
158 | - ${sfm_model_dir}/0535-odbmilk-box
159 | - ${sfm_model_dir}/0537-petsnack-box
160 | - ${sfm_model_dir}/0539-spamwrapper-others
161 | - ${sfm_model_dir}/0543-brownhouse-others
162 | - ${sfm_model_dir}/0547-cubebox-box
163 | - ${sfm_model_dir}/0548-duck-others
164 | - ${sfm_model_dir}/0550-greenbox-box
165 | - ${sfm_model_dir}/0551-milk-others
166 | - ${sfm_model_dir}/0552-mushroom-others
167 | - ${sfm_model_dir}/0557-santachoc-others
168 | - ${sfm_model_dir}/0558-teddychoc-others
169 | - ${sfm_model_dir}/0559-tissuebox-box
170 | - ${sfm_model_dir}/0560-tofubox-box
171 | - ${sfm_model_dir}/0564-biatee-others
172 | - ${sfm_model_dir}/0565-biscuits-box
173 | - ${sfm_model_dir}/0568-cornflakes-box
174 | - ${sfm_model_dir}/0570-kasekuchen-box
175 | - ${sfm_model_dir}/0577-schoko-box
176 | - ${sfm_model_dir}/0578-tee-others
177 | - ${sfm_model_dir}/0579-tomatocan-bottle
178 | - ${sfm_model_dir}/0580-xmaxbox-others
179 | - ${sfm_model_dir}/0582-yogurtlarge-others
180 | - ${sfm_model_dir}/0583-yogurtmedium-others
181 | - ${sfm_model_dir}/0594-martinBootsLeft-others
182 | - ${sfm_model_dir}/0595-martinBootsRight-others
--------------------------------------------------------------------------------
/configs/experiment/test_GATsSPG.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: inference
4 | task_name: test_onepose
5 | num_leaf: 8
6 | suffix: ''
7 | save_demo: False
8 | save_wis3d: False
9 | demo_root: ${data_dir}/runs/demo
10 |
11 | model:
12 | onepose_model_path: ${data_dir}/models/checkpoints/onepose/GATsSPG.ckpt
13 | extractor_model_path: ${data_dir}/models/extractors/SuperPoint/superpoint_v1.pth
14 |
15 | network:
16 | detection: superpoint
17 | matching: superglue
18 |
19 | # object_detect_mode: 'GT_box' # ["GT_box", "feature_matching"]
20 | object_detect_mode: 'GT_box' # ["GT_box", "feature_matching"]
21 | max_num_kp3d: 2500
22 | scan_data_dir: ${data_dir}/onepose_datasets/test_data
23 | sfm_model_dir: ${data_dir}/sfm_model
24 |
25 | input:
26 | data_dirs:
27 | - ${scan_data_dir}/0408-colorbox-box colorbox-4
28 | - ${scan_data_dir}/0409-aptamil-box aptamil-3
29 | - ${scan_data_dir}/0419-cookies2-others cookies2-4
30 | - ${scan_data_dir}/0422-qvduoduo-box qvduoduo-4
31 | - ${scan_data_dir}/0423-oreo-box oreo-4
32 | - ${scan_data_dir}/0424-chocbox-box chocbox-4
33 | - ${scan_data_dir}/0447-nabati-box nabati-5
34 | - ${scan_data_dir}/0450-hlychocpie-box hlychocpie-4
35 | - ${scan_data_dir}/0452-hlymatchapie-box hlymatchapie-4
36 | - ${scan_data_dir}/0455-strawberryoreo-box strawberryoreo-4
37 | - ${scan_data_dir}/0456-chocoreo-box chocoreo-4
38 | - ${scan_data_dir}/0458-hetaocakes-box hetaocakes-4
39 | - ${scan_data_dir}/0459-jzhg-box jzhg-4
40 | - ${scan_data_dir}/0466-mfmilkcake-box mfmilkcake-4
41 | - ${scan_data_dir}/0468-minipuff-box minipuff-4
42 | - ${scan_data_dir}/0469-diycookies-box diycookies-4
43 | - ${scan_data_dir}/0470-eggrolls-box eggrolls-4
44 | - ${scan_data_dir}/0471-hlyormosiapie-box hlyormosiapie-4
45 | - ${scan_data_dir}/0472-chocoreo-bottle chocoreo-4
46 | - ${scan_data_dir}/0473-twgrassjelly1-box twgrassjelly1-4
47 | - ${scan_data_dir}/0474-twgrassjelly2-box twgrassjelly2-4
48 | - ${scan_data_dir}/0476-giraffecup-bottle giraffecup-4
49 | - ${scan_data_dir}/0480-ljcleaner-others ljcleaner-4
50 | - ${scan_data_dir}/0483-ambrosial-box ambrosial-4
51 | - ${scan_data_dir}/0486-sanqitoothpaste-box sanqitoothpaste-4
52 | - ${scan_data_dir}/0487-jindiantoothpaste-box jindiantoothpaste-4
53 | - ${scan_data_dir}/0488-jijiantoothpaste-box jijiantoothpaste-4
54 | - ${scan_data_dir}/0489-taipingcookies-others taipingcookies-4
55 | - ${scan_data_dir}/0490-haochidiancookies-others haochidiancookies-4
56 | - ${scan_data_dir}/0492-tuccookies-box tuccookies-4
57 | - ${scan_data_dir}/0493-haochidianeggroll-box haochidianeggroll-4
58 | - ${scan_data_dir}/0494-qvduoduocookies-box qvduoduocookies-4
59 | - ${scan_data_dir}/0495-fulingstapler-box fulingstapler-4
60 | - ${scan_data_dir}/0496-delistapler-box delistapler-4
61 | - ${scan_data_dir}/0497-delistaplerlarger-box delistaplerlarger-4
62 | - ${scan_data_dir}/0498-yousuanru-box yousuanru-4
63 | - ${scan_data_dir}/0500-chocfranzzi-box chocfranzzi-4
64 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-4
65 | - ${scan_data_dir}/0502-shufujia-box shufujia-5
66 | - ${scan_data_dir}/0503-shufujiawhite-box shufujiawhite-3
67 | - ${scan_data_dir}/0504-lux-box lux-4
68 | - ${scan_data_dir}/0508-yqsl-others yqsl-4
69 | - ${scan_data_dir}/0510-yqslmilk-others yqslmilk-4
70 | - ${scan_data_dir}/0511-policecar-others policecar-4
71 | - ${scan_data_dir}/0517-nationalgeo-box nationalgeo-4
72 | - ${scan_data_dir}/0518-jasmine-box jasmine-4
73 | - ${scan_data_dir}/0519-backpack1-box backpack1-4
74 | - ${scan_data_dir}/0520-lipault-box lipault-4
75 | - ${scan_data_dir}/0521-ranova-box ranova-4
76 | - ${scan_data_dir}/0522-milkbox-box milkbox-4
77 | - ${scan_data_dir}/0523-edibleoil-others edibleoil-4
78 | - ${scan_data_dir}/0525-toygrab-others toygrab-2
79 | - ${scan_data_dir}/0526-toytable-others toytable-3
80 | - ${scan_data_dir}/0527-spalding-others spalding-2
81 | - ${scan_data_dir}/0534-tonkotsuramen-box tonkotsuramen-4
82 | - ${scan_data_dir}/0535-odbmilk-box odbmilk-4
83 | - ${scan_data_dir}/0537-petsnack-box petsnack-4
84 | - ${scan_data_dir}/0539-spamwrapper-others spamwrapper-5
85 | - ${scan_data_dir}/0543-brownhouse-others brownhouse-4
86 | - ${scan_data_dir}/0547-cubebox-box cubebox-4
87 | - ${scan_data_dir}/0548-duck-others duck-4
88 | - ${scan_data_dir}/0550-greenbox-box greenbox-4
89 | - ${scan_data_dir}/0551-milk-others milk-4
90 | - ${scan_data_dir}/0552-mushroom-others mushroom-4
91 | - ${scan_data_dir}/0557-santachoc-others santachoc-4
92 | - ${scan_data_dir}/0558-teddychoc-others teddychoc-4
93 | - ${scan_data_dir}/0559-tissuebox-box tissuebox-4
94 | - ${scan_data_dir}/0560-tofubox-box tofubox-4
95 | - ${scan_data_dir}/0564-biatee-others biatee-4
96 | - ${scan_data_dir}/0565-biscuits-box biscuits-4
97 | - ${scan_data_dir}/0568-cornflakes-box cornflakes-5
98 | - ${scan_data_dir}/0570-kasekuchen-box kasekuchen-4
99 | - ${scan_data_dir}/0577-schoko-box schoko-4
100 | - ${scan_data_dir}/0578-tee-others tee-4
101 | - ${scan_data_dir}/0579-tomatocan-bottle tomatocan-4
102 | - ${scan_data_dir}/0580-xmaxbox-others xmaxbox-4
103 | - ${scan_data_dir}/0582-yogurtlarge-others yogurtlarge-4
104 | - ${scan_data_dir}/0583-yogurtmedium-others yogurtmedium-4
105 | - ${scan_data_dir}/0594-martinBootsLeft-others martinBootsLeft-2
106 | - ${scan_data_dir}/0595-martinBootsRight-others martinBootsRight-4
107 |
108 | sfm_model_dirs:
109 | - ${sfm_model_dir}/0408-colorbox-box
110 | - ${sfm_model_dir}/0409-aptamil-box
111 | - ${sfm_model_dir}/0419-cookies2-others
112 | - ${sfm_model_dir}/0422-qvduoduo-box
113 | - ${sfm_model_dir}/0423-oreo-box
114 | - ${sfm_model_dir}/0424-chocbox-box
115 | - ${sfm_model_dir}/0447-nabati-box
116 | - ${sfm_model_dir}/0450-hlychocpie-box
117 | - ${sfm_model_dir}/0452-hlymatchapie-box
118 | - ${sfm_model_dir}/0455-strawberryoreo-box
119 | - ${sfm_model_dir}/0456-chocoreo-box
120 | - ${sfm_model_dir}/0458-hetaocakes-box
121 | - ${sfm_model_dir}/0459-jzhg-box
122 | - ${sfm_model_dir}/0466-mfmilkcake-box
123 | - ${sfm_model_dir}/0468-minipuff-box
124 | - ${sfm_model_dir}/0469-diycookies-box
125 | - ${sfm_model_dir}/0470-eggrolls-box
126 | - ${sfm_model_dir}/0471-hlyormosiapie-box
127 | - ${sfm_model_dir}/0472-chocoreo-bottle
128 | - ${sfm_model_dir}/0473-twgrassjelly1-box
129 | - ${sfm_model_dir}/0474-twgrassjelly2-box
130 | - ${sfm_model_dir}/0476-giraffecup-bottle
131 | - ${sfm_model_dir}/0480-ljcleaner-others
132 | - ${sfm_model_dir}/0483-ambrosial-box
133 | - ${sfm_model_dir}/0486-sanqitoothpaste-box
134 | - ${sfm_model_dir}/0487-jindiantoothpaste-box
135 | - ${sfm_model_dir}/0488-jijiantoothpaste-box
136 | - ${sfm_model_dir}/0489-taipingcookies-others
137 | - ${sfm_model_dir}/0490-haochidiancookies-others
138 | - ${sfm_model_dir}/0492-tuccookies-box
139 | - ${sfm_model_dir}/0493-haochidianeggroll-box
140 | - ${sfm_model_dir}/0494-qvduoduocookies-box
141 | - ${sfm_model_dir}/0495-fulingstapler-box
142 | - ${sfm_model_dir}/0496-delistapler-box
143 | - ${sfm_model_dir}/0497-delistaplerlarger-box
144 | - ${sfm_model_dir}/0498-yousuanru-box
145 | - ${sfm_model_dir}/0500-chocfranzzi-box
146 | - ${sfm_model_dir}/0501-matchafranzzi-box
147 | - ${sfm_model_dir}/0502-shufujia-box
148 | - ${sfm_model_dir}/0503-shufujiawhite-box
149 | - ${sfm_model_dir}/0504-lux-box
150 | - ${sfm_model_dir}/0508-yqsl-others
151 | - ${sfm_model_dir}/0510-yqslmilk-others
152 | - ${sfm_model_dir}/0511-policecar-others
153 | - ${sfm_model_dir}/0517-nationalgeo-box
154 | - ${sfm_model_dir}/0518-jasmine-box
155 | - ${sfm_model_dir}/0519-backpack1-box
156 | - ${sfm_model_dir}/0520-lipault-box
157 | - ${sfm_model_dir}/0521-ranova-box
158 | - ${sfm_model_dir}/0522-milkbox-box
159 | - ${sfm_model_dir}/0523-edibleoil-others
160 | - ${sfm_model_dir}/0525-toygrab-others
161 | - ${sfm_model_dir}/0526-toytable-others
162 | - ${sfm_model_dir}/0527-spalding-others
163 | - ${sfm_model_dir}/0534-tonkotsuramen-box
164 | - ${sfm_model_dir}/0535-odbmilk-box
165 | - ${sfm_model_dir}/0537-petsnack-box
166 | - ${sfm_model_dir}/0539-spamwrapper-others
167 | - ${sfm_model_dir}/0543-brownhouse-others
168 | - ${sfm_model_dir}/0547-cubebox-box
169 | - ${sfm_model_dir}/0548-duck-others
170 | - ${sfm_model_dir}/0550-greenbox-box
171 | - ${sfm_model_dir}/0551-milk-others
172 | - ${sfm_model_dir}/0552-mushroom-others
173 | - ${sfm_model_dir}/0557-santachoc-others
174 | - ${sfm_model_dir}/0558-teddychoc-others
175 | - ${sfm_model_dir}/0559-tissuebox-box
176 | - ${sfm_model_dir}/0560-tofubox-box
177 | - ${sfm_model_dir}/0564-biatee-others
178 | - ${sfm_model_dir}/0565-biscuits-box
179 | - ${sfm_model_dir}/0568-cornflakes-box
180 | - ${sfm_model_dir}/0570-kasekuchen-box
181 | - ${sfm_model_dir}/0577-schoko-box
182 | - ${sfm_model_dir}/0578-tee-others
183 | - ${sfm_model_dir}/0579-tomatocan-bottle
184 | - ${sfm_model_dir}/0580-xmaxbox-others
185 | - ${sfm_model_dir}/0582-yogurtlarge-others
186 | - ${sfm_model_dir}/0583-yogurtmedium-others
187 | - ${sfm_model_dir}/0594-martinBootsLeft-others
188 | - ${sfm_model_dir}/0595-martinBootsRight-others
189 |
190 |
191 | output:
192 | vis_dir: ${work_dir}/runs/vis/GATsSPG
193 | eval_dir: ${work_dir}/runs/eval/GATsSPG
--------------------------------------------------------------------------------
/configs/experiment/test_demo.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: inference
4 | task_name: demo
5 | num_leaf: 8
6 | suffix: ''
7 | save_demo: False
8 | save_wis3d: False
9 | use_tracking: False
10 |
11 | model:
12 | onepose_model_path: ${work_dir}/data/models/checkpoints/onepose/GATsSPG.ckpt
13 | extractor_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth
14 | match_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth
15 |
16 | network:
17 | detection: superpoint
18 | matching: superglue
19 |
20 | max_num_kp3d: 2500
21 |
22 | input:
23 | data_dirs: null
24 | sfm_model_dirs: null
25 |
26 | output:
27 | vis_dir: ${work_dir}/runs/vis/demo
28 | eval_dir: ${work_dir}/runs/eval/demo
--------------------------------------------------------------------------------
/configs/experiment/test_sample.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: inference
4 | task_name: test_sample
5 | num_leaf: 8
6 | suffix: ''
7 | save_demo: False
8 | save_wis3d: False
9 | demo_root: ${data_dir}/runs/demo
10 |
11 | model:
12 | onepose_model_path: ${data_dir}/models/checkpoints/onepose/GATsSPG.ckpt
13 | extractor_model_path: ${data_dir}/models/extractors/SuperPoint/superpoint_v1.pth
14 |
15 | network:
16 | detection: superpoint
17 | matching: superglue
18 |
19 | object_detect_mode: 'GT_box' # ["GT_box", "feature_matching"]
20 | max_num_kp3d: 2500
21 | scan_data_dir: ${data_dir}/onepose_datasets/sample_data
22 | sfm_model_dir: ${data_dir}/sfm_model
23 |
24 | input:
25 | data_dirs:
26 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-4
27 |
28 | sfm_model_dirs:
29 | - ${sfm_model_dir}/0501-matchafranzzi-box
30 |
31 | output:
32 | vis_dir: ${work_dir}/runs/vis/test_sample
33 | eval_dir: ${work_dir}/runs/eval/test_sample
--------------------------------------------------------------------------------
/configs/experiment/train_GATsSPG.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # to execute this experiment run:
4 | # python train.py +experiment=train_GATsSPG
5 |
6 | defaults:
7 | - override /trainer: null # override trainer to null so it's not loaded from main config defaults...
8 | - override /model: null
9 | - override /datamodule: null
10 | - override /callbacks: null
11 | - override /logger: null
12 |
13 | # we override default configurations with nulls to prevent them from loading at all
14 | # instead we define all modules and their paths directly in this config,
15 | # so everything is stored in one place for more readability
16 |
17 | seed: 12345
18 |
19 | task_name: null
20 | exp_name: train_onepose
21 | trainer:
22 | _target_: pytorch_lightning.Trainer
23 | gpus:
24 | - 6
25 | min_epochs: 1
26 | max_epochs: 10
27 | gradient_clip_val: 0.5
28 | accumulate_grad_batches: 2
29 | weights_summary: null
30 | num_sanity_val_steps: 2
31 |
32 |
33 | model:
34 | # _target_: src.models.spg_model.LitModelSPG
35 | _target_: src.models.GATsSPG_lightning_model.LitModelGATsSPG
36 | optimizer: adam
37 | lr: 1e-3
38 | weight_decay: 0.
39 | architecture: SuperGlue
40 |
41 | milestones: [5, 10, 15, 20]
42 | gamma: 0.5
43 |
44 | descriptor_dim: 256
45 | keypoints_encoder: [32, 64, 128]
46 | sinkhorn_iterations: 100
47 | match_threshold: 0.2
48 | match_type: 'softmax'
49 | scale_factor: 0.07
50 |
51 | # focal loss
52 | focal_loss_alpha: 0.5
53 | focal_loss_gamma: 2
54 | pos_weights: 0.5
55 | neg_weights: 0.5
56 |
57 | # GATs
58 | include_self: True
59 | with_linear_transform: False
60 | additional: False
61 |
62 | # SuperPoint
63 | spp_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth
64 |
65 | # trainer:
66 | # n_val_pairs_to_plot: 5
67 |
68 | datamodule:
69 | _target_: src.datamodules.GATs_spg_datamodule.GATsSPGDataModule
70 | data_dirs: ${data_dir}/sfm_model
71 | anno_dirs: outputs_${model.match_type}/anno
72 | train_anno_file: ${work_dir}/data/cache/${task_name}/train.json
73 | val_anno_file: ${work_dir}/data/cache/${task_name}/val.json
74 | batch_size: 8
75 | num_workers: 16
76 | num_leaf: 8
77 | pin_memory: True
78 | shape2d: 1000
79 | shape3d: 2000
80 | assign_pad_val: 0
81 |
82 | callbacks:
83 | model_checkpoint:
84 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
85 | monitor: "val/loss"
86 | save_top_k: -1
87 | save_last: True
88 | mode: "min"
89 | dirpath: '${data_dir}/models/checkpoints/${exp_name}'
90 | filename: '{epoch}'
91 | lr_monitor:
92 | _target_: pytorch_lightning.callbacks.LearningRateMonitor
93 | logging_interval: 'step'
94 |
95 | logger:
96 | tensorboard:
97 | _target_: pytorch_lightning.loggers.TensorBoardLogger
98 | save_dir: '${data_dir}/logs'
99 | name: ${exp_name}
100 | default_hp_metric: False
101 |
102 | neptune:
103 | tags: ["best_model"]
104 | csv_logger:
105 | save_dir: "."
106 |
107 | hydra:
108 | run:
109 | dir: ${work_dir}
--------------------------------------------------------------------------------
/configs/preprocess/merge_anno.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: merge_anno
4 | task_name: null
5 | split: 'train'
6 |
7 | train:
8 | names:
9 | - 0410-huiyuan-box
10 | - 0413-juliecookies-box
11 | - 0414-babydiapers-others
12 | - 0415-captaincup-others
13 | - 0416-tongyinoodles-others
14 | - 0418-cookies1-others
15 | - 0420-liuliumei-others
16 | - 0421-cannedfish-others
17 | - 0443-wheatbear-others
18 | - 0445-pistonhead-others
19 | - 0448-soyabeancookies-bottle
20 | - 0460-bhdsoyabeancookies-bottle
21 | - 0461-cranberrycookies-bottle
22 | - 0462-ylmilkpowder-bottle
23 | - 0463-camelmilk-bottle
24 | - 0464-mfchoccake-box
25 | - 0465-mfcreamcake-box
26 | - 0477-cutlet-bottle
27 | - 0479-ggbondcutlet-others
28 | - 0484-bigroll-box
29 | - 0499-tiramisufranzzi-box
30 | - 0506-sauerkrautnoodles-others
31 | - 0507-hotsournoodles-others
32 | - 0509-bscola-others
33 | - 0512-ugreenhub-box
34 | - 0513-busbox-box
35 | - 0516-wewarm-box
36 | - 0529-onionnoodles-box
37 | - 0530-trufflenoodles-box
38 | - 0531-whiskware-box
39 | - 0532-delonghi-box
40 | - 0533-shiramyun-box
41 | - 0536-ranovarect-box
42 | - 0542-bueno-box
43 | - 0545-book-others
44 | - 0546-can-bottle
45 | - 0549-footballcan-bottle
46 | - 0556-pinkbox-box
47 | - 0561-yellowbottle-bottle
48 | - 0562-yellowbox-box
49 | - 0563-applejuice-box
50 | - 0566-chillisauce-box
51 | - 0567-coffeebox-box
52 | - 0569-greentea-bottle
53 | - 0571-cakebox-box
54 | - 0572-milkbox-others
55 | - 0573-redchicken-others
56 | - 0574-rubberduck-others
57 | - 0575-saltbottle-bottle
58 |
59 | val:
60 | names:
61 | - 0601-loquat-box
62 | - 0606-tiger-others
63 | - 0611-pikachubowl-others
64 | - 0616-hmbb-others
65 |
66 | network:
67 | detection: superpoint
68 | matching: superglue
69 |
70 | datamodule:
71 | scan_data_dir: ${work_dir}/data/onepose_datasets
72 | data_dir: ${work_dir}/data/sfm_model
73 | out_path: ${work_dir}/data/cache/${task_name}/${split}.json
74 |
75 |
76 | hydra:
77 | run:
78 | dir: ${work_dir}
--------------------------------------------------------------------------------
/configs/preprocess/sfm_spp_spg_demo.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: sfm
4 | work_dir: ${hydra:runtime.cwd}
5 | redo: True
6 |
7 | dataset:
8 | max_num_kp3d: 2500
9 | max_num_kp2d: 1000
10 |
11 | data_dir: null
12 | outputs_dir: ${work_dir}/data/sfm_model/{}
13 |
14 | network:
15 | detection: superpoint
16 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth
17 |
18 | matching: superglue
19 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth
20 |
21 | sfm:
22 | down_ratio: 5
23 | covis_num: 10
24 | rotation_thresh: 50
25 |
26 | disable_lightning_logs: True
27 |
28 |
--------------------------------------------------------------------------------
/configs/preprocess/sfm_spp_spg_sample.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: sfm
4 | work_dir: ${hydra:runtime.cwd}
5 | redo: False
6 |
7 | scan_data_dir: ${work_dir}/data/onepose_datasets/sample_data
8 |
9 | dataset:
10 | max_num_kp3d: 2500
11 | max_num_kp2d: 1000
12 |
13 | data_dir:
14 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-1
15 |
16 | outputs_dir: ${work_dir}/data/sfm_model/{}
17 |
18 | network:
19 | detection: superpoint
20 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth
21 |
22 | matching: superglue
23 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth
24 |
25 | sfm:
26 | down_ratio: 5
27 | covis_num: 10
28 | rotation_thresh: 50
29 |
30 | disable_lightning_logs: True
--------------------------------------------------------------------------------
/configs/preprocess/sfm_spp_spg_test.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: sfm
4 | work_dir: ${hydra:runtime.cwd}
5 | redo: False
6 |
7 | scan_data_dir: ${work_dir}/data/onepose_datasets/test_data
8 |
9 | dataset:
10 | max_num_kp3d: 2500
11 | max_num_kp2d: 1000
12 |
13 | data_dir:
14 | - ${scan_data_dir}/0408-colorbox-box colorbox-1
15 | - ${scan_data_dir}/0409-aptamil-box aptamil-1
16 | - ${scan_data_dir}/0419-cookies2-others cookies2-1
17 | - ${scan_data_dir}/0422-qvduoduo-box qvduoduo-1
18 | - ${scan_data_dir}/0423-oreo-box oreo-1
19 | - ${scan_data_dir}/0424-chocbox-box chocbox-1
20 | - ${scan_data_dir}/0447-nabati-box nabati-1
21 | - ${scan_data_dir}/0450-hlychocpie-box hlychocpie-1
22 | - ${scan_data_dir}/0452-hlymatchapie-box hlymatchapie-1
23 | - ${scan_data_dir}/0455-strawberryoreo-box strawberryoreo-1
24 | - ${scan_data_dir}/0456-chocoreo-box chocoreo-1
25 | - ${scan_data_dir}/0458-hetaocakes-box hetaocakes-1
26 | - ${scan_data_dir}/0459-jzhg-box jzhg-1
27 | - ${scan_data_dir}/0466-mfmilkcake-box mfmilkcake-1
28 | - ${scan_data_dir}/0468-minipuff-box minipuff-1
29 | - ${scan_data_dir}/0469-diycookies-box diycookies-1
30 | - ${scan_data_dir}/0470-eggrolls-box eggrolls-1
31 | - ${scan_data_dir}/0471-hlyormosiapie-box hlyormosiapie-1
32 | - ${scan_data_dir}/0472-chocoreo-bottle chocoreo-1
33 | - ${scan_data_dir}/0473-twgrassjelly1-box twgrassjelly1-1
34 | - ${scan_data_dir}/0474-twgrassjelly2-box twgrassjelly2-1
35 | - ${scan_data_dir}/0476-giraffecup-bottle giraffecup-1
36 | - ${scan_data_dir}/0480-ljcleaner-others ljcleaner-1
37 | - ${scan_data_dir}/0483-ambrosial-box ambrosial-1
38 | - ${scan_data_dir}/0486-sanqitoothpaste-box sanqitoothpaste-1
39 | - ${scan_data_dir}/0487-jindiantoothpaste-box jindiantoothpaste-1
40 | - ${scan_data_dir}/0488-jijiantoothpaste-box jijiantoothpaste-1
41 | - ${scan_data_dir}/0489-taipingcookies-others taipingcookies-1
42 | - ${scan_data_dir}/0490-haochidiancookies-others haochidiancookies-1
43 | - ${scan_data_dir}/0492-tuccookies-box tuccookies-1
44 | - ${scan_data_dir}/0493-haochidianeggroll-box haochidianeggroll-1
45 | - ${scan_data_dir}/0494-qvduoduocookies-box qvduoduocookies-1
46 | - ${scan_data_dir}/0495-fulingstapler-box fulingstapler-1
47 | - ${scan_data_dir}/0496-delistapler-box delistapler-1
48 | - ${scan_data_dir}/0497-delistaplerlarger-box delistaplerlarger-1
49 | - ${scan_data_dir}/0498-yousuanru-box yousuanru-1
50 | - ${scan_data_dir}/0500-chocfranzzi-box chocfranzzi-1
51 | - ${scan_data_dir}/0501-matchafranzzi-box matchafranzzi-1
52 | - ${scan_data_dir}/0502-shufujia-box shufujia-1
53 | - ${scan_data_dir}/0503-shufujiawhite-box shufujiawhite-1
54 | - ${scan_data_dir}/0504-lux-box lux-1
55 | - ${scan_data_dir}/0508-yqsl-others yqsl-1
56 | - ${scan_data_dir}/0510-yqslmilk-others yqslmilk-1
57 | - ${scan_data_dir}/0511-policecar-others policecar-1
58 | - ${scan_data_dir}/0517-nationalgeo-box nationalgeo-1
59 | - ${scan_data_dir}/0518-jasmine-box jasmine-1
60 | - ${scan_data_dir}/0519-backpack1-box backpack1-1
61 | - ${scan_data_dir}/0520-lipault-box lipault-1
62 | - ${scan_data_dir}/0521-ranova-box ranova-1
63 | - ${scan_data_dir}/0522-milkbox-box milkbox-1
64 | - ${scan_data_dir}/0523-edibleoil-others edibleoil-1
65 | - ${scan_data_dir}/0525-toygrab-others toygrab-1
66 | - ${scan_data_dir}/0526-toytable-others toytable-1
67 | - ${scan_data_dir}/0527-spalding-others spalding-1
68 | - ${scan_data_dir}/0534-tonkotsuramen-box tonkotsuramen-1
69 | - ${scan_data_dir}/0535-odbmilk-box odbmilk-1
70 | - ${scan_data_dir}/0537-petsnack-box petsnack-1
71 | - ${scan_data_dir}/0539-spamwrapper-others spamwrapper-1
72 | - ${scan_data_dir}/0543-brownhouse-others brownhouse-1
73 | - ${scan_data_dir}/0547-cubebox-box cubebox-1
74 | - ${scan_data_dir}/0548-duck-others duck-1
75 | - ${scan_data_dir}/0550-greenbox-box greenbox-1
76 | - ${scan_data_dir}/0551-milk-others milk-1
77 | - ${scan_data_dir}/0552-mushroom-others mushroom-1
78 | - ${scan_data_dir}/0557-santachoc-others santachoc-1
79 | - ${scan_data_dir}/0558-teddychoc-others teddychoc-1
80 | - ${scan_data_dir}/0559-tissuebox-box tissuebox-1
81 | - ${scan_data_dir}/0560-tofubox-box tofubox-1
82 | - ${scan_data_dir}/0564-biatee-others biatee-1
83 | - ${scan_data_dir}/0565-biscuits-box biscuits-1
84 | - ${scan_data_dir}/0568-cornflakes-box cornflakes-1
85 | - ${scan_data_dir}/0570-kasekuchen-box kasekuchen-1
86 | - ${scan_data_dir}/0577-schoko-box schoko-1
87 | - ${scan_data_dir}/0578-tee-others tee-1
88 | - ${scan_data_dir}/0579-tomatocan-bottle tomatocan-1
89 | - ${scan_data_dir}/0580-xmaxbox-others xmaxbox-1
90 | - ${scan_data_dir}/0582-yogurtlarge-others yogurtlarge-1
91 | - ${scan_data_dir}/0583-yogurtmedium-others yogurtmedium-1
92 | - ${scan_data_dir}/0594-martinBootsLeft-others martinBootsLeft-1
93 | - ${scan_data_dir}/0595-martinBootsRight-others martinBootsRight-1
94 |
95 | outputs_dir: ${work_dir}/data/sfm_model/{}
96 |
97 | network:
98 | detection: superpoint
99 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth
100 |
101 | matching: superglue
102 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth
103 |
104 | sfm:
105 | down_ratio: 5
106 | covis_num: 10
107 | rotation_thresh: 50
108 |
109 | disable_lightning_logs: True
110 |
111 |
--------------------------------------------------------------------------------
/configs/preprocess/sfm_spp_spg_train.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: sfm
4 | work_dir: ${hydra:runtime.cwd}
5 | redo: False
6 |
7 | scan_data_dir: ${work_dir}/data/onepose_datasets/train_data
8 |
9 | dataset:
10 | max_num_kp3d: 2500
11 | max_num_kp2d: 1000
12 |
13 | data_dir:
14 | - ${scan_data_dir}/0410-huiyuan-box huiyuan-1 huiyuan-2 huiyuan-3
15 | - ${scan_data_dir}/0413-juliecookies-box juliecookies-1 juliecookies-2 juliecookies-3
16 | - ${scan_data_dir}/0414-babydiapers-others babydiapers-1 babydiapers-2 babydiapers-3
17 | - ${scan_data_dir}/0415-captaincup-others captaincup-1 captaincup-2 captaincup-3
18 | - ${scan_data_dir}/0416-tongyinoodles-others tongyinoodles-1 tongyinoodles-2 tongyinoodles-3
19 | - ${scan_data_dir}/0418-cookies1-others cookies1-1 cookies1-2
20 | - ${scan_data_dir}/0420-liuliumei-others liuliumei-1 liuliumei-2 liuliumei-3
21 | - ${scan_data_dir}/0421-cannedfish-others cannedfish-1 cannedfish-2 cannedfish-3
22 | - ${scan_data_dir}/0443-wheatbear-others wheatbear-1 wheatbear-2
23 | - ${scan_data_dir}/0445-pistonhead-others pistonhead-1 pistonhead-2 pistonhead-3
24 | - ${scan_data_dir}/0448-soyabeancookies-bottle soyabeancookies-1 soyabeancookies-2
25 | - ${scan_data_dir}/0460-hbdsoyabeancookies-bottle hbdsoyabeancookies-1 bhdsoyabeancookies-2 bhdsoyabeancookies-3
26 | - ${scan_data_dir}/0461-cranberrycookies-bottle cranberrycookies-1 cranberrycookies-2 cranberrycookies-3
27 | - ${scan_data_dir}/0462-ylmilkpowder-bottle ylmilkpowder-1 ylmilkpowder-2 ylmilkpowder-3
28 | - ${scan_data_dir}/0463-camelmilk-bottle camelmilk-1 camelmilk-2 camelmilk-3
29 | - ${scan_data_dir}/0464-mfchoccake-box mfchoccake-1 mfchoccake-2 mfchoccake-3
30 | - ${scan_data_dir}/0465-mfcreamcake-box mfcreamcake-1 mfcreamcake-2 mfcreamcake-3
31 | - ${scan_data_dir}/0477-cutlet-bottle cutlet-1 cutlet-2 cutlet-3
32 | - ${scan_data_dir}/0479-ggbondcutlet-others ggbondcutlet-1 ggbondcutlet-2 ggbondcutlet-3
33 | - ${scan_data_dir}/0484-bigroll-box bigroll-1 bigroll-2 bigroll-3
34 | - ${scan_data_dir}/0499-tiramisufranzzi-box tiramisufranzzi-1 tiramisufranzzi-2 tiramisufranzzi-3
35 | - ${scan_data_dir}/0506-sauerkrautnoodles-others sauerkrautnoodles-1 sauerkrautnoodles-2 sauerkrautnoodles-3
36 | - ${scan_data_dir}/0507-hotsournoodles-others hotsournoodles-1 hotsournoodles-2 hotsournoodles-3
37 | - ${scan_data_dir}/0509-bscola-others bscola-1 bscola-2 bscola-3
38 | - ${scan_data_dir}/0512-ugreenhub-box ugreenhub-1 ugreenhub-2 ugreenhub-3
39 | - ${scan_data_dir}/0513-busbox-box busbox-1 busbox-2 busbox-3
40 | - ${scan_data_dir}/0516-wewarm-box wewarm-1 wewarm-2 wewarm-3
41 | - ${scan_data_dir}/0529-onionnoodles-box onionnoodles-1 onionnoodles-2
42 | - ${scan_data_dir}/0530-trufflenoodles-box trufflenoodles-1 trufflenoodles-2 trufflenoodles-3
43 | - ${scan_data_dir}/0531-whiskware-box whiskware-1 whiskware-2 whiskware-3
44 | - ${scan_data_dir}/0532-delonghi-box delonghi-1 delonghi-2
45 | - ${scan_data_dir}/0533-shiramyun-box shiramyun-1 shiramyun-2 shiramyun-3
46 | - ${scan_data_dir}/0536-ranovarect-box ranovarect-1 ranovarect-2 ranvorect-3 ranvorect-4
47 | - ${scan_data_dir}/0542-bueno-box bueno-1 bueno-2 bueno-3
48 | - ${scan_data_dir}/0545-book-others book-1 book-2 book-3
49 | - ${scan_data_dir}/0546-can-bottle can-1 can-2 can-3
50 | - ${scan_data_dir}/0549-footballcan-bottle footballcan-1 footballcan-2 footballcan-3
51 | - ${scan_data_dir}/0556-pinkbox-box pinkbox-1 pinkbox-2 pinkbox-3
52 | - ${scan_data_dir}/0561-yellowbottle-bottle yellowbottle-1 yellowbottle-2 yellowbottle-3
53 | - ${scan_data_dir}/0562-yellowbox-box yellowbox-1 yellowbox-2 yellowbox-3
54 | - ${scan_data_dir}/0563-applejuice-box applejuice-1 applejuice-2 applejuice-3
55 | - ${scan_data_dir}/0566-chillisauce-box chillisauce-1 chillisauce-2 chillisauce-3
56 | - ${scan_data_dir}/0567-coffeebox-box coffeebox-1 coffeebox-2 coffeebox-3
57 | - ${scan_data_dir}/0569-greentea-bottle greentea-1 greentea-2 greentea-3
58 | - ${scan_data_dir}/0571-cakebox-box cakebox-1 cakebox-2 cakebox-3
59 | - ${scan_data_dir}/0572-milkbox-others milkbox-1 milkbox-2 milkbox-3
60 | - ${scan_data_dir}/0573-redchicken-others redchicken-1 redchicken-2 redchicken-3
61 | - ${scan_data_dir}/0574-rubberduck-others rubberduck-1 rubberduck-2 rubberduck-3
62 | - ${scan_data_dir}/0575-saltbottle-bottle saltbottle-1 saltbottle-2 satlbottle-3
63 |
64 | outputs_dir: ${work_dir}/data/sfm_model/{}
65 |
66 | network:
67 | detection: superpoint
68 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth
69 |
70 | matching: superglue
71 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth
72 |
73 | sfm:
74 | down_ratio: 5
75 | covis_num: 10
76 | rotation_thresh: 50
77 |
78 |
79 | disable_lightning_logs: True
80 |
81 |
--------------------------------------------------------------------------------
/configs/preprocess/sfm_spp_spg_val.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | type: sfm
4 | work_dir: ${hydra:runtime.cwd}
5 | redo: False
6 |
7 | scan_data_dir: ${work_dir}/data/onepose_datasets/val_data
8 |
9 | dataset:
10 | max_num_kp3d: 2500
11 | max_num_kp2d: 1000
12 |
13 | data_dir:
14 | - ${scan_data_dir}/0601-loquat-box loquat-1
15 | - ${scan_data_dir}/0602-aficion-box aficion-1
16 | - ${scan_data_dir}/0603-redbook-box redbook-1
17 | - ${scan_data_dir}/0604-pillbox1-box pillbox1-1
18 | - ${scan_data_dir}/0605-pillbox2-box pillbox2-1
19 | - ${scan_data_dir}/0606-tiger-others tiger-1
20 | - ${scan_data_dir}/0607-admilk-others admilk-1
21 | - ${scan_data_dir}/0608-teacan-others teacan-1
22 | - ${scan_data_dir}/0609-doll-others doll-1
23 | - ${scan_data_dir}/0610-calendar-box calendar-1
24 | - ${scan_data_dir}/0611-pikachubowl-others pikachubowl-1
25 | - ${scan_data_dir}/0612-originaloreo-box originaloreo-1
26 | - ${scan_data_dir}/0613-adidasshoeright-others adidasshoeright-1
27 | - ${scan_data_dir}/0614-darlietoothpaste-box darlietoothpaste-1
28 | - ${scan_data_dir}/0615-nabati-bottle nabati-1
29 | - ${scan_data_dir}/0616-hmbb-others hmbb-1
30 | - ${scan_data_dir}/0617-porcelain-others porcelain-1
31 | - ${scan_data_dir}/0618-yogurt-bottle yogurt-1
32 | - ${scan_data_dir}/0619-newtolmeat-others newtolmeat-1
33 | - ${scan_data_dir}/0620-dinosaurcup-bottle dinosaurcup-1
34 | - ${scan_data_dir}/0621-saltbox-box saltbox-1
35 |
36 | outputs_dir: ${work_dir}/data/sfm_model/{}
37 |
38 | network:
39 | detection: superpoint
40 | detection_model_path: ${work_dir}/data/models/extractors/SuperPoint/superpoint_v1.pth
41 |
42 | matching: superglue
43 | matching_model_path: ${work_dir}/data/models/matchers/SuperGlue/superglue_outdoor.pth
44 |
45 | sfm:
46 | down_ratio: 5
47 | covis_num: 10
48 | rotation_thresh: 50
49 |
50 | disable_lightning_logs: True
--------------------------------------------------------------------------------
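Each `data_dir` entry in the three preprocess configs above is a single space-separated string: the object's root directory followed by the names of its capture sequences. A minimal sketch of how `run.py` and `inference.py` split these entries (the root path below is hypothetical):

```python
# Minimal sketch of how the space-separated `data_dir` entries above are parsed
# (mirrors the splitting done in run.py and inference.py); the root path is hypothetical.
import os.path as osp

entry = "/path/to/train_data/0410-huiyuan-box huiyuan-1 huiyuan-2 huiyuan-3"
splits = entry.split(" ")
root_dir, seq_names = splits[0], splits[1:]          # object root + its sequences
seq_dirs = [osp.join(root_dir, name) for name in seq_names]
print(root_dir)   # /path/to/train_data/0410-huiyuan-box
print(seq_dirs)   # ['/path/to/train_data/0410-huiyuan-box/huiyuan-1', ...]
```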
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: onepose
2 | channels:
3 | # - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch
4 | - pytorch
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - python=3.7
9 | - pytorch=1.8.0
10 | - torchvision=0.9.1
11 | - cudatoolkit=10.2
12 | - ipython
13 | - tqdm
14 | - matplotlib
15 | - pylint
16 | - conda-forge::jupyterlab
17 | - conda-forge::h5py=3.1.0
18 | - conda-forge::loguru=0.5.3
19 | - conda-forge::scipy
20 | - conda-forge::numba
21 | - conda-forge::ipdb
22 | - conda-forge::albumentations=0.5.1
23 | - pip
24 | - pip:
25 | - -r requirements.txt
--------------------------------------------------------------------------------
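Assuming the environment above is created with `conda env create -f environment.yaml` and activated with `conda activate onepose`, a quick sanity check of the pinned packages:

```python
# Quick post-install check for the `onepose` conda environment above.
import torch
import torchvision

print(torch.__version__)          # expected: 1.8.0
print(torchvision.__version__)    # expected: 0.9.1
print(torch.cuda.is_available())  # True on a machine with a CUDA 10.2-capable driver
```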
/feature_matching_object_detector.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import torch
3 | import hydra
4 | from tqdm import tqdm
5 | import os
6 | import os.path as osp
7 | import natsort
8 |
9 | from loguru import logger
10 | from torch.utils.data import DataLoader
11 | from src.utils import data_utils
12 | from src.utils.model_io import load_network
13 | from src.local_feature_2D_detector import LocalFeatureObjectDetector
14 |
15 | from pytorch_lightning import seed_everything
16 |
17 | seed_everything(12345)
18 |
19 |
20 | def get_default_paths(cfg, data_root, data_dir, sfm_model_dir):
21 | anno_dir = osp.join(
22 | sfm_model_dir, f"outputs_{cfg.network.detection}_{cfg.network.matching}", "anno"
23 | )
24 | avg_anno_3d_path = osp.join(anno_dir, "anno_3d_average.npz")
25 | clt_anno_3d_path = osp.join(anno_dir, "anno_3d_collect.npz")
26 | idxs_path = osp.join(anno_dir, "idxs.npy")
27 | sfm_ws_dir = osp.join(
28 | sfm_model_dir,
29 | f"outputs_{cfg.network.detection}_{cfg.network.matching}",
30 | "sfm_ws",
31 | "model",
32 | )
33 |
34 | img_lists = []
35 | color_dir = osp.join(data_dir, "color_full")
36 | if not osp.exists(color_dir):
37 | logger.info('color_full directory does not exist! Trying to parse images from Frames.m4v.')
38 | scan_video_dir = osp.join(data_dir, 'Frames.m4v')
39 | assert osp.exists(scan_video_dir), 'Frames.m4v not found! Cannot run the object detector!'
40 | data_utils.video2img(scan_video_dir, color_dir)
41 | img_lists += glob.glob(color_dir + "/*.png", recursive=True)
42 |
43 | img_lists = natsort.natsorted(img_lists)
44 |
45 | # Save detect results:
46 | detect_img_dir = osp.join(data_dir, "color_det")
47 | if osp.exists(detect_img_dir):
48 | os.system(f"rm -rf {detect_img_dir}")
49 | os.makedirs(detect_img_dir, exist_ok=True)
50 |
51 | detect_K_dir = osp.join(data_dir, "intrin_det")
52 | if osp.exists(detect_K_dir):
53 | os.system(f"rm -rf {detect_K_dir}")
54 | os.makedirs(detect_K_dir, exist_ok=True)
55 |
56 | intrin_full_path = osp.join(data_dir, "intrinsics.txt")
57 | paths = {
58 | "data_root": data_root,
59 | "data_dir": data_dir,
60 | "sfm_model_dir": sfm_model_dir,
61 | "sfm_ws_dir": sfm_ws_dir,
62 | "avg_anno_3d_path": avg_anno_3d_path,
63 | "clt_anno_3d_path": clt_anno_3d_path,
64 | "idxs_path": idxs_path,
65 | "intrin_full_path": intrin_full_path,
66 | "output_detect_img_dir": detect_img_dir,
67 | "output_K_crop_dir": detect_K_dir
68 | }
69 | return img_lists, paths
70 |
71 | def load_2D_matching_model(cfg):
72 |
73 | def load_extractor_model(cfg, model_path):
74 | """Load extractor model(SuperPoint)"""
75 | from src.models.extractors.SuperPoint.superpoint import SuperPoint
76 | from src.sfm.extract_features import confs
77 |
78 | extractor_model = SuperPoint(confs[cfg.network.detection]["conf"])
79 | extractor_model.cuda()
80 | extractor_model.eval()
81 | load_network(extractor_model, model_path, force=True)
82 |
83 | return extractor_model
84 |
85 | def load_2D_matcher(cfg):
86 | """Load matching model(SuperGlue)"""
87 | from src.models.matchers.SuperGlue.superglue import SuperGlue
88 | from src.sfm.match_features import confs
89 |
90 | match_model = SuperGlue(confs[cfg.network.matching]["conf"])
91 | match_model.eval()
92 | load_network(match_model, cfg.model.match_model_path)
93 | return match_model
94 |
95 | extractor_model = load_extractor_model(cfg, cfg.model.extractor_model_path)
96 | matcher = load_2D_matcher(cfg)
97 | return extractor_model, matcher
98 |
99 |
100 | def pack_data(avg_descriptors3d, clt_descriptors, keypoints3d, detection, image_size):
101 | """Prepare data for OnePose inference"""
102 | keypoints2d = torch.Tensor(detection["keypoints"])
103 | descriptors2d = torch.Tensor(detection["descriptors"])
104 |
105 | inp_data = {
106 | "keypoints2d": keypoints2d[None].cuda(), # [1, n1, 2]
107 | "keypoints3d": keypoints3d[None].cuda(), # [1, n2, 3]
108 | "descriptors2d_query": descriptors2d[None].cuda(), # [1, dim, n1]
109 | "descriptors3d_db": avg_descriptors3d[None].cuda(), # [1, dim, n2]
110 | "descriptors2d_db": clt_descriptors[None].cuda(), # [1, dim, n2*num_leaf]
111 | "image_size": image_size,
112 | }
113 |
114 | return inp_data
115 |
116 |
117 | @torch.no_grad()
118 | def inference_core(cfg, data_root, seq_dir, sfm_model_dir):
119 | """Inference & visualize"""
120 | from src.datasets.normalized_dataset import NormalizedDataset
121 | from src.sfm.extract_features import confs
122 |
123 | # Load models and prepare data:
124 | extractor_model, matching_2D_model = load_2D_matching_model(cfg)
125 | img_lists, paths = get_default_paths(cfg, data_root, seq_dir, sfm_model_dir)
126 | K, _ = data_utils.get_K(paths["intrin_full_path"])
127 |
128 | local_feature_obj_detector = LocalFeatureObjectDetector(
129 | extractor_model,
130 | matching_2D_model,
131 | sfm_ws_dir=paths["sfm_ws_dir"],
132 | n_ref_view=cfg.n_ref_view,
133 | output_results=True,
134 | detect_save_dir=paths['output_detect_img_dir'],
135 | K_crop_save_dir=paths['output_K_crop_dir']
136 | )
137 | dataset = NormalizedDataset(
138 | img_lists, confs[cfg.network.detection]["preprocessing"]
139 | )
140 | loader = DataLoader(dataset, num_workers=1)
141 |
142 | # Begin Object detection:
143 | for id, data in enumerate(tqdm(loader)):
144 | img_path = data["path"][0]
145 | inp = data["image"].cuda()
146 |
147 | # Detect the object by 2D local feature matching for each frame:
148 | local_feature_obj_detector.detect(inp, img_path, K)
149 |
150 | def inference(cfg):
151 | data_dirs = cfg.input.data_dirs
152 | sfm_model_dirs = cfg.input.sfm_model_dirs
153 | if isinstance(data_dirs, str) and isinstance(sfm_model_dirs, str):
154 | data_dirs = [data_dirs]
155 | sfm_model_dirs = [sfm_model_dirs]
156 |
157 | for data_dir, sfm_model_dir in tqdm(
158 | zip(data_dirs, sfm_model_dirs), total=len(data_dirs)
159 | ):
160 | splits = data_dir.split(" ")
161 | data_root = splits[0]
162 | for seq_name in splits[1:]:
163 | seq_dir = osp.join(data_root, seq_name)
164 | logger.info(f"Run feature matching object detector for: {seq_dir}")
165 | inference_core(cfg, data_root, seq_dir, sfm_model_dir)
166 |
167 |
168 | @hydra.main(config_path="configs/", config_name="config.yaml")
169 | def main(cfg):
170 | globals()[cfg.type](cfg)
171 |
172 |
173 | if __name__ == "__main__":
174 | main()
175 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import torch
3 | import hydra
4 | from tqdm import tqdm
5 | import os.path as osp
6 | import numpy as np
7 |
8 | from PIL import Image
9 | from loguru import logger
10 | from torch.utils.data import DataLoader
11 | from src.utils import data_utils, path_utils, eval_utils, vis_utils
12 |
13 | from pytorch_lightning import seed_everything
14 | seed_everything(12345)
15 |
16 |
17 | def get_default_paths(cfg, data_root, data_dir, sfm_model_dir):
18 | anno_dir = osp.join(sfm_model_dir, f'outputs_{cfg.network.detection}_{cfg.network.matching}', 'anno')
19 | avg_anno_3d_path = osp.join(anno_dir, 'anno_3d_average.npz')
20 | clt_anno_3d_path = osp.join(anno_dir, 'anno_3d_collect.npz')
21 | idxs_path = osp.join(anno_dir, 'idxs.npy')
22 |
23 | object_detect_mode = cfg.object_detect_mode
24 | logger.info(f"Use {object_detect_mode} as object detector")
25 | if object_detect_mode == 'GT_box':
26 | color_dir = osp.join(data_dir, 'color')
27 | elif object_detect_mode == 'feature_matching':
28 | color_dir = osp.join(data_dir, 'color_det')
29 | assert osp.exists(color_dir), "color_det directory does not exist! You need to run feature_matching_object_detector.py for object detection first. Please refer to README.md for instructions."
30 | else:
31 | raise NotImplementedError
32 |
33 | img_lists = []
34 | img_lists += glob.glob(color_dir + '/*.png', recursive=True)
35 |
36 | intrin_full_path = osp.join(data_dir, 'intrinsics.txt')
37 | paths = {
38 | "data_root": data_root,
39 | 'data_dir': data_dir,
40 | 'sfm_model_dir': sfm_model_dir,
41 | 'avg_anno_3d_path': avg_anno_3d_path,
42 | 'clt_anno_3d_path': clt_anno_3d_path,
43 | 'idxs_path': idxs_path,
44 | 'intrin_full_path': intrin_full_path
45 | }
46 | return img_lists, paths
47 |
48 |
49 | def load_model(cfg):
50 | """ Load model """
51 | def load_matching_model(model_path):
52 | """ Load onepose model """
53 | from src.models.GATsSPG_lightning_model import LitModelGATsSPG
54 |
55 | trained_model = LitModelGATsSPG.load_from_checkpoint(checkpoint_path=model_path)
56 | trained_model.cuda()
57 | trained_model.eval()
58 | trained_model.freeze()
59 |
60 | return trained_model
61 |
62 | def load_extractor_model(cfg, model_path):
63 | """ Load extractor model(SuperPoint) """
64 | from src.models.extractors.SuperPoint.superpoint import SuperPoint
65 | from src.sfm.extract_features import confs
66 | from src.utils.model_io import load_network
67 |
68 | extractor_model = SuperPoint(confs[cfg.network.detection]['conf'])
69 | extractor_model.cuda()
70 | extractor_model.eval()
71 | load_network(extractor_model, model_path)
72 |
73 | return extractor_model
74 |
75 | matching_model = load_matching_model(cfg.model.onepose_model_path)
76 | extractor_model = load_extractor_model(cfg, cfg.model.extractor_model_path)
77 | return matching_model, extractor_model
78 |
79 |
80 | def pack_data(avg_descriptors3d, clt_descriptors, keypoints3d, detection, image_size):
81 | """ Prepare data for OnePose inference """
82 | keypoints2d = torch.Tensor(detection['keypoints'])
83 | descriptors2d = torch.Tensor(detection['descriptors'])
84 |
85 | inp_data = {
86 | 'keypoints2d': keypoints2d[None].cuda(), # [1, n1, 2]
87 | 'keypoints3d': keypoints3d[None].cuda(), # [1, n2, 3]
88 | 'descriptors2d_query': descriptors2d[None].cuda(), # [1, dim, n1]
89 | 'descriptors3d_db': avg_descriptors3d[None].cuda(), # [1, dim, n2]
90 | 'descriptors2d_db': clt_descriptors[None].cuda(), # [1, dim, n2*num_leaf]
91 | 'image_size': image_size
92 | }
93 |
94 | return inp_data
95 |
96 |
97 | @torch.no_grad()
98 | def inference_core(cfg, data_root, seq_dir, sfm_model_dir):
99 | """ Inference & visualize"""
100 | from src.datasets.normalized_dataset import NormalizedDataset
101 | from src.sfm.extract_features import confs
102 | from src.evaluators.cmd_evaluator import Evaluator
103 |
104 | matching_model, extractor_model = load_model(cfg)
105 | img_lists, paths = get_default_paths(cfg, data_root, seq_dir, sfm_model_dir)
106 |
107 | dataset = NormalizedDataset(img_lists, confs[cfg.network.detection]['preprocessing'])
108 | loader = DataLoader(dataset, num_workers=1)
109 | evaluator = Evaluator()
110 |
111 | idx = 0
112 | num_leaf = cfg.num_leaf
113 | avg_data = np.load(paths['avg_anno_3d_path'])
114 | clt_data = np.load(paths['clt_anno_3d_path'])
115 | idxs = np.load(paths['idxs_path'])
116 |
117 | keypoints3d = torch.Tensor(clt_data['keypoints3d']).cuda()
118 | num_3d = keypoints3d.shape[0]
119 | # Load average 3D features:
120 | avg_descriptors3d, _ = data_utils.pad_features3d_random(
121 | avg_data['descriptors3d'],
122 | avg_data['scores3d'],
123 | num_3d
124 | )
125 | # Load corresponding 2D features of each 3D point:
126 | clt_descriptors, _ = data_utils.build_features3d_leaves(
127 | clt_data['descriptors3d'],
128 | clt_data['scores3d'],
129 | idxs, num_3d, num_leaf
130 | )
131 |
132 | for data in tqdm(loader):
133 | img_path = data['path'][0]
134 | inp = data['image'].cuda()
135 |
136 | intrin_path = path_utils.get_intrin_path_by_color(img_path, det_type=cfg.object_detect_mode)
137 | K_crop = np.loadtxt(intrin_path)
138 |
139 | # Detect query image keypoints and extract descriptors:
140 | pred_detection = extractor_model(inp)
141 | pred_detection = {k: v[0].cpu().numpy() for k, v in pred_detection.items()}
142 |
143 | # 2D-3D matching by GATsSPG:
144 | inp_data = pack_data(avg_descriptors3d, clt_descriptors,
145 | keypoints3d, pred_detection, data['size'])
146 | pred, _ = matching_model(inp_data)
147 | matches = pred['matches0'].detach().cpu().numpy()
148 | valid = matches > -1
149 | kpts2d = pred_detection['keypoints']
150 | kpts3d = inp_data['keypoints3d'][0].detach().cpu().numpy()
151 | confidence = pred['matching_scores0'].detach().cpu().numpy()
152 | mkpts2d, mkpts3d, mconf = kpts2d[valid], kpts3d[matches[valid]], confidence[valid]
153 |
154 | # Estimate object pose by 2D-3D correspondences:
155 | pose_pred, pose_pred_homo, inliers = eval_utils.ransac_PnP(K_crop, mkpts2d, mkpts3d, scale=1000)
156 |
157 | # Evaluate:
158 | gt_pose_path = path_utils.get_gt_pose_path_by_color(img_path, det_type=cfg.object_detect_mode)
159 | pose_gt = np.loadtxt(gt_pose_path)
160 | evaluator.evaluate(pose_pred, pose_gt)
161 |
162 | # Visualize:
163 | if cfg.save_wis3d:
164 | poses = [pose_gt, pose_pred_homo]
165 | box3d_path = path_utils.get_3d_box_path(data_root)
166 | intrin_full_path = path_utils.get_intrin_full_path(seq_dir)
167 | image_full_path = path_utils.get_img_full_path_by_color(img_path, det_type=cfg.object_detect_mode)
168 |
169 | image_full = vis_utils.vis_reproj(image_full_path, poses, box3d_path, intrin_full_path,
170 | save_demo=cfg.save_demo, demo_root=cfg.demo_root)
171 | mkpts3d_2d = vis_utils.reproj(K_crop, pose_gt, mkpts3d)
172 | image0 = Image.open(img_path).convert('LA')
173 | image1 = image0.copy()
174 | vis_utils.dump_wis3d(idx, cfg, seq_dir, image0, image1, image_full,
175 | mkpts2d, mkpts3d_2d, mconf, inliers)
176 |
177 | idx += 1
178 |
179 | eval_result = evaluator.summarize()
180 | obj_name = sfm_model_dir.split('/')[-1]
181 | seq_name = seq_dir.split('/')[-1]
182 | eval_utils.record_eval_result(cfg.output.eval_dir, obj_name, seq_name, eval_result)
183 |
184 |
185 | def inference(cfg):
186 | data_dirs = cfg.input.data_dirs
187 | sfm_model_dirs = cfg.input.sfm_model_dirs
188 | if isinstance(data_dirs, str) and isinstance(sfm_model_dirs, str):
189 | data_dirs = [data_dirs]
190 | sfm_model_dirs = [sfm_model_dirs]
191 |
192 | for data_dir, sfm_model_dir in tqdm(zip(data_dirs, sfm_model_dirs), total=len(data_dirs)):
193 | splits = data_dir.split(" ")
194 | data_root = splits[0]
195 | for seq_name in splits[1:]:
196 | seq_dir = osp.join(data_root, seq_name)
197 | logger.info(f'Eval {seq_dir}')
198 | inference_core(cfg, data_root, seq_dir, sfm_model_dir)
199 |
200 |
201 | @hydra.main(config_path='configs/', config_name='config.yaml')
202 | def main(cfg):
203 | globals()[cfg.type](cfg)
204 |
205 | if __name__ == "__main__":
206 | main()
207 |
--------------------------------------------------------------------------------
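The dictionary packed by `pack_data()` in `inference.py` above follows fixed shape conventions. A shape-only sketch with random tensors (n1/n2/num_leaf are placeholder sizes; dim=256 matches SuperPoint descriptors; the real pipeline additionally moves every tensor to the GPU):

```python
# Shape-only sketch of the input dict built by pack_data(); values are random placeholders.
import torch

n1, n2, dim, num_leaf = 1000, 2000, 256, 8   # placeholder sizes
inp_data = {
    "keypoints2d": torch.rand(1, n1, 2),                    # query 2D keypoints
    "keypoints3d": torch.rand(1, n2, 3),                    # SfM 3D points
    "descriptors2d_query": torch.rand(1, dim, n1),          # query 2D descriptors
    "descriptors3d_db": torch.rand(1, dim, n2),             # averaged 3D descriptors
    "descriptors2d_db": torch.rand(1, dim, n2 * num_leaf),  # per-point 2D "leaf" descriptors
    "image_size": torch.tensor([[512, 512]]),               # (h, w) of the query image
}
# matching_model(inp_data) then yields 'matches0' / 'matching_scores0' as used in inference_core().
```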
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.5.10
2 | aiohttp==3.7
3 | aioredis==1.3.1
4 | pydegensac==0.1.2
5 | opencv_python==4.4.0.46
6 | yacs>=0.1.8
7 | einops==0.3.0
8 | kornia==0.4.1
9 | pickle5==0.0.11
10 | timm>=0.3.2
11 | hydra-core==1.1.1
12 | omegaconf==2.1.2
13 | pycocotools==2.0.4
14 | wandb==0.12.17
15 | rich==12.4.4
16 | transforms3d==0.3.1
17 | natsort==8.1.0
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import glob
4 | import hydra
5 |
6 | import os.path as osp
7 | from loguru import logger
8 | from pathlib import Path
9 | from omegaconf import DictConfig
10 |
11 |
12 | def merge_(anno_2d_file, avg_anno_3d_file, collect_anno_3d_file,
13 | idxs_file, img_id, ann_id, images, annotations):
14 | """ To prepare training and test objects, we merge annotations of different objects """
15 | with open(anno_2d_file, 'r') as f:
16 | annos_2d = json.load(f)
17 |
18 | for anno_2d in annos_2d:
19 | img_id += 1
20 | info = {
21 | 'id': img_id,
22 | 'img_file': anno_2d['img_file'],
23 | }
24 | images.append(info)
25 |
26 | ann_id += 1
27 | anno = {
28 | 'image_id': img_id,
29 | 'id': ann_id,
30 | 'pose_file': anno_2d['pose_file'],
31 | 'anno2d_file': anno_2d['anno_file'],
32 | 'avg_anno3d_file': avg_anno_3d_file,
33 | 'collect_anno3d_file': collect_anno_3d_file,
34 | 'idxs_file': idxs_file
35 | }
36 | annotations.append(anno)
37 | return img_id, ann_id
38 |
39 |
40 | def merge_anno(cfg):
41 | """ Merge different objects' anno file into one anno file """
42 | anno_dirs = []
43 |
44 | if cfg.split == 'train':
45 | names = cfg.train.names
46 | elif cfg.split == 'val':
47 | names = cfg.val.names
48 |
49 | for name in names:
50 | anno_dir = osp.join(cfg.datamodule.data_dir, name, f'outputs_{cfg.network.detection}_{cfg.network.matching}', 'anno')
51 | anno_dirs.append(anno_dir)
52 |
53 | img_id = 0
54 | ann_id = 0
55 | images = []
56 | annotations = []
57 | for anno_dir in anno_dirs:
58 | logger.info(f'Merging anno dir: {anno_dir}')
59 | anno_2d_file = osp.join(anno_dir, 'anno_2d.json')
60 | avg_anno_3d_file = osp.join(anno_dir, 'anno_3d_average.npz')
61 | collect_anno_3d_file = osp.join(anno_dir, 'anno_3d_collect.npz')
62 | idxs_file = osp.join(anno_dir, 'idxs.npy')
63 |
64 | if not osp.isfile(anno_2d_file) or not osp.isfile(avg_anno_3d_file) or not osp.isfile(collect_anno_3d_file):
65 | logger.info(f'No annotation in: {anno_dir}')
66 | continue
67 |
68 | img_id, ann_id = merge_(anno_2d_file, avg_anno_3d_file, collect_anno_3d_file,
69 | idxs_file, img_id, ann_id, images, annotations)
70 |
71 | logger.info(f'Total num: {len(images)}')
72 | instance = {'images': images, 'annotations': annotations}
73 |
74 | out_dir = osp.dirname(cfg.datamodule.out_path)
75 | Path(out_dir).mkdir(exist_ok=True, parents=True)
76 | with open(cfg.datamodule.out_path, 'w') as f:
77 | json.dump(instance, f)
78 |
79 |
80 | def sfm(cfg):
81 | """ Reconstruct and postprocess sparse object point cloud, and store point cloud features"""
82 | data_dirs = cfg.dataset.data_dir
83 | down_ratio = cfg.sfm.down_ratio
84 | data_dirs = [data_dirs] if isinstance(data_dirs, str) else data_dirs
85 |
86 | for data_dir in data_dirs:
87 | logger.info(f"Processing {data_dir}.")
88 | root_dir, sub_dirs = data_dir.split(' ')[0], data_dir.split(' ')[1:]
89 |
90 | # Parse image directory and downsample images:
91 | img_lists = []
92 | for sub_dir in sub_dirs:
93 | seq_dir = osp.join(root_dir, sub_dir)
94 | img_lists += glob.glob(str(Path(seq_dir)) + '/color/*.png', recursive=True)
95 |
96 | down_img_lists = []
97 | for img_file in img_lists:
98 | index = int(img_file.split('/')[-1].split('.')[0])
99 | if index % down_ratio == 0:
100 | down_img_lists.append(img_file)
101 | img_lists = down_img_lists
102 |
103 | if len(img_lists) == 0:
104 | logger.info(f"No png image in {root_dir}")
105 | continue
106 |
107 | obj_name = root_dir.split('/')[-1]
108 | outputs_dir_root = cfg.dataset.outputs_dir.format(obj_name)
109 |
110 | # Begin SfM and postprocess:
111 | sfm_core(cfg, img_lists, outputs_dir_root)
112 | postprocess(cfg, img_lists, root_dir, outputs_dir_root)
113 |
114 |
115 | def sfm_core(cfg, img_lists, outputs_dir_root):
116 | """ Sparse reconstruction: extract features, match features, triangulation"""
117 | from src.sfm import extract_features, match_features, \
118 | generate_empty, triangulation, pairs_from_poses
119 |
120 | # Construct output directory structure:
121 | outputs_dir = osp.join(outputs_dir_root, 'outputs' + '_' + cfg.network.detection + '_' + cfg.network.matching)
122 | feature_out = osp.join(outputs_dir, f'feats-{cfg.network.detection}.h5')
123 | covis_pairs_out = osp.join(outputs_dir, f'pairs-covis{cfg.sfm.covis_num}.txt')
124 | matches_out = osp.join(outputs_dir, f'matches-{cfg.network.matching}.h5')
125 | empty_dir = osp.join(outputs_dir, 'sfm_empty')
126 | deep_sfm_dir = osp.join(outputs_dir, 'sfm_ws')
127 |
128 | if cfg.redo:
129 | os.system(f'rm -rf {outputs_dir}')
130 | Path(outputs_dir).mkdir(exist_ok=True, parents=True)
131 |
132 | # Extract image features, construct image pairs and then match:
133 | extract_features.main(img_lists, feature_out, cfg)
134 | pairs_from_poses.covis_from_pose(img_lists, covis_pairs_out, cfg.sfm.covis_num, max_rotation=cfg.sfm.rotation_thresh)
135 | match_features.main(cfg, feature_out, covis_pairs_out, matches_out, vis_match=False)
136 |
137 | # Reconstruct 3D point cloud with known image poses:
138 | generate_empty.generate_model(img_lists, empty_dir)
139 | triangulation.main(deep_sfm_dir, empty_dir, outputs_dir, covis_pairs_out, feature_out, matches_out, image_dir=None)
140 |
141 |
142 | def postprocess(cfg, img_lists, root_dir, outputs_dir_root):
143 | """ Filter points and average feature"""
144 | from src.sfm.postprocess import filter_points, feature_process, filter_tkl
145 |
146 | bbox_path = osp.join(root_dir, "box3d_corners.txt")
147 | # Construct output directory structure:
148 | outputs_dir = osp.join(outputs_dir_root, 'outputs' + '_' + cfg.network.detection + '_' + cfg.network.matching)
149 | feature_out = osp.join(outputs_dir, f'feats-{cfg.network.detection}.h5')
150 | deep_sfm_dir = osp.join(outputs_dir, 'sfm_ws')
151 | model_path = osp.join(deep_sfm_dir, 'model')
152 |
153 | # Select feature track length to limit the number of 3D points below the 'max_num_kp3d' threshold:
154 | track_length, points_count_list = filter_tkl.get_tkl(model_path, thres=cfg.dataset.max_num_kp3d, show=False)
155 | filter_tkl.vis_tkl_filtered_pcds(model_path, points_count_list, track_length, outputs_dir) # For visualization only
156 |
157 | # Leverage the selected feature track length threshold and 3D BBox to filter 3D points:
158 | xyzs, points_idxs = filter_points.filter_3d(model_path, track_length, bbox_path)
159 | # Merge 3d points by distance between points
160 | merge_xyzs, merge_idxs = filter_points.merge(xyzs, points_idxs, dist_threshold=1e-3)
161 |
162 | # Save features of the filtered point cloud:
163 | feature_process.get_kpt_ann(cfg, img_lists, feature_out, outputs_dir, merge_idxs, merge_xyzs)
164 |
165 |
166 | @hydra.main(config_path='configs/', config_name='config.yaml')
167 | def main(cfg: DictConfig):
168 | globals()[cfg.type](cfg)
169 |
170 |
171 | if __name__ == "__main__":
172 | main()
--------------------------------------------------------------------------------
/scripts/demo_pipeline.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | PROJECT_DIR="$(pwd)"
3 | OBJ_NAME=$1
4 | echo "Current work dir: $PROJECT_DIR"
5 |
6 | echo '-------------------'
7 | echo 'Parse scanned data:'
8 | echo '-------------------'
9 | # Parse scanned annotated & test sequence:
10 | python $PROJECT_DIR/parse_scanned_data.py \
11 | --scanned_object_path \
12 | "$PROJECT_DIR/data/demo/$OBJ_NAME"
13 |
14 | echo '--------------------------------------------------------------'
15 | echo 'Run SfM to reconstruct object point cloud for pose estimation:'
16 | echo '--------------------------------------------------------------'
17 | # Run SfM to reconstruct object sparse point cloud from $OBJ_NAME-annotate sequence:
18 | python $PROJECT_DIR/run.py \
19 | +preprocess="sfm_spp_spg_demo" \
20 | dataset.data_dir="$PROJECT_DIR/data/demo/$OBJ_NAME $OBJ_NAME-annotate" \
21 | dataset.outputs_dir="$PROJECT_DIR/data/demo/$OBJ_NAME/sfm_model" \
22 |
23 | echo "-----------------------------------"
24 | echo "Run inference and output demo video:"
25 | echo "-----------------------------------"
26 |
27 | WITH_TRACKING=False
28 | while [[ "$#" -gt 0 ]]; do
29 | case $1 in
30 | -u|--WITH_TRACKING) WITH_TRACKING=True ;;
31 | esac
32 | shift
33 | done
34 |
35 | # Run inference on $OBJ_NAME-test and output demo video:
36 | python $PROJECT_DIR/inference_demo.py \
37 | +experiment="test_demo" \
38 | input.data_dirs="$PROJECT_DIR/data/demo/$OBJ_NAME $OBJ_NAME-test" \
39 | input.sfm_model_dirs="$PROJECT_DIR/data/demo/$OBJ_NAME/sfm_model" \
40 | use_tracking=${WITH_TRACKING}
41 |
--------------------------------------------------------------------------------
/scripts/parse_full_img.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | PROJECT_DIR="$(pwd)"
3 | VIDEO_PATH=$1
4 |
5 | echo '-------------------'
6 | echo 'Parse full image: '
7 | echo '-------------------'
8 |
9 | # Parse full image from Frames.m4v
10 | python $PROJECT_DIR/video2img.py \
11 | --video_file ${VIDEO_PATH}
12 |
--------------------------------------------------------------------------------
/scripts/prepare_2D_matching_resources.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | REPO_ROOT=$(pwd)
3 | echo "work space:$REPO_ROOT"
4 |
5 | # Download the pretrained SuperPoint model:
6 | mkdir -p $REPO_ROOT/data/models/extractors/SuperPoint
7 | # cd $REPO_ROOT/src/models/extractors/SuperPoint
8 | # wget https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/superpoint.py
9 | cd $REPO_ROOT/data/models/extractors/SuperPoint
10 | wget https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superpoint_v1.pth
11 |
12 | # Download the pretrained SuperGlue model:
13 | mkdir -p $REPO_ROOT/data/models/matchers/SuperGlue
14 | # cd $REPO_ROOT/src/models/matchers/SuperGlue
15 | # wget https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/superglue.py
16 | cd $REPO_ROOT/data/models/matchers/SuperGlue
17 | wget https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superglue_outdoor.pth
--------------------------------------------------------------------------------
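A minimal sanity-check sketch (run from the repo root, on a CUDA machine as assumed throughout the repo code) that the weights downloaded by the script above load correctly, mirroring `load_extractor_model()` and `load_2D_matcher()` in `feature_matching_object_detector.py`:

```python
# Sanity-check sketch for the downloaded SuperPoint / SuperGlue weights.
from src.models.extractors.SuperPoint.superpoint import SuperPoint
from src.models.matchers.SuperGlue.superglue import SuperGlue
from src.sfm.extract_features import confs as extract_confs
from src.sfm.match_features import confs as match_confs
from src.utils.model_io import load_network

extractor = SuperPoint(extract_confs["superpoint"]["conf"]).cuda().eval()
load_network(extractor, "data/models/extractors/SuperPoint/superpoint_v1.pth", force=True)

matcher = SuperGlue(match_confs["superglue"]["conf"]).cuda().eval()
load_network(matcher, "data/models/matchers/SuperGlue/superglue_outdoor.pth")
```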
/src/callbacks/custom_callbacks.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning import Callback
2 |
3 |
4 | class ExampleCallback(Callback):
5 | def __init__(self):
6 | pass
7 |
8 | def on_init_start(self, trainer):
9 | print("Starting to initialize trainer!")
10 |
11 | def on_init_end(self, trainer):
12 | print("Trainer is initialized now.")
13 |
14 | def on_train_end(self, trainer, pl_module):
15 | print("Do something when training ends.")
16 |
17 |
18 | class UnfreezeModelCallback(Callback):
19 | """
20 | Unfreeze all model parameters after a few epochs.
21 | """
22 |
23 | def __init__(self, wait_epochs=5):
24 | self.wait_epochs = wait_epochs
25 |
26 | def on_epoch_end(self, trainer, pl_module):
27 | if trainer.current_epoch == self.wait_epochs:
28 | for param in pl_module.model.model.parameters():
29 | param.requires_grad = True
30 |
--------------------------------------------------------------------------------
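A minimal sketch of attaching the `UnfreezeModelCallback` above to a PyTorch Lightning `Trainer`. Note the callback assumes the LightningModule exposes its network as `self.model.model`; the module and datamodule in the commented line are defined elsewhere:

```python
# Sketch: wiring UnfreezeModelCallback into a Trainer (PL 1.5-style API).
from pytorch_lightning import Trainer
from src.callbacks.custom_callbacks import UnfreezeModelCallback

trainer = Trainer(max_epochs=20, callbacks=[UnfreezeModelCallback(wait_epochs=5)])
# trainer.fit(lit_module, datamodule=datamodule)  # lit_module / datamodule defined elsewhere
```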
/src/callbacks/wandb_callbacks.py:
--------------------------------------------------------------------------------
1 | # wandb
2 | from pytorch_lightning.loggers import WandbLogger
3 | import wandb
4 |
5 | # pytorch
6 | from pytorch_lightning import Callback
7 | import pytorch_lightning as pl
8 | import torch
9 |
10 | # others
11 | from sklearn.metrics import precision_score, recall_score, f1_score
12 | from typing import List
13 | import glob
14 | import os
15 |
16 |
17 | def get_wandb_logger(trainer: pl.Trainer) -> WandbLogger:
18 | logger = None
19 | for lg in trainer.logger:
20 | if isinstance(lg, WandbLogger):
21 | logger = lg
22 |
23 | if not logger:
24 | raise Exception(
25 | "You're using wandb related callback, "
26 | "but WandbLogger was not found for some reason..."
27 | )
28 |
29 | return logger
30 |
31 |
32 | class UploadCodeToWandbAsArtifact(Callback):
33 | """Upload all *.py files to wandb as an artifact at the beginning of the run."""
34 |
35 | def __init__(self, code_dir: str):
36 | self.code_dir = code_dir
37 |
38 | def on_train_start(self, trainer, pl_module):
39 | logger = get_wandb_logger(trainer=trainer)
40 | experiment = logger.experiment
41 |
42 | code = wandb.Artifact("project-source", type="code")
43 | for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True):
44 | code.add_file(path)
45 |
46 | experiment.use_artifact(code)
47 |
48 |
49 | class UploadCheckpointsToWandbAsArtifact(Callback):
50 | """Upload experiment checkpoints to wandb as an artifact at the end of training."""
51 |
52 | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
53 | self.ckpt_dir = ckpt_dir
54 | self.upload_best_only = upload_best_only
55 |
56 | def on_train_end(self, trainer, pl_module):
57 | logger = get_wandb_logger(trainer=trainer)
58 | experiment = logger.experiment
59 |
60 | ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
61 |
62 | if self.upload_best_only:
63 | ckpts.add_file(trainer.checkpoint_callback.best_model_path)
64 | else:
65 | for path in glob.glob(
66 | os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True
67 | ):
68 | ckpts.add_file(path)
69 |
70 | experiment.use_artifact(ckpts)
71 |
72 |
73 | class WatchModelWithWandb(Callback):
74 | """Make WandbLogger watch model at the beginning of the run."""
75 |
76 | def __init__(self, log: str = "gradients", log_freq: int = 100):
77 | self.log = log
78 | self.log_freq = log_freq
79 |
80 | def on_train_start(self, trainer, pl_module):
81 | logger = get_wandb_logger(trainer=trainer)
82 | logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
83 |
84 |
85 | class LogF1PrecisionRecallHeatmapToWandb(Callback):
86 | """
87 | Generate f1, precision and recall heatmap from validation step outputs.
88 | Expects validation step to return predictions and targets.
89 | Works only for single label classification!
90 | """
91 |
92 | def __init__(self, class_names: List[str] = None):
93 | self.class_names = class_names
94 | self.preds = []
95 | self.targets = []
96 | self.ready = False
97 |
98 | def on_sanity_check_end(self, trainer, pl_module):
99 | """Start executing this callback only after all validation sanity checks end."""
100 | self.ready = True
101 |
102 | def on_validation_batch_end(
103 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
104 | ):
105 | """Gather data from single batch."""
106 | if self.ready:
107 | preds, targets = outputs["preds"], outputs["targets"]
108 | self.preds.append(preds)
109 | self.targets.append(targets)
110 |
111 | def on_validation_epoch_end(self, trainer, pl_module):
112 | """Generate f1, precision and recall heatmap."""
113 | if self.ready:
114 | logger = get_wandb_logger(trainer=trainer)
115 | experiment = logger.experiment
116 |
117 | self.preds = torch.cat(self.preds).cpu()
118 | self.targets = torch.cat(self.targets).cpu()
119 | f1 = f1_score(self.targets, self.preds, average=None)  # sklearn expects (y_true, y_pred)
120 | r = recall_score(self.targets, self.preds, average=None)
121 | p = precision_score(self.targets, self.preds, average=None)
122 |
123 | experiment.log(
124 | {
125 | f"f1_p_r_heatmap/{trainer.current_epoch}_{experiment.id}": wandb.plots.HeatMap(
126 | x_labels=self.class_names,
127 | y_labels=["f1", "precision", "recall"],
128 | matrix_values=[f1, p, r],
129 | show_text=True,
130 | )
131 | },
132 | commit=False,
133 | )
134 |
135 | self.preds = []
136 | self.targets = []
137 |
138 |
139 | class LogConfusionMatrixToWandb(Callback):
140 | """
141 | Generate Confusion Matrix.
142 | Expects validation step to return predictions and targets.
143 | Works only for single label classification!
144 | """
145 |
146 | def __init__(self, class_names: List[str] = None):
147 | self.class_names = class_names
148 | self.preds = []
149 | self.targets = []
150 | self.ready = False
151 |
152 | def on_sanity_check_end(self, trainer, pl_module):
153 | """Start executing this callback only after all validation sanity checks end."""
154 | self.ready = True
155 |
156 | def on_validation_batch_end(
157 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
158 | ):
159 | """Gather data from single batch."""
160 | if self.ready:
161 | preds, targets = outputs["preds"], outputs["targets"]
162 | self.preds.append(preds)
163 | self.targets.append(targets)
164 |
165 | def on_validation_epoch_end(self, trainer, pl_module):
166 | """Generate confusion matrix."""
167 | if self.ready:
168 | logger = get_wandb_logger(trainer=trainer)
169 | experiment = logger.experiment
170 |
171 | self.preds = torch.cat(self.preds).tolist()
172 | self.targets = torch.cat(self.targets).tolist()
173 |
174 | experiment.log(
175 | {
176 | f"confusion_matrix/{trainer.current_epoch}_{experiment.id}": wandb.plot.confusion_matrix(
177 | preds=self.preds,
178 | y_true=self.targets,
179 | class_names=self.class_names,
180 | )
181 | },
182 | commit=False,
183 | )
184 |
185 | self.preds = []
186 | self.targets = []
187 |
188 |
189 | ''' BUGGED :(
190 | class LogBestMetricScoresToWandb(Callback):
191 | """
192 | Store in wandb:
193 | - max train acc
194 | - min train loss
195 | - max val acc
196 | - min val loss
197 | Useful for comparing runs in table views, as wandb doesn't currently support column aggregation.
198 | """
199 |
200 | def __init__(self):
201 | self.train_loss_best = None
202 | self.train_acc_best = None
203 | self.val_loss_best = None
204 | self.val_acc_best = None
205 | self.ready = False
206 |
207 | def on_sanity_check_end(self, trainer, pl_module):
208 | """Start executing this callback only after all validation sanity checks end."""
209 | self.ready = True
210 |
211 | def on_epoch_end(self, trainer, pl_module):
212 | if self.ready:
213 | logger = get_wandb_logger(trainer=trainer)
214 | experiment = logger.experiment
215 |
216 | metrics = trainer.callback_metrics
217 |
218 | if not self.train_loss_best or metrics["train/loss"] < self.train_loss_best:
219 | self.train_loss_best = metrics["train_loss"]
220 |
221 | if not self.train_acc_best or metrics["train/acc"] > self.train_acc_best:
222 | self.train_acc_best = metrics["train/acc"]
223 |
224 | if not self.val_loss_best or metrics["val/loss"] < self.val_loss_best:
225 | self.val_loss_best = metrics["val/loss"]
226 |
227 | if not self.val_acc_best or metrics["val/acc"] > self.val_acc_best:
228 | self.val_acc_best = metrics["val/acc"]
229 |
230 | experiment.log({"train/loss_best": self.train_loss_best}, commit=False)
231 | experiment.log({"train/acc_best": self.train_acc_best}, commit=False)
232 | experiment.log({"val/loss_best": self.val_loss_best}, commit=False)
233 | experiment.log({"val/acc_best": self.val_acc_best}, commit=False)
234 | '''
235 |
--------------------------------------------------------------------------------
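The wandb callbacks above all require a `WandbLogger` to be attached to the trainer (see `get_wandb_logger`). A minimal sketch, with a placeholder project name and offline mode so no wandb account is needed:

```python
# Sketch: using the wandb callbacks above with a WandbLogger (PL 1.5-style API).
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from src.callbacks.wandb_callbacks import UploadCodeToWandbAsArtifact, WatchModelWithWandb

logger = WandbLogger(project="onepose", offline=True)   # placeholder project, offline run
callbacks = [UploadCodeToWandbAsArtifact(code_dir="src"), WatchModelWithWandb(log="gradients")]
trainer = Trainer(max_epochs=10, logger=logger, callbacks=callbacks)
# trainer.fit(lit_module, datamodule=datamodule)  # model / datamodule defined elsewhere
```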
/src/datamodules/GATs_spg_datamodule.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning import LightningDataModule
2 | from torch.utils.data.dataloader import DataLoader
3 | from src.datasets.GATs_spg_dataset import GATsSPGDataset
4 |
5 |
6 | class GATsSPGDataModule(LightningDataModule):
7 | def __init__(self, *args, **kwargs):
8 | super().__init__()
9 |
10 | self.train_anno_file = kwargs['train_anno_file']
11 | self.val_anno_file = kwargs['val_anno_file']
12 | self.batch_size = kwargs['batch_size']
13 | self.num_workers = kwargs['num_workers']
14 | self.pin_memory = kwargs['pin_memory']
15 | self.num_leaf = kwargs['num_leaf']
16 | self.shape2d = kwargs['shape2d']
17 | self.shape3d = kwargs['shape3d']
18 | self.assign_pad_val = kwargs['assign_pad_val']
19 |
20 | self.data_train = None
21 | self.data_val = None
22 | self.data_test = None
23 |
24 | # Loader parameters
25 | self.train_loader_params = {
26 | 'batch_size': self.batch_size,
27 | 'shuffle': True,
28 | 'num_workers': self.num_workers,
29 | 'pin_memory': self.pin_memory,
30 | }
31 | self.val_loader_params = {
32 | 'batch_size': 1,
33 | 'shuffle': False,
34 | 'num_workers': self.num_workers,
35 | 'pin_memory': self.pin_memory,
36 | }
37 | self.test_loader_params = {
38 | 'batch_size': 1,
39 | 'shuffle': False,
40 | 'num_workers': self.num_workers,
41 | 'pin_memory': self.pin_memory,
42 | }
43 |
44 | def prepare_data(self):
45 | pass
46 |
47 | def setup(self, stage=None):
48 | """ Load data. Set variable: self.data_train, self.data_val, self.data_test"""
49 | trainset = GATsSPGDataset(
50 | anno_file=self.train_anno_file,
51 | num_leaf=self.num_leaf,
52 | split='train',
53 | shape2d=self.shape2d,
54 | shape3d=self.shape3d,
55 | pad_val=self.assign_pad_val
56 | )
57 | print("=> Read train anno file: ", self.train_anno_file)
58 |
59 | valset = GATsSPGDataset(
60 | anno_file=self.val_anno_file,
61 | num_leaf=self.num_leaf,
62 | split='val',
63 | shape2d=self.shape2d,
64 | shape3d=self.shape3d,
65 | pad_val=self.assign_pad_val,
66 | load_pose_gt=True
67 | )
68 | print("=> Read validation anno file: ", self.val_anno_file)
69 |
70 | self.data_train = trainset
71 | self.data_val = valset
72 | self.data_test = valset
73 |
74 | def train_dataloader(self):
75 | return DataLoader(dataset=self.data_train, **self.train_loader_params)
76 |
77 | def val_dataloader(self):
78 | return DataLoader(dataset=self.data_val, **self.val_loader_params)
79 |
80 | def test_dataloader(self):
81 | return DataLoader(dataset=self.data_test, **self.test_loader_params)
82 |
--------------------------------------------------------------------------------
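A minimal construction sketch for the `GATsSPGDataModule` above. The keyword names are exactly what `__init__` reads from `**kwargs`; the annotation paths are hypothetical (they would come from the merged files produced by `merge_anno` in `run.py`), and the numeric values are illustrative (`shape2d`/`shape3d` follow the `GATsSPGDataset` defaults, `num_leaf` is a guess):

```python
# Sketch: constructing the datamodule; paths and numbers are placeholders.
from src.datamodules.GATs_spg_datamodule import GATsSPGDataModule

dm = GATsSPGDataModule(
    train_anno_file="data/cache/train.json",  # hypothetical merged annotation file
    val_anno_file="data/cache/val.json",      # hypothetical
    batch_size=8,
    num_workers=4,
    pin_memory=True,
    num_leaf=8,
    shape2d=1000,
    shape3d=2000,
    assign_pad_val=0,
)
dm.setup()                          # builds the train / val GATsSPGDataset instances
train_loader = dm.train_dataloader()
```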
/src/datasets/GATs_spg_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | try:
3 | import ujson as json
4 | except ImportError:
5 | import json
6 | import torch
7 | import numpy as np
8 |
9 | from pycocotools.coco import COCO
10 | from torch.utils.data import Dataset
11 | from src.utils import data_utils
12 |
13 |
14 | class GATsSPGDataset(Dataset):
15 | def __init__(
16 | self,
17 | anno_file,
18 | num_leaf,
19 | split,
20 | pad=True,
21 | shape2d=1000,
22 | shape3d=2000,
23 | pad_val=0,
24 | load_pose_gt=False,
25 | ):
26 | super(Dataset, self).__init__()
27 |
28 | self.coco = COCO(anno_file)
29 | self.anns = np.array(self.coco.getImgIds())
30 | self.num_leaf = num_leaf
31 |
32 | self.split = split
33 | self.pad = pad
34 | self.shape2d = shape2d
35 | self.shape3d = shape3d
36 | self.pad_val = pad_val
37 | self.load_pose_gt = load_pose_gt
38 |
39 | def get_intrin_by_color_path(self, color_path):
40 | intrin_path = color_path.replace('/color/', '/intrin_ba/').replace(
41 | '.png', '.txt'
42 | )
43 | K_crop = torch.from_numpy(np.loadtxt(intrin_path)) # [3, 3]
44 | return K_crop
45 |
46 | def get_gt_pose_by_color_path(self, color_path):
47 | gt_pose_path = color_path.replace('/color/', '/poses_ba/').replace(
48 | '.png', '.txt'
49 | )
50 | pose_gt = torch.from_numpy(np.loadtxt(gt_pose_path)) # [4, 4]
51 | return pose_gt
52 |
53 | def read_anno2d(self, anno2d_file, height, width, pad=True):
54 | """ Read (and pad) 2d info"""
55 | with open(anno2d_file, 'r') as f:
56 | data = json.load(f)
57 |
58 | keypoints2d = torch.Tensor(data['keypoints2d']) # [n, 2]
59 | descriptors2d = torch.Tensor(data['descriptors2d']) # [dim, n]
60 | scores2d = torch.Tensor(data['scores2d']) # [n, 1]
61 | assign_matrix = torch.Tensor(data['assign_matrix']) # [2, k]
62 |
63 | num_2d_orig = keypoints2d.shape[0]
64 |
65 | if pad:
66 | keypoints2d, descriptors2d, scores2d = data_utils.pad_keypoints2d_random(keypoints2d, descriptors2d, scores2d,
67 | height, width, self.shape2d)
68 | return keypoints2d, descriptors2d, scores2d, assign_matrix, num_2d_orig
69 |
70 | def read_anno3d(self, avg_anno3d_file, clt_anno3d_file, idxs_file, pad=True):
71 | """ Read(and pad) 3d info"""
72 | avg_data = np.load(avg_anno3d_file)
73 | clt_data = np.load(clt_anno3d_file)
74 | idxs = np.load(idxs_file)
75 |
76 | keypoints3d = torch.Tensor(clt_data['keypoints3d']) # [m, 3]
77 | avg_descriptors3d = torch.Tensor(avg_data['descriptors3d']) # [dim, m]
78 | clt_descriptors = torch.Tensor(clt_data['descriptors3d']) # [dim, k]
79 | avg_scores = torch.Tensor(avg_data['scores3d']) # [m, 1]
80 | clt_scores = torch.Tensor(clt_data['scores3d']) # [k, 1]
81 |
82 | num_3d_orig = keypoints3d.shape[0]
83 | if pad:
84 | keypoints3d = data_utils.pad_keypoints3d_random(keypoints3d, self.shape3d)
85 | avg_descriptors3d, avg_scores = data_utils.pad_features3d_random(avg_descriptors3d, avg_scores, self.shape3d)
86 | clt_descriptors, clt_scores = data_utils.build_features3d_leaves(clt_descriptors, clt_scores, idxs,
87 | self.shape3d, num_leaf=self.num_leaf)
88 | return keypoints3d, avg_descriptors3d, avg_scores, clt_descriptors, clt_scores, num_3d_orig
89 |
90 | def read_anno(self, img_id):
91 | """
92 | Read image, 2d info and 3d info.
93 | Pad 2d info and 3d info to a constant size.
94 | """
95 | ann_ids = self.coco.getAnnIds(imgIds=img_id)
96 | anno = self.coco.loadAnns(ann_ids)[0]
97 |
98 | color_path = self.coco.loadImgs(int(img_id))[0]['img_file']
99 | image = cv2.imread(color_path)
100 | height, width, _ = image.shape
101 |
102 | idxs_file = anno['idxs_file']
103 | avg_anno3d_file = anno['avg_anno3d_file']
104 | collect_anno3d_file = anno['collect_anno3d_file']
105 |
106 | # Load 3D points and features:
107 | (
108 | keypoints3d,
109 | avg_descriptors3d,
110 | avg_scores3d,
111 | clt_descriptors2d,
112 | clt_scores2d,
113 | num_3d_orig
114 | ) = self.read_anno3d(avg_anno3d_file, collect_anno3d_file, idxs_file, pad=self.pad)
115 |
116 | if self.split == 'train':
117 | anno2d_file = anno['anno2d_file']
118 | # Load 2D keypoints, features and GT 2D-3D correspondences:
119 | (
120 | keypoints2d,
121 | descriptors2d,
122 | scores2d,
123 | assign_matrix,
124 | num_2d_orig
125 | ) = self.read_anno2d(anno2d_file, height, width, pad=self.pad)
126 |
127 | # Construct GT conf_matrix:
128 | conf_matrix = data_utils.reshape_assign_matrix(
129 | assign_matrix,
130 | num_2d_orig,
131 | num_3d_orig,
132 | self.shape2d,
133 | self.shape3d,
134 | pad=True,
135 | pad_val=self.pad_val
136 | )
137 |
138 | data = {
139 | 'keypoints2d': keypoints2d, # [n1, 2]
140 | 'descriptors2d_query': descriptors2d, # [dim, n1]
141 | }
142 |
143 | elif self.split == 'val':
144 | image_gray = data_utils.read_gray_scale(color_path) / 255.
145 | data = {
146 | 'image': image_gray
147 | }
148 | conf_matrix = torch.Tensor([])
149 |
150 | data.update({
151 | 'keypoints3d': keypoints3d, # [n2, 3]
152 | 'descriptors3d_db': avg_descriptors3d, # [dim, n2]
153 | 'descriptors2d_db': clt_descriptors2d, # [dim, n2 * num_leaf]
154 | 'image_size': torch.Tensor([height, width])
155 | })
156 |
157 | if self.load_pose_gt:
158 | K_crop = self.get_intrin_by_color_path(color_path)
159 | pose_gt = self.get_gt_pose_by_color_path(color_path)
160 | data.update({'query_intrinsic': K_crop, 'query_pose_gt': pose_gt, 'query_image': image})
161 |
162 | return data, conf_matrix
163 |
164 | def __getitem__(self, index):
165 | img_id = self.anns[index]
166 |
167 | data, conf_matrix = self.read_anno(img_id)
168 | return data, conf_matrix
169 |
170 | def __len__(self):
171 | return len(self.anns)
--------------------------------------------------------------------------------
/src/datasets/normalized_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | from types import SimpleNamespace
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class NormalizedDataset(Dataset):
9 | """Read images (assumes the images have already been cropped)."""
10 | default_conf = {
11 | 'globs': ['*.jpg', '*.png'],
12 | 'grayscale': True,
13 | }
14 |
15 | def __init__(self, img_lists, conf):
16 | self.img_lists = img_lists
17 | self.conf = SimpleNamespace(**{**self.default_conf, **conf})
18 |
19 | if len(img_lists) == 0:
20 | raise ValueError('Could not find any image.')
21 |
22 | def __getitem__(self, index):
23 | img_path = self.img_lists[index]
24 |
25 | mode = cv2.IMREAD_GRAYSCALE if self.conf.grayscale else cv2.IMREAD_COLOR
26 | image = cv2.imread(img_path, mode)
27 | size = image.shape[:2]
28 |
29 | image = image.astype(np.float32)
30 | if self.conf.grayscale:
31 | image = image[None]
32 | else:
33 | image = image.transpose((2, 0, 1))
34 | image /= 255.
35 |
36 | data = {
37 | 'path': str(img_path),
38 | 'image': image,
39 | 'size': np.array(size),
40 | }
41 | return data
42 |
43 | def __len__(self):
44 | return len(self.img_lists)
--------------------------------------------------------------------------------
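A minimal usage sketch for the `NormalizedDataset` above, as it is used in `inference_core()`. The image directory is hypothetical; the repo passes `confs[...]['preprocessing']` from `src/sfm/extract_features.py` as the conf, whereas here only `grayscale` is set:

```python
# Sketch: loading cropped query images for the extractor; the image directory is hypothetical.
import glob
from torch.utils.data import DataLoader
from src.datasets.normalized_dataset import NormalizedDataset

img_lists = sorted(glob.glob("data/demo/your_obj/your_obj-test/color_det/*.png"))
dataset = NormalizedDataset(img_lists, {"grayscale": True})
loader = DataLoader(dataset, num_workers=1)

for data in loader:
    image, path, size = data["image"], data["path"][0], data["size"]
    # image: [1, 1, H, W] float tensor in [0, 1], ready for the SuperPoint extractor
    break
```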
/src/evaluators/cmd_evaluator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | class Evaluator():
4 | def __init__(self):
5 | self.cmd1 = []
6 | self.cmd3 = []
7 | self.cmd5 = []
8 | self.cmd7 = []
9 | self.add = []
10 |
11 | def cm_degree_1_metric(self, pose_pred, pose_target):
12 | translation_distance = np.linalg.norm(pose_pred[:, 3] - pose_target[:, 3]) * 100
13 | rotation_diff = np.dot(pose_pred[:, :3], pose_target[:, :3].T)
14 | trace = np.trace(rotation_diff)
15 | trace = trace if trace <= 3 else 3
16 | angular_distance = np.rad2deg(np.arccos((trace - 1.) / 2.))
17 | self.cmd1.append(translation_distance < 1 and angular_distance < 1)
18 |
19 | def cm_degree_5_metric(self, pose_pred, pose_target):
20 | translation_distance = np.linalg.norm(pose_pred[:, 3] - pose_target[:, 3]) * 100
21 | rotation_diff = np.dot(pose_pred[:, :3], pose_target[:, :3].T)
22 | trace = np.trace(rotation_diff)
23 | trace = trace if trace <= 3 else 3
24 | angular_distance = np.rad2deg(np.arccos((trace - 1.) / 2.))
25 | self.cmd5.append(translation_distance < 5 and angular_distance < 5)
26 |
27 | def cm_degree_3_metric(self, pose_pred, pose_target):
28 | translation_distance = np.linalg.norm(pose_pred[:, 3] - pose_target[:, 3]) * 100
29 | rotation_diff = np.dot(pose_pred[:, :3], pose_target[:, :3].T)
30 | trace = np.trace(rotation_diff)
31 | trace = trace if trace <= 3 else 3
32 | angular_distance = np.rad2deg(np.arccos((trace - 1.) / 2.))
33 | self.cmd3.append(translation_distance < 3 and angular_distance < 3)
34 |
35 | def evaluate(self, pose_pred, pose_gt):
36 | if pose_pred is None:
37 | self.cmd5.append(False)
38 | self.cmd1.append(False)
39 | self.cmd3.append(False)
40 | self.cmd7.append(False)
41 | else:
42 | if pose_pred.shape == (4, 4):
43 | pose_pred = pose_pred[:3, :4]
44 | if pose_gt.shape == (4, 4):
45 | pose_gt = pose_gt[:3, :4]
46 | self.cm_degree_1_metric(pose_pred, pose_gt)
47 | self.cm_degree_3_metric(pose_pred, pose_gt)
48 | self.cm_degree_5_metric(pose_pred, pose_gt)
49 |
50 | def summarize(self):
51 | cmd1 = np.mean(self.cmd1)
52 | cmd3 = np.mean(self.cmd3)
53 | cmd5 = np.mean(self.cmd5)
54 | print('1 cm 1 degree metric: {}'.format(cmd1))
55 | print('3 cm 3 degree metric: {}'.format(cmd3))
56 | print('5 cm 5 degree metric: {}'.format(cmd5))
57 |
58 | self.cmd1 = []
59 | self.cmd3 = []
60 | self.cmd5 = []
61 | self.cmd7 = []
62 | return {'cmd1': cmd1, 'cmd3': cmd3, 'cmd5': cmd5}
--------------------------------------------------------------------------------
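A small worked example for the `Evaluator` above on synthetic poses: the prediction is off by 2 degrees and 2 cm, so it should pass the 3cm-3deg and 5cm-5deg thresholds but fail 1cm-1deg. Poses are `[3, 4]` with translation in meters (the metric multiplies by 100 to get cm):

```python
# Worked example: cm-degree metrics on a synthetic pose pair.
import numpy as np
from src.evaluators.cmd_evaluator import Evaluator

def make_pose(rot_deg, t):
    a = np.deg2rad(rot_deg)
    R = np.array([[np.cos(a), -np.sin(a), 0.0],
                  [np.sin(a),  np.cos(a), 0.0],
                  [0.0,        0.0,       1.0]])
    return np.concatenate([R, np.asarray(t).reshape(3, 1)], axis=1)  # [3, 4]

evaluator = Evaluator()
pose_gt = make_pose(0.0, [0.0, 0.0, 0.5])
pose_pred = make_pose(2.0, [0.02, 0.0, 0.5])   # 2 deg rotation, 2 cm translation error
evaluator.evaluate(pose_pred, pose_gt)
results = evaluator.summarize()                 # {'cmd1': 0.0, 'cmd3': 1.0, 'cmd5': 1.0}
```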
/src/local_feature_2D_detector/__init__.py:
--------------------------------------------------------------------------------
1 | from .local_feature_2D_detector import LocalFeatureObjectDetector
--------------------------------------------------------------------------------
/src/local_feature_2D_detector/local_feature_2D_detector.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import time
3 | import cv2
4 | import torch
5 | import numpy as np
6 | from src.utils.colmap.read_write_model import read_model
7 | from src.utils.data_utils import get_K_crop_resize, get_image_crop_resize
8 | from src.utils.vis_utils import reproj
9 |
10 |
11 | def pack_extract_data(img_path):
12 | image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
13 |
14 | image = image[None] / 255.0
15 | return torch.Tensor(image)
16 |
17 |
18 | def pack_match_data(db_detection, query_detection, db_size, query_size):
19 | data = {}
20 | for k in db_detection.keys():
21 | data[k + "0"] = db_detection[k].__array__()
22 | for k in query_detection.keys():
23 | data[k + "1"] = query_detection[k].__array__()
24 | data = {k: torch.from_numpy(v)[None].float().cuda() for k, v in data.items()}
25 |
26 | data["image0"] = torch.empty(
27 | (
28 | 1,
29 | 1,
30 | )
31 | + tuple(db_size)[::-1]
32 | )
33 | data["image1"] = torch.empty(
34 | (
35 | 1,
36 | 1,
37 | )
38 | + tuple(query_size)[::-1]
39 | )
40 | return data
41 |
42 |
43 | class LocalFeatureObjectDetector():
44 | def __init__(self, extractor, matcher, sfm_ws_dir, n_ref_view=15, output_results=False, detect_save_dir=None, K_crop_save_dir=None):
45 | self.extractor = extractor.cuda()
46 | self.matcher = matcher.cuda()
47 | self.db_dict = self.extract_ref_view_features(sfm_ws_dir, n_ref_view)
48 | self.output_results = output_results
49 | self.detect_save_dir = detect_save_dir
50 | self.K_crop_save_dir = K_crop_save_dir
51 |
52 | def extract_ref_view_features(self, sfm_ws_dir, n_ref_views):
53 |         assert osp.exists(sfm_ws_dir), f"SfM workspace: {sfm_ws_dir} does not exist!"
54 | cameras, images, points3D = read_model(sfm_ws_dir)
55 | idx = 0
56 | sample_gap = len(images) // n_ref_views
57 |
58 | # Prepare reference input to matcher:
59 | db_dict = {} # id: image
60 | for idx in range(1, len(images), sample_gap):
61 | db_img_path = images[idx].name
62 |
63 | db_img = pack_extract_data(db_img_path)
64 |
65 | # Detect DB image keypoints:
66 | db_inp = db_img[None].cuda()
67 | db_detection = self.extractor(db_inp)
68 | db_detection = {
69 | k: v[0].detach().cpu().numpy() for k, v in db_detection.items()
70 | }
71 | db_detection["size"] = np.array(db_img.shape[-2:])
72 | db_dict[idx] = db_detection
73 |
74 | return db_dict
75 |
76 | @torch.no_grad()
77 | def match_worker(self, query):
78 | detect_results_dict = {}
79 | for idx, db in self.db_dict.items():
80 | db_shape = db["size"]
81 | query_shape = query["size"]
82 |
83 | match_data = pack_match_data(db, query, db["size"], query["size"])
84 | match_pred = self.matcher(match_data)
85 | matches = match_pred["matches0"][0].detach().cpu().numpy()
86 | confs = match_pred["matching_scores0"][0].detach().cpu().numpy()
87 | valid = matches > -1
88 |
89 | mkpts0 = db["keypoints"][valid]
90 | mkpts1 = query["keypoints"][matches[valid]]
91 | confs = confs[valid]
92 |
93 | if mkpts0.shape[0] < 6:
94 | affine = None
95 | inliers = np.empty((0))
96 | detect_results_dict[idx] = {
97 | "inliers": inliers,
98 | "bbox": np.array([0, 0, query["size"][0], query["size"][1]]),
99 | }
100 | continue
101 |
102 | # Estimate affine and warp source image:
103 | affine, inliers = cv2.estimateAffinePartial2D(
104 | mkpts0, mkpts1, ransacReprojThreshold=6
105 | )
106 |
107 | # Estimate box:
108 | four_corner = np.array(
109 | [
110 | [0, 0, 1],
111 | [db_shape[1], 0, 1],
112 | [0, db_shape[0], 1],
113 | [db_shape[1], db_shape[0], 1],
114 | ]
115 | ).T # 3*4
116 |
117 | bbox = (affine @ four_corner).T.astype(np.int32) # 4*2
118 |
119 | left_top = np.min(bbox, axis=0)
120 | right_bottom = np.max(bbox, axis=0)
121 |
122 | w, h = right_bottom - left_top
123 | offset_percent = 0.0
124 | x0 = left_top[0] - int(w * offset_percent)
125 | y0 = left_top[1] - int(h * offset_percent)
126 | x1 = right_bottom[0] + int(w * offset_percent)
127 | y1 = right_bottom[1] + int(h * offset_percent)
128 |
129 | detect_results_dict[idx] = {
130 | "inliers": inliers,
131 | "bbox": np.array([x0, y0, x1, y1]),
132 | }
133 | return detect_results_dict
134 |
135 | def detect_by_matching(self, query):
136 | detect_results_dict = self.match_worker(query)
137 |
138 |         # Sort bbox candidates and use the one with the maximum number of inliers:
139 | idx_sorted = [
140 | k
141 | for k, _ in sorted(
142 | detect_results_dict.items(),
143 | reverse=True,
144 | key=lambda item: item[1]["inliers"].shape[0],
145 | )
146 | ]
147 | return detect_results_dict[idx_sorted[0]]["bbox"]
148 |
149 | def robust_crop(self, query_img_path, bbox, K, crop_size=512):
150 | x0, y0 = bbox[0], bbox[1]
151 | x1, y1 = bbox[2], bbox[3]
152 | x_c = (x0 + x1) / 2
153 | y_c = (y0 + y1) / 2
154 |
155 | origin_img = cv2.imread(query_img_path, cv2.IMREAD_GRAYSCALE)
156 | image_crop = origin_img
157 | K_crop, K_crop_homo = get_K_crop_resize(bbox, K, [crop_size, crop_size])
158 | return image_crop, K_crop
159 |
160 | def crop_img_by_bbox(self, query_img_path, bbox, K=None, crop_size=512):
161 | """
162 |         Crop image by the detected bbox
163 | Input:
164 | query_img_path: str,
165 | bbox: np.ndarray[x0, y0, x1, y1],
166 | K[optional]: 3*3
167 | Output:
168 | image_crop: np.ndarray[crop_size * crop_size],
169 | K_crop[optional]: 3*3
170 | """
171 | x0, y0 = bbox[0], bbox[1]
172 | x1, y1 = bbox[2], bbox[3]
173 | origin_img = cv2.imread(query_img_path, cv2.IMREAD_GRAYSCALE)
174 |
175 | resize_shape = np.array([y1 - y0, x1 - x0])
176 | if K is not None:
177 | K_crop, K_crop_homo = get_K_crop_resize(bbox, K, resize_shape)
178 | image_crop, trans1 = get_image_crop_resize(origin_img, bbox, resize_shape)
179 |
180 | bbox_new = np.array([0, 0, x1 - x0, y1 - y0])
181 | resize_shape = np.array([crop_size, crop_size])
182 | if K is not None:
183 | K_crop, K_crop_homo = get_K_crop_resize(bbox_new, K_crop, resize_shape)
184 | image_crop, trans2 = get_image_crop_resize(image_crop, bbox_new, resize_shape)
185 |
186 |         return image_crop, (K_crop if K is not None else None)
187 |
188 | def save_detection(self, crop_img, query_img_path):
189 | if self.output_results and self.detect_save_dir is not None:
190 | cv2.imwrite(osp.join(self.detect_save_dir, osp.basename(query_img_path)), crop_img)
191 |
192 | def save_K_crop(self, K_crop, query_img_path):
193 | if self.output_results and self.K_crop_save_dir is not None:
194 | np.savetxt(osp.join(self.K_crop_save_dir, osp.splitext(osp.basename(query_img_path))[0] + '.txt'), K_crop) # K_crop: 3*3
195 |
196 | def detect(self, query_img, query_img_path, K, crop_size=512):
197 | """
198 | Detect object by local feature matching and crop image.
199 | Input:
200 |             query_img: torch.Tensor[1*H*W] or [1*1*H*W],
201 | query_img_path: str,
202 | K: np.ndarray[3*3], intrinsic matrix of original image
203 | Output:
204 | bounding_box: np.ndarray[x0, y0, x1, y1]
205 | cropped_image: torch.tensor[1 * 1 * crop_size * crop_size] (normalized),
206 | cropped_K: np.ndarray[3*3];
207 | """
208 | if len(query_img.shape) != 4:
209 | query_inp = query_img[None].cuda()
210 | else:
211 | query_inp = query_img.cuda()
212 |
213 | # Extract query image features:
214 | query_inp = self.extractor(query_inp)
215 | query_inp = {k: v[0].detach().cpu().numpy() for k, v in query_inp.items()}
216 | query_inp["size"] = np.array(query_img.shape[-2:])
217 |
218 | # Detect bbox and crop image:
219 | bbox = self.detect_by_matching(
220 | query=query_inp,
221 | )
222 | image_crop, K_crop = self.crop_img_by_bbox(query_img_path, bbox, K, crop_size=crop_size)
223 | self.save_detection(image_crop, query_img_path)
224 | self.save_K_crop(K_crop, query_img_path)
225 |
226 | # To Tensor:
227 | image_crop = image_crop.astype(np.float32) / 255
228 | image_crop_tensor = torch.from_numpy(image_crop)[None][None].cuda()
229 |
230 | return bbox, image_crop_tensor, K_crop
231 |
232 | def previous_pose_detect(self, query_img_path, K, pre_pose, bbox3D_corner, crop_size=512):
233 | """
234 | Detect object by projecting 3D bbox with estimated last frame pose.
235 | Input:
236 | query_image_path: str,
237 | K: np.ndarray[3*3], intrinsic matrix of original image
238 | pre_pose: np.ndarray[3*4] or [4*4], pose of last frame
239 | bbox3D_corner: np.ndarray[8*3], corner coordinate of annotated 3D bbox
240 | Output:
241 | bounding_box: np.ndarray[x0, y0, x1, y1]
242 | cropped_image: torch.tensor[1 * 1 * crop_size * crop_size] (normalized),
243 | cropped_K: np.ndarray[3*3];
244 | """
245 | # Project 3D bbox:
246 | proj_2D_coor = reproj(K, pre_pose, bbox3D_corner)
247 | x0, y0 = np.min(proj_2D_coor, axis=0)
248 | x1, y1 = np.max(proj_2D_coor, axis=0)
249 | bbox = np.array([x0, y0, x1, y1]).astype(np.int32)
250 |
251 | image_crop, K_crop = self.crop_img_by_bbox(query_img_path, bbox, K, crop_size=crop_size)
252 | self.save_detection(image_crop, query_img_path)
253 | self.save_K_crop(K_crop, query_img_path)
254 |
255 | # To Tensor:
256 | image_crop = image_crop.astype(np.float32) / 255
257 | image_crop_tensor = torch.from_numpy(image_crop)[None][None].cuda()
258 |
259 | return bbox, image_crop_tensor, K_crop
--------------------------------------------------------------------------------
/src/losses/focal_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class FocalLoss(nn.Module):
5 |
6 | def __init__(self, alpha=1, gamma=2, neg_weights=0.5, pos_weights=0.5):
7 | super(FocalLoss, self).__init__()
8 | self.alpha = alpha
9 | self.gamma = gamma
10 | self.neg_weights = neg_weights
11 | self.pos_weights = pos_weights
12 |
13 | def forward(self, pred, target):
14 | loss_pos = - self.alpha * torch.pow(1 - pred[target==1], self.gamma) * (pred[target==1]).log()
15 | loss_neg = - (1 - self.alpha) * torch.pow(pred[target==0], self.gamma) * (1 - pred[target==0]).log()
16 |
17 | assert len(loss_pos) != 0 or len(loss_neg) != 0, 'Invalid loss.'
18 |         # taking the mean of an empty tensor would produce NaN, so handle empty cases separately
19 | if len(loss_pos) == 0:
20 | loss_mean = self.neg_weights * loss_neg.mean()
21 | elif len(loss_neg) == 0:
22 | loss_mean = self.pos_weights * loss_pos.mean()
23 | else:
24 | loss_mean = self.pos_weights * loss_pos.mean() + self.neg_weights * loss_neg.mean()
25 |
26 | return loss_mean
--------------------------------------------------------------------------------
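A minimal sketch of how this loss is typically invoked during training, assuming a dense predicted confidence matrix in (0, 1) and a binary ground-truth assignment matrix of the same shape (all shapes and values below are illustrative):

```python
import torch
from src.losses.focal_loss import FocalLoss

crit = FocalLoss()  # default alpha/gamma/weights
conf_matrix_pred = torch.rand(1, 100, 60).clamp(1e-6, 1 - 1e-6)  # predicted 2D-3D matching confidences
conf_matrix_gt = torch.zeros(1, 100, 60)
conf_matrix_gt[0, torch.arange(60), torch.arange(60)] = 1         # hypothetical ground-truth matches
loss = crit(conf_matrix_pred, conf_matrix_gt)                     # scalar loss for backprop
```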
/src/models/GATsSPG_architectures/GATs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class GraphAttentionLayer(nn.Module):
7 | def __init__(
8 | self,
9 | in_features,
10 | out_features,
11 | dropout,
12 | alpha,
13 | concat=True,
14 | include_self=True,
15 | additional=False,
16 | with_linear_transform=True
17 | ):
18 | super(GraphAttentionLayer, self).__init__()
19 | self.dropout = dropout
20 | self.in_features = in_features
21 | self.out_features = out_features
22 | self.alpha = alpha
23 | self.concat = concat
24 |
25 | self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
26 | nn.init.xavier_normal_(self.W.data, gain=1.414)
27 | self.a = nn.Parameter(torch.empty(size=(2 * out_features, 1)))
28 | nn.init.xavier_normal_(self.a.data, gain=1.414)
29 |
30 | self.leakyrelu = nn.LeakyReLU(self.alpha)
31 | self.include_self = include_self
32 | self.with_linear_transform = with_linear_transform
33 | self.additional = additional
34 |
35 | def forward(self, h_2d, h_3d):
36 | b, n1, dim = h_3d.shape
37 | b, n2, dim = h_2d.shape
38 | num_leaf = int(n2 / n1)
39 |
40 | wh_2d = torch.matmul(h_2d, self.W)
41 | wh_3d = torch.matmul(h_3d, self.W)
42 |
43 | e = self._prepare_attentional_mechanism_input(wh_2d, wh_3d, num_leaf, self.include_self)
44 | attention = F.softmax(e, dim=2)
45 |
46 | h_2d = torch.reshape(h_2d, (b, n1, num_leaf, dim))
47 | wh_2d = torch.reshape(wh_2d, (b, n1, num_leaf, dim))
48 | if self.include_self:
49 | wh_2d = torch.cat(
50 | [wh_3d.unsqueeze(-2), wh_2d], dim=-2
51 | ) # [b, N, 1+num_leaf, d_out]
52 | h_2d = torch.cat(
53 | [h_3d.unsqueeze(-2), h_2d], dim=-2
54 | )
55 |
56 | if self.with_linear_transform:
57 | h_prime = torch.einsum('bncd,bncq->bnq', attention, wh_2d)
58 | else:
59 | h_prime = torch.einsum('bncd,bncq->bnq', attention, h_2d)
60 |
61 | if self.additional:
62 | h_prime = h_prime + h_3d
63 | else:
64 | if self.with_linear_transform:
65 | h_prime = torch.einsum('bncd,bncq->bnq', attention, wh_2d) / 2. + wh_3d
66 | else:
67 | h_prime = torch.einsum('bncd,bncq->bnq', attention, h_2d) / 2. + h_3d
68 |
69 | if self.concat:
70 | return F.elu(h_prime)
71 | else:
72 | return h_prime
73 |
74 | def _prepare_attentional_mechanism_input(self, wh_2d, wh_3d, num_leaf, include_self=False):
75 | b, n1, dim = wh_3d.shape
76 | b, n2, dim = wh_2d.shape
77 |
78 | wh_2d_ = torch.matmul(wh_2d, self.a[:self.out_features, :]) # [b, N2, 1]
79 | wh_2d_ = torch.reshape(wh_2d_, (b, n1, num_leaf, -1)) # [b, n1, 6, 1]
80 | wh_3d_ = torch.matmul(wh_3d, self.a[self.out_features:, :]) # [b, N1, 1]
81 |
82 | if include_self:
83 | wh_2d_ = torch.cat(
84 | [wh_3d_.unsqueeze(2), wh_2d_], dim=-2
85 | ) # [b, N1, 1 + num_leaf, 1]
86 |
87 | e = wh_3d_.unsqueeze(2) + wh_2d_
88 | return self.leakyrelu(e)
89 |
--------------------------------------------------------------------------------
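A shape-level sketch of the layer above: each 3D point carries `num_leaf` associated 2D descriptors, and the layer aggregates them into an updated 3D descriptor (all sizes below are illustrative):

```python
import torch
from src.models.GATsSPG_architectures.GATs import GraphAttentionLayer

layer = GraphAttentionLayer(in_features=256, out_features=256, dropout=0.6, alpha=0.2,
                            include_self=True, additional=False, with_linear_transform=True)
num_points, num_leaf = 100, 8
h_3d = torch.rand(1, num_points, 256)             # one descriptor per 3D point
h_2d = torch.rand(1, num_points * num_leaf, 256)  # num_leaf 2D descriptors grouped per 3D point
h_3d_updated = layer(h_2d, h_3d)                  # -> [1, num_points, 256]
```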
/src/models/GATsSPG_architectures/GATs_SuperGlue.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .GATs import GraphAttentionLayer
6 |
7 |
8 | def arange_like(x, dim: int):
9 | return x.new_ones(x.shape[dim]).cumsum(0) - 1
10 |
11 |
12 | def buildAdjMatrix(num_2d, num_3d):
13 | num_leaf = int(num_2d / num_3d)
14 |
15 | adj_matrix = torch.zeros(num_3d, num_2d)
16 | for i in range(num_3d):
17 | adj_matrix[i, num_leaf*i: num_leaf*(i+1)] = 1 / num_leaf
18 | return adj_matrix.cuda()
19 |
20 |
21 | class AttentionalGNN(nn.Module):
22 |
23 | def __init__(
24 | self,
25 | feature_dim: int,
26 | layer_names: list,
27 | include_self: bool,
28 | additional: bool,
29 | with_linear_transform: bool
30 | ):
31 | super().__init__()
32 |
33 | self.layers = nn.ModuleList([
34 | GraphAttentionLayer(
35 | in_features=256,
36 | out_features=256,
37 | dropout=0.6,
38 | alpha=0.2,
39 | concat=True,
40 | include_self=include_self,
41 | additional=additional,
42 | with_linear_transform=with_linear_transform,
43 | ) if i % 3 == 0 else AttentionPropagation(feature_dim, 4)
44 | for i in range(len(layer_names))
45 | ])
46 | self.names = layer_names
47 |
48 | def forward(self, desc2d_query, desc3d_db, desc2d_db):
49 | for layer, name in zip(self.layers, self.names):
50 | if name == 'GATs':
51 | desc2d_db_ = torch.einsum('bdn->bnd', desc2d_db)
52 | desc3d_db_ = torch.einsum('bdn->bnd', desc3d_db)
53 | desc3d_db_ = layer(desc2d_db_, desc3d_db_)
54 | desc3d_db = torch.einsum('bnd->bdn', desc3d_db_)
55 | elif name == 'cross':
56 | layer.attn.prob = []
57 | src0, src1 = desc3d_db, desc2d_query # [b, c, l1], [b, c, l2]
58 | delta0, delta1 = layer(desc2d_query, src0), layer(desc3d_db, src1)
59 | desc2d_query, desc3d_db = (desc2d_query + delta0), (desc3d_db + delta1)
60 | elif name == 'self':
61 | layer.attn.prob = []
62 | src0, src1 = desc2d_query, desc3d_db
63 | delta0, delta1 = layer(desc2d_query, src0), layer(desc3d_db, src1)
64 | desc2d_query, desc3d_db = (desc2d_query + delta0), (desc3d_db + delta1)
65 |
66 | return desc2d_query, desc3d_db
67 |
68 |
69 | def linear_attention(query, key, value):
70 | eps = 1e-6
71 | query = F.elu(query) + 1
72 | key = F.elu(key) + 1
73 |
74 | v_length = value.size(3)
75 | value = value / v_length
76 |
77 | KV = torch.einsum('bdhm,bqhm->bqdh', key, value)
78 | Z = 1 / (torch.einsum('bdhm,bdh->bhm', query, key.sum(3)) + eps)
79 | queried_values = torch.einsum('bdhm,bqdh,bhm->bqhm', query, KV, Z) * v_length
80 | return queried_values.contiguous()
81 |
82 |
83 | class MultiHeadedAttention(nn.Module):
84 | """ Multi-head attention to increase model expressivity"""
85 | def __init__(self, num_heads: int, d_model: int):
86 | super().__init__()
87 | assert d_model % num_heads == 0
88 | self.dim = d_model // num_heads
89 | self.num_heads = num_heads
90 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
91 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
92 |
93 | def forward(self, query, key, value):
94 | batch_dim = query.size(0)
95 |
96 | query, key, value = [
97 | l(x).view(batch_dim, self.dim, self.num_heads, -1)
98 | for l, x in zip(self.proj, (query, key, value))
99 | ]
100 | x = linear_attention(query, key, value)
101 | return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
102 |
103 |
104 | class AttentionPropagation(nn.Module):
105 | def __init__(self, feature_dim: int, num_heads: int):
106 | super().__init__()
107 | self.attn = MultiHeadedAttention(num_heads, feature_dim)
108 | self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
109 | nn.init.constant_(self.mlp[-1].bias, 0.)
110 |
111 | def forward(self, x, source):
112 | message = self.attn(x, source, source)
113 | return self.mlp(torch.cat([x, message], dim=1)) # [b, 2c, 1000] / [b, 2c, 2000]
114 |
115 |
116 | def MLP(channels: list, do_bn=True):
117 | """ Multi-layer perceptron"""
118 | n = len(channels)
119 | layers = []
120 | for i in range(1, n):
121 | layers.append(
122 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True)
123 | )
124 |         if i < n - 1:
125 | if do_bn:
126 | layers.append(nn.InstanceNorm1d(channels[i]))
127 | layers.append(nn.ReLU())
128 | return nn.Sequential(*layers)
129 |
130 |
131 | class KeypointEncoder(nn.Module):
132 | """ Joint encoding of visual appearance and location using MLPs """
133 | def __init__(self, inp_dim, feature_dim, layers):
134 | super().__init__()
135 | self.encoder = MLP([inp_dim] + list(layers) + [feature_dim])
136 | nn.init.constant_(self.encoder[-1].bias, 0.0)
137 |
138 | def forward(self, kpts, scores):
139 | inputs = [kpts.transpose(1, 2), scores.unsqueeze(1)]
140 | return self.encoder(torch.cat(inputs, dim=1))
141 |
142 |
143 | class GATsSuperGlue(nn.Module):
144 |
145 | def __init__(self, hparams):
146 | super().__init__()
147 | self.hparams = hparams
148 | self.match_type = hparams['match_type']
149 |
150 | self.kenc_2d = KeypointEncoder(
151 | inp_dim=3,
152 | feature_dim=hparams['descriptor_dim'],
153 | layers=hparams['keypoints_encoder']
154 | )
155 |
156 | self.kenc_3d = KeypointEncoder(
157 | inp_dim=4,
158 | feature_dim=hparams['descriptor_dim'],
159 | layers=hparams['keypoints_encoder']
160 | )
161 |
162 | GNN_layers = ['GATs', 'self', 'cross'] * 4
163 | self.gnn = AttentionalGNN(
164 | feature_dim=hparams['descriptor_dim'],
165 | layer_names=GNN_layers,
166 | include_self=hparams['include_self'],
167 | additional=hparams['additional'],
168 | with_linear_transform=hparams['with_linear_transform']
169 | )
170 | self.final_proj = nn.Conv1d(
171 | in_channels=hparams['descriptor_dim'],
172 | out_channels=hparams['descriptor_dim'],
173 | kernel_size=1,
174 | bias=True
175 | )
176 | bin_score = nn.Parameter(torch.tensor(1.))
177 | self.register_parameter('bin_score', bin_score)
178 |
179 | def forward(self, data):
180 | """
181 | Keys of data:
182 | keypoints2d: [b, n1, 2]
183 | keypoints3d: [b, n2, 3]
184 | descriptors2d_query: [b, dim, n1]
185 | descriptors3d_db: [b, dim, n2]
186 | descriptors2d_db: [b, dim, n2 * num_leaf]
187 | scores2d_query: [b, n1, 1]
188 | scores3d_db: [b, n2, 1]
189 | scores2d_db: [b, n2 * num_leaf, 1]
190 | """
191 | kpts2d, kpts3d = data['keypoints2d'].float(), data['keypoints3d'].float()
192 | desc2d_query = data['descriptors2d_query'].float()
193 | desc3d_db, desc2d_db = data['descriptors3d_db'].float(), data['descriptors2d_db'].float()
194 |
195 | if kpts2d.shape[1] == 0 or kpts3d.shape[1] == 0:
196 | shape0, shape1 = kpts2d.shape[:-1], kpts3d.shape[:-1]
197 | return {
198 | 'matches0': kpts2d.new_full(shape0, -1, dtype=torch.int)[0],
199 | 'matches1': kpts3d.new_full(shape1, -1, dtype=torch.int)[0],
200 | 'matching_scores0': kpts2d.new_zeros(shape0)[0],
201 | 'matching_scores1': kpts3d.new_zeros(shape1)[0],
202 | 'skip_train': True
203 | }
204 |
205 | # Multi-layer Transformer network
206 | desc2d_query, desc3d_db = self.gnn(desc2d_query, desc3d_db, desc2d_db)
207 |
208 | # Final MLP projection
209 | mdesc2d_query, mdesc3d_db = self.final_proj(desc2d_query), self.final_proj(desc3d_db)
210 |
211 | # Normalize mdesc to avoid NaN
212 | mdesc2d_query = F.normalize(mdesc2d_query, p=2, dim=1)
213 | mdesc3d_db = F.normalize(mdesc3d_db, p=2, dim=1)
214 |
215 | # Get the matches with score above "match_threshold"
216 | if self.match_type == "softmax":
217 | scores = torch.einsum('bdn,bdm->bnm', mdesc2d_query, mdesc3d_db) / self.hparams['scale_factor']
218 | conf_matrix = F.softmax(scores, 1) * F.softmax(scores, 2)
219 |
220 | max0, max1 = conf_matrix[:, :, :].max(2), conf_matrix[:, :, :].max(1)
221 | indices0, indices1 = max0.indices, max1.indices
222 | mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
223 | mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
224 | zero = conf_matrix.new_tensor(0)
225 | mscores0 = torch.where(mutual0, max0.values, zero)
226 | mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero)
227 | valid0 = mutual0 & (mscores0 > self.hparams['match_threshold'])
228 | valid1 = mutual1 & valid0.gather(1, indices1)
229 | indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
230 | indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
231 |
232 | pred = {
233 | 'matches0': indices0[0], # use -1 for invalid match
234 | 'matches1': indices1[0], # use -1 for invalid match
235 | 'matching_scores0': mscores0[0],
236 | 'matching_scores1': mscores1[0],
237 | }
238 | else:
239 | raise NotImplementedError
240 |
241 | return pred, conf_matrix
242 |
--------------------------------------------------------------------------------
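A minimal sketch of driving the matcher with the input dictionary documented in the forward docstring. The `hparams` values here are hypothetical placeholders covering every key the model reads; the real values come from the experiment configs:

```python
import torch
from src.models.GATsSPG_architectures.GATs_SuperGlue import GATsSuperGlue

hparams = {
    'match_type': 'softmax', 'descriptor_dim': 256, 'keypoints_encoder': [32, 64, 128],
    'include_self': True, 'additional': False, 'with_linear_transform': True,
    'scale_factor': 0.07, 'match_threshold': 0.2,
}
model = GATsSuperGlue(hparams).eval()
n1, n2, num_leaf = 500, 400, 8
data = {
    'keypoints2d': torch.rand(1, n1, 2),
    'keypoints3d': torch.rand(1, n2, 3),
    'descriptors2d_query': torch.rand(1, 256, n1),
    'descriptors3d_db': torch.rand(1, 256, n2),
    'descriptors2d_db': torch.rand(1, 256, n2 * num_leaf),
}
with torch.no_grad():
    pred, conf_matrix = model(data)  # pred['matches0']: [n1] indices into the 3D points, -1 if unmatched
```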
/src/models/GATsSPG_lightning_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 |
4 | from itertools import chain
5 | from src.models.GATsSPG_architectures.GATs_SuperGlue import GATsSuperGlue
6 | from src.losses.focal_loss import FocalLoss
7 | from src.utils.eval_utils import compute_query_pose_errors, aggregate_metrics
8 | from src.utils.vis_utils import draw_reprojection_pair
9 | from src.utils.comm import gather
10 | from src.models.extractors.SuperPoint.superpoint import SuperPoint
11 | from src.sfm.extract_features import confs
12 | from src.utils.model_io import load_network
13 |
14 |
15 | class LitModelGATsSPG(pl.LightningModule):
16 |
17 | def __init__(self, *args, **kwargs):
18 | super().__init__()
19 |
20 | self.save_hyperparameters()
21 | self.extractor = SuperPoint(confs['superpoint'])
22 | load_network(self.extractor, self.hparams.spp_model_path, force=False)
23 | self.matcher = GATsSuperGlue(hparams=self.hparams)
24 | self.crit = FocalLoss(
25 | alpha=self.hparams.focal_loss_alpha,
26 | gamma=self.hparams.focal_loss_gamma,
27 | neg_weights=self.hparams.neg_weights,
28 | pos_weights=self.hparams.pos_weights
29 | )
30 | self.n_vals_plot = 10
31 |
32 | self.train_loss_hist = []
33 | self.val_loss_hist = []
34 | self.save_flag = True
35 |
36 | def forward(self, x):
37 | return self.matcher(x)
38 |
39 | def training_step(self, batch, batch_idx):
40 | self.save_flag = False
41 | data, conf_matrix_gt = batch
42 | preds, conf_matrix_pred = self.matcher(data)
43 |
44 | loss_mean = self.crit(conf_matrix_pred, conf_matrix_gt)
45 | if (
46 | self.trainer.global_rank == 0
47 | and self.global_step % self.trainer.log_every_n_steps == 0
48 | ):
49 | self.logger.experiment.add_scalar('train/loss', loss_mean, self.global_step)
50 |
51 | return {'loss': loss_mean, 'preds': preds}
52 |
53 | def validation_step(self, batch, batch_idx):
54 | data, _ = batch
55 | extraction = self.extractor(data['image'])
56 | data.update({
57 | 'keypoints2d': extraction['keypoints'][0][None],
58 | 'descriptors2d_query': extraction['descriptors'][0][None],
59 | })
60 | preds, conf_matrix_pred = self.matcher(data)
61 |
62 | pose_pred, val_results = compute_query_pose_errors(data, preds)
63 |
64 | # Visualize match:
65 | val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
66 | figures = {'evaluation': []}
67 | if batch_idx % val_plot_interval == 0:
68 | figures = draw_reprojection_pair(data, val_results, visual_color_type='conf')
69 |
70 | loss_mean = 0
71 | self.log('val/loss', loss_mean, on_step=False, on_epoch=True, prog_bar=False)
72 | del data
73 | return {'figures': figures, 'metrics': val_results}
74 |
75 | def test_step(self, batch, batch_idx):
76 | pass
77 |
78 | def training_epoch_end(self, outputs):
79 | avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
80 | if self.trainer.global_rank == 0:
81 | self.logger.experiment.add_scalar(
82 | 'train/avg_loss_on_epoch', avg_loss, global_step=self.current_epoch
83 | )
84 |
85 | def validation_epoch_end(self, outputs):
86 | self.val_loss_hist.append(self.trainer.callback_metrics['val/loss'])
87 | self.log('val/loss_best', min(self.val_loss_hist), prog_bar=False)
88 |
89 | multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
90 | for valset_idx, outputs in enumerate(multi_outputs):
91 | cur_epoch = self.trainer.current_epoch
92 | if not self.trainer.resume_from_checkpoint and self.trainer.sanity_checking:
93 | cur_epoch = -1
94 |
95 | def flattenList(x): return list(chain(*x))
96 |
97 | # Log val metrics: dict of list, numpy
98 | _metrics = [o['metrics'] for o in outputs]
99 | metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
100 |
101 | # Log figures:
102 | _figures = [o['figures'] for o in outputs]
103 | figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
104 |
105 | # NOTE: tensorboard records only on rank 0
106 | if self.trainer.global_rank == 0:
107 | val_metrics_4tb = aggregate_metrics(metrics)
108 | for k, v in val_metrics_4tb.items():
109 | self.logger.experiment.add_scalar(f'metrics_{valset_idx}/{k}', v, global_step=cur_epoch)
110 |
111 | for k, v in figures.items():
112 | for plot_idx, fig in enumerate(v):
113 | self.logger.experiment.add_figure(
114 | f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True
115 | )
116 |
117 | def configure_optimizers(self):
118 | if self.hparams.optimizer == 'adam':
119 | optimizer = torch.optim.Adam(
120 | self.parameters(),
121 | lr=self.hparams.lr,
122 | weight_decay=self.hparams.weight_decay
123 | )
124 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
125 | milestones=self.hparams.milestones,
126 | gamma=self.hparams.gamma)
127 | return [optimizer], [lr_scheduler]
128 | else:
129 | raise Exception("Invalid optimizer name.")
--------------------------------------------------------------------------------
/src/models/extractors/SuperPoint/superpoint.py:
--------------------------------------------------------------------------------
1 | # %BANNER_BEGIN%
2 | # ---------------------------------------------------------------------
3 | # %COPYRIGHT_BEGIN%
4 | #
5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
6 | #
7 | # Unpublished Copyright (c) 2020
8 | # Magic Leap, Inc., All Rights Reserved.
9 | #
10 | # NOTICE: All information contained herein is, and remains the property
11 | # of COMPANY. The intellectual and technical concepts contained herein
12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign
13 | # Patents, patents in process, and are protected by trade secret or
14 | # copyright law. Dissemination of this information or reproduction of
15 | # this material is strictly forbidden unless prior written permission is
16 | # obtained from COMPANY. Access to the source code contained herein is
17 | # hereby forbidden to anyone except current COMPANY employees, managers
18 | # or contractors who have executed Confidentiality and Non-disclosure
19 | # agreements explicitly covering such access.
20 | #
21 | # The copyright notice above does not evidence any actual or intended
22 | # publication or disclosure of this source code, which includes
23 | # information that is confidential and/or proprietary, and is a trade
24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
32 | #
33 | # %COPYRIGHT_END%
34 | # ----------------------------------------------------------------------
35 | # %AUTHORS_BEGIN%
36 | #
37 | # Originating Authors: Paul-Edouard Sarlin
38 | #
39 | # %AUTHORS_END%
40 | # --------------------------------------------------------------------*/
41 | # %BANNER_END%
42 |
43 | from pathlib import Path
44 | import torch
45 | from torch import nn
46 |
47 | def simple_nms(scores, nms_radius: int):
48 | """ Fast Non-maximum suppression to remove nearby points """
49 | assert(nms_radius >= 0)
50 |
51 | def max_pool(x):
52 | return torch.nn.functional.max_pool2d(
53 | x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)
54 |
55 | zeros = torch.zeros_like(scores)
56 | max_mask = scores == max_pool(scores)
57 | for _ in range(2):
58 | supp_mask = max_pool(max_mask.float()) > 0
59 | supp_scores = torch.where(supp_mask, zeros, scores)
60 | new_max_mask = supp_scores == max_pool(supp_scores)
61 | max_mask = max_mask | (new_max_mask & (~supp_mask))
62 | return torch.where(max_mask, scores, zeros)
63 |
64 |
65 | def remove_borders(keypoints, scores, border: int, height: int, width: int):
66 | """ Removes keypoints too close to the border """
67 | mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
68 | mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
69 | mask = mask_h & mask_w
70 | return keypoints[mask], scores[mask]
71 |
72 |
73 | def top_k_keypoints(keypoints, scores, k: int):
74 | if k >= len(keypoints):
75 | return keypoints, scores
76 | scores, indices = torch.topk(scores, k, dim=0)
77 | return keypoints[indices], scores
78 |
79 |
80 | def sample_descriptors(keypoints, descriptors, s: int = 8):
81 | """ Interpolate descriptors at keypoint locations """
82 | b, c, h, w = descriptors.shape
83 | keypoints = keypoints - s / 2 + 0.5
84 | keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
85 | ).to(keypoints)[None]
86 | keypoints = keypoints*2 - 1 # normalize to (-1, 1)
87 | args = {'align_corners': True} if int(torch.__version__[2]) > 2 else {}
88 | descriptors = torch.nn.functional.grid_sample(
89 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
90 | descriptors = torch.nn.functional.normalize(
91 | descriptors.reshape(b, c, -1), p=2, dim=1)
92 | return descriptors
93 |
94 |
95 | class SuperPoint(nn.Module):
96 | """SuperPoint Convolutional Detector and Descriptor
97 |
98 | SuperPoint: Self-Supervised Interest Point Detection and
99 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
100 | Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629
101 |
102 | """
103 | default_config = {
104 | 'descriptor_dim': 256,
105 | 'nms_radius': 4,
106 | 'keypoint_threshold': 0.005,
107 | 'max_keypoints': -1,
108 | 'remove_borders': 4,
109 | }
110 |
111 | def __init__(self, config):
112 | super().__init__()
113 | self.config = {**self.default_config, **config}
114 |
115 | self.relu = nn.ReLU(inplace=True)
116 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
117 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
118 |
119 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
120 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
121 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
122 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
123 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
124 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
125 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
126 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
127 |
128 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
129 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
130 |
131 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
132 | self.convDb = nn.Conv2d(
133 | c5, self.config['descriptor_dim'],
134 | kernel_size=1, stride=1, padding=0)
135 |
136 | mk = self.config['max_keypoints']
137 | if mk == 0 or mk < -1:
138 | raise ValueError('\"max_keypoints\" must be positive or \"-1\"')
139 |
140 | def forward(self, inp):
141 | """ Compute keypoints, scores, descriptors for image """
142 | # Shared Encoder
143 | x = self.relu(self.conv1a(inp))
144 | x = self.relu(self.conv1b(x))
145 | x = self.pool(x)
146 | x = self.relu(self.conv2a(x))
147 | x = self.relu(self.conv2b(x))
148 | x = self.pool(x)
149 | x = self.relu(self.conv3a(x))
150 | x = self.relu(self.conv3b(x))
151 | x = self.pool(x)
152 | x = self.relu(self.conv4a(x))
153 | x = self.relu(self.conv4b(x))
154 |
155 | # Compute the dense keypoint scores
156 | cPa = self.relu(self.convPa(x))
157 | scores = self.convPb(cPa)
158 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
159 | b, _, h, w = scores.shape
160 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
161 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
162 | scores = simple_nms(scores, self.config['nms_radius'])
163 |
164 | # Extract keypoints
165 | keypoints = [
166 | torch.nonzero(s > self.config['keypoint_threshold'])
167 | for s in scores]
168 | scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]
169 |
170 | # Discard keypoints near the image borders
171 | keypoints, scores = list(zip(*[
172 | remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
173 | for k, s in zip(keypoints, scores)]))
174 |
175 | # Keep the k keypoints with highest score
176 | if self.config['max_keypoints'] >= 0:
177 | keypoints, scores = list(zip(*[
178 | top_k_keypoints(k, s, self.config['max_keypoints'])
179 | for k, s in zip(keypoints, scores)]))
180 |
181 | # Convert (h, w) to (x, y)
182 | keypoints = [torch.flip(k, [1]).float() for k in keypoints]
183 |
184 | # Compute the dense descriptors
185 | cDa = self.relu(self.convDa(x))
186 | descriptors = self.convDb(cDa)
187 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)
188 |
189 | # Extract descriptors
190 | descriptors = [sample_descriptors(k[None], d[None], 8)[0]
191 | for k, d in zip(keypoints, descriptors)]
192 |
193 | return {
194 | 'keypoints': keypoints,
195 | 'scores': scores,
196 | 'descriptors': descriptors,
197 | }
198 |
--------------------------------------------------------------------------------
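A shape-level sketch of calling the extractor above on a single grayscale image. The weights here are randomly initialized and only illustrate the output format; in the pipeline the pretrained weights are loaded with `load_network` before use:

```python
import torch
from src.models.extractors.SuperPoint.superpoint import SuperPoint

model = SuperPoint({'max_keypoints': 1024}).eval()
img = torch.rand(1, 1, 480, 640)  # [B, 1, H, W], grayscale, values in [0, 1]
with torch.no_grad():
    pred = model(img)
# pred['keypoints'][0]: [N, 2] (x, y), pred['scores'][0]: [N], pred['descriptors'][0]: [256, N]
```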
/src/models/matchers/nn/nearest_neighbour.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def find_nn(sim, ratio_thresh, distance_thresh):
6 | sim_nn, ind_nn = sim.topk(2 if ratio_thresh else 1, dim=-1, largest=True)
7 | dist_nn = 2 * (1 - sim_nn)
8 | mask = torch.ones(ind_nn.shape[:-1], dtype=torch.bool, device=sim.device)
9 | if ratio_thresh:
10 | mask = mask & (dist_nn[..., 0] <= (ratio_thresh ** 2) * dist_nn[..., 1])
11 | if distance_thresh:
12 | mask = mask & (dist_nn[..., 0] <= distance_thresh ** 2)
13 | matches = torch.where(mask, ind_nn[..., 0], ind_nn.new_tensor(-1))
14 | scores = torch.where(mask, (sim_nn[..., 0] + 1) / 2, sim_nn.new_tensor(0))
15 | return matches, scores
16 |
17 |
18 | def mutual_check(m0, m1):
19 | inds0 = torch.arange(m0.shape[-1], device=m0.device)
20 | loop = torch.gather(m1, -1, torch.where(m0 > -1, m0, m0.new_tensor(0)))
21 | ok = (m0 > -1) & (inds0 == loop)
22 | m0_new = torch.where(ok, m0, m0.new_tensor(-1))
23 | return m0_new
24 |
25 |
26 | class NearestNeighbour(nn.Module):
27 | default_conf = {
28 | 'match_threshold': None,
29 | 'ratio_threshold': None,
30 | 'distance_threshold': None,
31 | 'do_mutual_check': True
32 | }
33 |
34 | def __init__(self, conf=None):
35 | super().__init__()
36 | if conf:
37 | self.conf = conf = {**self.default_conf, **conf}
38 | else:
39 | self.conf = self.default_conf
40 |
41 | def forward(self, data):
42 | sim = torch.einsum('bdn, bdm->bnm', data['descriptors0'], data['descriptors1'])
43 | matches0, scores0 = find_nn(
44 | sim, self.conf['ratio_threshold'], self.conf['distance_threshold']
45 | )
46 |
47 | matches1, scores1 = find_nn(
48 | sim.transpose(1, 2), self.conf['ratio_threshold'], self.conf['distance_threshold']
49 | )
50 |
51 | if self.conf['do_mutual_check']:
52 | matches1, scores1 = find_nn(
53 | sim.transpose(1, 2), self.conf['ratio_threshold'],
54 | self.conf['distance_threshold']
55 | )
56 | matches0 = mutual_check(matches0, matches1)
57 |
58 | return {
59 | 'matches0': matches0,
60 | 'matches1': matches1,
61 | 'matching_scores0': scores0,
62 | 'matching_scores1': scores1
63 | }
--------------------------------------------------------------------------------
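A minimal sketch of the mutual nearest-neighbour matcher above on random L2-normalized descriptors (the thresholds are illustrative):

```python
import torch
import torch.nn.functional as F
from src.models.matchers.nn.nearest_neighbour import NearestNeighbour

matcher = NearestNeighbour({'ratio_threshold': 0.8, 'distance_threshold': 0.7})
desc0 = F.normalize(torch.rand(1, 256, 500), p=2, dim=1)
desc1 = F.normalize(torch.rand(1, 256, 480), p=2, dim=1)
out = matcher({'descriptors0': desc0, 'descriptors1': desc1})
# out['matches0']: [1, 500] indices into descriptors1, -1 where the ratio/distance/mutual checks fail
```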
/src/sfm/extract_features.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import tqdm
3 | import torch
4 | import logging
5 |
6 | from torch.utils.data import DataLoader
7 |
8 | confs = {
9 | 'superpoint': {
10 | 'output': 'feats-spp',
11 | 'model': {
12 | 'name': 'spp_det',
13 | },
14 | 'preprocessing': {
15 | 'grayscale': True,
16 | 'resize_h': 512,
17 | 'resize_w': 512
18 | },
19 | 'conf': {
20 | 'descriptor_dim': 256,
21 | 'nms_radius': 3,
22 | 'max_keypoints': 4096,
23 | 'keypoints_threshold': 0.6
24 | }
25 | }
26 | }
27 |
28 |
29 | @torch.no_grad()
30 | def spp(img_lists, feature_out, cfg):
31 | """extract keypoints info by superpoint"""
32 | from src.utils.model_io import load_network
33 | from src.models.extractors.SuperPoint.superpoint import SuperPoint as spp_det
34 | from src.datasets.normalized_dataset import NormalizedDataset
35 |
36 | conf = confs[cfg.network.detection]
37 | model = spp_det(conf['conf']).cuda()
38 | model.eval()
39 | load_network(model, cfg.network.detection_model_path, force=True)
40 |
41 | dataset = NormalizedDataset(img_lists, conf['preprocessing'])
42 | loader = DataLoader(dataset, num_workers=1)
43 |
44 | feature_file = h5py.File(feature_out, 'w')
45 | logging.info(f'Exporting features to {feature_out}')
46 | for data in tqdm.tqdm(loader):
47 | inp = data['image'].cuda()
48 | pred = model(inp)
49 |
50 | pred = {k: v[0].cpu().numpy() for k, v in pred.items()}
51 | pred['image_size'] = data['size'][0].numpy()
52 |
53 | grp = feature_file.create_group(data['path'][0])
54 | for k, v in pred.items():
55 | grp.create_dataset(k, data=v)
56 |
57 | del pred
58 |
59 | feature_file.close()
60 |     logging.info('Finished exporting features.')
61 |
62 |
63 | def main(img_lists, feature_out, cfg):
64 | if cfg.network.detection == 'superpoint':
65 | spp(img_lists, feature_out, cfg)
66 | else:
67 | raise NotImplementedError
--------------------------------------------------------------------------------
/src/sfm/generate_empty.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import logging
3 | import os.path as osp
4 | import numpy as np
5 |
6 | from pathlib import Path
7 | from src.utils import path_utils
8 | from src.utils.colmap.read_write_model import Camera, Image, Point3D
9 | from src.utils.colmap.read_write_model import rotmat2qvec
10 | from src.utils.colmap.read_write_model import write_model
11 |
12 |
13 | def get_pose_from_txt(img_index, pose_dir):
14 | """ Read 4x4 transformation matrix from txt """
15 | pose_file = osp.join(pose_dir, '{}.txt'.format(img_index))
16 | pose = np.loadtxt(pose_file)
17 |
18 | tvec = pose[:3, 3].reshape(3, )
19 | qvec = rotmat2qvec(pose[:3, :3]).reshape(4, )
20 | return pose, tvec, qvec
21 |
22 |
23 | def get_intrin_from_txt(img_index, intrin_dir):
24 | """ Read 3x3 intrinsic matrix from txt """
25 | intrin_file = osp.join(intrin_dir, '{}.txt'.format(img_index))
26 | intrin = np.loadtxt(intrin_file)
27 |
28 | return intrin
29 |
30 |
31 | def import_data(img_lists, do_ba=False):
32 | """ Import intrinsics and camera pose info """
33 | points3D_out = {}
34 | images_out = {}
35 | cameras_out = {}
36 |
37 | def compare(img_name):
38 | key = img_name.split('/')[-1]
39 | return int(key.split('.')[0])
40 | img_lists.sort(key=compare)
41 |
42 | key, img_id, camera_id = 0, 0, 0
43 | xys_ = np.zeros((0, 2), float)
44 | point3D_ids_ = np.full(0, -1, int) # will be filled after triangulation
45 |
46 | # import data
47 | for img_path in img_lists:
48 | key += 1
49 | img_id += 1
50 | camera_id += 1
51 |
52 | img_name = img_path.split('/')[-1]
53 | base_dir = osp.dirname(osp.dirname(img_path))
54 | img_index = int(img_name.split('.')[0])
55 |
56 | # read pose
57 | pose_dir = path_utils.get_gt_pose_dir(base_dir)
58 | _, tvec, qvec = get_pose_from_txt(img_index, pose_dir)
59 |
60 | # read intrinsic
61 | intrin_dir = path_utils.get_intrin_dir(base_dir)
62 | K = get_intrin_from_txt(img_index, intrin_dir)
63 |         fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
64 |
65 | image = cv2.imread(img_path)
66 | h, w, _ = image.shape
67 |
68 | image = Image(
69 | id=img_id,
70 | qvec=qvec,
71 | tvec=tvec,
72 | camera_id=camera_id,
73 | name=img_path,
74 | xys=xys_,
75 | point3D_ids=point3D_ids_
76 | )
77 |
78 | camera = Camera(
79 | id=camera_id,
80 | model='PINHOLE',
81 | width=w,
82 | height=h,
83 | params=np.array([fx, fy, cx, cy])
84 | )
85 |
86 | images_out[key] = image
87 | cameras_out[key] = camera
88 |
89 | return cameras_out, images_out, points3D_out
90 |
91 |
92 | def generate_model(img_lists, empty_dir, do_ba=False):
93 | """ Write intrinsics and camera poses into COLMAP format model"""
94 |     logging.info('Generating empty model...')
95 | model = import_data(img_lists, do_ba)
96 |
97 | logging.info(f'Writing the COLMAP model to {empty_dir}')
98 | Path(empty_dir).mkdir(exist_ok=True, parents=True)
99 | write_model(*model, path=str(empty_dir), ext='.bin')
100 |     logging.info('Finished writing model.')
101 |
--------------------------------------------------------------------------------
/src/sfm/global_ba.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import subprocess
4 | import os.path as osp
5 |
6 | from pathlib import Path
7 |
8 |
9 | def run_bundle_adjuster(deep_sfm_dir, ba_dir, colmap_path):
10 | logging.info("Running the bundle adjuster.")
11 |
12 | deep_sfm_model_dir = osp.join(deep_sfm_dir, 'model')
13 | cmd = [
14 | str(colmap_path), 'bundle_adjuster',
15 | '--input_path', str(deep_sfm_model_dir),
16 | '--output_path', str(ba_dir),
17 | '--BundleAdjustment.max_num_iterations', '150',
18 | '--BundleAdjustment.max_linear_solver_iterations', '500',
19 | '--BundleAdjustment.function_tolerance', '0',
20 | '--BundleAdjustment.gradient_tolerance', '0',
21 | '--BundleAdjustment.parameter_tolerance', '0',
22 | '--BundleAdjustment.refine_focal_length', '0',
23 | '--BundleAdjustment.refine_principal_point', '0',
24 | '--BundleAdjustment.refine_extra_params', '0',
25 | '--BundleAdjustment.refine_extrinsics', '1'
26 | ]
27 | logging.info(' '.join(cmd))
28 |
29 | ret = subprocess.call(cmd)
30 | if ret != 0:
31 |         logging.warning('Problem with bundle_adjuster, exiting.')
32 | exit(ret)
33 |
34 |
35 | def main(deep_sfm_dir, ba_dir, colmap_path='colmap'):
36 | assert Path(deep_sfm_dir).exists(), deep_sfm_dir
37 |
38 | Path(ba_dir).mkdir(parents=True, exist_ok=True)
39 | run_bundle_adjuster(deep_sfm_dir, ba_dir, colmap_path)
--------------------------------------------------------------------------------
/src/sfm/match_features.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import torch
3 | import logging
4 | import tqdm
5 |
6 | import os.path as osp
7 |
8 | confs = {
9 | 'superglue': {
10 | 'output': 'matches-spg',
11 | 'conf': {
12 | 'descriptor_dim': 256,
13 | 'weights': 'outdoor',
14 | 'match_threshold': 0.7
15 | }
16 | }
17 | }
18 |
19 |
20 | def names_to_pair(name0, name1):
21 | return '_'.join((name0.replace('/', '-'), name1.replace('/', '-')))
22 |
23 |
24 | @torch.no_grad()
25 | def spg(cfg, feature_path, covis_pairs, matches_out, vis_match=False):
26 | """Match features by SuperGlue"""
27 | from src.models.matchers.SuperGlue.superglue import SuperGlue as spg_matcher
28 | from src.utils.model_io import load_network
29 | from src.utils.vis_utils import vis_match_pairs
30 |
31 | assert osp.exists(feature_path), feature_path
32 | feature_file = h5py.File(feature_path, 'r')
33 | logging.info(f'Exporting matches to {matches_out}')
34 |
35 | with open(covis_pairs, 'r') as f:
36 | pair_list = f.read().rstrip('\n').split('\n')
37 |
38 | # load superglue model
39 | conf = confs[cfg.network.matching]['conf']
40 | model = spg_matcher(conf).cuda()
41 | model.eval()
42 | load_network(model, cfg.network.matching_model_path, force=True)
43 |
44 | # match features by superglue
45 | match_file = h5py.File(matches_out, 'w')
46 | matched = set()
47 | for pair in tqdm.tqdm(pair_list):
48 | name0, name1 = pair.split(' ')
49 | pair = names_to_pair(name0, name1)
50 |
51 | if len({(name0, name1), (name1, name0)} & matched) \
52 | or pair in match_file:
53 | continue
54 |
55 | data = {}
56 | feats0, feats1 = feature_file[name0], feature_file[name1]
57 | for k in feats0.keys():
58 | data[k+'0'] = feats0[k].__array__()
59 | for k in feats1.keys():
60 | data[k+'1'] = feats1[k].__array__()
61 | data = {k: torch.from_numpy(v)[None].float().cuda() for k, v in data.items()}
62 |
63 | data['image0'] = torch.empty((1, 1, ) + tuple(feats0['image_size'])[::-1])
64 | data['image1'] = torch.empty((1, 1, ) + tuple(feats1['image_size'])[::-1])
65 | pred = model(data)
66 |
67 | grp = match_file.create_group(pair)
68 | matches0 = pred['matches0'][0].cpu().short().numpy()
69 | grp.create_dataset('matches0', data=matches0)
70 |
71 | matches1 = pred['matches1'][0].cpu().short().numpy()
72 | grp.create_dataset('matches1', data=matches1)
73 |
74 | if 'matching_scores0' in pred:
75 | scores = pred['matching_scores0'][0].cpu().half().numpy()
76 | grp.create_dataset('matching_scores0', data=scores)
77 |
78 | if 'matching_scores1' in pred:
79 | scores = pred['matching_scores1'][0].cpu().half().numpy()
80 | grp.create_dataset('matching_scores1', data=scores)
81 |
82 | matched |= {(name0, name1), (name1, name0)}
83 |
84 | if vis_match:
85 | vis_match_pairs(pred, feats0, feats1, name0, name1)
86 |
87 | match_file.close()
88 |     logging.info('Finished exporting matches.')
89 |
90 |
91 | def main(cfg, feature_out, covis_pairs_out, matches_out, vis_match=False):
92 | if cfg.network.matching == 'superglue':
93 | spg(cfg, feature_out, covis_pairs_out, matches_out, vis_match)
94 | else:
95 | raise NotImplementedError
--------------------------------------------------------------------------------
/src/sfm/pairs_from_poses.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.spatial.distance as distance
3 | from src.utils import path_utils
4 |
5 |
6 | def get_pairswise_distances(pose_files):
7 | Rs = []
8 | ts = []
9 |
10 | seqs_ids = {}
11 | for i in range(len(pose_files)):
12 | pose_file = pose_files[i]
13 | seq_name = pose_file.split('/')[-3]
14 | if seq_name not in seqs_ids.keys():
15 | seqs_ids[seq_name] = [i]
16 | else:
17 | seqs_ids[seq_name].append(i)
18 |
19 | for pose_file in pose_files:
20 | pose = np.loadtxt(pose_file)
21 | R = pose[:3, :3]
22 | t = pose[:3, 3:]
23 | Rs.append(R)
24 | ts.append(t)
25 |
26 | Rs = np.stack(Rs, axis=0)
27 | ts = np.stack(ts, axis=0)
28 |
29 | Rs = Rs.transpose(0, 2, 1) # [n, 3, 3]
30 | ts = -(Rs @ ts)[:, :, 0] # [n, 3, 3] @ [n, 3, 1]
31 |
32 | dist = distance.squareform(distance.pdist(ts))
33 | trace = np.einsum('nji,mji->mn', Rs, Rs, optimize=True)
34 | dR = np.clip((trace - 1) / 2, -1., 1.)
35 | dR = np.rad2deg(np.abs(np.arccos(dR)))
36 |
37 | return dist, dR, seqs_ids
38 |
39 |
40 | def covis_from_pose(img_lists, covis_pairs_out, num_matched, max_rotation, do_ba=False):
41 | pose_lists = [path_utils.get_gt_pose_path_by_color(color_path) for color_path in img_lists]
42 | dist, dR, seqs_ids = get_pairswise_distances(pose_lists)
43 |
44 | min_rotation = 10
45 | valid = dR > min_rotation
46 | np.fill_diagonal(valid, False)
47 | dist = np.where(valid, dist, np.inf)
48 |
49 | pairs = []
50 | num_matched_per_seq = num_matched // len(seqs_ids.keys())
51 | for i in range(len(img_lists)):
52 | dist_i = dist[i]
53 | for seq_id in seqs_ids:
54 | ids = np.array(seqs_ids[seq_id])
55 | try:
56 | idx = np.argpartition(dist_i[ids], num_matched_per_seq * 2)[: num_matched_per_seq:2]
57 |             except ValueError:  # fewer poses in this sequence than the requested number of pairs
58 | idx = np.argpartition(dist_i[ids], dist_i.shape[0]-1)
59 | idx = ids[idx]
60 | idx = idx[np.argsort(dist_i[idx])]
61 | idx = idx[valid[i][idx]]
62 |
63 | for j in idx:
64 | name0 = img_lists[i]
65 | name1 = img_lists[j]
66 |
67 | pairs.append((name0, name1))
68 |
69 | with open(covis_pairs_out, 'w') as f:
70 | f.write('\n'.join(' '.join([i, j]) for i, j in pairs))
71 |
--------------------------------------------------------------------------------
/src/sfm/postprocess/filter_points.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | import os.path as osp
5 | from src.utils.colmap import read_write_model
6 |
7 |
8 | def filter_by_track_length(points3D, track_length):
9 | """
10 | Filter 3d points by track length.
11 |     Return the filtered point cloud and the corresponding point ids in the original model.
12 | """
13 | idxs_3d = list(points3D.keys())
14 | idxs_3d.sort()
15 | xyzs = np.empty(shape=[0, 3])
16 | points_idxs = np.empty(shape=[0], dtype=int)
17 | for i in range(len(idxs_3d)):
18 | idx_3d = idxs_3d[i]
19 | if len(points3D[idx_3d].point2D_idxs) < track_length:
20 | continue
21 | xyz = points3D[idx_3d].xyz.reshape(1, -1)
22 | xyzs = np.append(xyzs, xyz, axis=0)
23 | points_idxs = np.append(points_idxs, idx_3d)
24 |
25 | return xyzs, points_idxs
26 |
27 |
28 | def filter_by_3d_box(points, points_idxs, box_path):
29 | """ Filter 3d points by 3d box."""
30 | corner_in_cano = np.loadtxt(box_path)
31 |
32 | assert points.shape[1] == 3, "Input pcds must have shape (n, 3)"
33 | if not isinstance(points, torch.Tensor):
34 | points = torch.as_tensor(points, dtype=torch.float32)
35 | if not isinstance(corner_in_cano, torch.Tensor):
36 | corner_in_cano = torch.as_tensor(corner_in_cano, dtype=torch.float32)
37 |
38 | def filter_(bbox_3d, points):
39 | """
40 | @param bbox_3d: corners (8, 3)
41 | @param points: (n, 3)
42 | """
43 | v45 = bbox_3d[5] - bbox_3d[4]
44 | v40 = bbox_3d[0] - bbox_3d[4]
45 | v47 = bbox_3d[7] - bbox_3d[4]
46 |
47 | points = points - bbox_3d[4]
48 | m0 = torch.matmul(points, v45)
49 | m1 = torch.matmul(points, v40)
50 | m2 = torch.matmul(points, v47)
51 |
52 | cs = []
53 | for m, v in zip([m0, m1, m2], [v45, v40, v47]):
54 | c0 = 0 < m
55 | c1 = m < torch.matmul(v, v)
56 | c = c0 & c1
57 | cs.append(c)
58 | cs = cs[0] & cs[1] & cs[2]
59 | passed_inds = torch.nonzero(cs).squeeze(1)
60 | num_passed = torch.sum(cs)
61 | return num_passed, passed_inds, cs
62 |
63 | num_passed, passed_inds, keeps = filter_(corner_in_cano, points)
64 |
65 | xyzs_filtered = np.empty(shape=(0, 3), dtype=np.float32)
66 | for i in range(int(num_passed)):
67 | ind = passed_inds[i]
68 | xyzs_filtered = np.append(xyzs_filtered, points[ind, None], axis=0)
69 |
70 | filtered_xyzs = points[passed_inds]
71 | passed_inds = points_idxs[passed_inds]
72 | return filtered_xyzs, passed_inds
73 |
74 |
75 | def filter_3d(model_path, track_length, box_path):
76 | """ Filter 3d points by tracke length and 3d box """
77 | points_model_path = osp.join(model_path, 'points3D.bin')
78 | points3D = read_write_model.read_points3d_binary(points_model_path)
79 |
80 | xyzs, points_idxs = filter_by_track_length(points3D, track_length)
81 | xyzs, points_idxs = filter_by_3d_box(xyzs, points_idxs, box_path)
82 |
83 | return xyzs, points_idxs
84 |
85 |
86 | def merge(xyzs, points_idxs, dist_threshold=1e-3):
87 | """
88 |     Merge points that are close to each other. ({[x1, y1], [x2, y2], ...} => [mean(x_i), mean(y_i)])
89 | """
90 | from scipy.spatial.distance import pdist, squareform
91 |
92 | if not isinstance(xyzs, np.ndarray):
93 | xyzs = np.array(xyzs)
94 |
95 | dist = pdist(xyzs, 'euclidean')
96 | distance_matrix = squareform(dist)
97 | close_than_thresh = distance_matrix < dist_threshold
98 |
99 | ret_points_count = 0 # num of return points
100 | ret_points = np.empty(shape=[0, 3]) # pcds after merge
101 | ret_idxs = {} # {new_point_idx: points idxs in Points3D}
102 |
103 | points3D_idx_record = [] # points that have been merged
104 | for j in range(distance_matrix.shape[0]):
105 | idxs = close_than_thresh[j]
106 |
107 | if np.isin(points_idxs[idxs], points3D_idx_record).any():
108 | continue
109 |
110 | points = np.mean(xyzs[idxs], axis=0) # new point
111 | ret_points = np.append(ret_points, points.reshape(1, 3), axis=0)
112 | ret_idxs[ret_points_count] = points_idxs[idxs]
113 | ret_points_count += 1
114 |
115 | points3D_idx_record = points3D_idx_record + points_idxs[idxs].tolist()
116 |
117 | return ret_points, ret_idxs
118 |
119 |
120 |
--------------------------------------------------------------------------------
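A small worked example of `merge` on toy points: the first two points lie within the 1 mm threshold and collapse to their mean, while the third stays separate (the ids are hypothetical COLMAP point ids):

```python
import numpy as np
from src.sfm.postprocess.filter_points import merge

xyzs = np.array([[0.0000, 0.0, 0.0],
                 [0.0005, 0.0, 0.0],   # within dist_threshold of the first point
                 [0.1000, 0.0, 0.0]])
points_idxs = np.array([3, 7, 12])     # ids of these points in the original points3D model
ret_points, ret_idxs = merge(xyzs, points_idxs, dist_threshold=1e-3)
# ret_points: 2 merged points; ret_idxs maps new point 0 -> original ids [3, 7] and new point 1 -> [12]
```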
/src/sfm/postprocess/filter_tkl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import os.path as osp
4 | import matplotlib.pyplot as plt
5 |
6 | from src.utils.colmap.read_write_model import write_model
7 |
8 |
9 | def get_points_count(points3D, show=False):
10 | """ Count track length for each point """
11 | points_count_list = [] # track length for each point
12 | for points_id, points_item in points3D.items():
13 | points_count = len(points_item.point2D_idxs)
14 | points_count_list.append(points_count)
15 |
16 | count_dict = dict() # {track length: num of 3d points}
17 | for count in points_count_list:
18 | if count not in count_dict.keys():
19 | count_dict[count] = 0
20 | count_dict[count] += 1
21 | counts = list(count_dict.keys())
22 | counts.sort()
23 |
24 | count_list_ordered = []
25 | for count in counts:
26 | count_list_ordered.append(count_dict[count])
27 |
28 | if show:
29 | plt.plot(counts, count_list_ordered)
30 | plt.show()
31 |
32 | return count_dict, points_count_list
33 |
34 |
35 | def get_tkl(model_path, thres, show=False):
36 | """ Get the track length value which can limit the number of 3d points below thres"""
37 | from src.utils.colmap.read_write_model import read_model
38 |
39 | cameras, images, points3D = read_model(model_path, ext='.bin')
40 | count_dict, points_count_list = get_points_count(points3D, show)
41 |
42 | ret_points = len(points3D.keys())
43 | count_keys = np.array(list(count_dict.keys()))
44 | count_keys.sort()
45 |
46 | for key in count_keys:
47 | ret_points -= count_dict[key]
48 | if ret_points <= thres:
49 | track_length = key
50 | break
51 |
52 | return track_length, points_count_list
53 |
54 |
55 | def vis_tkl_filtered_pcds(model_path, points_count_list, track_length, output_path):
56 | """
57 | Given a track length value, filter 3d points.
58 | Output filtered pcds for visualization.
59 | """
60 | from src.utils.colmap.read_write_model import read_model
61 |
62 | cameras, images, points3D = read_model(model_path, ext='.bin')
63 |
64 | invalid_points = np.where(np.array(points_count_list) < track_length)[0]
65 | point_ids = []
66 | for k, v in points3D.items():
67 | point_ids.append(k)
68 |
69 | invalid_points_ids = []
70 | for invalid_count in invalid_points:
71 | points3D.pop(point_ids[invalid_count])
72 | invalid_points_ids.append(point_ids[invalid_count])
73 |
74 | output_path = osp.join(output_path, 'tkl_model')
75 | output_file_path = osp.join(output_path, 'tl-{}.ply'.format(track_length))
76 | if not osp.exists(output_path):
77 | os.makedirs(output_path)
78 |
79 | write_model(cameras, images, points3D, output_path, '.bin')
80 | os.system(f'colmap model_converter --input_path {output_path} --output_path {output_file_path} --output_type PLY')
81 | return output_file_path
--------------------------------------------------------------------------------
/src/sfm/triangulation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import h5py
3 | import logging
4 | import tqdm
5 | import subprocess
6 | import os.path as osp
7 | import numpy as np
8 |
9 | from pathlib import Path
10 | from src.utils.colmap.read_write_model import CAMERA_MODEL_NAMES, Image, read_cameras_binary, read_images_binary
11 | from src.utils.colmap.database import COLMAPDatabase
12 |
13 |
14 | def names_to_pair(name0, name1):
15 | return '_'.join((name0.replace('/', '-'), name1.replace('/', '-')))
16 |
17 |
18 | def geometric_verification(colmap_path, database_path, pairs_path):
19 | """ Geometric verfication """
20 | logging.info('Performing geometric verification of the matches...')
21 | cmd = [
22 | str(colmap_path), 'matches_importer',
23 | '--database_path', str(database_path),
24 | '--match_list_path', str(pairs_path),
25 | '--match_type', 'pairs'
26 | ]
27 | ret = subprocess.call(cmd)
28 | if ret != 0:
29 |         logging.warning('Problem with matches_importer, exiting.')
30 | exit(ret)
31 |
32 |
33 | def create_db_from_model(empty_model, database_path):
34 | """ Create COLMAP database file from empty COLMAP binary file. """
35 | if database_path.exists():
36 | logging.warning('Database already exists.')
37 |
38 | cameras = read_cameras_binary(str(empty_model / 'cameras.bin'))
39 | images = read_images_binary(str(empty_model / 'images.bin'))
40 |
41 | db = COLMAPDatabase.connect(database_path)
42 | db.create_tables()
43 |
44 | for i, camera in cameras.items():
45 | model_id = CAMERA_MODEL_NAMES[camera.model].model_id
46 | db.add_camera(model_id, camera.width, camera.height, camera.params,
47 | camera_id=i, prior_focal_length=True)
48 |
49 | for i, image in images.items():
50 | db.add_image(image.name, image.camera_id, image_id=i)
51 |
52 | db.commit()
53 | db.close()
54 | return {image.name: i for i, image in images.items()}
55 |
56 |
57 | def import_features(image_ids, database_path, feature_path):
58 | """ Import keypoints info into COLMAP database. """
59 | logging.info("Importing features into the database...")
60 | feature_file = h5py.File(str(feature_path), 'r')
61 | db = COLMAPDatabase.connect(database_path)
62 |
63 | for image_name, image_id in tqdm.tqdm(image_ids.items()):
64 | keypoints = feature_file[image_name]['keypoints'].__array__()
65 |         keypoints += 0.5  # COLMAP's image-coordinate origin is the corner of the top-left pixel
66 | db.add_keypoints(image_id, keypoints)
67 |
68 | feature_file.close()
69 | db.commit()
70 | db.close()
71 |
72 |
73 | def import_matches(image_ids, database_path, pairs_path, matches_path, feature_path,
74 | min_match_score=None, skip_geometric_verification=False):
75 | """ Import matches info into COLMAP database. """
76 | logging.info("Importing matches into the database...")
77 |
78 | with open(str(pairs_path), 'r') as f:
79 | pairs = [p.split(' ') for p in f.read().split('\n')]
80 |
81 | match_file = h5py.File(str(matches_path), 'r')
82 | db = COLMAPDatabase.connect(database_path)
83 |
84 | matched = set()
85 | for name0, name1 in tqdm.tqdm(pairs):
86 | id0, id1 = image_ids[name0], image_ids[name1]
87 | if len({(id0, id1), (id1, id0)} & matched) > 0:
88 | continue
89 |
90 | pair = names_to_pair(name0, name1)
91 | if pair not in match_file:
92 | raise ValueError(
93 | f'Could not find pair {(name0, name1)}... '
94 | 'Maybe you matched with a different list of pairs? '
95 |                 f'Reverse in file: {names_to_pair(name1, name0) in match_file}.'
96 | )
97 |
98 | matches = match_file[pair]['matches0'].__array__()
99 | valid = matches > -1
100 | if min_match_score:
101 | scores = match_file[pair]['matching_scores0'].__array__()
102 | valid = valid & (scores > min_match_score)
103 |
104 | matches = np.stack([np.where(valid)[0], matches[valid]], -1)
105 |
106 | db.add_matches(id0, id1, matches)
107 | matched |= {(id0, id1), (id1, id0)}
108 |
109 | if skip_geometric_verification:
110 | db.add_two_view_geometry(id0, id1, matches)
111 |
112 | match_file.close()
113 | db.commit()
114 | db.close()
115 |
116 |
117 | def run_triangulation(colmap_path, model_path, database_path, image_dir, empty_model):
118 | """ run triangulation on given database """
119 | logging.info('Running the triangulation...')
120 |
121 | cmd = [
122 | str(colmap_path), 'point_triangulator',
123 | '--database_path', str(database_path),
124 | '--image_path', str(image_dir),
125 | '--input_path', str(empty_model),
126 | '--output_path', str(model_path),
127 | '--Mapper.ba_refine_focal_length', '0',
128 | '--Mapper.ba_refine_principal_point', '0',
129 | '--Mapper.ba_refine_extra_params', '0'
130 | ]
131 | logging.info(' '.join(cmd))
132 | ret = subprocess.call(cmd)
133 | if ret != 0:
134 |         logging.warning('Problem with point_triangulator, exiting.')
135 | exit(ret)
136 |
137 | stats_raw = subprocess.check_output(
138 | [str(colmap_path), 'model_analyzer', '--path', model_path]
139 | )
140 | stats_raw = stats_raw.decode().split('\n')
141 | stats = dict()
142 | for stat in stats_raw:
143 | if stat.startswith('Register images'):
144 | stats['num_reg_images'] = int(stat.split()[-1])
145 | elif stat.startswith('Points'):
146 | stats['num_sparse_points'] = int(stat.split()[-1])
147 | elif stat.startswith('Observation'):
148 | stats['num_observations'] = int(stat.split()[-1])
149 | elif stat.startswith('Mean track length'):
150 | stats['mean_track_length'] = float(stat.split()[-1])
151 | elif stat.startswith('Mean observation per image'):
152 | stats['num_observations_per_image'] = float(stat.split()[-1])
153 | elif stat.startswith('Mean reprojection error'):
154 | stats['mean_reproj_error'] = float(stat.split()[-1][:-2])
155 | return stats
156 |
157 |
158 | def main(sfm_dir, empty_sfm_model, outputs_dir, pairs, features, matches,
159 | colmap_path='colmap', skip_geometric_verification=False, min_match_score=None, image_dir=None):
160 | """
161 |     Import keypoints and matches into a COLMAP database, then triangulate a
162 |     sparse model from the known camera poses in the empty model.
163 | """
164 | assert Path(empty_sfm_model).exists(), empty_sfm_model
165 | assert Path(features).exists(), features
166 | assert Path(pairs).exists(), pairs
167 | assert Path(matches).exists(), matches
168 |
169 | Path(sfm_dir).mkdir(parents=True, exist_ok=True)
170 | database = osp.join(sfm_dir, 'database.db')
171 | model = osp.join(sfm_dir, 'model')
172 | Path(model).mkdir(exist_ok=True)
173 |
174 | image_ids = create_db_from_model(Path(empty_sfm_model), Path(database))
175 | import_features(image_ids, database, features)
176 | import_matches(image_ids, database, pairs, matches, features,
177 | min_match_score, skip_geometric_verification)
178 |
179 | if not skip_geometric_verification:
180 | geometric_verification(colmap_path, database, pairs)
181 |
182 | if not image_dir:
183 | image_dir = '/'
184 | stats = run_triangulation(colmap_path, model, database, image_dir, empty_sfm_model)
185 | os.system(f'colmap model_converter --input_path {model} --output_path {outputs_dir}/model.ply --output_type PLY')
--------------------------------------------------------------------------------
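A hedged sketch of how `main` above might be invoked once features and matches have been extracted; every path below is an illustrative placeholder, not a value taken from this repository:

```python
from src.sfm import triangulation

# All paths are illustrative placeholders.
triangulation.main(
    sfm_dir='/path/to/outputs/sfm_ws',             # database.db and model/ are created here
    empty_sfm_model='/path/to/outputs/sfm_empty',  # empty model holding the known camera poses
    outputs_dir='/path/to/outputs',                # model.ply is written here
    pairs='/path/to/outputs/pairs.txt',            # one "name0 name1" pair per line
    features='/path/to/outputs/feats.h5',          # keypoints (HDF5)
    matches='/path/to/outputs/matches.h5',         # matches (HDF5)
    colmap_path='colmap',
    image_dir='/path/to/images',
)
```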
/src/tracker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zju3dv/OnePose/d0c313a5c36994658e63eac7ca3a63d7e3573d7b/src/tracker/__init__.py
--------------------------------------------------------------------------------
/src/tracker/tracking_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class Timer(object):
6 | def __init__(self):
7 | self.time_dict = dict()
8 | self.res_dict = dict()
9 | self.stash_dict = dict()
10 |
11 | def set(self, label, value):
12 | if label not in self.stash_dict.keys():
13 | self.stash_dict[label] = []
14 | self.stash_dict[label].append(value)
15 |
16 | def tick(self, tick_name):
17 | import time
18 | self.time_dict[tick_name] = time.time() * 1000.0
19 | return self.time_dict[tick_name]
20 |
21 | def tock(self, tock_name, pop=False):
22 | if tock_name not in self.time_dict:
23 | return 0.0
24 | else:
25 | import time
26 | t2 = time.time() * 1000.0
27 | t1 = self.time_dict[tock_name]
28 | self.res_dict[tock_name] = t2 - t1
29 | if pop:
30 | self.time_dict.pop(tock_name)
31 | return self.res_dict[tock_name]
32 |
33 | def stash(self):
34 | for k, v in self.res_dict.items():
35 | if k not in self.stash_dict.keys():
36 | self.stash_dict[k] = []
37 | self.stash_dict[k].append(v)
38 |
39 | def report_stash(self):
40 | res_dict = dict()
41 | for k, v in self.stash_dict.items():
42 | res_dict[k] = np.mean(v)
43 | return res_dict
44 |
45 | def report(self):
46 | return self.res_dict
47 |
48 |
49 | def compute_epipolar_error(kpts0, kpts1, T_0to1, K0, K1):
50 | def to_homogeneous(points):
51 | return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
52 |
53 | kpts0 = (kpts0 - K0[[0, 1], [2, 2]][None]) / K0[[0, 1], [0, 1]][None]
54 | kpts1 = (kpts1 - K1[[0, 1], [2, 2]][None]) / K1[[0, 1], [0, 1]][None]
55 | kpts0 = to_homogeneous(kpts0)
56 | kpts1 = to_homogeneous(kpts1)
57 |
58 | t0, t1, t2 = T_0to1[:3, 3]
59 | t_skew = np.array([
60 | [0, -t2, t1],
61 | [t2, 0, -t0],
62 | [-t1, t0, 0]
63 | ])
64 | E = t_skew @ T_0to1[:3, :3]
65 |
66 | Ep0 = kpts0 @ E.T # N x 3
67 | p1Ep0 = np.sum(kpts1 * Ep0, -1) # N
68 | Etp1 = kpts1 @ E # N x 3
69 | d = p1Ep0 ** 2 * (1.0 / (Ep0[:, 0] ** 2 + Ep0[:, 1] ** 2)
70 | + 1.0 / (Etp1[:, 0] ** 2 + Etp1[:, 1] ** 2))
71 | return d
72 |
73 |
74 | def project(xyz, K, RT, need_depth=False):
75 | """
76 | xyz: [N, 3]
77 | K: [3, 3]
78 | RT: [3, 4]
79 | """
80 | xyz = np.dot(xyz, RT[:, :3].T)
81 | xyz += RT[:, 3:].T
82 | depth = xyz[:, 2:].flatten()
83 | xyz = np.dot(xyz, K.T)
84 | xy = xyz[:, :2] / xyz[:, 2:]
85 | if need_depth:
86 | return xy, depth
87 | else:
88 | return xy
89 |
90 |
91 | def AngleAxisRotatePoint(angleAxis, pt):
92 | theta2 = (angleAxis * angleAxis).sum(dim=1)
93 |
94 | mask = (theta2 > 0).float()
95 |
96 | theta = torch.sqrt(theta2 + (1 - mask))
97 |
98 | mask = mask.reshape((mask.shape[0], 1))
99 | mask = torch.cat([mask, mask, mask], dim=1)
100 |
101 | costheta = torch.cos(theta)
102 | sintheta = torch.sin(theta)
103 | thetaInverse = 1.0 / theta
104 |
105 | w0 = angleAxis[:, 0] * thetaInverse
106 | w1 = angleAxis[:, 1] * thetaInverse
107 | w2 = angleAxis[:, 2] * thetaInverse
108 |
109 | wCrossPt0 = w1 * pt[:, 2] - w2 * pt[:, 1]
110 | wCrossPt1 = w2 * pt[:, 0] - w0 * pt[:, 2]
111 | wCrossPt2 = w0 * pt[:, 1] - w1 * pt[:, 0]
112 |
113 | tmp = (w0 * pt[:, 0] + w1 * pt[:, 1] + w2 * pt[:, 2]) * (1.0 - costheta)
114 |
115 | r0 = pt[:, 0] * costheta + wCrossPt0 * sintheta + w0 * tmp
116 | r1 = pt[:, 1] * costheta + wCrossPt1 * sintheta + w1 * tmp
117 | r2 = pt[:, 2] * costheta + wCrossPt2 * sintheta + w2 * tmp
118 |
119 | r0 = r0.reshape((r0.shape[0], 1))
120 | r1 = r1.reshape((r1.shape[0], 1))
121 | r2 = r2.reshape((r2.shape[0], 1))
122 |
123 | res1 = torch.cat([r0, r1, r2], dim=1)
124 |
125 | wCrossPt0 = angleAxis[:, 1] * pt[:, 2] - angleAxis[:, 2] * pt[:, 1]
126 | wCrossPt1 = angleAxis[:, 2] * pt[:, 0] - angleAxis[:, 0] * pt[:, 2]
127 | wCrossPt2 = angleAxis[:, 0] * pt[:, 1] - angleAxis[:, 1] * pt[:, 0]
128 |
129 | r00 = pt[:, 0] + wCrossPt0
130 | r01 = pt[:, 1] + wCrossPt1
131 | r02 = pt[:, 2] + wCrossPt2
132 |
133 | r00 = r00.reshape((r00.shape[0], 1))
134 | r01 = r01.reshape((r01.shape[0], 1))
135 | r02 = r02.reshape((r02.shape[0], 1))
136 |
137 | res2 = torch.cat([r00, r01, r02], dim=1)
138 |
139 | return res1 * mask + res2 * (1 - mask)
140 |
141 |
142 | def SnavelyReprojectionErrorV2(points_ob, cameras_ob, features):
143 | if (len(points_ob.shape) == 3):
144 | points_ob = points_ob[:,0,:]
145 | cameras_ob = cameras_ob[:,0,:]
146 | focals = features[:, 2]
147 | l1 = features[:, 3]
148 | l2 = features[:, 4]
149 |
150 | # camera[0,1,2] are the angle-axis rotation.
151 | p = AngleAxisRotatePoint(cameras_ob[:, :3], points_ob)
152 | p = p + cameras_ob[:, 3:6]
153 |
154 | xp = p[:,0] / p[:,2]
155 | yp = p[:,1] / p[:,2]
156 |
157 | # predicted_x, predicted_y = DistortV2(xp, yp, cameras_ob, cam_K)
158 |
159 | predicted_x = focals * xp + l1
160 | predicted_y = focals * yp + l2
161 |
162 | residual_0 = predicted_x - features[:, 0]
163 | residual_1 = predicted_y - features[:, 1]
164 |
165 | residual_0 = residual_0.reshape((residual_0.shape[0], 1))
166 | residual_1 = residual_1.reshape((residual_1.shape[0], 1))
167 |
168 | #return torch.sqrt(residual_0**2 + residual_1 ** 2)
169 | return torch.cat([residual_0, residual_1], dim=1)
170 |
171 |
172 | def put_text(img, inform_text, color=None):
173 | import cv2
174 | fontScale = 1
175 | if color is None:
176 | color = (255, 0, 0)
177 | org = (50, 50)
178 | font = cv2.FONT_HERSHEY_SIMPLEX
179 | thickness = 2
180 | img = cv2.putText(img, inform_text, org, font,
181 | fontScale, color, thickness, cv2.LINE_AA)
182 | return img
183 |
184 |
185 | def draw_kpt2d(image, kpt2d, color=(0, 0, 255), radius=2, thickness=1):
186 | import cv2
187 | for coord in kpt2d:
188 |         cv2.circle(image, (int(coord[0]), int(coord[1])), radius, color, thickness, 1)
189 | # cv2.circle(image, (int(coord[0]), int(coord[1])), 7, color, 1, 1)
190 | return image
191 |
192 |
193 | class MovieWriter:
194 | def __init__(self):
195 | self.video_out_path = ''
196 | self.movie_cap = None
197 | self.id = 0
198 |
199 | def start(self):
200 | if self.movie_cap is not None:
201 | self.movie_cap.release()
202 |
203 | def write(self, im_bgr, video_out_path, text_info=[], fps=20):
204 | import cv2
205 | if self.movie_cap is None:
206 | fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
207 | length = fps
208 | self.video_out_path = video_out_path
209 | self.movie_cap = cv2.VideoWriter(video_out_path, fourcc,
210 | length, (im_bgr.shape[1], im_bgr.shape[0]))
211 |
212 | if len(text_info) > 0:
213 | self.put_text(im_bgr, text_info[0], color=text_info[1])
214 | self.movie_cap.write(im_bgr)
215 | self.id += 1
216 |
217 | def put_text(self, img, inform_text, color=None):
218 | import cv2
219 | fontScale = 1
220 | if color is None:
221 | color = (255, 0, 0)
222 | org = (200, 200)
223 | font = cv2.FONT_HERSHEY_SIMPLEX
224 | thickness = 2
225 | img = cv2.putText(img, inform_text, org, font,
226 | fontScale, color, thickness, cv2.LINE_AA)
227 | return img
228 |
229 | def end(self):
230 | if self.movie_cap is not None:
231 | self.movie_cap.release()
232 | self.movie_cap = None
233 | print(f"Output frames:{self.id} to {self.video_out_path}")
234 | self.id = 0
235 |
--------------------------------------------------------------------------------
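A small self-contained example of the `Timer` and `project` helpers above; the intrinsics, points, and pose are toy values chosen only to make the projection well defined:

```python
import numpy as np
from src.tracker.tracking_utils import Timer, project

timer = Timer()

# Toy 3D points roughly one metre in front of a pinhole camera.
xyz = np.array([[0.0, 0.0, 1.0], [0.1, 0.0, 1.0], [0.0, 0.1, 1.2]])
K = np.array([[600.0, 0.0, 320.0],
              [0.0, 600.0, 240.0],
              [0.0, 0.0, 1.0]])
RT = np.hstack([np.eye(3), np.array([[0.0], [0.0], [0.5]])])  # [R | t], 3x4

timer.tick('project')
xy, depth = project(xyz, K, RT, need_depth=True)
elapsed_ms = timer.tock('project')

print(xy.shape, depth.shape)            # (3, 2) (3,)
print(f'projection took {elapsed_ms:.3f} ms')
```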
/src/utils/comm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | [Copied from detectron2]
4 | This file contains primitives for multi-gpu communication.
5 | This is useful when doing distributed training.
6 | """
7 |
8 | import functools
9 | import logging
10 | import numpy as np
11 | import pickle
12 | import torch
13 | import torch.distributed as dist
14 |
15 | _LOCAL_PROCESS_GROUP = None
16 | """
17 | A torch process group which only includes processes that are on the same machine as the current process.
18 | This variable is set when processes are spawned by `launch()` in "engine/launch.py".
19 | """
20 |
21 |
22 | def get_world_size() -> int:
23 | if not dist.is_available():
24 | return 1
25 | if not dist.is_initialized():
26 | return 1
27 | return dist.get_world_size()
28 |
29 |
30 | def get_rank() -> int:
31 | if not dist.is_available():
32 | return 0
33 | if not dist.is_initialized():
34 | return 0
35 | return dist.get_rank()
36 |
37 |
38 | def get_local_rank() -> int:
39 | """
40 | Returns:
41 | The rank of the current process within the local (per-machine) process group.
42 | """
43 | if not dist.is_available():
44 | return 0
45 | if not dist.is_initialized():
46 | return 0
47 | assert _LOCAL_PROCESS_GROUP is not None
48 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
49 |
50 |
51 | def get_local_size() -> int:
52 | """
53 | Returns:
54 | The size of the per-machine process group,
55 | i.e. the number of processes per machine.
56 | """
57 | if not dist.is_available():
58 | return 1
59 | if not dist.is_initialized():
60 | return 1
61 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
62 |
63 |
64 | def is_main_process() -> bool:
65 | return get_rank() == 0
66 |
67 |
68 | def synchronize():
69 | """
70 | Helper function to synchronize (barrier) among all processes when
71 | using distributed training
72 | """
73 | if not dist.is_available():
74 | return
75 | if not dist.is_initialized():
76 | return
77 | world_size = dist.get_world_size()
78 | if world_size == 1:
79 | return
80 | dist.barrier()
81 |
82 |
83 | @functools.lru_cache()
84 | def _get_global_gloo_group():
85 | """
86 | Return a process group based on gloo backend, containing all the ranks
87 | The result is cached.
88 | """
89 | if dist.get_backend() == "nccl":
90 | return dist.new_group(backend="gloo")
91 | else:
92 | return dist.group.WORLD
93 |
94 |
95 | def _serialize_to_tensor(data, group):
96 | backend = dist.get_backend(group)
97 | assert backend in ["gloo", "nccl"]
98 | device = torch.device("cpu" if backend == "gloo" else "cuda")
99 |
100 | buffer = pickle.dumps(data)
101 | if len(buffer) > 1024 ** 3:
102 | logger = logging.getLogger(__name__)
103 | logger.warning(
104 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
105 | get_rank(), len(buffer) / (1024 ** 3), device
106 | )
107 | )
108 | storage = torch.ByteStorage.from_buffer(buffer)
109 | tensor = torch.ByteTensor(storage).to(device=device)
110 | return tensor
111 |
112 |
113 | def _pad_to_largest_tensor(tensor, group):
114 | """
115 | Returns:
116 | list[int]: size of the tensor, on each rank
117 | Tensor: padded tensor that has the max size
118 | """
119 | world_size = dist.get_world_size(group=group)
120 | assert (
121 | world_size >= 1
122 | ), "comm.gather/all_gather must be called from ranks within the given group!"
123 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
124 | size_list = [
125 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
126 | ]
127 | dist.all_gather(size_list, local_size, group=group)
128 |
129 | size_list = [int(size.item()) for size in size_list]
130 |
131 | max_size = max(size_list)
132 |
133 | # we pad the tensor because torch all_gather does not support
134 | # gathering tensors of different shapes
135 | if local_size != max_size:
136 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
137 | tensor = torch.cat((tensor, padding), dim=0)
138 | return size_list, tensor
139 |
140 |
141 | def all_gather(data, group=None):
142 | """
143 | Run all_gather on arbitrary picklable data (not necessarily tensors).
144 | Args:
145 | data: any picklable object
146 | group: a torch process group. By default, will use a group which
147 | contains all ranks on gloo backend.
148 | Returns:
149 | list[data]: list of data gathered from each rank
150 | """
151 | if get_world_size() == 1:
152 | return [data]
153 | if group is None:
154 | group = _get_global_gloo_group()
155 | if dist.get_world_size(group) == 1:
156 | return [data]
157 |
158 | tensor = _serialize_to_tensor(data, group)
159 |
160 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
161 | max_size = max(size_list)
162 |
163 | # receiving Tensor from all ranks
164 | tensor_list = [
165 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
166 | ]
167 | dist.all_gather(tensor_list, tensor, group=group)
168 |
169 | data_list = []
170 | for size, tensor in zip(size_list, tensor_list):
171 | buffer = tensor.cpu().numpy().tobytes()[:size]
172 | data_list.append(pickle.loads(buffer))
173 |
174 | return data_list
175 |
176 |
177 | def gather(data, dst=0, group=None):
178 | """
179 | Run gather on arbitrary picklable data (not necessarily tensors).
180 | Args:
181 | data: any picklable object
182 | dst (int): destination rank
183 | group: a torch process group. By default, will use a group which
184 | contains all ranks on gloo backend.
185 | Returns:
186 | list[data]: on dst, a list of data gathered from each rank. Otherwise,
187 | an empty list.
188 | """
189 | if get_world_size() == 1:
190 | return [data]
191 | if group is None:
192 | group = _get_global_gloo_group()
193 | if dist.get_world_size(group=group) == 1:
194 | return [data]
195 | rank = dist.get_rank(group=group)
196 |
197 | tensor = _serialize_to_tensor(data, group)
198 | size_list, tensor = _pad_to_largest_tensor(tensor, group)
199 |
200 | # receiving Tensor from all ranks
201 | if rank == dst:
202 | max_size = max(size_list)
203 | tensor_list = [
204 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
205 | ]
206 | dist.gather(tensor, tensor_list, dst=dst, group=group)
207 |
208 | data_list = []
209 | for size, tensor in zip(size_list, tensor_list):
210 | buffer = tensor.cpu().numpy().tobytes()[:size]
211 | data_list.append(pickle.loads(buffer))
212 | return data_list
213 | else:
214 | dist.gather(tensor, [], dst=dst, group=group)
215 | return []
216 |
217 |
218 | def shared_random_seed():
219 | """
220 | Returns:
221 | int: a random number that is the same across all workers.
222 | If workers need a shared RNG, they can use this shared seed to
223 | create one.
224 | All workers must call this function, otherwise it will deadlock.
225 | """
226 | ints = np.random.randint(2 ** 31)
227 | all_ints = all_gather(ints)
228 | return all_ints[0]
229 |
230 |
231 | def reduce_dict(input_dict, average=True):
232 | """
233 | Reduce the values in the dictionary from all processes so that process with rank
234 | 0 has the reduced results.
235 | Args:
236 | input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
237 | average (bool): whether to do average or sum
238 | Returns:
239 | a dict with the same keys as input_dict, after reduction.
240 | """
241 | world_size = get_world_size()
242 | if world_size < 2:
243 | return input_dict
244 | with torch.no_grad():
245 | names = []
246 | values = []
247 | # sort the keys so that they are consistent across processes
248 | for k in sorted(input_dict.keys()):
249 | names.append(k)
250 | values.append(input_dict[k])
251 | values = torch.stack(values, dim=0)
252 | dist.reduce(values, dst=0)
253 | if dist.get_rank() == 0 and average:
254 | # only main process gets accumulated, so only divide by
255 | # world_size in this case
256 | values /= world_size
257 | reduced_dict = {k: v for k, v in zip(names, values)}
258 | return reduced_dict
259 |
--------------------------------------------------------------------------------
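A minimal sketch of how these communication helpers might be used in a training loop. In a single-process run they degrade gracefully (`all_gather` returns a one-element list and `reduce_dict` returns its input unchanged); under DDP the dict values should be scalar CUDA tensors, as the docstring above requires:

```python
import torch
from src.utils import comm

# Per-rank statistics; use scalar CUDA tensors when running under DDP.
local_stats = {'loss': torch.tensor(0.25), 'acc': torch.tensor(0.9)}

gathered = comm.all_gather(local_stats)   # list with one entry per rank
reduced = comm.reduce_dict(local_stats)   # rank 0 receives the averaged values

if comm.is_main_process():
    print(f'{len(gathered)} rank(s), reduced stats: {reduced}')
comm.synchronize()                        # barrier; a no-op when not distributed
```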
/src/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os.path as osp
4 |
5 | from pathlib import Path
6 |
7 | def record_eval_result(out_dir, obj_name, seq_name, eval_result):
8 | Path(out_dir).mkdir(exist_ok=True, parents=True)
9 |
10 | out_file = osp.join(out_dir, obj_name + seq_name + '.txt')
11 | f = open(out_file, 'w')
12 | for k, v in eval_result.items():
13 | f.write(f'{k}: {v}\n')
14 |
15 | f.close()
16 |
17 |
18 | def ransac_PnP(K, pts_2d, pts_3d, scale=1):
19 | """ solve pnp """
20 | dist_coeffs = np.zeros(shape=[8, 1], dtype='float64')
21 |
22 | pts_2d = np.ascontiguousarray(pts_2d.astype(np.float64))
23 | pts_3d = np.ascontiguousarray(pts_3d.astype(np.float64))
24 | K = K.astype(np.float64)
25 |
26 | pts_3d *= scale
27 | try:
28 | _, rvec, tvec, inliers = cv2.solvePnPRansac(pts_3d, pts_2d, K, dist_coeffs, reprojectionError=5,
29 | iterationsCount=10000, flags=cv2.SOLVEPNP_EPNP)
30 |
31 | rotation = cv2.Rodrigues(rvec)[0]
32 |
33 | tvec /= scale
34 | pose = np.concatenate([rotation, tvec], axis=-1)
35 | pose_homo = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
36 |
37 | inliers = [] if inliers is None else inliers
38 |
39 | return pose, pose_homo, inliers
40 | except cv2.error:
41 | print("CV ERROR")
42 | return np.eye(4)[:3], np.eye(4), []
43 |
44 |
45 | def query_pose_error(pose_pred, pose_gt):
46 | """
47 | Input:
48 | ---------
49 | pose_pred: np.array 3*4 or 4*4
50 | pose_gt: np.array 3*4 or 4*4
51 | """
52 | # Dim check:
53 | if pose_pred.shape[0] == 4:
54 | pose_pred = pose_pred[:3]
55 | if pose_gt.shape[0] == 4:
56 | pose_gt = pose_gt[:3]
57 |
58 | translation_distance = np.linalg.norm(pose_pred[:, 3] - pose_gt[:, 3]) * 100
59 | rotation_diff = np.dot(pose_pred[:, :3], pose_gt[:, :3].T)
60 | trace = np.trace(rotation_diff)
61 | trace = trace if trace <= 3 else 3
62 | angular_distance = np.rad2deg(np.arccos((trace - 1.0) / 2.0))
63 | return angular_distance, translation_distance
64 |
65 |
66 | def compute_query_pose_errors(data, preds):
67 | query_pose_gt = data['query_pose_gt'][0].cpu().numpy()
68 | query_K = data['query_intrinsic'][0].cpu().numpy()
69 | query_kpts2d = data['keypoints2d'][0].cpu().numpy()
70 | query_kpts3d = data['keypoints3d'][0].cpu().numpy()
71 |
72 | matches0 = preds['matches0'].cpu().numpy()
73 | confidence = preds['matching_scores0'].cpu().numpy()
74 | valid = matches0 > -1
75 | mkpts2d = query_kpts2d[valid]
76 | mkpts3d = query_kpts3d[matches0[valid]]
77 | mconf = confidence[valid]
78 |
79 | pose_pred = []
80 | val_results = {'R_errs': [], 't_errs': [], 'inliers': []}
81 |
82 | query_pose_pred, query_pose_pred_homo, inliers = ransac_PnP(
83 | query_K,
84 | mkpts2d,
85 | mkpts3d
86 | )
87 | pose_pred.append(query_pose_pred_homo)
88 |
89 | if query_pose_pred is None:
90 | val_results['R_errs'].append(np.inf)
91 | val_results['t_errs'].append(np.inf)
92 |         val_results['inliers'].append(np.array([]).astype(bool))
93 | else:
94 | R_err, t_err = query_pose_error(query_pose_pred, query_pose_gt)
95 | val_results['R_errs'].append(R_err)
96 | val_results['t_errs'].append(t_err)
97 | val_results['inliers'].append(inliers)
98 |
99 | pose_pred = np.stack(pose_pred)
100 |
101 | val_results.update({'mkpts2d': mkpts2d, 'mkpts3d': mkpts3d, 'mconf': mconf})
102 | return pose_pred, val_results
103 |
104 |
105 | def aggregate_metrics(metrics, thres=[1, 3, 5]):
106 | """ Aggregate metrics for the whole dataset:
107 | (This method should be called once per dataset)
108 |     Computes the cm-degree pose accuracy: the fraction of queries whose rotation
109 |     error (deg) and translation error (cm) both fall below each threshold in thres.
110 | """
111 | R_errs = metrics['R_errs']
112 | t_errs = metrics['t_errs']
113 |
114 | degree_distance_metric = {}
115 | for threshold in thres:
116 | degree_distance_metric[f'{threshold}cm@{threshold}degree'] = np.mean(
117 | (np.array(R_errs) < threshold) & (np.array(t_errs) < threshold)
118 | )
119 |
120 | return degree_distance_metric
--------------------------------------------------------------------------------
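A toy example of the error and aggregation helpers above, using hand-made ground-truth and predicted poses; the 5 mm offset is arbitrary and only serves to produce a non-zero translation error:

```python
import numpy as np
from src.utils.eval_utils import query_pose_error, aggregate_metrics

# 3x4 [R | t] poses; the prediction is shifted by 5 mm along x.
pose_gt = np.hstack([np.eye(3), np.array([[0.0], [0.0], [1.0]])])
pose_pred = np.hstack([np.eye(3), np.array([[0.005], [0.0], [1.0]])])

R_err, t_err = query_pose_error(pose_pred, pose_gt)
print(f'rotation error: {R_err:.2f} deg, translation error: {t_err:.2f} cm')

# Accumulate errors over a whole sequence, then report the cm-degree accuracy.
metrics = {'R_errs': [R_err], 't_errs': [t_err]}
print(aggregate_metrics(metrics))  # {'1cm@1degree': 1.0, '3cm@3degree': 1.0, '5cm@5degree': 1.0}
```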
/src/utils/model_io.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from collections import OrderedDict
4 |
5 |
6 | def load_model(net, optim, scheduler, recorder, model_dir, resume=True, epoch=-1):
7 | if not resume:
8 | os.system('rm -rf {}'.format(model_dir))
9 |
10 | if not os.path.exists(model_dir):
11 | return 0
12 |
13 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir)]
14 | if len(pths) == 0:
15 | return 0
16 | if epoch == -1:
17 | pth = max(pths)
18 | else:
19 | pth = epoch
20 | print('Load model: {}'.format(os.path.join(model_dir, '{}.pth'.format(pth))))
21 | pretrained_model = torch.load(os.path.join(model_dir, '{}.pth'.format(pth)))
22 | net.load_state_dict(pretrained_model['net'])
23 | optim.load_state_dict(pretrained_model['optim'])
24 | scheduler.load_state_dict(pretrained_model['scheduler'])
25 | recorder.load_state_dict(pretrained_model['recorder'])
26 | return pretrained_model['epoch'] + 1
27 |
28 |
29 | def save_model(net, optim, scheduler, recorder, epoch, model_dir):
30 | os.system('mkdir -p {}'.format(model_dir))
31 | torch.save({
32 | 'net': net.state_dict(),
33 | 'optim': optim.state_dict(),
34 | 'scheduler': scheduler.state_dict(),
35 | 'recorder': recorder.state_dict(),
36 | 'epoch': epoch
37 | }, os.path.join(model_dir, '{}.pth'.format(epoch)))
38 |
39 | # remove previous pretrained model if the number of models is too big
40 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir)]
41 | if len(pths) <= 200:
42 | return
43 | os.system('rm {}'.format(os.path.join(model_dir, '{}.pth'.format(min(pths)))))
44 |
45 |
46 | def load_network_ckpt(net, ckpt_path):
47 | pretrained_model = torch.load(ckpt_path, torch.device('cpu'))
48 | pretrained_model = pretrained_model['state_dict']
49 |
50 |     pretrained_model = remove_net_layer(pretrained_model, ['detector'])
51 | pretrained_model = remove_net_prefix(pretrained_model, 'superglue.')
52 |
53 | print('=> load weights: ', ckpt_path)
54 | net.load_state_dict(pretrained_model)
55 | return None
56 |
57 |
58 | def load_network(net, model_dir, resume=True, epoch=-1, strict=True, force=False):
59 | """
60 | Load latest network-weights from dir or path
61 | """
62 | if not resume:
63 | return 0
64 |
65 | if not os.path.exists(model_dir):
66 | if force:
67 | raise NotImplementedError
68 | else:
69 | print('pretrained model does not exist')
70 | return 0
71 |
72 | if os.path.isdir(model_dir):
73 | pths = [int(pth.split('.')[0]) for pth in os.listdir(model_dir) if 'pth' in pth]
74 | if len(pths) == 0:
75 | return 0
76 | if epoch == -1:
77 | pth = max(pths)
78 | else:
79 | pth = epoch
80 | model_path = os.path.join(model_dir, '{}.pth'.format(pth))
81 | else:
82 | model_path = model_dir
83 |
84 | print('=> load weights: ', model_path)
85 | pretrained_model = torch.load(model_path, torch.device("cpu"))
86 | if 'net' in pretrained_model.keys():
87 | net.load_state_dict(pretrained_model['net'], strict=strict)
88 | else:
89 | net.load_state_dict(pretrained_model, strict=strict)
90 | return pretrained_model.get('epoch', 0) + 1
91 |
92 |
93 | def remove_net_prefix(net, prefix):
94 | net_ = OrderedDict()
95 | for k in net.keys():
96 | if k.startswith(prefix):
97 | net_[k[len(prefix):]] = net[k]
98 | else:
99 | net_[k] = net[k]
100 | return net_
101 |
102 |
103 | def add_net_prefix(net, prefix):
104 | net_ = OrderedDict()
105 | for k in net.keys():
106 | net_[prefix + k] = net[k]
107 | return net_
108 |
109 |
110 | def replace_net_prefix(net, orig_prefix, prefix):
111 | net_ = OrderedDict()
112 | for k in net.keys():
113 | if k.startswith(orig_prefix):
114 | net_[prefix + k[len(orig_prefix):]] = net[k]
115 | else:
116 | net_[k] = net[k]
117 | return net_
118 |
119 |
120 | def remove_net_layer(net, layers):
121 | keys = list(net.keys())
122 | for k in keys:
123 | for layer in layers:
124 | if k.startswith(layer):
125 | del net[k]
126 | return net
127 |
128 |
129 | def to_cuda(data):
130 | if type(data).__name__ == "Tensor":
131 | data = data.cuda()
132 | elif type(data).__name__ == 'list':
133 | data = [d.cuda() for d in data]
134 | elif type(data).__name__ == 'dict':
135 | data = {k: v.cuda() for k, v in data.items()}
136 | else:
137 | raise NotImplementedError
138 | return data
139 |
--------------------------------------------------------------------------------
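A small illustration of the state-dict helpers above on a made-up checkpoint; the key names are invented and only show how layers and prefixes are stripped before loading:

```python
import torch
from collections import OrderedDict
from src.utils.model_io import remove_net_layer, remove_net_prefix

# Invented checkpoint keys: a matcher wrapped in a parent module plus a detector.
state_dict = OrderedDict([
    ('matcher.gnn.weight', torch.zeros(1)),
    ('matcher.final_proj.bias', torch.zeros(1)),
    ('detector.convPa.weight', torch.zeros(1)),
])

state_dict = remove_net_layer(state_dict, ['detector'])  # expects a list of prefixes
state_dict = remove_net_prefix(state_dict, 'matcher.')
print(list(state_dict.keys()))  # ['gnn.weight', 'final_proj.bias']
```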
/src/utils/path_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 |
4 |
5 | """
6 | For each object, we store in the following directory format:
7 |
8 | data_root:
9 | - box3d_corners.txt
10 | - seq1_root
11 | - intrinsics.txt
12 | - color/
13 | - color_det[optional]/
14 | - poses_ba/
15 | - intrin_ba/
16 | - intrin_det[optional]/
17 | - ......
18 | - seq2_root
19 | - ......
20 | """
21 |
22 | def get_gt_pose_path_by_color(color_path, det_type='GT_box'):
23 | if det_type == "GT_box":
24 | return color_path.replace("/color/", "/poses_ba/").replace(
25 | ".png", ".txt"
26 | )
27 | elif det_type == 'feature_matching':
28 | return color_path.replace("/color_det/", "/poses_ba/").replace(
29 | ".png", ".txt"
30 | )
31 | else:
32 | raise NotImplementedError
33 |
34 | def get_img_full_path_by_color(color_path, det_type='GT_box'):
35 | if det_type == "GT_box":
36 | return color_path.replace("/color/", "/color_full/")
37 | elif det_type == 'feature_matching':
38 | return color_path.replace("/color_det/", "/color_full/")
39 | else:
40 | raise NotImplementedError
41 |
42 | def get_intrin_path_by_color(color_path, det_type='GT_box'):
43 | if det_type == "GT_box":
44 | return color_path.replace("/color/", "/intrin_ba/").replace(
45 | ".png", ".txt"
46 | )
47 | elif det_type == 'feature_matching':
48 | return color_path.replace("/color_det/", "/intrin_det/").replace(
49 | ".png", ".txt"
50 | )
51 | else:
52 | raise NotImplementedError
53 |
54 | def get_intrin_dir(seq_root):
55 | return osp.join(seq_root, "intrin_ba")
56 |
57 | def get_gt_pose_dir(seq_root):
58 | return osp.join(seq_root, "poses_ba")
59 |
60 | def get_intrin_full_path(seq_root):
61 | return osp.join(seq_root, "intrinsics.txt")
62 |
63 | def get_3d_box_path(data_root):
64 | return osp.join(data_root, "box3d_corners.txt")
65 |
66 |
--------------------------------------------------------------------------------
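A quick sketch of the path conversions above, using a hypothetical frame that follows the directory layout documented at the top of the file:

```python
from src.utils.path_utils import (
    get_gt_pose_path_by_color,
    get_img_full_path_by_color,
    get_intrin_path_by_color,
)

# Hypothetical frame path following the layout described in the module docstring.
color_path = '/path/to/data_root/seq1_root/color/0.png'

print(get_gt_pose_path_by_color(color_path))   # .../seq1_root/poses_ba/0.txt
print(get_intrin_path_by_color(color_path))    # .../seq1_root/intrin_ba/0.txt
print(get_img_full_path_by_color(color_path))  # .../seq1_root/color_full/0.png
```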
/src/utils/template_utils.py:
--------------------------------------------------------------------------------
1 | # pytorch lightning imports
2 | import pytorch_lightning as pl
3 |
4 | # hydra imports
5 | from omegaconf import DictConfig, OmegaConf
6 | from hydra.utils import get_original_cwd, to_absolute_path
7 |
8 | # loggers
9 | import wandb
10 | from pytorch_lightning.loggers.wandb import WandbLogger
11 |
12 | # from pytorch_lightning.loggers.neptune import NeptuneLogger
13 | # from pytorch_lightning.loggers.comet import CometLogger
14 | # from pytorch_lightning.loggers.mlflow import MLFlowLogger
15 | # from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
16 |
17 | # rich imports
18 | from rich import print
19 | from rich.syntax import Syntax
20 | from rich.tree import Tree
21 |
22 | # normal imports
23 | from typing import List
24 |
25 |
26 | def print_config(config: DictConfig):
27 | """Prints content of Hydra config using Rich library.
28 |
29 | Args:
30 | config (DictConfig): [description]
31 | """
32 |
33 | # TODO print main config path and experiment config path
34 | # directory = to_absolute_path("configs/config.yaml")
35 | # print(f"Main config path: [link file://{directory}]{directory}")
36 |
37 | style = "dim"
38 |
39 | tree = Tree(f":gear: FULL HYDRA CONFIG", style=style, guide_style=style)
40 |
41 | trainer = OmegaConf.to_yaml(config["trainer"], resolve=True)
42 | trainer_branch = tree.add("Trainer", style=style, guide_style=style)
43 | trainer_branch.add(Syntax(trainer, "yaml"))
44 |
45 | model = OmegaConf.to_yaml(config["model"], resolve=True)
46 | model_branch = tree.add("Model", style=style, guide_style=style)
47 | model_branch.add(Syntax(model, "yaml"))
48 |
49 | datamodule = OmegaConf.to_yaml(config["datamodule"], resolve=True)
50 | datamodule_branch = tree.add("Datamodule", style=style, guide_style=style)
51 | datamodule_branch.add(Syntax(datamodule, "yaml"))
52 |
53 | callbacks_branch = tree.add("Callbacks", style=style, guide_style=style)
54 | if "callbacks" in config:
55 | for cb_name, cb_conf in config["callbacks"].items():
56 | cb = callbacks_branch.add(cb_name, style=style, guide_style=style)
57 | cb.add(Syntax(OmegaConf.to_yaml(cb_conf, resolve=True), "yaml"))
58 | else:
59 | callbacks_branch.add("None")
60 |
61 | logger_branch = tree.add("Logger", style=style, guide_style=style)
62 | if "logger" in config:
63 | for lg_name, lg_conf in config["logger"].items():
64 | lg = logger_branch.add(lg_name, style=style, guide_style=style)
65 | lg.add(Syntax(OmegaConf.to_yaml(lg_conf, resolve=True), "yaml"))
66 | else:
67 | logger_branch.add("None")
68 |
69 | seed = config.get("seed", "None")
70 | seed_branch = tree.add(f"Seed", style=style, guide_style=style)
71 | seed_branch.add(str(seed) + "\n")
72 |
73 | print(tree)
74 |
75 |
76 | def log_hparams_to_all_loggers(
77 | config: DictConfig,
78 | model: pl.LightningModule,
79 | datamodule: pl.LightningDataModule,
80 | trainer: pl.Trainer,
81 | callbacks: List[pl.Callback],
82 | logger: List[pl.loggers.LightningLoggerBase],
83 | ):
84 | """This method controls which parameters from Hydra config are saved by Lightning loggers.
85 |
86 | Args:
87 | config (DictConfig): [description]
88 | model (pl.LightningModule): [description]
89 | datamodule (pl.LightningDataModule): [description]
90 | trainer (pl.Trainer): [description]
91 | callbacks (List[pl.Callback]): [description]
92 | logger (List[pl.loggers.LightningLoggerBase]): [description]
93 | """
94 |
95 | hparams = {}
96 |
97 | # save all params of model, datamodule and trainer
98 | hparams.update(config["model"])
99 | hparams.update(config["datamodule"])
100 | hparams.update(config["trainer"])
101 | hparams.pop("_target_")
102 |
103 | # save seed
104 | hparams["seed"] = config.get("seed", "None")
105 |
106 | # save targets
107 | hparams["_class_model"] = config["model"]["_target_"]
108 | hparams["_class_datamodule"] = config["datamodule"]["_target_"]
109 |
110 | # save sizes of each dataset
111 | if hasattr(datamodule, "data_train") and datamodule.data_train:
112 | hparams["train_size"] = len(datamodule.data_train)
113 | if hasattr(datamodule, "data_val") and datamodule.data_val:
114 | hparams["val_size"] = len(datamodule.data_val)
115 | if hasattr(datamodule, "data_test") and datamodule.data_test:
116 | hparams["test_size"] = len(datamodule.data_test)
117 |
118 | # save number of model parameters
119 | hparams["#params_total"] = sum(p.numel() for p in model.parameters())
120 | hparams["#params_trainable"] = sum(
121 | p.numel() for p in model.parameters() if p.requires_grad
122 | )
123 | hparams["#params_not_trainable"] = sum(
124 | p.numel() for p in model.parameters() if not p.requires_grad
125 | )
126 |
127 | # send hparams to all loggers
128 | for lg in logger:
129 | lg.log_hyperparams(hparams)
130 |
131 |
132 | def finish(
133 | config: DictConfig,
134 | model: pl.LightningModule,
135 | datamodule: pl.LightningDataModule,
136 | trainer: pl.Trainer,
137 | callbacks: List[pl.Callback],
138 | logger: List[pl.loggers.LightningLoggerBase],
139 | ):
140 | """Makes sure everything closed properly.
141 |
142 | Args:
143 | config (DictConfig): [description]
144 | model (pl.LightningModule): [description]
145 | datamodule (pl.LightningDataModule): [description]
146 | trainer (pl.Trainer): [description]
147 | callbacks (List[pl.Callback]): [description]
148 | logger (List[pl.loggers.LightningLoggerBase]): [description]
149 | """
150 |
151 | # without this sweeps with wandb logger might crash!
152 | for lg in logger:
153 | if isinstance(lg, WandbLogger):
154 | wandb.finish()
155 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
2 | from pytorch_lightning import seed_everything
3 | from pytorch_lightning.loggers import LightningLoggerBase
4 |
5 | import hydra
6 | from omegaconf import DictConfig
7 | from typing import List
8 | from src.utils import template_utils as utils
9 |
10 | import warnings
11 | warnings.filterwarnings('ignore')
12 |
13 |
14 | def train(config: DictConfig):
15 | if config['print_config']:
16 | utils.print_config(config)
17 |
18 | if "seed" in config:
19 | seed_everything(config['seed'])
20 |
21 | # Init PyTorch Lightning model ⚡
22 | model: LightningModule = hydra.utils.instantiate(config['model'])
23 |
24 | # Init PyTorch Lightning datamodule ⚡
25 |     datamodule: LightningDataModule = hydra.utils.instantiate(config['datamodule'])
26 | datamodule.setup()
27 |
28 | # Init PyTorch Lightning callbacks ⚡
29 | callbacks: List[Callback] = []
30 | if "callbacks" in config:
31 | for _, cb_conf in config['callbacks'].items():
32 | if "_target_" in cb_conf:
33 | callbacks.append(hydra.utils.instantiate(cb_conf))
34 |
35 | # Init PyTorch Lightning loggers ⚡
36 | logger: List[LightningLoggerBase] = []
37 | if "logger" in config:
38 | for _, lg_conf in config['logger'].items():
39 | if "_target_" in lg_conf:
40 | logger.append(hydra.utils.instantiate(lg_conf))
41 |
42 | # Init PyTorch Lightning trainer ⚡
43 | trainer: Trainer = hydra.utils.instantiate(
44 | config['trainer'], callbacks=callbacks, logger=logger
45 | )
46 |
47 | # Send some parameters from config to all lightning loggers
48 | utils.log_hparams_to_all_loggers(
49 | config=config,
50 | model=model,
51 | datamodule=datamodule,
52 | trainer=trainer,
53 | callbacks=callbacks,
54 | logger=logger
55 | )
56 |
57 | # Train the model
58 | trainer.fit(model=model, datamodule=datamodule)
59 |
60 | # Evaluate model on test set after training
61 | # trainer.test()
62 |
63 | # Make sure everything closed properly
64 | utils.finish(
65 | config=config,
66 | model=model,
67 | datamodule=datamodule,
68 | trainer=trainer,
69 | callbacks=callbacks,
70 | logger=logger
71 | )
72 |
73 | # Return best achieved metric score for optuna
74 | optimized_metric = config.get("optimized_metric", None)
75 | if optimized_metric:
76 | return trainer.callback_metrics[optimized_metric]
77 |
78 |
79 | @hydra.main(config_path="configs/", config_name="config.yaml")
80 | def main(config: DictConfig):
81 | return train(config)
82 |
83 |
84 | if __name__ == "__main__":
85 | main()
--------------------------------------------------------------------------------
/video2img.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from argparse import ArgumentParser
3 | from pathlib import Path
4 | from src.utils.data_utils import video2img
5 |
6 |
7 | def main():
8 | parser = ArgumentParser()
9 | parser.add_argument("--input", required=True, help="The video file or directory to be parsed")
10 | parser.add_argument("--downsample", default=1, type=int)
11 | args = parser.parse_args()
12 |
13 | input = args.input
14 |
15 | if osp.isdir(input): # in case of directory which contains video file
16 | video_file = osp.join(input, 'Frames.m4v')
17 | else: # in case of video file
18 | video_file = input
19 |     assert osp.exists(video_file), "Please input a valid video file!"
20 |
21 | data_dir = osp.dirname(video_file)
22 | out_dir = osp.join(data_dir, 'color_full')
23 | Path(out_dir).mkdir(exist_ok=True, parents=True)
24 |
25 | video2img(video_file, out_dir, args.downsample)
26 |
27 |
28 | if __name__ == "__main__":
29 | main()
--------------------------------------------------------------------------------