├── LICENSE
├── README.md
├── benchmark
│   ├── config.py
│   ├── mapfree.py
│   ├── metrics.py
│   ├── reprojection.py
│   ├── test_metrics.py
│   └── utils.py
├── config
│   ├── MicKey
│   │   ├── curriculum_learning.yaml
│   │   ├── curriculum_learning_warm_up.yaml
│   │   ├── overlap_score.yaml
│   │   └── overlap_score_warm_up.yaml
│   ├── datasets
│   │   └── mapfree.yaml
│   └── default.py
├── data
│   └── toy_example
│       ├── im0.jpg
│       ├── im1.jpg
│       └── intrinsics.txt
├── demo_inference.py
├── lib
│   ├── benchmarks
│   │   ├── reprojection.py
│   │   └── utils.py
│   ├── datasets
│   │   ├── datamodules.py
│   │   ├── mapfree.py
│   │   ├── sampler.py
│   │   └── utils.py
│   ├── models
│   │   ├── MicKey
│   │   │   ├── compute_pose.py
│   │   │   ├── model.py
│   │   │   └── modules
│   │   │       ├── DINO_modules
│   │   │       │   ├── dinov2.py
│   │   │       │   └── layers
│   │   │       │       ├── __init__.py
│   │   │       │       ├── attention.py
│   │   │       │       ├── block.py
│   │   │       │       ├── dino_head.py
│   │   │       │       ├── drop_path.py
│   │   │       │       ├── layer_scale.py
│   │   │       │       ├── mlp.py
│   │   │       │       ├── patch_embed.py
│   │   │       │       └── swiglu_ffn.py
│   │   │       ├── att_layers
│   │   │       │   ├── attention.py
│   │   │       │   ├── transformer.py
│   │   │       │   └── transformer_utils.py
│   │   │       ├── compute_correspondences.py
│   │   │       ├── loss
│   │   │       │   ├── loss_class.py
│   │   │       │   ├── loss_utils.py
│   │   │       │   └── solvers.py
│   │   │       ├── mickey_extractor.py
│   │   │       └── utils
│   │   │           ├── extractor_utils.py
│   │   │           ├── feature_matcher.py
│   │   │           ├── probabilisticProcrustes.py
│   │   │           └── training_utils.py
│   │   └── builder.py
│   └── utils
│       ├── data.py
│       ├── metrics.py
│       └── visualization.py
├── resources
│   ├── environment.yml
│   └── teaser_mickey.png
├── submission.py
└── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright © Niantic, Inc. 2024. Patent Pending.
2 |
3 | All rights reserved.
4 |
5 |
6 |
7 | =======================================================================================
8 |
9 |
10 |
11 | This Software is licensed under the terms of the following Matching 2D Images in 3D:
12 | Metric Relative Pose from Metric Correspondences license which allows for non-commercial
13 | use only. For any other use of the software not covered by the terms of this license,
14 | please contact partnerships@nianticlabs.com
15 |
16 |
17 |
18 | =======================================================================================
19 |
20 |
21 |
22 | Matching 2D Images in 3D: Metric Relative Pose from Metric Correspondences License
23 |
24 |
25 | This Agreement is made by and between the Licensor and the Licensee as
26 | defined and identified below.
27 |
28 |
29 | 1. Definitions.
30 |
31 | In this Agreement (“the Agreement”) the following words shall have the
32 | following meanings:
33 |
34 | "Authors" shall mean A. Barroso-Laguna, S. Munukutla, V. Prisacariu, E. Brachmann
35 | "Licensee" Shall mean the person or organization agreeing to use the
36 | Software in accordance with these terms and conditions.
37 | "Licensor" shall mean Niantic Inc., a company organized and existing under
38 | the laws of Delaware, whose principal place of business is at 1 Ferry Building,
39 | Suite 200, San Francisco, 94111.
40 | "Software" shall mean the Matching 2D Images in 3D: Metric Relative Pose
41 | from Metric Correspondences Software uploaded by Licensor to the GitHub
42 | repository at https://github.com/nianticlabs/mickey
43 | on April 10th 2024 in source code or object code form and any
44 | accompanying documentation as well as any modifications or additions uploaded
45 | to the same GitHub repository by Licensor.
46 |
47 |
48 | 2. License.
49 |
50 | 2.1 The Licensor has all necessary rights to grant a license under: (i)
51 | copyright and rights in the nature of copyright subsisting in the Software; and
52 | (ii) certain patent rights resulting from a patent application(s) filed by the
53 | Licensor in the United States and/or other jurisdictions in connection with the
54 | Software. The Licensor grants the Licensee for the duration of this Agreement,
55 | a free of charge, non-sublicenseable, non-exclusive, non-transferable copyright
56 | and patent license (in consequence of said patent application(s)) to use the
57 | Software for non-commercial purpose only, including teaching and research at
58 | educational institutions and research at not-for-profit research institutions
59 | in accordance with the provisions of this Agreement. Non-commercial use
60 | expressly excludes any profit-making or commercial activities, including without
61 | limitation sale, license, manufacture or development of commercial products, use in
62 | commercially-sponsored research, use at a laboratory or other facility owned or
63 | controlled (whether in whole or in part) by a commercial entity, provision of
64 | consulting service, use for or on behalf of any commercial entity, use in
65 | research where a commercial party obtains rights to research results or any
66 | other benefit, and use of the code in any models, model weights or code
67 | resulting from such procedure in any commercial product. Notwithstanding the
68 | foregoing restrictions, you can use this code for publishing comparison results
69 | for academic papers, including retraining on your own data. Any use of the
70 | Software for any purpose other than pursuant to the license grant set forth
71 | above shall automatically terminate this License.
72 |
73 |
74 | 2.2 The Licensee is permitted to make modifications to the Software
75 | provided that any distribution of such modifications is in accordance with
76 | Clause 3.
77 |
78 | 2.3 Except as expressly permitted by this Agreement and save to the
79 | extent and in the circumstances expressly required to be permitted by law, the
80 | Licensee is not permitted to rent, lease, sell, offer to sell, or loan the
81 | Software or its associated documentation.
82 |
83 |
84 | 3. Redistribution and modifications
85 |
86 | 3.1 The Licensee may reproduce and distribute copies of the Software, with
87 | or without modifications, in source format only and only to this same GitHub
88 | repository , and provided that any and every distribution is accompanied by an
89 | unmodified copy of this License and that the following copyright notice is
90 | always displayed in an obvious manner: Copyright © Niantic, Inc. 2018. All
91 | rights reserved.
92 |
93 |
94 | 3.2 In the case where the Software has been modified, any distribution must
95 | include prominent notices indicating which files have been changed.
96 |
97 | 3.3 The Licensee shall cause any work that it distributes or publishes,
98 | that in whole or in part contains or is derived from the Software or any part
99 | thereof (“Work based on the Software”), to be licensed as a whole at no charge
100 | to all third parties entitled to a license to the Software under the terms of
101 | this License and on the same terms provided in this License.
102 |
103 |
104 | 4. Duration.
105 |
106 | This Agreement is effective until the Licensee terminates it by destroying
107 | the Software, any Work based on the Software, and its documentation together
108 | with all copies. It will also terminate automatically if the Licensee fails to
109 | abide by its terms. Upon automatic termination the Licensee agrees to destroy
110 | all copies of the Software, Work based on the Software, and its documentation.
111 |
112 |
113 | 5. Disclaimer of Warranties.
114 |
115 | The Software is provided as is. To the maximum extent permitted by law,
116 | Licensor provides no warranties or conditions of any kind, either express or
117 | implied, including without limitation, any warranties or condition of title,
118 | non-infringement or fitness for a particular purpose.
119 |
120 |
121 | 6. LIMITATION OF LIABILITY.
122 |
123 | IN NO EVENT SHALL THE LICENSOR AND/OR AUTHORS BE LIABLE FOR ANY DIRECT,
124 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING
125 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
126 | DATA OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
127 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
128 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
129 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
130 |
131 |
132 | 7. Indemnity.
133 |
134 | The Licensee shall indemnify the Licensor and/or Authors against all third
135 | party claims that may be asserted against or suffered by the Licensor and/or
136 | Authors and which relate to use of the Software by the Licensee.
137 |
138 |
139 | 8. Intellectual Property.
140 |
141 | 8.1 As between the Licensee and Licensor, copyright and all other
142 | intellectual property rights subsisting in or in connection with the Software
143 | and supporting information shall remain at all times the property of the
144 | Licensor. The Licensee shall acquire no rights in any such material except as
145 | expressly provided in this Agreement.
146 |
147 | 8.2 No permission is granted to use the trademarks or product names of the
148 | Licensor except as required for reasonable and customary use in describing the
149 | origin of the Software and for the purposes of abiding by the terms of Clause
150 | 3.1.
151 |
152 | 8.3 The Licensee shall promptly notify the Licensor of any improvement or
153 | new use of the Software (“Improvements”) in sufficient detail for Licensor to
154 | evaluate the Improvements. The Licensee hereby grants the Licensor and its
155 | affiliates a non-exclusive, fully paid-up, royalty-free, irrevocable and
156 | perpetual license to all Improvements for non-commercial academic research and
157 | teaching purposes upon creation of such improvements.
158 |
159 | 8.4 The Licensee grants an exclusive first option to the Licensor to be
160 | exercised by the Licensor within three (3) years of the date of notification of
161 | an Improvement under Clause 8.3 to use any the Improvement for commercial
162 | purposes on terms to be negotiated and agreed by Licensee and Licensor in good
163 | faith within a period of six (6) months from the date of exercise of the said
164 | option (including without limitation any royalty share in net income from such
165 | commercialization payable to the Licensee, as the case may be).
166 |
167 |
168 | 9. Acknowledgements.
169 |
170 | The Licensee shall acknowledge the Authors and use of the Software in the
171 | publication of any work that uses, or results that are achieved through, the
172 | use of the Software. The following citation shall be included in the
173 | acknowledgement: “Matching 2D Images in 3D: Metric Relative Pose
174 | from Metric Correspondences", by A. Barroso-Laguna, S. Munukutla,
175 | V. Prisacariu, E. Brachmann, CVPR 2024.
176 |
177 |
178 | 10. Governing Law.
179 |
180 | This Agreement shall be governed by, construed and interpreted in
181 | accordance with English law and the parties submit to the exclusive
182 | jurisdiction of the English courts.
183 |
184 |
185 | 11. Termination.
186 |
187 | Upon termination of this Agreement, the licenses granted hereunder will
188 | terminate and Sections 5, 6, 7, 8, 9, 10 and 11 shall survive any termination
189 | of this Agreement.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
15 |
16 | This is the reference implementation of the paper **"Matching 2D Images in 3D: Metric Relative Pose from Metric Correspondences"** presented at **CVPR 2024**.
17 |
18 | The paper introduces **M**etr**ic Key**points (MicKey), a feature detection pipeline that regresses keypoint positions in camera space.
19 | MicKey presents a differentiable approach to establish metric correspondences via descriptor matching. From the metric correspondences, MicKey recovers metric relative poses.
20 | MicKey is trained in an end-to-end fashion using differentiable pose optimization and requires only image pairs and their ground truth relative poses for supervision.
21 |
22 |
23 |
24 |
25 |
26 | ## Setup
27 |
28 | Assuming a fresh [Anaconda](https://www.anaconda.com/download/) distribution, you can install dependencies with:
29 | ```shell
30 | conda env create -f resources/environment.yml
31 | conda activate mickey
32 | ```
33 | We ran our experiments with PyTorch 2.0.1, CUDA 11.6, Python 3.8.17 and Debian GNU/Linux 11.
34 |
35 | ## Evaluating MicKey
36 | MicKey addresses the problem of instant Augmented Reality (AR) introduced in the [Map-free benchmark](https://research.nianticlabs.com/mapfree-reloc-benchmark).
37 | In the Map-free setup, instead of building 3D maps from hundreds of images and scale calibration, a single photo of a scene serves as the map.
38 | The Map-free benchmark then evaluates how accurate the estimated metric relative pose is between the reference image (the map)
39 | and the query image (the user).
40 |
41 | ### Download Map-free dataset
42 | You can find the Map-free dataset in [their project page](https://research.nianticlabs.com/mapfree-reloc-benchmark/dataset).
43 | Extract the test.zip file into `data/mapfree`. Optionally, if you want to train MicKey, also download the train and val zip files.
44 |
45 | ### Pre-trained Models
46 | We provide two [MicKey models](https://storage.googleapis.com/niantic-lon-static/research/mickey/assets/mickey_weights.zip).
47 | * _mickey.ckpt_: These are the default weights for MicKey, trained without the overlap scores provided in the Map-free dataset and following the curriculum learning strategy described in the paper.
48 | * _mickey_sc.ckpt_: These are the weights obtained when training MicKey with the min and max overlap scores defined in Map-free.
49 |
50 | Extract mickey_weights.zip into `weights/`. In the zip file, we also provide the default configuration needed to run the evaluation.
51 |
52 | ### Run the submission script
53 | Similar to the Map-free code base, we provide a [submission script](submission.py) to generate submission files:
54 |
55 | ```shell
56 | python submission.py --config path/to/config --checkpoint path/to/checkpoint --o results/your_method
57 | ```
58 | The resulting file `results/your_method/submission.zip` can be uploaded to the Map-free [online benchmark website](https://research.nianticlabs.com/mapfree-reloc-benchmark) and compared against existing methods in the [leaderboard](https://research.nianticlabs.com/mapfree-reloc-benchmark/leaderboard).
59 |
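For reference, each per-scene file inside `submission.zip` (read as `pose_{scene}.txt` by `benchmark/mapfree.py`) follows the format parsed by `load_poses` in `benchmark/utils.py`: one world-to-camera pose per query frame plus a confidence value. A sketch with made-up values:

```text
# pose_<scene>.txt
# imgpath qw qx qy qz tx ty tz confidence
seq1/frame_00000.jpg 0.998 0.012 -0.051 0.021 0.11 -0.03 1.42 183.0
seq1/frame_00005.jpg 0.996 0.020 -0.070 0.034 0.09 -0.02 1.51 142.0
```
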
60 | ### Run the local evaluation
61 | The Map-free benchmark does not provide ground-truth poses for the test set, but we can still evaluate our method locally on the validation set.
62 | ```shell
63 | python submission.py --config path/to/config --checkpoint path/to/checkpoint --o results/your_method --split val
64 | ```
65 | and evaluate it as:
66 | ```shell
67 | python -m benchmark.mapfree --submission_path results/your_method/submission.zip --split val
68 | ```
69 |
70 | ### Download MicKey correspondences and depth files
71 | We provide the depth maps and correspondences computed by MicKey.
72 | - Download [MicKey depth maps](https://storage.googleapis.com/niantic-lon-static/research/map-free-reloc/assets/mickey_depths.tar.gz).
73 | - Download [MicKey correspondences](https://storage.googleapis.com/niantic-lon-static/research/map-free-reloc/assets/mickey_correspondences.zip).
74 | - Extract the contents of both files to `data/mapfree`
75 |
76 | Refer to the [Map-free benchmark](https://github.com/nianticlabs/map-free-reloc/tree/main?tab=readme-ov-file#feature-matching--scale-from-depth-baselines) to learn how to load precomputed correspondences and depth maps in their feature matching pipeline.
77 |
78 | ### Running MicKey on custom images
79 | We provide a [demo script](demo_inference.py) to run the relative pose estimation pipeline on custom image pairs.
80 | As an example, we provide two images and their respective intrinsics in `data/toy_example`.
81 | The script computes their metric relative pose and saves the corresponding depth and keypoint score maps in the image folder.
82 | Run the demo script as:
83 | ```shell
84 | python demo_inference.py --im_path_ref data/toy_example/im0.jpg \
85 | --im_path_dst data/toy_example/im1.jpg \
86 | --intrinsics data/toy_example/intrinsics.txt \
87 | --checkpoint path/to/checkpoint \
88 | --config path/to/config
89 | ```
90 |
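If you supply your own images, `intrinsics.txt` needs the pinhole parameters of each image. A sketch, assuming the same per-frame layout that `load_K` in `benchmark/utils.py` uses for the Map-free data (`path fx fy cx cy W H`); the file names and values below are only illustrative:

```text
im0.jpg 1401.27 1401.27 270.0 360.0 540 720
im1.jpg 1401.27 1401.27 270.0 360.0 540 720
```
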
91 | To generate the 3D assets shown on [MicKey's webpage](https://nianticlabs.github.io/mickey/), you can turn on the
92 | `--generate_3D_vis` flag. This will generate a rendered image with the input images, their computed 3D camera positions,
93 | and the set of 3D point inliers.
94 |
95 | ## Training MicKey
96 | Besides the test scripts, we also provide the training code for MicKey.
97 |
98 | We provide two default configurations in `config/MicKey/`:
99 | * _curriculum_learning.yaml_: This configuration follows the curriculum learning approach detailed in the paper.
100 | It hence does not use any image overlap information, only ground-truth relative poses, during training.
101 | * _overlap_score.yaml_: This configuration relies on the image overlap information to select only solvable image pairs during training.
102 |
103 | Besides the two default configurations, we also provide warm-up configurations to speed up training.
104 | These configurations use low-resolution images and do not add the null hypothesis (refer to Section 3.1.4 for details).
105 | We recommend initializing MicKey with these configurations and then
106 | fine-tuning the network with the default ones (which use high-resolution images and the null hypothesis).
107 | They can be found under `config/MicKey/`:
108 | * _curriculum_learning_warm_up.yaml_
109 | * _overlap_score_warm_up.yaml_
110 |
111 | To train the default MicKey model, use:
112 | ```shell
113 | python train.py --config config/MicKey/curriculum_learning.yaml \
114 | --dataset_config config/datasets/mapfree.yaml \
115 | --experiment experiment_name \
116 | --path_weights path/to/checkpoint/folder
117 | ```
118 | Resume training from a checkpoint by adding `--resume {path_to_checkpoint}`.
119 |
120 | The top models, according to the validation loss, the VCRE metric, and the pose AUC score, are saved during training.
121 | Tensorboard results and checkpoints are saved into the folder `dir/to/weights/experiment_name`.
122 |
123 | Note that by default, the configuration is set to use 4 GPUs.
124 | You can reduce the expected number of GPUs in the config file (e.g., _NUM_GPUS: 1_).
125 |
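For example, a minimal single-GPU override of `config/MicKey/curriculum_learning.yaml` could look like this (the batch size shown is only an illustration; pick it according to your GPU memory):

```yaml
TRAINING:
  NUM_GPUS: 1
  BATCH_SIZE: 4   # default is 8 per GPU
```
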
126 | ## Changelog
127 | - 19 September 2024: Added vectorized RANSAC and warm up configurations.
128 | - 13 August 2024: Added visualization code.
129 | - 7 June 2024: Added precomputed depth maps and keypoint correspondences.
130 |
131 | ## BibTeX
132 | If you use this code in your research, please consider citing our paper:
133 |
134 | ```bibtex
135 | @inproceedings{barroso2024mickey,
136 | title={Matching 2D Images in 3D: Metric Relative Pose from Metric Correspondences},
137 | author={Barroso-Laguna, Axel and Munukutla, Sowmya and Prisacariu, Victor and Brachmann, Eric},
138 | booktitle={CVPR},
139 | year={2024}
140 | }
141 | ```
142 |
143 | ## License
144 | Copyright © Niantic, Inc. 2024. Patent Pending. All rights reserved. This code is for non-commercial use. Please see the [license](LICENSE) file for terms.
145 |
146 | ## Acknowledgements
147 | We use code from several other repositories and thank their authors and maintainers:
148 | - [Map-free](https://research.nianticlabs.com/mapfree-reloc-benchmark)
149 | - [ACE](https://github.com/nianticlabs/ace)
150 | - [ACE0](https://github.com/nianticlabs/acezero)
151 | - [DUSt3R](https://github.com/naver/dust3r)
152 | - [RoMa](https://github.com/Parskatt/RoMa)
153 | - [DINOv2](https://github.com/facebookresearch/dinov2)
154 | - [LoFTR](https://github.com/zju3dv/LoFTR)
155 | - [DPT](https://github.com/isl-org/DPT)
156 | - [ExtremeRotation](https://github.com/RuojinCai/ExtremeRotation_code)
157 |
158 |
--------------------------------------------------------------------------------
/benchmark/config.py:
--------------------------------------------------------------------------------
1 | # translation and rotation thresholds [meters, degrees]
2 | # used to compute Precision and AUC considering Pose Error
3 | t_threshold = 0.25
4 | R_threshold = 5
5 |
6 | # reprojection (VCRE) threshold [pixels]
7 | # used to compute Precision and AUC considering VCRE
8 | vcre_threshold = 90
9 |
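For context, a minimal sketch of how these thresholds are applied, mirroring `aggregate_results` in `benchmark/mapfree.py`; the per-frame errors below are made up:

```python
import numpy as np

import benchmark.config as config

# toy per-frame errors: translation [m], rotation [deg], VCRE [px]
trans_err = np.array([0.10, 0.40, 0.20])
rot_err = np.array([2.0, 1.0, 10.0])
reproj_err = np.array([45.0, 120.0, 60.0])

# a pose is accepted only if both translation and rotation are below threshold
accepted_poses = (trans_err < config.t_threshold) & (rot_err < config.R_threshold)
accepted_vcre = reproj_err < config.vcre_threshold

print(accepted_poses)  # [ True False False]
print(accepted_vcre)   # [ True False  True]
```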
--------------------------------------------------------------------------------
/benchmark/mapfree.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | from pathlib import Path
4 | from zipfile import ZipFile
5 | from io import TextIOWrapper
6 | import json
7 | import logging
8 | import numpy as np
9 |
10 | from benchmark.utils import load_poses, subsample_poses, load_K, precision_recall
11 | from benchmark.metrics import MetricManager, Inputs
12 | import benchmark.config as config
13 | from config.default import cfg
14 |
15 | def plot_perfect_curve(P):
16 | total_bins = 1000
17 | prec_values = []
18 | ratio_values = []
19 |     for i in range(1, total_bins):  # start at 1 to avoid division by zero
20 |         ratio_tmp = i/total_bins
21 | value = min(1, P / ratio_tmp)
22 | prec_values.append(value)
23 | ratio_values.append(ratio_tmp)
24 | return prec_values, ratio_values
25 |
26 | def compute_scene_metrics(dataset_path: Path, submission_zip: ZipFile, scene: str):
27 | metric_manager = MetricManager()
28 |
29 | # load intrinsics and poses
30 | try:
31 | K, W, H = load_K(dataset_path / scene / 'intrinsics.txt')
32 | with (dataset_path / scene / 'poses.txt').open('r', encoding='utf-8') as gt_poses_file:
33 | gt_poses = load_poses(gt_poses_file, load_confidence=False)
34 | except FileNotFoundError as e:
35 | logging.error(f'Could not find ground-truth dataset files: {e}')
36 | raise
37 | else:
38 | logging.info(
39 | f'Loaded ground-truth intrinsics and poses for scene {scene}')
40 |
41 | # try to load estimated poses from submission
42 | try:
43 | with submission_zip.open(f'pose_{scene}.txt') as estimated_poses_file:
44 | estimated_poses_file_wrapper = TextIOWrapper(
45 | estimated_poses_file, encoding='utf-8')
46 | estimated_poses = load_poses(
47 | estimated_poses_file_wrapper, load_confidence=True)
48 | except KeyError as e:
49 | logging.warning(
50 | f'Submission does not have estimates for scene {scene}.')
51 | return dict(), len(gt_poses)
52 | except UnicodeDecodeError as e:
53 | logging.error('Unsupported file encoding: please use UTF-8')
54 | raise
55 | else:
56 | logging.info(f'Loaded estimated poses for scene {scene}')
57 |
58 | # The val/test set is subsampled by a factor of 5
59 | gt_poses = subsample_poses(gt_poses, subsample=5)
60 |
61 | # failures encode how many frames did not have an estimate
62 | # e.g. user/method did not provide an estimate for that frame
63 | # it's different from when an estimate is provided with low confidence!
64 | failures = 0
65 |
66 | # Results encoded as dict
67 | # key: metric name; value: list of values (one per frame).
68 | # e.g. results['t_err'] = [1.2, 0.3, 0.5, ...]
69 | results = defaultdict(list)
70 |
71 | # compute metrics per frame
72 | for frame_num, (q_gt, t_gt, _) in gt_poses.items():
73 | if frame_num not in estimated_poses:
74 | failures += 1
75 | continue
76 |
77 | q_est, t_est, confidence = estimated_poses[frame_num]
78 | inputs = Inputs(q_gt=q_gt, t_gt=t_gt, q_est=q_est, t_est=t_est,
79 | confidence=confidence, K=K[frame_num], W=W, H=H)
80 | metric_manager(inputs, results)
81 |
82 | return results, failures
83 |
84 |
85 | def aggregate_results(all_results, all_failures):
86 | # aggregate metrics
87 | median_metrics = defaultdict(list)
88 | all_metrics = defaultdict(list)
89 | for scene_results in all_results.values():
90 | for metric, values in scene_results.items():
91 | median_metrics[metric].append(np.median(values))
92 | all_metrics[metric].extend(values)
93 | all_metrics = {k: np.array(v) for k, v in all_metrics.items()}
94 | assert all([v.ndim == 1 for v in all_metrics.values()]
95 | ), 'invalid metrics shape'
96 |
97 | # compute avg median metrics
98 | avg_median_metrics = {metric: np.mean(
99 | values) for metric, values in median_metrics.items()}
100 |
101 | # compute precision/AUC for pose error and reprojection errors
102 | accepted_poses = (all_metrics['trans_err'] < config.t_threshold) * \
103 | (all_metrics['rot_err'] < config.R_threshold)
104 | accepted_vcre = all_metrics['reproj_err'] < config.vcre_threshold
105 | total_samples = len(next(iter(all_metrics.values()))) + all_failures
106 |
107 | prec_pose = np.sum(accepted_poses) / total_samples
108 | prec_vcre = np.sum(accepted_vcre) / total_samples
109 |
110 | # compute AUC for pose and VCRE
111 | pose_prec_values, pose_recall_values, auc_pose = precision_recall(
112 | inliers=all_metrics['confidence'], tp=accepted_poses, failures=all_failures)
113 | vcre_prec_values, vcre_recall_values, auc_vcre = precision_recall(
114 | inliers=all_metrics['confidence'], tp=accepted_vcre, failures=all_failures)
115 |
116 | curves_data = {}
117 | curves_data['vcre_prec_values'], curves_data['vcre_recall_values'] = vcre_prec_values, vcre_recall_values
118 | curves_data['pose_prec_values'], curves_data['pose_recall_values'] = pose_prec_values, pose_recall_values
119 |
120 | # output metrics
121 | output_metrics = dict()
122 | output_metrics['Average Median Translation Error'] = avg_median_metrics['trans_err']
123 | output_metrics['Average Median Rotation Error'] = avg_median_metrics['rot_err']
124 | output_metrics['Average Median Reprojection Error'] = avg_median_metrics['reproj_err']
125 | output_metrics[f'Precision @ Pose Error < ({config.t_threshold*100}cm, {config.R_threshold}deg)'] = prec_pose
126 | output_metrics[f'AUC @ Pose Error < ({config.t_threshold*100}cm, {config.R_threshold}deg)'] = auc_pose
127 | output_metrics[f'Precision @ VCRE < {config.vcre_threshold}px'] = prec_vcre
128 | output_metrics[f'AUC @ VCRE < {config.vcre_threshold}px'] = auc_vcre
129 | output_metrics[f'Estimates for % of frames'] = len(all_metrics['trans_err']) / total_samples
130 | return output_metrics, curves_data
131 |
132 |
133 | def count_unexpected_scenes(scenes: tuple, submission_zip: ZipFile):
134 | submission_scenes = [fname[5:-4]
135 | for fname in submission_zip.namelist() if fname.startswith("pose_")]
136 | return len(set(submission_scenes) - set(scenes))
137 |
138 | def main(args):
139 | dataset_path = args.dataset_path / args.split
140 | scenes = tuple(f.name for f in dataset_path.iterdir() if f.is_dir())
141 |
142 | try:
143 | submission_zip = ZipFile(args.submission_path, 'r')
144 | except FileNotFoundError as e:
145 | logging.error(f'Could not find ZIP file in path {args.submission_path}')
146 | return
147 |
148 | all_results = dict()
149 | all_failures = 0
150 | for scene in scenes:
151 | metrics, failures = compute_scene_metrics(
152 | dataset_path, submission_zip, scene)
153 | all_results[scene] = metrics
154 | all_failures += failures
155 |
156 | if all_failures > 0:
157 | logging.warning(
158 | f'Submission is missing pose estimates for {all_failures} frames')
159 |
160 | unexpected_scene_count = count_unexpected_scenes(scenes, submission_zip)
161 | if unexpected_scene_count > 0:
162 | logging.warning(
163 | f'Submission contains estimates for {unexpected_scene_count} scenes outside the {args.split} set')
164 |
165 | if all((len(metrics) == 0 for metrics in all_results.values())):
166 | logging.error(
167 | f'Submission does not have any valid pose estimates')
168 | return
169 |
170 | output_metrics, curves_data = aggregate_results(all_results, all_failures)
171 | output_json = json.dumps(output_metrics, indent=2)
172 | print(output_json)
173 |
174 |
175 | if __name__ == '__main__':
176 | parser = argparse.ArgumentParser(
177 | 'eval', description='Evaluate submissions for the MapFree dataset benchmark')
178 | parser.add_argument('--submission_path', type=Path, default='',
179 | help='Path to the submission ZIP file')
180 | parser.add_argument('--split', choices=('val', 'test'), default='test',
181 | help='Dataset split to use for evaluation. Default: test')
182 | parser.add_argument('--log', choices=('warning', 'info', 'error'),
183 | default='warning', help='Logging level. Default: warning')
184 | parser.add_argument('--dataset_path', type=Path, default=None,
185 | help='Path to the dataset folder')
186 |
187 | args = parser.parse_args()
188 |
189 | if args.dataset_path is None:
190 | cfg.merge_from_file('config/datasets/mapfree.yaml')
191 | args.dataset_path = Path(cfg.DATASET.DATA_ROOT)
192 |
193 | logging.basicConfig(level=args.log.upper())
194 | try:
195 | main(args)
196 | except Exception:
197 |         logging.exception("Unexpected behaviour. Exiting.")
198 |
199 |
--------------------------------------------------------------------------------
/benchmark/metrics.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable
3 |
4 | import numpy as np
5 |
6 | from benchmark.reprojection import reprojection_error
7 | from benchmark.utils import VARIANTS_ANGLE_SIN, quat_angle_error
8 |
9 |
10 | @dataclass
11 | class Inputs:
12 |     q_gt: np.ndarray
13 |     t_gt: np.ndarray
14 |     q_est: np.ndarray
15 |     t_est: np.ndarray
16 |     confidence: float
17 |     K: np.ndarray
18 | W: int
19 | H: int
20 |
21 | def __post_init__(self):
22 | assert self.q_gt.shape == (4,), 'invalid gt quaternion shape'
23 | assert self.t_gt.shape == (3,), 'invalid gt translation shape'
24 | assert self.q_est.shape == (4,), 'invalid estimated quaternion shape'
25 | assert self.t_est.shape == (3,), 'invalid estimated translation shape'
26 | assert self.confidence >= 0, 'confidence must be non negative'
27 | assert self.K.shape == (3, 3), 'invalid K shape'
28 | assert self.W > 0, 'invalid image width'
29 | assert self.H > 0, 'invalid image height'
30 |
31 |
32 | class MyDict(dict):
33 | def register(self, fn) -> Callable:
34 | """Registers a function within dict(fn_name -> fn_ref).
35 | This is used to evaluate all registered metrics in MetricManager.__call__()"""
36 | self[fn.__name__] = fn
37 | return fn
38 |
39 |
40 | class MetricManager:
41 | _metrics = MyDict()
42 |
43 | def __call__(self, inputs: Inputs, results: dict) -> None:
44 | for metric, metric_fn in self._metrics.items():
45 | results[metric].append(metric_fn(inputs))
46 |
47 | @staticmethod
48 | @_metrics.register
49 | def trans_err(inputs: Inputs) -> np.float64:
50 | return np.linalg.norm(inputs.t_est - inputs.t_gt)
51 |
52 | @staticmethod
53 | @_metrics.register
54 | def rot_err(inputs: Inputs, variant: str = VARIANTS_ANGLE_SIN) -> np.float64:
55 | return quat_angle_error(label=inputs.q_est, pred=inputs.q_gt, variant=variant)[0, 0]
56 |
57 | @staticmethod
58 | @_metrics.register
59 | def reproj_err(inputs: Inputs) -> float:
60 | return reprojection_error(
61 | q_est=inputs.q_est, t_est=inputs.t_est, q_gt=inputs.q_gt, t_gt=inputs.t_gt, K=inputs.K,
62 | W=inputs.W, H=inputs.H)
63 |
64 | @staticmethod
65 | @_metrics.register
66 | def confidence(inputs: Inputs) -> float:
67 | return inputs.confidence
68 |
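A minimal usage sketch, mirroring how `compute_scene_metrics` in `benchmark/mapfree.py` combines `Inputs` and `MetricManager` per frame; the poses and intrinsics below are illustrative:

```python
from collections import defaultdict

import numpy as np

from benchmark.metrics import Inputs, MetricManager

metric_manager = MetricManager()
results = defaultdict(list)

# identity rotations (qw qx qy qz) and a 0.1 m translation offset
inputs = Inputs(q_gt=np.array([1., 0., 0., 0.]), t_gt=np.zeros(3),
                q_est=np.array([1., 0., 0., 0.]), t_est=np.array([0.1, 0., 0.]),
                confidence=100.0, K=np.eye(3), W=540, H=720)
metric_manager(inputs, results)  # appends trans_err, rot_err, reproj_err, confidence

print(results['trans_err'][0])  # 0.1
print(results['rot_err'][0])    # 0.0
```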
--------------------------------------------------------------------------------
/benchmark/reprojection.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | from transforms3d.quaternions import quat2mat
5 |
6 |
7 | def project(pts: np.ndarray, K: np.ndarray, img_size: Optional[Union[List[int], Tuple[int, int]]] = None) -> np.ndarray:
8 | """Projects 3D points to image plane.
9 |
10 | Args:
11 | - pts [N, 3/4]: points in camera coordinates (homogeneous or non-homogeneous)
12 | - K [3, 3]: intrinsic matrix
13 | - img_size (width, height): optional, clamp projection to image borders
14 | Outputs:
15 | - uv [N, 2]: coordinates of projected points
16 | """
17 |
18 | assert len(pts.shape) == 2, 'incorrect number of dimensions'
19 | assert pts.shape[1] in [3, 4], 'invalid dimension size'
20 | assert K.shape == (3, 3), 'incorrect intrinsic shape'
21 |
22 | uv_h = (K @ pts[:, :3].T).T
23 | uv = uv_h[:, :2] / uv_h[:, -1:]
24 |
25 | if img_size is not None:
26 | uv[:, 0] = np.clip(uv[:, 0], 0, img_size[0])
27 | uv[:, 1] = np.clip(uv[:, 1], 0, img_size[1])
28 |
29 | return uv
30 |
31 |
32 | def get_grid_multipleheight() -> np.ndarray:
33 | # create grid of points
34 | ar_grid_step = 0.3
35 | ar_grid_num_x = 7
36 | ar_grid_num_y = 4
37 | ar_grid_num_z = 7
38 | ar_grid_z_offset = 1.8
39 | ar_grid_y_offset = 0
40 |
41 | ar_grid_x_pos = np.arange(0, ar_grid_num_x)-(ar_grid_num_x-1)/2
42 | ar_grid_x_pos *= ar_grid_step
43 |
44 | ar_grid_y_pos = np.arange(0, ar_grid_num_y)-(ar_grid_num_y-1)/2
45 | ar_grid_y_pos *= ar_grid_step
46 | ar_grid_y_pos += ar_grid_y_offset
47 |
48 | ar_grid_z_pos = np.arange(0, ar_grid_num_z).astype(float)
49 | ar_grid_z_pos *= ar_grid_step
50 | ar_grid_z_pos += ar_grid_z_offset
51 |
52 | xx, yy, zz = np.meshgrid(ar_grid_x_pos, ar_grid_y_pos, ar_grid_z_pos)
53 | ones = np.ones(xx.shape[0]*xx.shape[1]*xx.shape[2])
54 | eye_coords = np.concatenate([c.reshape(-1, 1)
55 | for c in (xx, yy, zz, ones)], axis=-1)
56 | return eye_coords
57 |
58 |
59 | # global variable, avoids creating it again
60 | eye_coords_glob = get_grid_multipleheight()
61 |
62 |
63 | def reprojection_error(
64 | q_est: np.ndarray, t_est: np.ndarray, q_gt: np.ndarray, t_gt: np.ndarray, K: np.ndarray,
65 | W: int, H: int) -> float:
66 | eye_coords = eye_coords_glob
67 |
68 | # obtain ground-truth position of projected points
69 | uv_gt = project(eye_coords, K, (W, H))
70 |
71 | # residual transformation
72 | cam2w_est = np.eye(4)
73 | cam2w_est[:3, :3] = quat2mat(q_est)
74 | cam2w_est[:3, -1] = t_est
75 | cam2w_gt = np.eye(4)
76 | cam2w_gt[:3, :3] = quat2mat(q_gt)
77 | cam2w_gt[:3, -1] = t_gt
78 |
79 | # residual reprojection
80 | eyes_residual = (np.linalg.inv(cam2w_est) @ cam2w_gt @ eye_coords.T).T
81 | uv_pred = project(eyes_residual, K, (W, H))
82 |
83 | # get reprojection error
84 | repr_err = np.linalg.norm(uv_gt - uv_pred, ord=2, axis=1)
85 | mean_repr_err = float(repr_err.mean().item())
86 | return mean_repr_err
87 |
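A quick sanity check of the VCRE computed by `reprojection_error`; the intrinsics are illustrative and, as in `test_metrics.py`, identical poses yield zero error:

```python
import numpy as np

from benchmark.reprojection import reprojection_error

q = np.array([1., 0., 0., 0.])  # identity rotation (qw qx qy qz)
t = np.zeros(3)
K = np.array([[500., 0., 270.], [0., 500., 360.], [0., 0., 1.]])

# identical estimated and ground-truth poses -> zero reprojection error
print(reprojection_error(q_est=q, t_est=t, q_gt=q, t_gt=t, K=K, W=540, H=720))  # 0.0

# a 10 cm translation error on the virtual grid yields a non-zero VCRE (in pixels)
t_off = np.array([0.1, 0., 0.])
print(reprojection_error(q_est=q, t_est=t_off, q_gt=q, t_gt=t, K=K, W=540, H=720) > 0)  # True
```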
--------------------------------------------------------------------------------
/benchmark/test_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from transforms3d.euler import euler2quat
4 | from transforms3d.quaternions import axangle2quat, qmult, quat2mat, rotate_vector
5 |
6 | from benchmark.metrics import Inputs, MetricManager
7 | from benchmark.reprojection import project
8 | from benchmark.utils import VARIANTS_ANGLE_COS, VARIANTS_ANGLE_SIN
9 |
10 |
11 | def createInput(q_gt=None, t_gt=None, q_est=None, t_est=None, confidence=None, K=None, W=None, H=None):
12 | q_gt = np.zeros(4) if q_gt is None else q_gt
13 | t_gt = np.zeros(3) if t_gt is None else t_gt
14 | q_est = np.zeros(4) if q_est is None else q_est
15 | t_est = np.zeros(3) if t_est is None else t_est
16 | confidence = 0. if confidence is None else confidence
17 | K = np.eye(3) if K is None else K
18 | H = 1 if H is None else H
19 | W = 1 if W is None else W
20 | return Inputs(q_gt=q_gt, t_gt=t_gt, q_est=q_est, t_est=t_est, confidence=confidence, K=K, W=W, H=H)
21 |
22 |
23 | def randomQuat():
24 | angles = np.random.uniform(0, 2*np.pi, 3)
25 | q = euler2quat(*angles)
26 | return q
27 |
28 |
29 | class TestMetrics:
30 | @pytest.mark.parametrize('run_number', range(50))
31 | def test_t_err_tinvariance(self, run_number: int) -> None:
32 | """Computes the translation error given an initial translation and displacement of this
33 | translation. The translation error must be equal to the norm of the displacement."""
34 | mean, var = 5, 10
35 | t0 = np.random.normal(mean, var, (3,))
36 | displacement = np.random.normal(mean, var, (3,))
37 |
38 | i = createInput(t_gt=t0, t_est=t0+displacement)
39 | trans_err = MetricManager.trans_err(i)
40 | assert np.isclose(trans_err, np.linalg.norm(displacement))
41 |
42 | @pytest.mark.parametrize('run_number', range(50))
43 | def test_trans_err_rinvariance(self, run_number: int) -> None:
44 | """Computes the translation error given estimated and gt vectors.
45 | The translation error must be the same for a rotated version of those vectors
46 | (same random rotation)"""
47 | mean, var = 5, 10
48 | t0 = np.random.normal(mean, var, (3,))
49 | t1 = np.random.normal(mean, var, (3,))
50 | q = randomQuat()
51 |
52 | i = createInput(t_gt=t0, t_est=t1)
53 | trans_err = MetricManager.trans_err(i)
54 |
55 | ir = createInput(t_gt=rotate_vector(t0, q), t_est=rotate_vector(t1, q))
56 | trans_err_r = MetricManager.trans_err(ir)
57 |
58 | assert np.isclose(trans_err, trans_err_r)
59 |
60 | @pytest.mark.parametrize('run_number', range(50))
61 | @pytest.mark.parametrize('dtype', (np.float64, np.float32))
62 | def test_rot_err_raxis(self, run_number: int, dtype: type) -> None:
63 | """Test rotation error for rotations around a random axis.
64 |
65 | Note: We create GT as high precision, and only downcast when calling rot_err.
66 | """
67 | q = randomQuat().astype(np.float64)
68 |
69 | axis = np.random.uniform(low=-1, high=1, size=3).astype(np.float64)
70 | angle = np.float64(np.random.uniform(low=-np.pi, high=np.pi))
71 | qres = axangle2quat(vector=axis, theta=angle, is_normalized=False).astype(np.float64)
72 |
73 | i = createInput(q_gt=q.astype(dtype), q_est=qmult(q, qres).astype(dtype))
74 | rot_err = MetricManager.rot_err(i)
75 | assert isinstance(rot_err, np.float64)
76 | rot_err_expected = np.abs(np.degrees(angle))
77 | # if we add up errors, we want them to be positive
78 | assert 0. <= rot_err
79 | rtol = 1.e-5 # numpy default
80 | atol = 1.e-8 # numpy default
81 |         if dtype == np.float32:
82 | atol = 1.e-7 # 1/50 test might fail at 1.e-8
83 | assert np.isclose(rot_err, rot_err_expected, rtol=rtol, atol=atol)
84 |
85 | @pytest.mark.parametrize('run_number', range(50))
86 | def test_r_err_mat(self, run_number: int) -> None:
87 | q0 = randomQuat()
88 | q1 = randomQuat()
89 |
90 | i = createInput(q_gt=q0, q_est=q1)
91 | rot_err = MetricManager.rot_err(i)
92 |
93 | R0 = quat2mat(q0)
94 | R1 = quat2mat(q1)
95 | Rres = R1 @ R0.T
96 | theta = (np.trace(Rres) - 1)/2
97 | theta = np.clip(theta, -1, 1)
98 | angle = np.degrees(np.arccos(theta))
99 |
100 | assert np.isclose(angle, rot_err)
101 |
102 | def test_reproj_error_identity(self):
103 | """Test that reprojection error is zero if poses match"""
104 | q = randomQuat()
105 | t = np.random.normal(0, 10, (3,))
106 | i = createInput(q_gt=q, t_gt=t, q_est=q, t_est=t)
107 |
108 | reproj_err = MetricManager.reproj_err(i)
109 | assert np.isclose(reproj_err, 0)
110 |
111 | @pytest.mark.parametrize('run_number', range(10))
112 | @pytest.mark.parametrize('variant', (VARIANTS_ANGLE_SIN,))
113 | @pytest.mark.parametrize('dtype', (np.float64,))
114 | def test_r_err_small(self, run_number: int, variant: str, dtype: type) -> None:
115 | """Test rotation error for small angle differences.
116 |
117 | Note: We create GT as high precision, and only downcast when calling rot_err.
118 | """
119 | scales_failed = []
120 | for scale in np.logspace(start=-1, stop=-9, num=9, base=10, dtype=dtype):
121 | q = randomQuat().astype(np.float64)
122 | angle = np.float64(np.random.uniform(low=-np.pi, high=np.pi)) * scale
123 | assert isinstance(angle, np.float64)
124 | axis = np.random.uniform(low=-1., high=1., size=3).astype(np.float64)
125 | assert axis.dtype == np.float64
126 | qres = axangle2quat(vector=axis, theta=angle, is_normalized=False).astype(np.float64)
127 | assert qres.dtype == np.float64
128 |
129 | i = createInput(q_gt=q.astype(dtype), q_est=qmult(q, qres).astype(dtype))
130 |
131 | # We expect the error to always be np.float64 for highest acc.
132 | rot_err = MetricManager.rot_err(i, variant=variant)
133 | assert isinstance(rot_err, np.float64)
134 | rot_err_expected = np.abs(np.degrees(angle))
135 | assert isinstance(rot_err_expected, type(rot_err))
136 |
137 | # if we add up errors, we want them to be positive
138 | assert 0. <= rot_err
139 |
140 | # check accuracy for one magnitude higher tolerance than the angle
141 | tol = 0.1 * scale
142 | # need to be more permissive for lower precision
143 | if dtype == np.float32:
144 | tol = 1.e3 * scale
145 |
146 | # cast to dtype for checking
147 | rot_err = rot_err.astype(dtype)
148 | rot_err_expected = rot_err_expected.astype(dtype)
149 |
150 | if variant == VARIANTS_ANGLE_SIN:
151 | assert np.isclose(rot_err, rot_err_expected, rtol=tol, atol=tol)
152 | elif variant == VARIANTS_ANGLE_COS:
153 | if not np.isclose(rot_err, rot_err_expected, rtol=tol, atol=tol):
154 | print(f"[variant '{variant}'] raises an error for\n"
155 | f"\trot_err: {rot_err}"
156 | f"\trot_err_expected: {rot_err_expected}"
157 | f"\trtol: {tol}"
158 | f"\tatol: {tol}")
159 | scales_failed.append(scale)
160 | if len(scales_failed):
161 | pytest.fail(f"Variant {variant} failed at scales {scales_failed}")
162 |
163 |
164 | def test_projection() -> None:
165 | xyz = np.array(((10, 20, 30), (10, 30, 50), (-20, -15, 5),
166 | (-20, -50, 10)), dtype=np.float32)
167 | K = np.eye(3)
168 |
169 | uv = np.array(((1/3, 2/3), (1/5, 3/5), (-4, -3),
170 | (-2, -5)), dtype=np.float32)
171 | assert np.allclose(uv, project(xyz, K))
172 |
173 | uv = np.array(((1/3, 2/3), (1/5, 3/5), (0, 0), (0, 0)), dtype=np.float32)
174 | assert np.allclose(uv, project(xyz, K, img_size=(5, 5)))
175 |
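These tests can be run from the repository root with pytest (assuming pytest is available in the environment):

```shell
python -m pytest benchmark/test_metrics.py -q
```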
--------------------------------------------------------------------------------
/benchmark/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import typing
3 | import logging
4 |
5 | import numpy as np
6 | from transforms3d.quaternions import qinverse, rotate_vector, qmult
7 |
8 | VARIANTS_ANGLE_SIN = 'sin'
9 | VARIANTS_ANGLE_COS = 'cos'
10 |
11 |
12 | def convert_world2cam_to_cam2world(q, t):
13 | qinv = qinverse(q)
14 | tinv = -rotate_vector(t, qinv)
15 | return qinv, tinv
16 |
17 |
18 | def load_poses(file: typing.IO, load_confidence: bool = False):
19 | """Load poses from text file and converts them to cam2world convention (t is the camera center in world coordinates)
20 |
21 | The text file encodes world2cam poses with the format:
22 | imgpath qw qx qy qz tx ty tz [confidence]
23 | where qw qx qy qz is the quaternion encoding rotation,
24 | and tx ty tz is the translation vector,
25 | and confidence is a float encoding confidence, for estimated poses
26 | """
27 |
28 | expected_parts = 9 if load_confidence else 8
29 |
30 | poses = dict()
31 | for line_number, line in enumerate(file.readlines()):
32 | parts = tuple(line.strip().split(' '))
33 |
34 | # if 'tensor' in parts[-1]:
35 | # print('ERROR: confidence is a tensor')
36 | # parts = list(parts)
37 | # parts[-1] = parts[-1].split('[')[-1].split(']')[0]
38 | if len(parts) != expected_parts:
39 | logging.warning(
40 | f'Invalid number of fields in file {file.name} line {line_number}.'
41 | f' Expected {expected_parts}, received {len(parts)}. Ignoring line.')
42 | continue
43 |
44 | try:
45 | name = parts[0]
46 | if '#' in name:
47 | logging.info(f'Ignoring comment line in {file.name} line {line_number}')
48 | continue
49 | frame_num = int(name[-9:-4])
50 | except ValueError:
51 | logging.warning(
52 | f'Invalid frame number in file {file.name} line {line_number}.'
53 | f' Expected formatting "seq1/frame_00000.jpg". Ignoring line.')
54 | continue
55 |
56 | try:
57 | parts_float = tuple(map(float, parts[1:]))
58 | if any(np.isnan(v) or np.isinf(v) for v in parts_float):
59 | raise ValueError()
60 | qw, qx, qy, qz, tx, ty, tz = parts_float[:7]
61 | confidence = parts_float[7] if load_confidence else None
62 | except ValueError:
63 | logging.warning(
64 | f'Error parsing pose in file {file.name} line {line_number}. Ignoring line.')
65 | continue
66 |
67 | q = np.array((qw, qx, qy, qz), dtype=np.float64)
68 | t = np.array((tx, ty, tz), dtype=np.float64)
69 |
70 | if np.isclose(np.linalg.norm(q), 0):
71 | logging.warning(
72 | f'Error parsing pose in file {file.name} line {line_number}. '
73 | 'Quaternion must have non-zero norm. Ignoring line.')
74 | continue
75 |
76 | q, t = convert_world2cam_to_cam2world(q, t)
77 | poses[frame_num] = (q, t, confidence)
78 | return poses
79 |
80 |
81 | def subsample_poses(poses: dict, subsample: int = 1):
82 | return {k: v for i, (k, v) in enumerate(poses.items()) if i % subsample == 0}
83 |
84 |
85 | def load_K(file_path: Path):
86 | K = dict()
87 | with file_path.open('r', encoding='utf-8') as f:
88 | for line in f.readlines():
89 | if '#' in line:
90 | continue
91 | line = line.strip().split(' ')
92 |
93 | frame_num = int(line[0][-9:-4])
94 | fx, fy, cx, cy, W, H = map(float, line[1:])
95 | K[frame_num] = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
96 | return K, W, H
97 |
98 |
99 | def quat_angle_error(label, pred, variant=VARIANTS_ANGLE_SIN) -> np.ndarray:
100 | assert label.shape == (4,)
101 | assert pred.shape == (4,)
102 | assert variant in (VARIANTS_ANGLE_SIN, VARIANTS_ANGLE_COS), \
103 | f"Need variant to be in ({VARIANTS_ANGLE_SIN}, {VARIANTS_ANGLE_COS})"
104 |
105 | if len(label.shape) == 1:
106 | label = np.expand_dims(label, axis=0)
107 | if len(label.shape) != 2 or label.shape[0] != 1 or label.shape[1] != 4:
108 | raise RuntimeError(f"Unexpected shape of label: {label.shape}, expected: (1, 4)")
109 |
110 | if len(pred.shape) == 1:
111 | pred = np.expand_dims(pred, axis=0)
112 | if len(pred.shape) != 2 or pred.shape[0] != 1 or pred.shape[1] != 4:
113 | raise RuntimeError(f"Unexpected shape of pred: {pred.shape}, expected: (1, 4)")
114 |
115 | label = label.astype(np.float64)
116 | pred = pred.astype(np.float64)
117 |
118 | q1 = pred / np.linalg.norm(pred, axis=1, keepdims=True)
119 | q2 = label / np.linalg.norm(label, axis=1, keepdims=True)
120 | if variant == VARIANTS_ANGLE_COS:
121 | d = np.abs(np.sum(np.multiply(q1, q2), axis=1, keepdims=True))
122 | d = np.clip(d, a_min=-1, a_max=1)
123 | angle = 2. * np.degrees(np.arccos(d))
124 | elif variant == VARIANTS_ANGLE_SIN:
125 | if q1.shape[0] != 1 or q2.shape[0] != 1:
126 | raise NotImplementedError(f"Multiple angles is todo")
127 | # https://www.researchgate.net/post/How_do_I_calculate_the_smallest_angle_between_two_quaternions/5d6ed4a84f3a3e1ed3656616/citation/download
128 | sine = qmult(q1[0], qinverse(q2[0])) # note: takes first element in 2D array
129 | # 114.59 = 2. * 180. / pi
130 | angle = np.arcsin(np.linalg.norm(sine[1:], keepdims=True)) * 114.59155902616465
131 | angle = np.expand_dims(angle, axis=0)
132 |
133 | return angle.astype(np.float64)
134 |
135 |
136 | def precision_recall(inliers, tp, failures):
137 | """
138 |     Computes the Precision/Recall curve for a set of poses given inliers (confidence) and whether the
139 |     estimated pose error (whichever error metric is used) is within a threshold.
140 | Each point in the plot is obtained by choosing a threshold for inliers (i.e. inlier_thr).
141 | Recall measures how many images have inliers >= inlier_thr
142 | Precision measures how many images that have inliers >= inlier_thr have
143 | estimated pose error <= pose_threshold (measured by counting tps)
144 | Where pose_threshold is (trans_thr[m], rot_thr[deg])
145 |
146 | Inputs:
147 |     - inliers [N]: confidence value of each estimate
148 |     - tp [N]: whether each estimate's pose error is within the threshold (true positive)
149 |     - failures (int): number of frames without an estimate
152 | Output
153 | - precision [N]
154 | - recall [N]
155 | - average_precision (scalar)
156 | """
157 |
158 | assert len(inliers) == len(tp), 'unequal shapes'
159 |
160 | # sort by inliers (descending order)
161 | inliers = np.array(inliers)
162 | sort_idx = np.argsort(inliers)[::-1]
163 | inliers = inliers[sort_idx]
164 | tp = np.array(tp).reshape(-1)[sort_idx]
165 |
166 | # get idxs where inliers change (avoid tied up values)
167 | distinct_value_indices = np.where(np.diff(inliers))[0]
168 | threshold_idxs = np.r_[distinct_value_indices, inliers.size - 1]
169 |
170 | # compute prec/recall
171 | N = inliers.shape[0]
172 | rec = np.arange(N, dtype=np.float32) + 1
173 | cum_tp = np.cumsum(tp)
174 | prec = cum_tp[threshold_idxs] / rec[threshold_idxs]
175 | rec = rec[threshold_idxs] / (float(N) + float(failures))
176 |
177 | # invert order and ensures (prec=1, rec=0) point
178 | last_ind = rec.searchsorted(rec[-1])
179 | sl = slice(last_ind, None, -1)
180 | prec = np.r_[prec[sl], 1]
181 | rec = np.r_[rec[sl], 0]
182 |
183 | # compute average precision (AUC) as the weighted average of precisions
184 | average_precision = np.abs(np.sum(np.diff(rec) * np.array(prec)[:-1]))
185 |
186 | return prec, rec, average_precision
187 |
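A toy run of `precision_recall` with made-up confidence values; four frames have estimates and one does not (`failures=1`):

```python
import numpy as np

from benchmark.utils import precision_recall

inliers = np.array([300., 200., 100., 50.])  # confidence of each estimate (sorted internally)
tp = np.array([True, True, False, True])     # pose error within threshold?

prec, rec, auc = precision_recall(inliers=inliers, tp=tp, failures=1)
print(np.round(prec, 3))  # approx. [0.75 0.667 1. 1. 1.] -> ends with the (prec=1, rec=0) point
print(np.round(rec, 3))   # approx. [0.8 0.6 0.4 0.2 0.]
print(round(auc, 3))      # 0.683
```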
--------------------------------------------------------------------------------
/config/MicKey/curriculum_learning.yaml:
--------------------------------------------------------------------------------
1 |
2 | MODEL: 'MicKey'
3 | DEBUG: False
4 | MICKEY:
5 | DINOV2:
6 | DOWN_FACTOR: 14
7 | CHANNEL_DIM: 1024
8 | FLOAT16: True
9 |
10 | KP_HEADS:
11 | BLOCKS_DIM: [512, 256, 128, 64]
12 | BN: True
13 | USE_SOFTMAX: True
14 | USE_DEPTHSIGMOID: False
15 | MAX_DEPTH: 60
16 | POS_ENCODING: True
17 |
18 | DSC_HEAD:
19 | LAST_DIM: 128
20 | BLOCKS_DIM: [512, 256, 128]
21 | BN: True
22 | NORM_DSC: True
23 | POS_ENCODING: True
24 |
25 | FEATURE_MATCHER:
26 | TYPE: 'DualSoftmax'
27 | DUAL_SOFTMAX:
28 | TEMPERATURE: 0.1
29 | USE_DUSTBIN: True
30 | SINKHORN:
31 | NUM_IT: 10
32 | DUSTBIN_SCORE_INIT: 1.
33 | USE_TRANSFORMER: False
34 |
35 | TRAINING:
36 | NUM_GPUS: 4
37 | BATCH_SIZE: 8 # BS for each dataloader (in every GPU)
38 | NUM_WORKERS: 8
39 | SAMPLER: 'scene_balance'
40 | N_SAMPLES_SCENE: 100
41 | SAMPLE_WITH_REPLACEMENT: True
42 | LR: 1e-4
43 | LOG_INTERVAL: 50
44 | VAL_INTERVAL: 0.5
45 | VAL_BATCHES: 100
46 | EPOCHS: 100
47 |
48 | DATASET:
49 | HEIGHT: 720
50 | WIDTH: 540
51 |
52 | MIN_OVERLAP_SCORE: 0.0 # [train only] discard data with overlap_score < min_overlap_score
53 |   MAX_OVERLAP_SCORE: 1.0 # [train only] discard data with overlap_score > max_overlap_score
54 |
55 | LOSS_CLASS:
56 |
57 | LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR
58 | SOFT_CLIPPING: True # It indicates if it soft-clips the loss values.
59 |
60 | POSE_ERR:
61 | MAX_LOSS_VALUE: 1.5
62 | MAX_LOSS_SOFTVALUE: 0.8
63 | VCRE:
64 | MAX_LOSS_VALUE: 90
65 | MAX_LOSS_SOFTVALUE: 0.8
66 |
67 | GENERATE_HYPOTHESES:
68 | SCORE_TEMPERATURE: 20
69 | IT_MATCHES: 20
70 | IT_RANSAC: 20
71 | INLIER_3D_TH: 0.3
72 | INLIER_REF_TH: 0.15
73 | NUM_REF_STEPS: 4
74 | NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability
75 |
76 | NULL_HYPOTHESIS:
77 | ADD_NULL_HYPOTHESIS: True
78 | TH_OUTLIERS: 0.35
79 |
80 | CURRICULUM_LEARNING:
81 | TRAIN_CURRICULUM: True # It indicates if MicKey should be trained with curriculum learning
82 | TRAIN_WITH_TOPK: True # It indicates if MicKey should be trained only with top image pairs
83 | TOPK_INIT: 30
84 | TOPK: 80
85 |
86 | SAMPLER:
87 | NUM_SAMPLES_MATCHES: 512
88 |
89 | PROCRUSTES:
90 | IT_MATCHES: 20
91 | IT_RANSAC: 100
92 | NUM_SAMPLED_MATCHES: 2048
93 | NUM_CORR_3D_3D: 3
94 | NUM_REFINEMENTS: 4
95 | TH_INLIER: 0.15
96 | TH_SOFT_INLIER: 0.3
97 |
98 |
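A minimal sketch of how this file is consumed through yacs, mirroring the `cfg.merge_from_file` call in `benchmark/mapfree.py`; it assumes every key above is declared in `config/default.py`:

```python
from config.default import cfg

# layer the model and dataset YAML files on top of the declared defaults
cfg.merge_from_file('config/MicKey/curriculum_learning.yaml')
cfg.merge_from_file('config/datasets/mapfree.yaml')

print(cfg.MODEL)              # MicKey
print(cfg.TRAINING.NUM_GPUS)  # 4
print(cfg.DATASET.HEIGHT)     # 720
```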
--------------------------------------------------------------------------------
/config/MicKey/curriculum_learning_warm_up.yaml:
--------------------------------------------------------------------------------
1 |
2 | MODEL: 'MicKey'
3 | DEBUG: False
4 | MICKEY:
5 | DINOV2:
6 | DOWN_FACTOR: 14
7 | CHANNEL_DIM: 1024
8 | FLOAT16: True
9 |
10 | KP_HEADS:
11 | BLOCKS_DIM: [512, 256, 128, 64]
12 | BN: True
13 | USE_SOFTMAX: True
14 | USE_DEPTHSIGMOID: False
15 | MAX_DEPTH: 60
16 | POS_ENCODING: True
17 |
18 | DSC_HEAD:
19 | LAST_DIM: 128
20 | BLOCKS_DIM: [512, 256, 128]
21 | BN: True
22 | NORM_DSC: True
23 | POS_ENCODING: True
24 |
25 | FEATURE_MATCHER:
26 | TYPE: 'DualSoftmax'
27 | DUAL_SOFTMAX:
28 | TEMPERATURE: 0.1
29 | USE_DUSTBIN: True
30 | SINKHORN:
31 | NUM_IT: 10
32 | DUSTBIN_SCORE_INIT: 1.
33 | USE_TRANSFORMER: False
34 |
35 | TRAINING:
36 | NUM_GPUS: 4
37 | BATCH_SIZE: 24 # BS for each dataloader (in every GPU)
38 | NUM_WORKERS: 8
39 | SAMPLER: 'scene_balance'
40 | N_SAMPLES_SCENE: 100
41 | SAMPLE_WITH_REPLACEMENT: True
42 | LR: 1e-4
43 | LOG_INTERVAL: 50
44 | VAL_INTERVAL: 0.5
45 | VAL_BATCHES: 100
46 | EPOCHS: 100
47 |
48 | DATASET:
49 | HEIGHT: 480
50 | WIDTH: 360
51 |
52 | MIN_OVERLAP_SCORE: 0.0 # [train only] discard data with overlap_score < min_overlap_score
53 |   MAX_OVERLAP_SCORE: 1.0 # [train only] discard data with overlap_score > max_overlap_score
54 |
55 | LOSS_CLASS:
56 |
57 | LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR
58 | SOFT_CLIPPING: True # It indicates if it soft-clips the loss values.
59 |
60 | POSE_ERR:
61 | MAX_LOSS_VALUE: 1.5
62 | MAX_LOSS_SOFTVALUE: 0.8
63 | VCRE:
64 | MAX_LOSS_VALUE: 90
65 | MAX_LOSS_SOFTVALUE: 0.8
66 |
67 | GENERATE_HYPOTHESES:
68 | SCORE_TEMPERATURE: 20
69 | IT_MATCHES: 20
70 | IT_RANSAC: 20
71 | INLIER_3D_TH: 0.3
72 | INLIER_REF_TH: 0.15
73 | NUM_REF_STEPS: 4
74 | NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability
75 |
76 | NULL_HYPOTHESIS:
77 | ADD_NULL_HYPOTHESIS: False
78 | TH_OUTLIERS: 0.35
79 |
80 | CURRICULUM_LEARNING:
81 | TRAIN_CURRICULUM: True # It indicates if MicKey should be trained with curriculum learning
82 | TRAIN_WITH_TOPK: True # It indicates if MicKey should be trained only with top image pairs
83 | TOPK_INIT: 30
84 | TOPK: 80
85 |
86 | SAMPLER:
87 | NUM_SAMPLES_MATCHES: 64
88 |
89 | PROCRUSTES:
90 | IT_MATCHES: 20
91 | IT_RANSAC: 100
92 | NUM_SAMPLED_MATCHES: 2048
93 | NUM_CORR_3D_3D: 3
94 | NUM_REFINEMENTS: 4
95 | TH_INLIER: 0.15
96 | TH_SOFT_INLIER: 0.3
97 |
98 |
--------------------------------------------------------------------------------
/config/MicKey/overlap_score.yaml:
--------------------------------------------------------------------------------
1 |
2 | MODEL: 'MicKey'
3 | DEBUG: False
4 | MICKEY:
5 | DINOV2:
6 | DOWN_FACTOR: 14
7 | CHANNEL_DIM: 1024
8 | FLOAT16: True
9 |
10 | KP_HEADS:
11 | BLOCKS_DIM: [512, 256, 128, 64]
12 | BN: True
13 | USE_SOFTMAX: True
14 | USE_DEPTHSIGMOID: False
15 | MAX_DEPTH: 60
16 | POS_ENCODING: True
17 |
18 | DSC_HEAD:
19 | LAST_DIM: 128
20 | BLOCKS_DIM: [512, 256, 128]
21 | BN: True
22 | NORM_DSC: True
23 | POS_ENCODING: True
24 |
25 | FEATURE_MATCHER:
26 | TYPE: 'DualSoftmax'
27 | DUAL_SOFTMAX:
28 | TEMPERATURE: 0.1
29 | USE_DUSTBIN: True
30 | SINKHORN:
31 | NUM_IT: 10
32 | DUSTBIN_SCORE_INIT: 1.
33 | USE_TRANSFORMER: False
34 |
35 | TRAINING:
36 | NUM_GPUS: 4
37 | BATCH_SIZE: 8 # BS for each dataloader (in every GPU)
38 | NUM_WORKERS: 8
39 | SAMPLER: 'scene_balance'
40 | N_SAMPLES_SCENE: 100
41 | SAMPLE_WITH_REPLACEMENT: True
42 | LR: 1e-4
43 | LOG_INTERVAL: 50
44 | VAL_INTERVAL: 0.5
45 | VAL_BATCHES: 100
46 | EPOCHS: 100
47 |
48 | DATASET:
49 | HEIGHT: 720
50 | WIDTH: 540
51 |
52 | MIN_OVERLAP_SCORE: 0.4 # [train only] discard data with overlap_score < min_overlap_score
53 |   MAX_OVERLAP_SCORE: 0.8 # [train only] discard data with overlap_score > max_overlap_score
54 |
55 | LOSS_CLASS:
56 |
57 | LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR
58 | SOFT_CLIPPING: True # It indicates if it soft-clips the loss values.
59 |
60 | POSE_ERR:
61 | MAX_LOSS_VALUE: 1.5
62 | MAX_LOSS_SOFTVALUE: 0.8
63 | VCRE:
64 | MAX_LOSS_VALUE: 90
65 | MAX_LOSS_SOFTVALUE: 0.8
66 |
67 | GENERATE_HYPOTHESES:
68 | SCORE_TEMPERATURE: 20
69 | IT_MATCHES: 20
70 | IT_RANSAC: 20
71 | INLIER_3D_TH: 0.3
72 | INLIER_REF_TH: 0.15
73 | NUM_REF_STEPS: 4
74 | NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability
75 |
76 | NULL_HYPOTHESIS:
77 | ADD_NULL_HYPOTHESIS: True
78 | TH_OUTLIERS: 0.35
79 |
80 | CURRICULUM_LEARNING:
81 | TRAIN_CURRICULUM: False # It indicates if MicKey should be trained with curriculum learning
82 | TRAIN_WITH_TOPK: False # It indicates if MicKey should be trained only with top image pairs
83 | TOPK_INIT: 30
84 | TOPK: 80
85 |
86 | SAMPLER:
87 | NUM_SAMPLES_MATCHES: 512
88 |
89 | PROCRUSTES:
90 | IT_MATCHES: 20
91 | IT_RANSAC: 100
92 | NUM_SAMPLED_MATCHES: 2048
93 | NUM_CORR_3D_3D: 3
94 | NUM_REFINEMENTS: 4
95 | TH_INLIER: 0.15
96 | TH_SOFT_INLIER: 0.3
--------------------------------------------------------------------------------
/config/MicKey/overlap_score_warm_up.yaml:
--------------------------------------------------------------------------------
1 |
2 | MODEL: 'MicKey'
3 | DEBUG: False
4 | MICKEY:
5 | DINOV2:
6 | DOWN_FACTOR: 14
7 | CHANNEL_DIM: 1024
8 | FLOAT16: True
9 |
10 | KP_HEADS:
11 | BLOCKS_DIM: [512, 256, 128, 64]
12 | BN: True
13 | USE_SOFTMAX: True
14 | USE_DEPTHSIGMOID: False
15 | MAX_DEPTH: 60
16 | POS_ENCODING: True
17 |
18 | DSC_HEAD:
19 | LAST_DIM: 128
20 | BLOCKS_DIM: [512, 256, 128]
21 | BN: True
22 | NORM_DSC: True
23 | POS_ENCODING: True
24 |
25 | FEATURE_MATCHER:
26 | TYPE: 'DualSoftmax'
27 | DUAL_SOFTMAX:
28 | TEMPERATURE: 0.1
29 | USE_DUSTBIN: True
30 | SINKHORN:
31 | NUM_IT: 10
32 | DUSTBIN_SCORE_INIT: 1.
33 | USE_TRANSFORMER: False
34 |
35 | TRAINING:
36 | NUM_GPUS: 4
37 | BATCH_SIZE: 24 # BS for each dataloader (in every GPU)
38 | NUM_WORKERS: 8
39 | SAMPLER: 'scene_balance'
40 | N_SAMPLES_SCENE: 100
41 | SAMPLE_WITH_REPLACEMENT: True
42 | LR: 1e-4
43 | LOG_INTERVAL: 50
44 | VAL_INTERVAL: 0.5
45 | VAL_BATCHES: 100
46 | EPOCHS: 100
47 |
48 | DATASET:
49 | HEIGHT: 480
50 | WIDTH: 360
51 |
52 | MIN_OVERLAP_SCORE: 0.4 # [train only] discard data with overlap_score < min_overlap_score
53 | MAX_OVERLAP_SCORE: 0.8 # [train only] discard data with overlap_score > max_overlap_score
54 |
55 | LOSS_CLASS:
56 |
57 | LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR
58 | SOFT_CLIPPING: True # Whether to soft-clip the loss values
59 |
60 | POSE_ERR:
61 | MAX_LOSS_VALUE: 1.5
62 | MAX_LOSS_SOFTVALUE: 0.8
63 | VCRE:
64 | MAX_LOSS_VALUE: 90
65 | MAX_LOSS_SOFTVALUE: 0.8
66 |
67 | GENERATE_HYPOTHESES:
68 | SCORE_TEMPERATURE: 20
69 | IT_MATCHES: 20
70 | IT_RANSAC: 20
71 | INLIER_3D_TH: 0.3
72 | INLIER_REF_TH: 0.15
73 | NUM_REF_STEPS: 4
74 | NUM_CORR_3d3d: 8 # A larger number of 3D-3D correspondences improves stability
75 |
76 | NULL_HYPOTHESIS:
77 | ADD_NULL_HYPOTHESIS: False
78 | TH_OUTLIERS: 0.35
79 |
80 | CURRICULUM_LEARNING:
81 | TRAIN_CURRICULUM: False # Whether to train MicKey with curriculum learning
82 | TRAIN_WITH_TOPK: False # Whether to train MicKey only on the top-scoring image pairs
83 | TOPK_INIT: 30
84 | TOPK: 80
85 |
86 | SAMPLER:
87 | NUM_SAMPLES_MATCHES: 64
88 |
89 | PROCRUSTES:
90 | IT_MATCHES: 20
91 | IT_RANSAC: 100
92 | NUM_SAMPLED_MATCHES: 2048
93 | NUM_CORR_3D_3D: 3
94 | NUM_REFINEMENTS: 4
95 | TH_INLIER: 0.15
96 | TH_SOFT_INLIER: 0.3
--------------------------------------------------------------------------------
/config/datasets/mapfree.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | DATA_SOURCE: 'MapFree'
3 | DATA_ROOT: 'data/'
4 | SCENES: None # should be a list [] or None. If none, use all scenes.
5 | AUGMENTATION_TYPE: None
6 | HEIGHT: 720
7 | WIDTH: 540
8 | MIN_OVERLAP_SCORE: 0.2 # [train only] discard data with overlap_score < min_overlap_score
9 | MAX_OVERLAP_SCORE: 0.7 # [train only] discard data with overlap_score > max_overlap_score
10 | SEED: 66
--------------------------------------------------------------------------------
/config/default.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _CN = CN()
4 |
5 | ############## Model ##############
6 | _CN.MODEL = None # options: ['MicKey']
7 | _CN.DEBUG = False
8 |
9 | # MicKey configuration
10 | _CN.MICKEY = CN()
11 |
12 | _CN.MICKEY.DINOV2 = CN()
13 | _CN.MICKEY.DINOV2.DOWN_FACTOR = None
14 | _CN.MICKEY.DINOV2.CHANNEL_DIM = None
15 | _CN.MICKEY.DINOV2.FLOAT16 = None
16 |
17 | _CN.MICKEY.KP_HEADS = CN()
18 | _CN.MICKEY.KP_HEADS.BLOCKS_DIM = None
19 | _CN.MICKEY.KP_HEADS.BN = None
20 | _CN.MICKEY.KP_HEADS.USE_SOFTMAX = None
21 | _CN.MICKEY.KP_HEADS.USE_DEPTHSIGMOID = None
22 | _CN.MICKEY.KP_HEADS.MAX_DEPTH = None
23 | _CN.MICKEY.KP_HEADS.POS_ENCODING = None
24 |
25 | _CN.MICKEY.DSC_HEAD = CN()
26 | _CN.MICKEY.DSC_HEAD.LAST_DIM = None
27 | _CN.MICKEY.DSC_HEAD.BLOCKS_DIM = None
28 | _CN.MICKEY.DSC_HEAD.BN = None
29 | _CN.MICKEY.DSC_HEAD.NORM_DSC = None
30 | _CN.MICKEY.DSC_HEAD.POS_ENCODING = None
31 |
32 |
33 | _CN.FEATURE_MATCHER = CN()
34 | _CN.FEATURE_MATCHER.TYPE = None
35 | _CN.FEATURE_MATCHER.DUAL_SOFTMAX = CN()
36 | _CN.FEATURE_MATCHER.DUAL_SOFTMAX.TEMPERATURE = None
37 | _CN.FEATURE_MATCHER.DUAL_SOFTMAX.USE_DUSTBIN = None
38 | _CN.FEATURE_MATCHER.SINKHORN = CN()
39 | _CN.FEATURE_MATCHER.SINKHORN.NUM_IT = None
40 | _CN.FEATURE_MATCHER.SINKHORN.DUSTBIN_SCORE_INIT = None
41 | _CN.FEATURE_MATCHER.USE_TRANSFORMER = None
42 | _CN.FEATURE_MATCHER.TOP_KEYPOINTS = False
43 |
44 | # LOSS_CLASS
45 | _CN.LOSS_CLASS = CN()
46 | _CN.LOSS_CLASS.LOSS_FUNCTION = None
47 | _CN.LOSS_CLASS.SOFT_CLIPPING = None
48 |
49 | _CN.LOSS_CLASS.POSE_ERR = CN()
50 | _CN.LOSS_CLASS.POSE_ERR.MAX_LOSS_VALUE = None
51 | _CN.LOSS_CLASS.POSE_ERR.MAX_LOSS_SOFTVALUE = None
52 |
53 | _CN.LOSS_CLASS.VCRE = CN()
54 | _CN.LOSS_CLASS.VCRE.MAX_LOSS_VALUE = None
55 | _CN.LOSS_CLASS.VCRE.MAX_LOSS_SOFTVALUE = None
56 |
57 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES = CN()
58 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.SCORE_TEMPERATURE = None
59 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.IT_MATCHES = None
60 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.IT_RANSAC = None
61 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.INLIER_3D_TH = None
62 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.INLIER_REF_TH = None
63 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.NUM_REF_STEPS = None
64 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.NUM_CORR_3d3d = None
65 |
66 | _CN.LOSS_CLASS.CURRICULUM_LEARNING = CN()
67 | _CN.LOSS_CLASS.CURRICULUM_LEARNING.TRAIN_CURRICULUM = None
68 | _CN.LOSS_CLASS.CURRICULUM_LEARNING.TRAIN_WITH_TOPK = None
69 | _CN.LOSS_CLASS.CURRICULUM_LEARNING.TOPK_INIT = None
70 | _CN.LOSS_CLASS.CURRICULUM_LEARNING.TOPK = None
71 |
72 | _CN.LOSS_CLASS.NULL_HYPOTHESIS = CN()
73 | _CN.LOSS_CLASS.NULL_HYPOTHESIS.ADD_NULL_HYPOTHESIS = None
74 | _CN.LOSS_CLASS.NULL_HYPOTHESIS.TH_OUTLIERS = None
75 |
76 | _CN.LOSS_CLASS.SAMPLER = CN()
77 | _CN.LOSS_CLASS.SAMPLER.NUM_SAMPLES_MATCHES = None
78 |
79 |
80 | # Procrustes RANSAC options
81 | _CN.PROCRUSTES = CN()
82 | _CN.PROCRUSTES.IT_MATCHES = None
83 | _CN.PROCRUSTES.IT_RANSAC = None
84 | _CN.PROCRUSTES.NUM_SAMPLED_MATCHES = None
85 | _CN.PROCRUSTES.NUM_CORR_3D_3D = None
86 | _CN.PROCRUSTES.NUM_REFINEMENTS = None
87 | _CN.PROCRUSTES.TH_INLIER = None
88 | _CN.PROCRUSTES.TH_SOFT_INLIER = None
89 |
90 |
91 |
92 |
93 | # Training Procrustes RANSAC options
94 | _CN.PROCRUSTES_TRAINING = CN()
95 | _CN.PROCRUSTES_TRAINING.MAX_CORR_DIST = None
96 | _CN.PROCRUSTES_TRAINING.REFINE = False #refine pose with ICP
97 |
98 |
99 | ############## Dataset ##############
100 | _CN.DATASET = CN()
101 | # 1. data config
102 | _CN.DATASET.DATA_SOURCE = None # options: ['ScanNet', '7Scenes', 'MapFree']
103 | _CN.DATASET.SCENES = None # scenes to use (for 7Scenes/MapFree); should be a list []; If none, use all scenes.
104 | _CN.DATASET.DATA_ROOT = None # path to dataset folder
105 | _CN.DATASET.SEED = None # SEED for dataset generation
106 | _CN.DATASET.NPZ_ROOT = None # path to npz files containing pairs of frame indices per sample
107 | _CN.DATASET.MIN_OVERLAP_SCORE = None # discard data with overlap_score < min_overlap_score
108 | _CN.DATASET.MAX_OVERLAP_SCORE = None # discard data with overlap_score > max_overlap_score
109 | _CN.DATASET.CONSECUTIVE_PAIRS = None
110 | _CN.DATASET.FRAME_RATE = None
111 | _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'colorjitter']
112 | _CN.DATASET.BLACK_WHITE = False # if true, transform images to black & white
113 | _CN.DATASET.PAIRS_TXT = CN() # Path to text file defining the train/val/test pairs (7Scenes)
114 | _CN.DATASET.PAIRS_TXT.TRAIN = None
115 | _CN.DATASET.PAIRS_TXT.VAL = None
116 | _CN.DATASET.PAIRS_TXT.TEST = None
117 | _CN.DATASET.PAIRS_TXT.ONE_NN = False # If true, keeps only reference image w/ highest similarity to each query
118 | _CN.DATASET.HEIGHT = None
119 | _CN.DATASET.WIDTH = None
120 |
121 | ############# TRAINING #############
122 | _CN.TRAINING = CN()
123 | # Data Loader settings
124 | _CN.TRAINING.BATCH_SIZE = None
125 | _CN.TRAINING.NUM_WORKERS = None
126 | _CN.TRAINING.NUM_GPUS = None
127 | _CN.TRAINING.SAMPLER = None # options: ['random', 'scene_balance']
128 | _CN.TRAINING.N_SAMPLES_SCENE = None # if 'scene_balance' sampler, the number of samples to get per scene
129 | _CN.TRAINING.SAMPLE_WITH_REPLACEMENT = None # if 'scene_balance' sampler, whether to sample with replacement
130 |
131 | # Training settings
132 | _CN.TRAINING.LR = None
133 | _CN.TRAINING.LR_STEP_INTERVAL = None
134 | _CN.TRAINING.LR_STEP_GAMMA = None  # multiplicative factor applied to LR every LR_STEP_INTERVAL
135 | _CN.TRAINING.VAL_INTERVAL = None
136 | _CN.TRAINING.VAL_BATCHES = None
137 | _CN.TRAINING.LOG_INTERVAL = None
138 | _CN.TRAINING.EPOCHS = None
139 | _CN.TRAINING.GRAD_CLIP = 0. # Indicates the L2 norm at which to clip the gradient. Disabled if 0
140 |
141 | cfg = _CN
--------------------------------------------------------------------------------
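The YAML files above only override values that are first declared in `config/default.py`; everything is merged through yacs. A minimal sketch of how a model config and a dataset config are combined, assuming the working directory is the repository root:

```python
# Sketch only: merge a model config and a dataset config into the yacs defaults.
# Assumes the working directory is the repository root so the relative paths resolve.
from config.default import cfg

# Later merges override earlier values; keys not declared in config/default.py
# would raise an error, which is why every option has a default there.
cfg.merge_from_file('config/MicKey/overlap_score.yaml')
cfg.merge_from_file('config/datasets/mapfree.yaml')

print(cfg.MODEL)                              # 'MicKey'
print(cfg.DATASET.WIDTH, cfg.DATASET.HEIGHT)  # 540 720
print(cfg.LOSS_CLASS.LOSS_FUNCTION)           # 'VCRE'
```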
/data/toy_example/im0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nianticlabs/mickey/2391be8a35491e7b43481c069f5dab65030839b9/data/toy_example/im0.jpg
--------------------------------------------------------------------------------
/data/toy_example/im1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nianticlabs/mickey/2391be8a35491e7b43481c069f5dab65030839b9/data/toy_example/im1.jpg
--------------------------------------------------------------------------------
/data/toy_example/intrinsics.txt:
--------------------------------------------------------------------------------
1 | im0.jpg 549.7018 549.7018 268.6665 351.8357 540 720
2 | im1.jpg 549.0616 549.0616 268.8559 351.8485 540 720
--------------------------------------------------------------------------------
/demo_inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from lib.models.builder import build_model
4 | from lib.datasets.utils import correct_intrinsic_scale
5 | from lib.utils.visualization import prepare_score_map, colorize_depth, get_render, create_point_cloud_from_inliers
6 | from config.default import cfg
7 | import numpy as np
8 | from pathlib import Path
9 | import cv2
10 |
11 |
12 | def read_color_image(path, resize):
13 | """
14 | Args:
15 | resize (tuple): align image to depthmap, in (w, h).
16 | Returns:
17 | image (torch.tensor): (3, h, w)
18 | """
19 | # read and resize image
20 | cv_type = cv2.IMREAD_COLOR
21 | image = cv2.imread(str(path), cv_type)
22 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
23 | if resize is not None:
24 | image = cv2.resize(image, resize)
25 |
26 | # (h, w, 3) -> (3, h, w) and normalized
27 | image = torch.from_numpy(image).float().permute(2, 0, 1) / 255
28 |
29 | return image.unsqueeze(0)
30 |
31 | def read_intrinsics(path_intrinsics, resize):
32 | Ks = {}
33 | with Path(path_intrinsics).open('r') as f:
34 | for line in f.readlines():
35 | if '#' in line:
36 | continue
37 |
38 | line = line.strip().split(' ')
39 | img_name = line[0]
40 | fx, fy, cx, cy, W, H = map(float, line[1:])
41 |
42 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
43 | if resize is not None:
44 | K = correct_intrinsic_scale(K, resize[0] / W, resize[1] / H).numpy()
45 | Ks[img_name] = K
46 | return Ks
47 |
48 | def generate_3d_vis(data, n_matches, root_dir, batch_id,
49 | use_3d_color_coded=True, color_src_frame=[223, 71, 28], color_dst_frame=[83, 154, 218],
50 | add_dst_lines=True, add_ref_lines=True, add_ref_pts=True, add_points=True,
51 | size_box=0.03, size_2d=0.015, cam_size=1., max_conf_th=0.8, add_confidence=True,
52 | angle_y=0, angle_x=-25, cam_offset_x=0.1, cam_offset_y=0.0, cam_offset_z=-2):
53 |
54 | print('Generating 3D visualization image...')
55 |
56 | # Generate point cloud from inlier matches
57 | point_cloud = create_point_cloud_from_inliers(data, batch_id, use_3d_color_coded)
58 |
59 | # Prepare the data:
60 | data_np = {}
61 | data_np['K_color0'] = data['K_color0'].detach().cpu().numpy()
62 | data_np['K_color1'] = data['K_color1'].detach().cpu().numpy()
63 | data_np['image0'] = 255 * data['image0'].permute(0, 2, 3, 1).detach().cpu().numpy()
64 | data_np['image1'] = 255 * data['image1'].permute(0, 2, 3, 1).detach().cpu().numpy()
65 |
66 | R = data['R'][batch_id][np.newaxis].detach().cpu().numpy()
67 | t = data['t'][batch_id].detach().cpu().numpy().reshape(-1)
68 | P = np.eye(4)
69 | P[:3, :3] = R
70 | P[:3, 3] = t
71 |
72 | # Render the image with camera and 3D points
73 | frame = get_render(P, data, batch_id, point_cloud, color_src_frame, color_dst_frame,
74 | angle_y, angle_x, cam_offset_x, cam_offset_y, cam_offset_z, cam_size, size_box, size_2d,
75 | add_ref_lines, add_dst_lines, add_ref_pts, add_points, n_matches, max_conf_th, add_confidence)
76 |
77 | cv2.imwrite(root_dir + '/3d_vis.png', cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB))
78 |
79 | def run_demo_inference(args):
80 |
81 | # Select device
82 | use_cuda = torch.cuda.is_available()
83 | device = torch.device('cuda:0' if use_cuda else 'cpu')
84 |
85 | print('Preparing data...')
86 |
87 | # Prepare config file
88 | cfg.merge_from_file(args.config)
89 |
90 | # Prepare the model
91 | model = build_model(cfg, checkpoint=args.checkpoint)
92 |
93 | # Load demo images
94 | im0 = read_color_image(args.im_path_ref, args.resize).to(device)
95 | im1 = read_color_image(args.im_path_dst, args.resize).to(device)
96 |
97 | # Load intrinsics
98 | K = read_intrinsics(args.intrinsics, args.resize)
99 |
100 | # Prepare data for MicKey
101 | batch_id = 0
102 | im0_name = args.im_path_ref.split('/')[-1]
103 | im1_name = args.im_path_dst.split('/')[-1]
104 | data = {}
105 | data['image0'] = im0
106 | data['image1'] = im1
107 | data['K_color0'] = torch.from_numpy(K[im0_name]).unsqueeze(0).to(device)
108 | data['K_color1'] = torch.from_numpy(K[im1_name]).unsqueeze(0).to(device)
109 |
110 | # Run inference
111 | print('Running MicKey relative pose estimation...')
112 | model(data, return_inliers=args.generate_3D_vis)
113 |
114 | # Pose, inliers and score are stored in:
115 | # data['R'] = R
116 | # data['t'] = t
117 | # data['inliers'] = inliers
118 |
119 | print('Saving depth and score maps in image directory ...')
120 | depth0_map = colorize_depth(data['depth0_map'][batch_id], invalid_mask=(data['depth0_map'][batch_id] < 0.001).cpu()[0])
121 | depth1_map = colorize_depth(data['depth1_map'][batch_id], invalid_mask=(data['depth1_map'][batch_id] < 0.001).cpu()[0])
122 | score0_map = prepare_score_map(data['scr0'][batch_id], data['image0'][batch_id], temperature=0.5)
123 | score1_map = prepare_score_map(data['scr1'][batch_id], data['image1'][batch_id], temperature=0.5)
124 |
125 | ext_im0 = args.im_path_ref.split('.')[-1]
126 | ext_im1 = args.im_path_dst.split('.')[-1]
127 |
128 | cv2.imwrite(args.im_path_ref.replace(ext_im0, 'score.jpg'), score0_map)
129 | cv2.imwrite(args.im_path_dst.replace(ext_im1, 'score.jpg'), score1_map)
130 |
131 | cv2.imwrite(args.im_path_ref.replace(ext_im0, 'depth.jpg'), depth0_map)
132 | cv2.imwrite(args.im_path_dst.replace(ext_im1, 'depth.jpg'), depth1_map)
133 |
134 | if args.generate_3D_vis:
135 | # We use the maximum possible number of inliers to draw the confidence
136 | n_matches = model.e2e_Procrustes.num_samples_matches
137 | dir_name = '/'.join(args.im_path_ref.split('/')[:-1])
138 | generate_3d_vis(data, n_matches, dir_name, batch_id)
139 |
140 | if __name__ == '__main__':
141 | parser = argparse.ArgumentParser()
142 | parser.add_argument('--im_path_ref', help='path to reference image', default='data/toy_example/im0.jpg')
143 | parser.add_argument('--im_path_dst', help='path to destination image', default='data/toy_example/im1.jpg')
144 | parser.add_argument('--intrinsics', help='path to intrinsics file', default='data/toy_example/intrinsics.txt')
145 | parser.add_argument('--resize', nargs=2, type=int, help='resize applied to the image and intrinsics (w, h)', default=None)
146 | parser.add_argument('--config', help='path to config file', default='weights/mickey_weights/config.yaml')
147 | parser.add_argument('--checkpoint', help='path to model checkpoint',
148 | default='weights/mickey_weights/mickey.ckpt')
149 | parser.add_argument('--generate_3D_vis', action='store_true',
150 | help='add this flag to generate a 3D visualisation of the output poses')
151 | args = parser.parse_args()
152 |
153 | run_demo_inference(args)
154 |
155 |
--------------------------------------------------------------------------------
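For reference, a sketch of invoking the demo programmatically instead of through the command line, assuming the MicKey checkpoint and config are available at the default paths used above:

```python
# Sketch only: run the demo from Python, mirroring the CLI defaults above.
# Assumes weights/mickey_weights/{config.yaml, mickey.ckpt} exist.
from argparse import Namespace
from demo_inference import run_demo_inference

args = Namespace(
    im_path_ref='data/toy_example/im0.jpg',
    im_path_dst='data/toy_example/im1.jpg',
    intrinsics='data/toy_example/intrinsics.txt',
    resize=None,                                    # or [w, h], e.g. [540, 720]
    config='weights/mickey_weights/config.yaml',
    checkpoint='weights/mickey_weights/mickey.ckpt',
    generate_3D_vis=False,
)
run_demo_inference(args)  # writes *.score.jpg and *.depth.jpg next to the input images
```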
/lib/benchmarks/reprojection.py:
--------------------------------------------------------------------------------
1 | # Code from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc
2 |
3 | from typing import List, Optional, Tuple, Union
4 | import numpy as np
5 | from transforms3d.quaternions import quat2mat
6 |
7 | def project(pts: np.ndarray, K: np.ndarray, img_size: Optional[Union[List[int], Tuple[int, int]]] = None) -> np.ndarray:
8 | """Projects 3D points to image plane.
9 |
10 | Args:
11 | - pts [N, 3/4]: points in camera coordinates (homogeneous or non-homogeneous)
12 | - K [3, 3]: intrinsic matrix
13 | - img_size (width, height): optional, clamp projection to image borders
14 | Outputs:
15 | - uv [N, 2]: coordinates of projected points
16 | """
17 |
18 | assert len(pts.shape) == 2, 'incorrect number of dimensions'
19 | assert pts.shape[1] in [3, 4], 'invalid dimension size'
20 | assert K.shape == (3, 3), 'incorrect intrinsic shape'
21 |
22 | uv_h = (K @ pts[:, :3].T).T
23 | uv = uv_h[:, :2] / uv_h[:, -1:]
24 |
25 | if img_size is not None:
26 | uv[:, 0] = np.clip(uv[:, 0], 0, img_size[0])
27 | uv[:, 1] = np.clip(uv[:, 1], 0, img_size[1])
28 |
29 | return uv
30 |
31 |
32 | def get_grid_multipleheight() -> np.ndarray:
33 | # create grid of points
34 | ar_grid_step = 0.3
35 | ar_grid_num_x = 7
36 | ar_grid_num_y = 4
37 | ar_grid_num_z = 7
38 | ar_grid_z_offset = 1.8
39 | ar_grid_y_offset = 0
40 |
41 | ar_grid_x_pos = np.arange(0, ar_grid_num_x)-(ar_grid_num_x-1)/2
42 | ar_grid_x_pos *= ar_grid_step
43 |
44 | ar_grid_y_pos = np.arange(0, ar_grid_num_y)-(ar_grid_num_y-1)/2
45 | ar_grid_y_pos *= ar_grid_step
46 | ar_grid_y_pos += ar_grid_y_offset
47 |
48 | ar_grid_z_pos = np.arange(0, ar_grid_num_z).astype(float)
49 | ar_grid_z_pos *= ar_grid_step
50 | ar_grid_z_pos += ar_grid_z_offset
51 |
52 | xx, yy, zz = np.meshgrid(ar_grid_x_pos, ar_grid_y_pos, ar_grid_z_pos)
53 | ones = np.ones(xx.shape[0]*xx.shape[1]*xx.shape[2])
54 | eye_coords = np.concatenate([c.reshape(-1, 1)
55 | for c in (xx, yy, zz, ones)], axis=-1)
56 | return eye_coords
57 |
58 |
59 | # global variable, avoids creating it again
60 | eye_coords_glob = get_grid_multipleheight()
61 |
62 |
63 | def reprojection_error(
64 | q_est: np.ndarray, t_est: np.ndarray, q_gt: np.ndarray, t_gt: np.ndarray, K: np.ndarray,
65 | W: int, H: int) -> float:
66 | eye_coords = eye_coords_glob
67 |
68 | # obtain ground-truth position of projected points
69 | uv_gt = project(eye_coords, K, (W, H))
70 |
71 | # residual transformation
72 | cam2w_est = np.eye(4)
73 | cam2w_est[:3, :3] = quat2mat(q_est)
74 | cam2w_est[:3, -1] = t_est
75 | cam2w_gt = np.eye(4)
76 | cam2w_gt[:3, :3] = quat2mat(q_gt)
77 | cam2w_gt[:3, -1] = t_gt
78 |
79 | # residual reprojection
80 | eyes_residual = (np.linalg.inv(cam2w_est) @ cam2w_gt @ eye_coords.T).T
81 | uv_pred = project(eyes_residual, K, (W, H))
82 |
83 | # get reprojection error
84 | repr_err = np.linalg.norm(uv_gt - uv_pred, ord=2, axis=1)
85 | mean_repr_err = float(repr_err.mean().item())
86 | return mean_repr_err
87 |
--------------------------------------------------------------------------------
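A small, self-contained sketch of how the reprojection error above behaves; the intrinsics and poses below are made up purely for illustration:

```python
# Sketch only: evaluate the virtual-grid reprojection error for a slightly wrong pose.
import numpy as np
from lib.benchmarks.reprojection import reprojection_error

W, H = 540, 720
K = np.array([[550., 0., W / 2],
              [0., 550., H / 2],
              [0., 0., 1.]])

q_gt = np.array([1., 0., 0., 0.])   # identity rotation (qw, qx, qy, qz)
t_gt = np.zeros(3)

q_est = np.array([1., 0., 0., 0.])
t_est = np.array([0.05, 0., 0.])    # 5 cm translation error along x

# Mean reprojection error (in pixels) of the virtual AR point grid.
err_px = reprojection_error(q_est, t_est, q_gt, t_gt, K, W, H)
print(f'mean reprojection error: {err_px:.2f} px')
```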
/lib/benchmarks/utils.py:
--------------------------------------------------------------------------------
1 | # Code from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc
2 |
3 | from pathlib import Path
4 | import typing
5 | import logging
6 |
7 | import numpy as np
8 | from transforms3d.quaternions import qinverse, rotate_vector, qmult
9 |
10 | VARIANTS_ANGLE_SIN = 'sin'
11 | VARIANTS_ANGLE_COS = 'cos'
12 |
13 |
14 | def convert_world2cam_to_cam2world(q, t):
15 | qinv = qinverse(q)
16 | tinv = -rotate_vector(t, qinv)
17 | return qinv, tinv
18 |
19 |
20 | def load_poses(file: typing.IO, load_confidence: bool = False):
21 | """Load poses from text file and converts them to cam2world convention (t is the camera center in world coordinates)
22 |
23 | The text file encodes world2cam poses with the format:
24 | imgpath qw qx qy qz tx ty tz [confidence]
25 | where qw qx qy qz is the quaternion encoding the rotation,
26 | tx ty tz is the translation vector,
27 | and confidence is an optional float confidence value (present only for estimated poses)
28 | """
29 |
30 | expected_parts = 9 if load_confidence else 8
31 |
32 | poses = dict()
33 | for line_number, line in enumerate(file.readlines()):
34 | parts = tuple(line.strip().split(' '))
35 |
36 | # if 'tensor' in parts[-1]:
37 | # print('ERROR: confidence is a tensor')
38 | # parts = list(parts)
39 | # parts[-1] = parts[-1].split('[')[-1].split(']')[0]
40 | if len(parts) != expected_parts:
41 | logging.warning(
42 | f'Invalid number of fields in file {file.name} line {line_number}.'
43 | f' Expected {expected_parts}, received {len(parts)}. Ignoring line.')
44 | continue
45 |
46 | try:
47 | name = parts[0]
48 | if '#' in name:
49 | logging.info(f'Ignoring comment line in {file.name} line {line_number}')
50 | continue
51 | frame_num = int(name[-9:-4])
52 | except ValueError:
53 | logging.warning(
54 | f'Invalid frame number in file {file.name} line {line_number}.'
55 | f' Expected formatting "seq1/frame_00000.jpg". Ignoring line.')
56 | continue
57 |
58 | try:
59 | parts_float = tuple(map(float, parts[1:]))
60 | if any(np.isnan(v) or np.isinf(v) for v in parts_float):
61 | raise ValueError()
62 | qw, qx, qy, qz, tx, ty, tz = parts_float[:7]
63 | confidence = parts_float[7] if load_confidence else None
64 | except ValueError:
65 | logging.warning(
66 | f'Error parsing pose in file {file.name} line {line_number}. Ignoring line.')
67 | continue
68 |
69 | q = np.array((qw, qx, qy, qz), dtype=np.float64)
70 | t = np.array((tx, ty, tz), dtype=np.float64)
71 |
72 | if np.isclose(np.linalg.norm(q), 0):
73 | logging.warning(
74 | f'Error parsing pose in file {file.name} line {line_number}. '
75 | 'Quaternion must have non-zero norm. Ignoring line.')
76 | continue
77 |
78 | q, t = convert_world2cam_to_cam2world(q, t)
79 | poses[frame_num] = (q, t, confidence)
80 | return poses
81 |
82 |
83 | def subsample_poses(poses: dict, subsample: int = 1):
84 | return {k: v for i, (k, v) in enumerate(poses.items()) if i % subsample == 0}
85 |
86 |
87 | def load_K(file_path: Path):
88 | K = dict()
89 | with file_path.open('r', encoding='utf-8') as f:
90 | for line in f.readlines():
91 | if '#' in line:
92 | continue
93 | line = line.strip().split(' ')
94 |
95 | frame_num = int(line[0][-9:-4])
96 | fx, fy, cx, cy, W, H = map(float, line[1:])
97 | K[frame_num] = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
98 | return K, W, H
99 |
100 |
101 | def quat_angle_error(label, pred, variant=VARIANTS_ANGLE_SIN) -> np.ndarray:
102 | assert label.shape == (4,)
103 | assert pred.shape == (4,)
104 | assert variant in (VARIANTS_ANGLE_SIN, VARIANTS_ANGLE_COS), \
105 | f"Need variant to be in ({VARIANTS_ANGLE_SIN}, {VARIANTS_ANGLE_COS})"
106 |
107 | if len(label.shape) == 1:
108 | label = np.expand_dims(label, axis=0)
109 | if len(label.shape) != 2 or label.shape[0] != 1 or label.shape[1] != 4:
110 | raise RuntimeError(f"Unexpected shape of label: {label.shape}, expected: (1, 4)")
111 |
112 | if len(pred.shape) == 1:
113 | pred = np.expand_dims(pred, axis=0)
114 | if len(pred.shape) != 2 or pred.shape[0] != 1 or pred.shape[1] != 4:
115 | raise RuntimeError(f"Unexpected shape of pred: {pred.shape}, expected: (1, 4)")
116 |
117 | label = label.astype(np.float64)
118 | pred = pred.astype(np.float64)
119 |
120 | q1 = pred / np.linalg.norm(pred, axis=1, keepdims=True)
121 | q2 = label / np.linalg.norm(label, axis=1, keepdims=True)
122 | if variant == VARIANTS_ANGLE_COS:
123 | d = np.abs(np.sum(np.multiply(q1, q2), axis=1, keepdims=True))
124 | d = np.clip(d, a_min=-1, a_max=1)
125 | angle = 2. * np.degrees(np.arccos(d))
126 | elif variant == VARIANTS_ANGLE_SIN:
127 | if q1.shape[0] != 1 or q2.shape[0] != 1:
128 | raise NotImplementedError("Angle computation for multiple quaternions is not implemented")
129 | # https://www.researchgate.net/post/How_do_I_calculate_the_smallest_angle_between_two_quaternions/5d6ed4a84f3a3e1ed3656616/citation/download
130 | sine = qmult(q1[0], qinverse(q2[0])) # note: takes first element in 2D array
131 | # 114.59 = 2. * 180. / pi
132 | angle = np.arcsin(np.linalg.norm(sine[1:], keepdims=True)) * 114.59155902616465
133 | angle = np.expand_dims(angle, axis=0)
134 |
135 | return angle.astype(np.float64)
136 |
137 |
138 | def precision_recall(inliers, tp, failures):
139 | """
140 | Computes a precision/recall curve for a set of poses, given a per-image confidence value (inliers)
141 | and a flag indicating whether the estimated pose error (whatever it may be) is within a threshold (tp).
142 | Each point on the curve is obtained by choosing a threshold on the confidence (i.e. inlier_thr):
143 | Recall measures how many images have inliers >= inlier_thr.
144 | Precision measures how many images with inliers >= inlier_thr also have an
145 | estimated pose error within the threshold, i.e. are marked as true positives in tp.
146 | Images for which no pose was estimated are counted as failures and lower the recall.
147 |
148 | Inputs:
149 | - inliers [N]: per-image confidence, e.g. number of inliers
150 | - tp [N]: per-image boolean, True if the estimated pose error is within the threshold
151 | - failures (int): number of images without an estimated pose
152 |
153 | Output:
154 | - precision [N]
155 | - recall [N]
156 | - average_precision (scalar)
157 |
158 | """
159 |
160 | assert len(inliers) == len(tp), 'unequal shapes'
161 |
162 | # sort by inliers (descending order)
163 | inliers = np.array(inliers)
164 | sort_idx = np.argsort(inliers)[::-1]
165 | inliers = inliers[sort_idx]
166 | tp = np.array(tp).reshape(-1)[sort_idx]
167 |
168 | # get idxs where inliers change (avoid tied up values)
169 | distinct_value_indices = np.where(np.diff(inliers))[0]
170 | threshold_idxs = np.r_[distinct_value_indices, inliers.size - 1]
171 |
172 | # compute prec/recall
173 | N = inliers.shape[0]
174 | rec = np.arange(N, dtype=np.float32) + 1
175 | cum_tp = np.cumsum(tp)
176 | prec = cum_tp[threshold_idxs] / rec[threshold_idxs]
177 | rec = rec[threshold_idxs] / (float(N) + float(failures))
178 |
179 | # invert order and ensures (prec=1, rec=0) point
180 | last_ind = rec.searchsorted(rec[-1])
181 | sl = slice(last_ind, None, -1)
182 | prec = np.r_[prec[sl], 1]
183 | rec = np.r_[rec[sl], 0]
184 |
185 | # compute average precision (AUC) as the weighted average of precisions
186 | average_precision = np.abs(np.sum(np.diff(rec) * np.array(prec)[:-1]))
187 |
188 | return prec, rec, average_precision
189 |
--------------------------------------------------------------------------------
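A toy sketch of the two metrics above; the numbers are illustrative only:

```python
# Sketch only: rotation error between two quaternions and a toy precision/recall curve.
import numpy as np
from lib.benchmarks.utils import quat_angle_error, precision_recall

# Angle between the identity rotation and a 10-degree rotation about the z-axis.
theta = np.deg2rad(10.)
q_label = np.array([1., 0., 0., 0.])
q_pred = np.array([np.cos(theta / 2), 0., 0., np.sin(theta / 2)])
print(quat_angle_error(q_label, q_pred))   # ~10 degrees

# Precision/recall over 5 images: per-image confidence and whether each pose was accepted,
# plus one image for which no pose was estimated.
inliers = [120, 90, 60, 30, 10]
tp = [True, True, False, True, False]
prec, rec, auc = precision_recall(inliers, tp, failures=1)
print(prec, rec, auc)
```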
/lib/datasets/datamodules.py:
--------------------------------------------------------------------------------
1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc
2 |
3 | import torch.utils as utils
4 | from torchvision.transforms import ColorJitter, Grayscale
5 | import pytorch_lightning as pl
6 |
7 | from torch.utils.data import DataLoader
8 | from lib.datasets.sampler import RandomConcatSampler
9 | from lib.datasets.mapfree import MapFreeDataset
10 |
11 |
12 | class DataModule(pl.LightningDataModule):
13 | def __init__(self, cfg, drop_last_val=True):
14 | super().__init__()
15 | self.cfg = cfg
16 | self.drop_last_val = drop_last_val
17 |
18 | datasets = {'MapFree': MapFreeDataset}
19 |
20 | assert cfg.DATASET.DATA_SOURCE in datasets.keys(), 'invalid DATA_SOURCE, this dataset is not implemented'
21 | self.dataset_type = datasets[cfg.DATASET.DATA_SOURCE]
22 |
23 | def get_sampler(self, dataset, reset_epoch=False):
24 | if self.cfg.TRAINING.SAMPLER == 'scene_balance':
25 | sampler = RandomConcatSampler(dataset,
26 | self.cfg.TRAINING.N_SAMPLES_SCENE,
27 | self.cfg.TRAINING.SAMPLE_WITH_REPLACEMENT,
28 | shuffle=True,
29 | reset_on_iter=reset_epoch
30 | )
31 | else:
32 | sampler = None
33 | return sampler
34 |
35 | def train_dataloader(self):
36 | transforms = ColorJitter() if self.cfg.DATASET.AUGMENTATION_TYPE == 'colorjitter' else None
37 | transforms = Grayscale(
38 | num_output_channels=3) if self.cfg.DATASET.BLACK_WHITE else transforms
39 |
40 | dataset = self.dataset_type(self.cfg, 'train', transforms=transforms)
41 | sampler = self.get_sampler(dataset)
42 |
43 | dataloader = utils.data.DataLoader(dataset,
44 | batch_size=self.cfg.TRAINING.BATCH_SIZE,
45 | num_workers=self.cfg.TRAINING.NUM_WORKERS,
46 | sampler=sampler
47 | )
48 | return dataloader
49 |
50 | def val_dataloader(self):
51 | dataset = self.dataset_type(self.cfg, 'val')
52 | dataloader = utils.data.DataLoader(dataset,
53 | batch_size=self.cfg.TRAINING.BATCH_SIZE,
54 | num_workers=self.cfg.TRAINING.NUM_WORKERS,
55 | sampler=None,
56 | drop_last=self.drop_last_val
57 | )
58 | return dataloader
59 |
60 | def test_dataloader(self):
61 | dataset = self.dataset_type(self.cfg, 'test')
62 | dataloader = utils.data.DataLoader(dataset,
63 | batch_size=self.cfg.TRAINING.BATCH_SIZE,
64 | num_workers=self.cfg.TRAINING.NUM_WORKERS,
65 | shuffle=False,
66 | drop_last=self.drop_last_val)
67 | return dataloader
68 |
69 |
70 | class DataModuleTraining(pl.LightningDataModule):
71 | def __init__(self, cfg):
72 | super().__init__()
73 | self.cfg = cfg
74 | self.seed = cfg.DATASET.SEED
75 |
76 | datasets = {'MapFree': MapFreeDataset}
77 |
78 | assert cfg.DATASET.DATA_SOURCE in datasets.keys(), 'invalid DATA_SOURCE, this dataset is not implemented'
79 | self.dataset_type = datasets[cfg.DATASET.DATA_SOURCE]
80 |
81 | def get_sampler(self, dataset, reset_epoch=False, seed=66):
82 | if self.cfg.TRAINING.SAMPLER == 'scene_balance':
83 | sampler = RandomConcatSampler(dataset,
84 | self.cfg.TRAINING.N_SAMPLES_SCENE,
85 | self.cfg.TRAINING.SAMPLE_WITH_REPLACEMENT,
86 | shuffle=True,
87 | reset_on_iter=reset_epoch,
88 | seed=seed)
89 | else:
90 | sampler = None
91 | return sampler
92 |
93 | def train_dataloader(self):
94 | transforms = ColorJitter() if self.cfg.DATASET.AUGMENTATION_TYPE == 'colorjitter' else None
95 | transforms = Grayscale(
96 | num_output_channels=3) if self.cfg.DATASET.BLACK_WHITE else transforms
97 |
98 | dataset = self.dataset_type(self.cfg, 'train', transforms=transforms)
99 | sampler = self.get_sampler(dataset, seed=self.seed)
100 | dataloader = DataLoader(dataset,
101 | batch_size=self.cfg.TRAINING.BATCH_SIZE,
102 | num_workers=self.cfg.TRAINING.NUM_WORKERS,
103 | sampler=sampler)
104 | return dataloader
105 |
106 | def val_dataloader(self):
107 | dataset = self.dataset_type(self.cfg, 'val')
108 | sampler = self.get_sampler(dataset, reset_epoch=True)
109 | # sampler = None
110 | dataloader = DataLoader(dataset,
111 | batch_size=self.cfg.TRAINING.BATCH_SIZE,
112 | num_workers=self.cfg.TRAINING.NUM_WORKERS,
113 | sampler=sampler,
114 | drop_last=True)
115 | return dataloader
116 |
117 | def test_dataloader(self):
118 | dataset = self.dataset_type(self.cfg, 'test')
119 | dataloader = DataLoader(dataset,
120 | batch_size=self.cfg.TRAINING.BATCH_SIZE,
121 | num_workers=self.cfg.TRAINING.NUM_WORKERS,
122 | shuffle=False)
123 | return dataloader
--------------------------------------------------------------------------------
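A minimal sketch of building the Lightning data module from the configs shown earlier; it assumes the Map-free dataset has been downloaded and extracted under `data/`, as expected by `DATASET.DATA_ROOT`:

```python
# Sketch only: instantiate the training data module and fetch one batch.
# Assumes the Map-free dataset is available under data/train, data/val, data/test.
from config.default import cfg
from lib.datasets.datamodules import DataModuleTraining

cfg.merge_from_file('config/MicKey/overlap_score.yaml')
cfg.merge_from_file('config/datasets/mapfree.yaml')

datamodule = DataModuleTraining(cfg)
train_loader = datamodule.train_dataloader()

batch = next(iter(train_loader))
print(batch['image0'].shape)   # (B, 3, H, W), e.g. (8, 3, 720, 540)
print(batch['T_0to1'].shape)   # (B, 4, 4) relative pose from image0 to image1
```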
/lib/datasets/mapfree.py:
--------------------------------------------------------------------------------
1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc
2 |
3 | from pathlib import Path
4 | import torch
5 | import torch.utils.data as data
6 | import numpy as np
7 | from transforms3d.quaternions import qinverse, qmult, rotate_vector, quat2mat
8 | from lib.datasets.utils import read_color_image, read_depth_image, correct_intrinsic_scale
9 |
10 | class MapFreeScene(data.Dataset):
11 | def __init__(
12 | self, scene_root, resize, sample_factor=1, overlap_limits=None, transforms=None,
13 | test_scene=False):
14 | super().__init__()
15 |
16 | self.scene_root = Path(scene_root)
17 | self.resize = resize
18 | self.sample_factor = sample_factor
19 | self.transforms = transforms
20 | self.test_scene = test_scene
21 |
22 | # load absolute poses
23 | self.poses = self.read_poses(self.scene_root)
24 |
25 | # read intrinsics
26 | self.K, self.K_ori = self.read_intrinsics(self.scene_root, resize)
27 |
28 | # load pairs
29 | self.pairs = self.load_pairs(self.scene_root, overlap_limits, self.sample_factor)
30 |
31 | @staticmethod
32 | def read_intrinsics(scene_root: Path, resize=None):
33 | Ks = {}
34 | K_ori = {}
35 | with (scene_root / 'intrinsics.txt').open('r') as f:
36 | for line in f.readlines():
37 | if '#' in line:
38 | continue
39 |
40 | line = line.strip().split(' ')
41 | img_name = line[0]
42 | fx, fy, cx, cy, W, H = map(float, line[1:])
43 |
44 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
45 | K_ori[img_name] = K
46 | if resize is not None:
47 | K = correct_intrinsic_scale(K, resize[0] / W, resize[1] / H)
48 | Ks[img_name] = K
49 | return Ks, K_ori
50 |
51 | @staticmethod
52 | def read_poses(scene_root: Path):
53 | """
54 | Returns a dictionary that maps: img_path -> (q, t) where
55 | np.array q = (qw, qx, qy, qz): quaternion encoding the rotation matrix;
56 | np.array t = (tx, ty, tz): translation vector;
57 | (q, t) encodes absolute pose (world-to-camera), i.e. X_c = R(q) X_W + t
58 | """
59 | poses = {}
60 | with (scene_root / 'poses.txt').open('r') as f:
61 | for line in f.readlines():
62 | if '#' in line:
63 | continue
64 |
65 | line = line.strip().split(' ')
66 | img_name = line[0]
67 | qt = np.array(list(map(float, line[1:])))
68 | poses[img_name] = (qt[:4], qt[4:])
69 | return poses
70 |
71 | def load_pairs(self, scene_root: Path, overlap_limits: tuple = None, sample_factor: int = 1):
72 | """
73 | For training scenes, filter pairs of frames based on overlap (pre-computed in overlaps.npz)
74 | For test/val scenes, pairs are formed between the keyframe and every sample_factor-th query frame.
75 | If sample_factor == 1, all query frames are used. Note: sample_factor applies only to test/val.
76 | Returns:
77 | pairs: np.ndarray [Npairs, 4], where the columns are seqA, imA, seqB, imB, respectively
78 | """
79 | pairs = self.load_pairs_overlap(scene_root, overlap_limits, sample_factor)
80 |
81 | return pairs
82 |
83 | def load_pairs_overlap(self, scene_root: Path, overlap_limits: tuple = None, sample_factor: int = 1):
84 | overlaps_path = scene_root / 'overlaps.npz'
85 |
86 | if overlaps_path.exists():
87 | f = np.load(overlaps_path, allow_pickle=True)
88 | idxs, overlaps = f['idxs'], f['overlaps']
89 | if overlap_limits is not None:
90 | min_overlap, max_overlap = overlap_limits
91 | mask = (overlaps > min_overlap) * (overlaps < max_overlap)
92 | idxs = idxs[mask]
93 | return idxs.copy()
94 | else:
95 | idxs = np.zeros((len(self.poses) - 1, 4), dtype=np.uint16)
96 | idxs[:, 2] = 1
97 | idxs[:, 3] = np.array([int(fn[-9:-4])
98 | for fn in self.poses.keys() if 'seq0' not in fn], dtype=np.uint16)
99 | return idxs[::sample_factor]
100 |
101 | def get_pair_path(self, pair):
102 | seqA, imgA, seqB, imgB = pair
103 | return (f'seq{seqA}/frame_{imgA:05}.jpg', f'seq{seqB}/frame_{imgB:05}.jpg')
104 |
105 | def __len__(self):
106 | return len(self.pairs)
107 |
108 | def __getitem__(self, index):
109 | # image paths (relative to scene_root)
110 | im1_path, im2_path = self.get_pair_path(self.pairs[index])
111 |
112 | # load color images
113 | image1 = read_color_image(self.scene_root / im1_path,
114 | self.resize, augment_fn=self.transforms)
115 | image2 = read_color_image(self.scene_root / im2_path,
116 | self.resize, augment_fn=self.transforms)
117 |
118 | # get absolute pose of im0 and im1
119 | if self.test_scene:
120 | t1, t2, c1, c2 = np.zeros([3]), np.zeros([3]), np.zeros([3]), np.zeros([3])
121 | q1, q2 = np.zeros([4]), np.zeros([4])
122 | T = np.zeros([4, 4])
123 | else:
124 | # quaternion and translation vector that transforms World-to-Cam
125 | q1, t1 = self.poses[im1_path]
126 | # quaternion and translation vector that transforms World-to-Cam
127 | q2, t2 = self.poses[im2_path]
128 | c1 = rotate_vector(-t1, qinverse(q1))  # center of camera 1 in world coordinates
129 | c2 = rotate_vector(-t2, qinverse(q2))  # center of camera 2 in world coordinates
130 |
131 | # get 4 x 4 relative pose transformation matrix (from im1 to im2)
132 | # for val set, q1,t1 is the identity pose, so the relative pose matches the absolute pose
133 | q12 = qmult(q2, qinverse(q1))
134 | t12 = t2 - rotate_vector(t1, q12)
135 | T = np.eye(4, dtype=np.float32)
136 | T[:3, :3] = quat2mat(q12)
137 | T[:3, -1] = t12
138 |
139 | T = torch.from_numpy(T)
140 |
141 | data = {
142 | 'image0': image1, # (3, h, w)
143 | 'image1': image2,
144 | 'T_0to1': T, # (4, 4) # relative pose
145 | 'abs_q_0': q1,
146 | 'abs_c_0': c1,
147 | 'abs_q_1': q2,
148 | 'abs_c_1': c2,
149 | 'K_color0': self.K[im1_path], # (3, 3)
150 | 'Kori_color0': self.K_ori[im1_path], # (3, 3)
151 | 'K_color1': self.K[im2_path], # (3, 3)
152 | 'Kori_color1': self.K_ori[im2_path], # (3, 3)
153 | 'dataset_name': 'Mapfree',
154 | 'scene_id': self.scene_root.stem,
155 | 'scene_root': str(self.scene_root),
156 | 'pair_id': index*self.sample_factor,
157 | 'pair_names': (im1_path, im2_path),
158 | }
159 |
160 | return data
161 |
162 |
163 | class MapFreeDataset(data.ConcatDataset):
164 | def __init__(self, cfg, mode, transforms=None):
165 | assert mode in ['train', 'val', 'test'], 'Invalid dataset mode'
166 |
167 | data_root = Path(cfg.DATASET.DATA_ROOT) / mode
168 | resize = (cfg.DATASET.WIDTH, cfg.DATASET.HEIGHT)
169 |
170 | if mode=='test':
171 | test_scene = True
172 | else:
173 | test_scene = False
174 |
175 | overlap_limits = (cfg.DATASET.MIN_OVERLAP_SCORE, cfg.DATASET.MAX_OVERLAP_SCORE)
176 | sample_factor = {'train': 1, 'val': 5, 'test': 5}[mode]
177 |
178 | scenes = cfg.DATASET.SCENES
179 | if scenes is None:
180 | # Locate all scenes of the current dataset
181 | scenes = [s.name for s in data_root.iterdir() if s.is_dir()]
182 |
183 | if cfg.DEBUG:
184 | if mode=='train':
185 | scenes = scenes[:30]
186 | elif mode=='val':
187 | scenes = scenes[:10]
188 |
189 | # Init dataset objects for each scene
190 | data_srcs = [
191 | MapFreeScene(
192 | data_root / scene, resize, sample_factor, overlap_limits, transforms,
193 | test_scene) for scene in scenes]
194 | super().__init__(data_srcs)
195 |
--------------------------------------------------------------------------------
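A small sketch of the pose convention used above: given two absolute world-to-camera poses (q, t), build the 4x4 relative transform from image0 to image1, exactly as in `MapFreeScene.__getitem__`. The pose values are illustrative only.

```python
# Sketch only: relative pose from two absolute world-to-camera poses.
import numpy as np
from transforms3d.quaternions import qinverse, qmult, rotate_vector, quat2mat

q1, t1 = np.array([1., 0., 0., 0.]), np.array([0., 0., 0.])          # camera 0 at the origin
q2, t2 = np.array([0.924, 0., 0.383, 0.]), np.array([0.1, 0., 0.3])  # camera 1, ~45 deg about y

# Same computation as MapFreeScene.__getitem__
q12 = qmult(q2, qinverse(q1))
t12 = t2 - rotate_vector(t1, q12)
T_0to1 = np.eye(4, dtype=np.float32)
T_0to1[:3, :3] = quat2mat(q12)
T_0to1[:3, -1] = t12
print(T_0to1)
```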
/lib/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc
2 | # From https://github.com/zju3dv/LoFTR/blob/261baf641cb9ada07dd9746e420ada7fe8a03152/src/datasets/sampler.py
3 | import torch
4 | from torch.utils.data import Sampler, ConcatDataset
5 |
6 | class RandomConcatSampler(Sampler):
7 | """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be draw from each subset
8 | in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement.
9 | However, it is impossible to sample data without replacement between epochs, unless bulding a stateful sampler lived along the entire training phase.
10 |
11 | For current implementation, the randomness of sampling is ensured no matter the sampler is recreated across epochs or not and call `torch.manual_seed()` or not.
12 | Args:
13 | shuffle (bool): shuffle the random sampled indices across all sub-datsets.
14 | repeat (int): repeatedly use the sampled indices multiple times for training.
15 | [arXiv:1902.05509, arXiv:1901.09335]
16 | NOTE: Don't re-initialize the sampler between epochs (will lead to repeated samples)
17 | NOTE: This sampler behaves differently with DistributedSampler.
18 | It assume the dataset is splitted across ranks instead of replicated.
19 | TODO: Add a `set_epoch()` method to fullfill sampling without replacement across epochs.
20 | ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373
21 | """
22 |
23 | def __init__(self,
24 | data_source: ConcatDataset,
25 | n_samples_per_subset: int,
26 | subset_replacement: bool = True,
27 | shuffle: bool = True,
28 | repeat: int = 1,
29 | seed: int = 66,
30 | reset_on_iter: bool = False):
31 | if not isinstance(data_source, ConcatDataset):
32 | raise TypeError("data_source should be torch.utils.data.ConcatDataset")
33 |
34 | self.data_source = data_source
35 | self.n_subset = len(self.data_source.datasets)
36 | self.n_samples_per_subset = n_samples_per_subset
37 | self.n_samples = self.n_subset * self.n_samples_per_subset * repeat
38 | self.subset_replacement = subset_replacement
39 | self.repeat = repeat
40 | self.shuffle = shuffle
41 | self.seed = seed
42 | self.reset_on_iter = reset_on_iter  # If true, reset the random seed so that samples are the same every epoch
43 | self.generator = torch.manual_seed(self.seed)
44 | assert self.repeat >= 1
45 |
46 | def __len__(self):
47 | return self.n_samples
48 |
49 | def __iter__(self):
50 | if self.reset_on_iter:
51 | self.generator = torch.manual_seed(self.seed)
52 |
53 | indices = []
54 | # sample from each sub-dataset
55 | for d_idx in range(self.n_subset):
56 | low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1]
57 | high = self.data_source.cumulative_sizes[d_idx]
58 | if self.subset_replacement:
59 | rand_tensor = torch.randint(low, high, (self.n_samples_per_subset,),
60 | generator=self.generator, dtype=torch.int64)
61 | else: # sample without replacement
62 | len_subset = len(self.data_source.datasets[d_idx])
63 | rand_tensor = torch.randperm(len_subset, generator=self.generator) + low
64 | if len_subset >= self.n_samples_per_subset:
65 | rand_tensor = rand_tensor[:self.n_samples_per_subset]
66 | else: # padding with replacement
67 | rand_tensor_replacement = torch.randint(
68 | low, high, (self.n_samples_per_subset - len_subset,),
69 | generator=self.generator, dtype=torch.int64)
70 | rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement])
71 | indices.append(rand_tensor)
72 | indices = torch.cat(indices)
73 | if self.shuffle: # shuffle the sampled dataset (from multiple subsets)
74 | rand_tensor = torch.randperm(len(indices), generator=self.generator)
75 | indices = indices[rand_tensor]
76 |
77 | # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling)
78 | if self.repeat > 1:
79 | repeat_indices = [indices.clone() for _ in range(self.repeat - 1)]
80 | if self.shuffle:
81 | def _choice(x): return x[torch.randperm(len(x), generator=self.generator)]
82 | repeat_indices = map(_choice, repeat_indices)
83 | indices = torch.cat([indices, *repeat_indices], 0)
84 |
85 | assert indices.shape[0] == self.n_samples
86 | return iter(indices.tolist())
87 |
--------------------------------------------------------------------------------
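A toy sketch of the scene-balanced sampler above, using small in-memory datasets instead of Map-free scenes:

```python
# Sketch only: RandomConcatSampler draws the same number of samples from each "scene",
# regardless of how unbalanced the scenes are.
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
from lib.datasets.sampler import RandomConcatSampler

# Three "scenes" with 100, 20 and 5 samples respectively.
scenes = [TensorDataset(torch.arange(n).float().unsqueeze(1)) for n in (100, 20, 5)]
dataset = ConcatDataset(scenes)

# Draw 10 samples per scene (with replacement), so every epoch yields 30 samples.
sampler = RandomConcatSampler(dataset, n_samples_per_subset=10, subset_replacement=True,
                              shuffle=True, seed=66)
loader = DataLoader(dataset, batch_size=6, sampler=sampler)

print(len(sampler))                         # 30
print(sum(x[0].shape[0] for x in loader))   # 30 samples per epoch
```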
/lib/datasets/utils.py:
--------------------------------------------------------------------------------
1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from numpy.linalg import inv
7 |
8 | def imread(path, augment_fn=None):
9 | cv_type = cv2.IMREAD_COLOR
10 | image = cv2.imread(str(path), cv_type)
11 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
12 |
13 | if augment_fn is not None:
14 | image = cv2.imread(str(path), cv2.IMREAD_COLOR)
15 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
16 | image = augment_fn(image)
17 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
18 | return image # (h, w, 3)
19 |
20 |
21 | def get_resized_wh(w, h, resize=None):
22 | if resize is not None: # resize the longer edge
23 | scale = resize / max(h, w)
24 | w_new, h_new = int(round(w * scale)), int(round(h * scale))
25 | else:
26 | w_new, h_new = w, h
27 | return w_new, h_new
28 |
29 |
30 | def get_divisible_wh(w, h, df=None):
31 | if df is not None:
32 | w_new, h_new = map(lambda x: int(x // df * df), [w, h])
33 | else:
34 | w_new, h_new = w, h
35 | return w_new, h_new
36 |
37 |
38 | def pad_bottom_right(inp, pad_size, df, ret_mask=False, border=3):
39 | assert isinstance(pad_size, int) and pad_size >= max(
40 | inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
41 | mask = None
42 | if inp.ndim == 2:
43 | padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
44 | padded[:inp.shape[0], :inp.shape[1]] = inp
45 | if ret_mask:
46 | mask = np.zeros((pad_size, pad_size), dtype=bool)
47 | mask[:inp.shape[0], :inp.shape[1]] = True
48 | elif inp.ndim == 3:
49 | padded = np.zeros((pad_size, pad_size, inp.shape[2]), dtype=inp.dtype)
50 | padded[:inp.shape[0], :inp.shape[1], :] = inp
51 | if ret_mask:
52 |
53 | mask = np.zeros((1, pad_size//df, pad_size//df))
54 | mask[:, :inp.shape[0]//df-border, :inp.shape[1]//df-border] = 1
55 |
56 | else:
57 | raise NotImplementedError()
58 | return padded, mask
59 |
60 |
61 | def read_color_image(path, resize=(640, 480), augment_fn=None):
62 | """
63 | Args:
64 | resize (tuple): align image to depthmap, in (w, h).
65 | augment_fn (callable, optional): augments images with pre-defined visual effects
66 | Returns:
67 | image (torch.tensor): (3, h, w)
68 | """
69 | # read and resize image
70 | image = imread(path, None)
71 | image = cv2.resize(image, resize)
72 |
73 | # (h, w, 3) -> (3, h, w) and normalized
74 | image = torch.from_numpy(image).float().permute(2, 0, 1) / 255
75 | if augment_fn:
76 | image = augment_fn(image)
77 | return image
78 |
79 |
80 | def read_depth_image(path):
81 | depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
82 | depth = depth / 1000
83 | depth = torch.from_numpy(depth).float() # (h, w)
84 | return depth
85 |
86 | def correct_intrinsic_scale(K, scale_x, scale_y):
87 | '''Given an intrinsic matrix (3x3) and two scale factors, returns the new intrinsic matrix corresponding to
88 | the new coordinates x' = scale_x * x; y' = scale_y * y
89 | Source: https://dsp.stackexchange.com/questions/6055/how-does-resizing-an-image-affect-the-intrinsic-camera-matrix
90 | '''
91 |
92 | transform = torch.eye(3)
93 | transform[0, 0] = scale_x
94 | transform[0, 2] = scale_x / 2 - 0.5
95 | transform[1, 1] = scale_y
96 | transform[1, 2] = scale_y / 2 - 0.5
97 | Kprime = transform @ K
98 |
99 | return Kprime
100 |
101 | def define_sampling_grid(im_size, feats_downsample=4, step=1):
102 | """
103 | Auxiliary function to generate the sampling grid from the feature map
104 | Args:
105 | im_size: original image size that goes into the network
106 | feats_downsample: rescaling factor that happens within the architecture due to downsampling steps
107 | Output:
108 | indexes_mat: dense grid sampling indexes, size: (im_size/feats_downsample, im_size/feats_downsample)
109 | """
110 |
111 | feats_size = int(im_size/feats_downsample)
112 | grid_size = int(im_size/feats_downsample/step)
113 |
114 | indexes = np.asarray(range(0, feats_size, step))[:grid_size]
115 | indexes_x = indexes.reshape((1, len(indexes), 1))
116 | indexes_y = indexes.reshape((len(indexes), 1, 1))
117 |
118 | indexes_x = np.tile(indexes_x, [len(indexes), 1, 1])
119 | indexes_y = np.tile(indexes_y, [1, len(indexes), 1])
120 |
121 | indexes_mat = np.concatenate([indexes_x, indexes_y], axis=-1)
122 | indexes_mat = indexes_mat.reshape((grid_size*grid_size, 2))
123 |
124 | return indexes_mat
125 |
126 |
127 |
--------------------------------------------------------------------------------
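`correct_intrinsic_scale` is used whenever the images are resized; a short worked example, using the toy-example intrinsics listed earlier and halving the 540x720 image:

```python
# Sketch only: rescale the intrinsics of data/toy_example/im0.jpg when resizing 540x720 -> 270x360.
import numpy as np
from lib.datasets.utils import correct_intrinsic_scale

K = np.array([[549.7018, 0., 268.6665],
              [0., 549.7018, 351.8357],
              [0., 0., 1.]], dtype=np.float32)

W, H = 540, 720
new_W, new_H = 270, 360
K_resized = correct_intrinsic_scale(K, new_W / W, new_H / H)
print(K_resized)
# fx, fy become ~274.85; cx, cy become ~134.08 and ~175.67 (scaled minus the half-pixel offset)
```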
/lib/models/MicKey/compute_pose.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 |
3 | from lib.models.MicKey.modules.compute_correspondences import ComputeCorrespondences
4 | from lib.models.MicKey.modules.utils.probabilisticProcrustes import e2eProbabilisticProcrustesSolver
5 |
6 | class MickeyRelativePose(pl.LightningModule):
7 | # Compute the metric relative pose between two input images, with given intrinsics (for the pose solver).
8 |
9 | def __init__(self, cfg):
10 | super().__init__()
11 |
12 | # Define MicKey architecture and matching module:
13 | self.compute_matches = ComputeCorrespondences(cfg)
14 |
15 | # Metric solver
16 | self.e2e_Procrustes = e2eProbabilisticProcrustesSolver(cfg)
17 |
18 | self.is_eval_model(True)
19 |
20 | def forward(self, data, return_inliers=False):
21 |
22 | self.compute_matches(data)
23 | data['final_scores'] = data['scores'] * data['kp_scores']
24 |
25 | if return_inliers:
26 | # Returns inliers list:
27 | R, t, inliers, inliers_list = self.e2e_Procrustes.estimate_pose_vectorized(data, return_inliers=True)
28 | data['inliers_list'] = inliers_list
29 | else:
30 | # If the inlier list is not needed:
31 | R, t, inliers = self.e2e_Procrustes.estimate_pose_vectorized(data, return_inliers=False)
32 |
33 | data['R'] = R
34 | data['t'] = t
35 | data['inliers'] = inliers
36 |
37 | return R, t
38 |
39 | def on_load_checkpoint(self, checkpoint):
40 | # This function avoids loading the DINOv2 weights, which are not stored in MicKey's checkpoint.
41 | # This saves memory during training: since DINOv2 is frozen and never updated, there is no need to store
42 | # its weights in every checkpoint.
43 |
44 | # Recover DINOv2 features from pretrained weights.
45 | for param_tensor in self.compute_matches.state_dict():
46 | if 'dinov2' in param_tensor:
47 | checkpoint['state_dict']['compute_matches.'+param_tensor] = \
48 | self.compute_matches.state_dict()[param_tensor]
49 |
50 | def is_eval_model(self, is_eval):
51 | if is_eval:
52 | self.compute_matches.extractor.depth_head.eval()
53 | self.compute_matches.extractor.det_offset.eval()
54 | self.compute_matches.extractor.dsc_head.eval()
55 | self.compute_matches.extractor.det_head.eval()
56 | else:
57 | self.compute_matches.extractor.depth_head.train()
58 | self.compute_matches.extractor.det_offset.train()
59 | self.compute_matches.extractor.dsc_head.train()
60 | self.compute_matches.extractor.det_head.train()
61 |
--------------------------------------------------------------------------------
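A minimal sketch of the forward contract of `MickeyRelativePose`, assuming the released config and checkpoint have been downloaded as in the demo script: the model expects a dict with two RGB images and their intrinsics, and writes the estimated pose back into the dict.

```python
# Sketch only: input/output contract of MickeyRelativePose (built via lib.models.builder).
# Assumes weights/mickey_weights/{config.yaml, mickey.ckpt} exist, as in demo_inference.py.
import torch
from config.default import cfg
from lib.models.builder import build_model

cfg.merge_from_file('weights/mickey_weights/config.yaml')
model = build_model(cfg, checkpoint='weights/mickey_weights/mickey.ckpt')
device = next(model.parameters()).device

K = torch.tensor([[550., 0., 270.],
                  [0., 550., 360.],
                  [0., 0., 1.]])[None]          # (1, 3, 3), illustrative intrinsics only

data = {
    'image0': torch.rand(1, 3, 720, 540),       # (B, 3, H, W), RGB in [0, 1]
    'image1': torch.rand(1, 3, 720, 540),
    'K_color0': K,
    'K_color1': K,
}
data = {k: v.to(device) for k, v in data.items()}

R, t = model(data)                              # also fills data['R'], data['t'], data['inliers']
print(R.shape, t.shape, data['inliers'])
```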
/lib/models/MicKey/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 |
4 | from lib.models.MicKey.modules.loss.loss_class import MetricPoseLoss
5 | from lib.models.MicKey.modules.compute_correspondences import ComputeCorrespondences
6 | from lib.models.MicKey.modules.utils.training_utils import log_image_matches, debug_reward_matches_log, vis_inliers
7 | from lib.models.MicKey.modules.utils.probabilisticProcrustes import e2eProbabilisticProcrustesSolver
8 |
9 | from lib.utils.metrics import pose_error_torch, vcre_torch
10 | from lib.benchmarks.utils import precision_recall
11 |
12 | class MicKeyTrainingModel(pl.LightningModule):
13 | def __init__(self, cfg):
14 | super().__init__()
15 |
16 | # Store MicKey's configuration
17 | self.cfg = cfg
18 |
19 | # Define MicKey architecture and matching module:
20 | self.compute_matches = ComputeCorrespondences(cfg)
21 | self.is_eval_model(False)
22 |
23 | # Loss function class
24 | self.loss_fn = MetricPoseLoss(cfg)
25 |
26 | # Metric solvers
27 | self.e2e_Procrustes = e2eProbabilisticProcrustesSolver(cfg)
28 |
29 | # Logger parameters
30 | self.counter_batch = 0
31 | self.log_store_ims = True
32 | self.log_max_ims = 5
33 | self.log_im_counter_train = 0
34 | self.log_im_counter_val = 0
35 | self.log_interval = cfg.TRAINING.LOG_INTERVAL
36 |
37 | # Define curriculum learning parameters:
38 | self.curriculum_learning = cfg.LOSS_CLASS.CURRICULUM_LEARNING.TRAIN_CURRICULUM
39 | self.topK = cfg.LOSS_CLASS.CURRICULUM_LEARNING.TOPK_INIT
40 | self.topK_max = cfg.LOSS_CLASS.CURRICULUM_LEARNING.TOPK
41 |
42 | # Lightning configurations
43 | self.automatic_optimization = False # This property activates manual optimization.
44 | self.multi_gpu = True
45 | self.validation_step_outputs = []
46 | # torch.autograd.set_detect_anomaly(True)
47 |
48 | def forward(self, data):
49 | self.compute_matches(data)
50 |
51 | def training_step(self, batch, batch_idx):
52 |
53 | self(batch)
54 | self.prepare_batch_for_loss(batch, batch_idx)
55 |
56 | avg_loss, outputs, probs_grad, num_its = self.loss_fn(batch)
57 |
58 | training_step_ok = self.backward_step(batch, outputs, probs_grad, avg_loss, num_its)
59 | self.tensorboard_log_step(batch, avg_loss, outputs, probs_grad, training_step_ok)
60 |
61 | def on_train_epoch_end(self):
62 | if self.curriculum_learning:
63 | self.topK = min(self.topK_max, self.topK + 5)
64 | self.loss_fn.topK = self.topK
65 |
66 | def validation_step(self, batch, batch_idx):
67 |
68 | self.is_eval_model(True)
69 | self(batch)
70 | self.prepare_batch_for_loss(batch, batch_idx)
71 |
72 | # validation metrics
73 | avg_loss, outputs, probs_grad, num_its = self.loss_fn(batch)
74 | outputs['loss'] = avg_loss
75 |
76 | # Metric pose evaluation
77 | R_ours, t_m_ours, inliers_ours = self.e2e_Procrustes.estimate_pose_vectorized(batch)
78 | outputs_metric_ours = pose_error_torch(R_ours, t_m_ours, batch['T_0to1'], reduce=None)
79 | outputs['metric_ours_t_err_ang'] = outputs_metric_ours['t_err_ang']
80 | outputs['metric_ours_t_err_euc'] = outputs_metric_ours['t_err_euc']
81 | outputs['metric_ours_R_err'] = outputs_metric_ours['R_err']
82 | outputs['metric_inliers'] = inliers_ours
83 |
84 | outputs_vcre_ours = vcre_torch(R_ours, t_m_ours, batch['T_0to1'], batch['Kori_color0'], reduce=None)
85 | outputs['metric_ours_vcre'] = outputs_vcre_ours['repr_err']
86 |
87 | self.validation_step_outputs.append(outputs)
88 |
89 | return outputs
90 |
91 | def backward_step(self, batch, outputs, probs_grad, avg_loss, num_its):
92 | opt = self.optimizers()
93 |
94 | # update model
95 | opt.zero_grad()
96 |
97 | if num_its == 0:
98 | print('No valid hypotheses were generated')
99 | return False
100 |
101 | # Generate gradients for learning keypoint offsets
102 | avg_loss.backward()
103 |
104 | invalid_probs = torch.isnan(probs_grad[0]).any()
105 | invalid_kps0 = (torch.isnan(outputs['kps0'].grad).any() or torch.isinf(outputs['kps0'].grad).any())
106 | invalid_kps1 = (torch.isnan(outputs['kps1'].grad).any() or torch.isinf(outputs['kps1'].grad).any())
107 | invalid_depth0 = (torch.isnan(outputs['depth0'].grad).any() or torch.isinf(outputs['depth0'].grad).any())
108 | invalid_depth1 = (torch.isnan(outputs['depth1'].grad).any() or torch.isinf(outputs['depth1'].grad).any())
109 |
110 | if invalid_probs:
111 | print('Found NaN/Inf in probs!')
112 | return False
113 |
114 | if invalid_depth0 or invalid_depth1:
115 | print('Found NaN/Inf in depth0/depth1 gradients!')
116 | return False
117 |
118 | if batch['kps0'].requires_grad:
119 |
120 | if invalid_kps0 or invalid_kps1:
121 | print('Found NaN/Inf in kps0/kps1 gradients!')
122 | return False
123 |
124 | torch.autograd.backward((torch.log(batch['final_scores'] + 1e-16),
125 | batch['kps0'], batch['kps1'], batch['depth_kp0'], batch['depth_kp1']),
126 | (probs_grad[0], outputs['kps0'].grad, outputs['kps1'].grad,
127 | outputs['depth0'].grad, outputs['depth1'].grad))
128 | elif batch['depth_kp0'].requires_grad:
129 | torch.autograd.backward((torch.log(batch['final_scores'] + 1e-16),
130 | batch['depth_kp0'], batch['depth_kp1']),
131 | (probs_grad[0], outputs['depth0'].grad, outputs['depth1'].grad))
132 | else:
133 | torch.autograd.backward((torch.log(batch['final_scores'] + 1e-16)),
134 | (probs_grad[0]))
135 |
136 | # add gradient clipping after backward to avoid gradient exploding
137 | torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5)
138 |
139 | # check if the gradients of the training parameters contain nan values
140 | nans = sum([torch.isnan(param.grad).any() for param in list(self.parameters()) if param.grad is not None])
141 | if nans != 0:
142 | print("parameter gradients includes {} nan values".format(nans))
143 | return False
144 |
145 | opt.step()
146 |
147 | return True
148 |
149 | def tensorboard_log_step(self, batch, avg_loss, outputs, probs_grad, training_step_ok):
150 |
151 | self.log('train/loss', avg_loss.detach())
152 | self.log('train/loss_rot', outputs['avg_loss_rot'].detach())
153 | self.log('train/loss_trans', outputs['avg_loss_trans'].detach())
154 |
155 | if self.log_store_ims:
156 | if self.counter_batch % self.log_interval == 0:
157 | self.counter_batch = 0
158 |
159 | # If training with curriculum learning, not all image pairs have valid gradients
160 | # ensure selecting one image pair for logging that has reward information
161 | batch_id = torch.where(outputs['mask_topk'] == 1.)[0][0].item()
162 |
163 | # Metric pose evaluation
164 | R_ours, t_m_ours, inliers_ours, inliers_list_ours = self.e2e_Procrustes.estimate_pose_vectorized(batch, return_inliers=True)
165 | outputs_metric_ours = pose_error_torch(R_ours, t_m_ours, batch['T_0to1'], reduce=None)
166 | self.log('train_metric_pose/ours_t_err_ang', outputs_metric_ours['t_err_ang'].mean().detach())
167 | self.log('train_metric_pose/ours_t_err_euc', outputs_metric_ours['t_err_euc'].mean().detach())
168 | self.log('train_metric_pose/ours_R_err', outputs_metric_ours['R_err'].mean().detach())
169 |
170 | outputs_vcre_ours = vcre_torch(R_ours, t_m_ours, batch['T_0to1'], batch['Kori_color0'], reduce=None)
171 | self.log('train_vcre/repr_err', outputs_vcre_ours['repr_err'].mean().detach())
172 |
173 | im_inliers = vis_inliers(inliers_list_ours, batch, batch_i=batch_id)
174 |
175 | im_matches, sc_map0, sc_map1, depth_map0, depth_map1 = log_image_matches(self.compute_matches.matcher,
176 | batch, train_depth=True,
177 | batch_i=batch_id)
178 |
179 | tensorboard = self.logger.experiment
180 | tensorboard.add_image('training_matching/best_inliers', im_inliers, global_step=self.log_im_counter_train)
181 | tensorboard.add_image('training_matching/best_matches_desc', im_matches, global_step=self.log_im_counter_train)
182 | tensorboard.add_image('training_scores/map0', sc_map0, global_step=self.log_im_counter_train)
183 | tensorboard.add_image('training_scores/map1', sc_map1, global_step=self.log_im_counter_train)
184 | tensorboard.add_image('training_depth/map0', depth_map0[0], global_step=self.log_im_counter_train)
185 | tensorboard.add_image('training_depth/map1', depth_map1[0], global_step=self.log_im_counter_train)
186 | if training_step_ok:
187 | try:
188 | im_rewards, rew_kp0, rew_kp1 = debug_reward_matches_log(batch, probs_grad, batch_i=batch_id)
189 | tensorboard.add_image('training_rewards/pair0', im_rewards, global_step=self.log_im_counter_train)
190 | except ValueError:
191 | print('[WARNING]: Failed to log reward image. Selected image is not in topK image pairs. ')
192 |
193 | self.log_im_counter_train += 1
194 |
195 | torch.cuda.empty_cache()
196 | self.counter_batch += 1
197 |
198 | def prepare_batch_for_loss(self, batch, batch_idx):
199 |
200 | batch['batch_idx'] = batch_idx
201 | batch['final_scores'] = batch['scores'] * batch['kp_scores']
202 |
203 | return batch
204 |
205 | def on_validation_epoch_end(self):
206 |
207 | # aggregates metrics/losses from all validation steps
208 | aggregated = {}
209 | for key in self.validation_step_outputs[0].keys():
210 | aggregated[key] = torch.stack([x[key] for x in self.validation_step_outputs])
211 |
212 | # compute stats
213 | mean_R_loss = aggregated['avg_loss_rot'].mean()
214 | mean_t_loss = aggregated['avg_loss_trans'].mean()
215 | mean_loss = aggregated['loss'].mean()
216 |
217 | # Metric stats:
218 | metric_ours_t_err_ang = aggregated['metric_ours_t_err_ang'].mean()
219 | metric_ours_t_err_euc = aggregated['metric_ours_t_err_euc'].mean()
220 | metric_ours_R_err = aggregated['metric_ours_R_err'].mean()
221 |
222 | metric_ours_vcre = aggregated['metric_ours_vcre'].mean()
223 |
224 |         # compute precision/AUC for pose error (0.25 m / 5 deg thresholds)
225 | t_threshold = 0.25
226 | R_threshold = 5
227 | accepted_poses_ours = (aggregated['metric_ours_t_err_euc'].view(-1) < t_threshold) * \
228 | (aggregated['metric_ours_R_err'].view(-1) < R_threshold)
229 |
230 | inliers = aggregated['metric_inliers'].view(-1).detach().cpu().numpy()
231 |
232 | prec_pose_ours = accepted_poses_ours.sum()/len(accepted_poses_ours)
233 |
234 | _, _, auc_pose = precision_recall(inliers=inliers, tp=accepted_poses_ours.detach().cpu().numpy(), failures=0)
235 |
236 |         # compute precision/AUC for pose error (0.5 m / 10 deg thresholds)
237 | t_threshold = 0.5
238 | R_threshold = 10
239 | accepted_poses_ours = (aggregated['metric_ours_t_err_euc'].view(-1) < t_threshold) * \
240 | (aggregated['metric_ours_R_err'].view(-1) < R_threshold)
241 |
242 | inliers = aggregated['metric_inliers'].view(-1).detach().cpu().numpy()
243 |
244 | prec_pose_ours_10 = accepted_poses_ours.sum() / len(accepted_poses_ours)
245 |
246 | _, _, auc_pose_10 = precision_recall(inliers=inliers, tp=accepted_poses_ours.detach().cpu().numpy(), failures=0)
247 |
248 |
249 | # compute precision/AUC for reprojection errors
250 | px_threshold = 90
251 | accepted_vcre_ours = aggregated['metric_ours_vcre'].view(-1) < px_threshold
252 |
253 | prec_vcre_ours = accepted_vcre_ours.sum()/len(accepted_vcre_ours)
254 |
255 | _, _, auc_vcre = precision_recall(inliers=inliers, tp=accepted_vcre_ours.detach().cpu().numpy(), failures=0)
256 |
257 | # log stats
258 | self.log('val_loss/loss_R', mean_R_loss, sync_dist=self.multi_gpu)
259 | self.log('val_loss/loss_t', mean_t_loss, sync_dist=self.multi_gpu)
260 | self.log('val_loss/loss', mean_loss, sync_dist=self.multi_gpu)
261 |
262 | self.log('val_metric_pose/ours_t_err_ang', metric_ours_t_err_ang, sync_dist=self.multi_gpu)
263 | self.log('val_metric_pose/ours_t_err_euc', metric_ours_t_err_euc, sync_dist=self.multi_gpu)
264 | self.log('val_metric_pose/ours_R_err', metric_ours_R_err, sync_dist=self.multi_gpu)
265 |
266 | self.log('val_vcre/auc_vcre', auc_vcre, sync_dist=self.multi_gpu)
267 | self.log('val_vcre/prec_vcre_ours', prec_vcre_ours, sync_dist=self.multi_gpu)
268 | self.log('val_vcre/metric_ours_vcre', metric_ours_vcre, sync_dist=self.multi_gpu)
269 |
270 | self.log('val_AUC_pose/prec_pose_ours', prec_pose_ours, sync_dist=self.multi_gpu)
271 | self.log('val_AUC_pose/auc_pose', torch.tensor(auc_pose), sync_dist=self.multi_gpu)
272 |
273 | self.log('val_AUC_pose/prec_pose_ours_10', prec_pose_ours_10, sync_dist=self.multi_gpu)
274 | self.log('val_AUC_pose/auc_pose_10', torch.tensor(auc_pose_10), sync_dist=self.multi_gpu)
275 |
276 | self.validation_step_outputs.clear() # free memory
277 |
278 | self.is_eval_model(False)
279 |
280 | return mean_loss
281 |
282 | def configure_optimizers(self):
283 | tcfg = self.cfg.TRAINING
284 | opt = torch.optim.Adam(self.parameters(), lr=tcfg.LR, eps=1e-6)
285 | if tcfg.LR_STEP_INTERVAL:
286 | scheduler = torch.optim.lr_scheduler.StepLR(
287 | opt, tcfg.LR_STEP_INTERVAL, tcfg.LR_STEP_GAMMA)
288 | return {'optimizer': opt, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}
289 | return opt
290 |
291 | def on_save_checkpoint(self, checkpoint):
292 |         # As DINOv2 is pre-trained (and not fine-tuned), avoid saving its weights (this keeps checkpoints smaller).
293 | dinov2_keys = []
294 | for key in checkpoint['state_dict'].keys():
295 | if 'dinov2' in key:
296 | dinov2_keys.append(key)
297 | for key in dinov2_keys:
298 | del checkpoint['state_dict'][key]
299 |
300 | def on_load_checkpoint(self, checkpoint):
301 |
302 | # Recover DINOv2 features from pretrained weights.
303 | for param_tensor in self.compute_matches.state_dict():
304 |             if 'dinov2' in param_tensor:
305 | checkpoint['state_dict']['compute_matches.'+param_tensor] = \
306 | self.compute_matches.state_dict()[param_tensor]
307 |
308 | def is_eval_model(self, is_eval):
309 | if is_eval:
310 | self.compute_matches.extractor.depth_head.eval()
311 | self.compute_matches.extractor.det_offset.eval()
312 | self.compute_matches.extractor.dsc_head.eval()
313 | self.compute_matches.extractor.det_head.eval()
314 | else:
315 | self.compute_matches.extractor.depth_head.train()
316 | self.compute_matches.extractor.det_offset.train()
317 | self.compute_matches.extractor.dsc_head.train()
318 | self.compute_matches.extractor.det_head.train()
319 |
--------------------------------------------------------------------------------
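The `backward_step` above relies on a two-stage backward: `avg_loss.backward()` first fills the `.grad` fields of detached keypoint/depth tensors in `outputs`, and `torch.autograd.backward` then re-injects those gradients (together with the log-score term) into the graph that produced `batch['kps0']`, `batch['depth_kp0']`, etc. A minimal, self-contained sketch of that gradient-routing pattern, using toy tensors rather than the repository's actual variables:

import torch

# Toy graph: y is produced by a parameter we want to train.
x = torch.randn(4, requires_grad=True)
y = 2 * x

# Downstream computation runs on a detached copy, so its gradient has to be
# harvested from the copy and routed back by hand.
y_detached = y.detach().requires_grad_()
loss = (y_detached ** 2).sum()
loss.backward()                                # fills y_detached.grad

# Re-inject the harvested gradient into the original graph.
torch.autograd.backward(y, y_detached.grad)
print(torch.allclose(x.grad, 8 * x.detach()))  # d/dx sum((2x)^2) = 8x -> True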
/lib/models/MicKey/modules/DINO_modules/dinov2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | from functools import partial
12 | import math
13 | import logging
14 | from typing import Sequence, Tuple, Union, Callable
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.utils.checkpoint
19 | from torch.nn.init import trunc_normal_
20 |
21 | from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
22 |
23 |
24 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
25 | if not depth_first and include_root:
26 | fn(module=module, name=name)
27 | for child_name, child_module in module.named_children():
28 | child_name = ".".join((name, child_name)) if name else child_name
29 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
30 | if depth_first and include_root:
31 | fn(module=module, name=name)
32 | return module
33 |
34 |
35 | class BlockChunk(nn.ModuleList):
36 | def forward(self, x):
37 | for b in self:
38 | x = b(x)
39 | return x
40 |
41 |
42 | class DinoVisionTransformer(nn.Module):
43 | def __init__(
44 | self,
45 | img_size=224,
46 | patch_size=16,
47 | in_chans=3,
48 | embed_dim=768,
49 | depth=12,
50 | num_heads=12,
51 | mlp_ratio=4.0,
52 | qkv_bias=True,
53 | ffn_bias=True,
54 | proj_bias=True,
55 | drop_path_rate=0.0,
56 | drop_path_uniform=False,
57 | init_values=None, # for layerscale: None or 0 => no layerscale
58 | embed_layer=PatchEmbed,
59 | act_layer=nn.GELU,
60 | block_fn=Block,
61 | ffn_layer="mlp",
62 | block_chunks=1,
63 | ):
64 | """
65 | Args:
66 | img_size (int, tuple): input image size
67 | patch_size (int, tuple): patch size
68 | in_chans (int): number of input channels
69 | embed_dim (int): embedding dimension
70 | depth (int): depth of transformer
71 | num_heads (int): number of attention heads
72 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
73 | qkv_bias (bool): enable bias for qkv if True
74 | proj_bias (bool): enable bias for proj in attn if True
75 | ffn_bias (bool): enable bias for ffn if True
76 | drop_path_rate (float): stochastic depth rate
77 | drop_path_uniform (bool): apply uniform drop rate across blocks
78 | weight_init (str): weight init scheme
79 | init_values (float): layer-scale init values
80 | embed_layer (nn.Module): patch embedding layer
81 | act_layer (nn.Module): MLP activation layer
82 | block_fn (nn.Module): transformer block class
83 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
84 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
85 | """
86 | super().__init__()
87 | norm_layer = partial(nn.LayerNorm, eps=1e-6)
88 |
89 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
90 | self.num_tokens = 1
91 | self.n_blocks = depth
92 | self.num_heads = num_heads
93 | self.patch_size = patch_size
94 |
95 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
96 | num_patches = self.patch_embed.num_patches
97 |
98 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
99 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
100 |
101 | if drop_path_uniform is True:
102 | dpr = [drop_path_rate] * depth
103 | else:
104 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
105 |
106 | if ffn_layer == "mlp":
107 | ffn_layer = Mlp
108 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
109 | ffn_layer = SwiGLUFFNFused
110 | elif ffn_layer == "identity":
111 |
112 | def f(*args, **kwargs):
113 | return nn.Identity()
114 |
115 | ffn_layer = f
116 | else:
117 | raise NotImplementedError
118 |
119 | blocks_list = [
120 | block_fn(
121 | dim=embed_dim,
122 | num_heads=num_heads,
123 | mlp_ratio=mlp_ratio,
124 | qkv_bias=qkv_bias,
125 | proj_bias=proj_bias,
126 | ffn_bias=ffn_bias,
127 | drop_path=dpr[i],
128 | norm_layer=norm_layer,
129 | act_layer=act_layer,
130 | ffn_layer=ffn_layer,
131 | init_values=init_values,
132 | )
133 | for i in range(depth)
134 | ]
135 | if block_chunks > 0:
136 | self.chunked_blocks = True
137 | chunked_blocks = []
138 | chunksize = depth // block_chunks
139 | for i in range(0, depth, chunksize):
140 | # this is to keep the block index consistent if we chunk the block list
141 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i: i + chunksize])
142 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
143 | else:
144 | self.chunked_blocks = False
145 | self.blocks = nn.ModuleList(blocks_list)
146 |
147 | self.norm = norm_layer(embed_dim)
148 | self.head = nn.Identity()
149 |
150 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
151 |
152 | self.init_weights()
153 | for param in self.parameters():
154 | param.requires_grad = False
155 |
156 | @property
157 | def device(self):
158 | return self.cls_token.device
159 |
160 | def init_weights(self):
161 | trunc_normal_(self.pos_embed, std=0.02)
162 | nn.init.normal_(self.cls_token, std=1e-6)
163 | named_apply(init_weights_vit_timm, self)
164 |
165 | def interpolate_pos_encoding(self, x, w, h):
166 | previous_dtype = x.dtype
167 | npatch = x.shape[1] - 1
168 | N = self.pos_embed.shape[1] - 1
169 | if npatch == N and w == h:
170 | return self.pos_embed
171 | pos_embed = self.pos_embed.float()
172 | class_pos_embed = pos_embed[:, 0]
173 | patch_pos_embed = pos_embed[:, 1:]
174 | dim = x.shape[-1]
175 | w0 = w // self.patch_size
176 | h0 = h // self.patch_size
177 | # we add a small number to avoid floating point error in the interpolation
178 | # see discussion at https://github.com/facebookresearch/dino/issues/8
179 | w0, h0 = w0 + 0.1, h0 + 0.1
180 |
181 | patch_pos_embed = nn.functional.interpolate(
182 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
183 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
184 | mode="bicubic",
185 | )
186 |
187 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
188 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
189 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
190 |
191 | def prepare_tokens_with_masks(self, x, masks=None):
192 | B, nc, w, h = x.shape
193 | x = self.patch_embed(x)
194 | if masks is not None:
195 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
196 |
197 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
198 | x = x + self.interpolate_pos_encoding(x, w, h)
199 |
200 | return x
201 |
202 | def forward_features_list(self, x_list, masks_list):
203 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
204 | for blk in self.blocks:
205 | x = blk(x)
206 |
207 | all_x = x
208 | output = []
209 | for x, masks in zip(all_x, masks_list):
210 | x_norm = self.norm(x)
211 | output.append(
212 | {
213 | "x_norm_clstoken": x_norm[:, 0],
214 | "x_norm_patchtokens": x_norm[:, 1:],
215 | "x_prenorm": x,
216 | "masks": masks,
217 | }
218 | )
219 | return output
220 |
221 | def forward_features(self, x, masks=None):
222 | if isinstance(x, list):
223 | return self.forward_features_list(x, masks)
224 |
225 | x = self.prepare_tokens_with_masks(x, masks)
226 |
227 | for blk in self.blocks:
228 | x = blk(x)
229 |
230 | x_norm = self.norm(x)
231 | return {
232 | "x_norm_clstoken": x_norm[:, 0],
233 | "x_norm_patchtokens": x_norm[:, 1:],
234 | "x_prenorm": x,
235 | "masks": masks,
236 | }
237 |
238 | def _get_intermediate_layers_not_chunked(self, x, n=1):
239 | x = self.prepare_tokens_with_masks(x)
240 | # If n is an int, take the n last blocks. If it's a list, take them
241 | output, total_block_len = [], len(self.blocks)
242 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
243 | for i, blk in enumerate(self.blocks):
244 | x = blk(x)
245 | if i in blocks_to_take:
246 | output.append(x)
247 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
248 | return output
249 |
250 | def _get_intermediate_layers_chunked(self, x, n=1):
251 | x = self.prepare_tokens_with_masks(x)
252 | output, i, total_block_len = [], 0, len(self.blocks[-1])
253 | # If n is an int, take the n last blocks. If it's a list, take them
254 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
255 | for block_chunk in self.blocks:
256 | for blk in block_chunk[i:]: # Passing the nn.Identity()
257 | x = blk(x)
258 | if i in blocks_to_take:
259 | output.append(x)
260 | i += 1
261 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
262 | return output
263 |
264 | def get_intermediate_layers(
265 | self,
266 | x: torch.Tensor,
267 | n: Union[int, Sequence] = 1, # Layers or n last layers to take
268 | reshape: bool = False,
269 | return_class_token: bool = False,
270 | norm=True,
271 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
272 | if self.chunked_blocks:
273 | outputs = self._get_intermediate_layers_chunked(x, n)
274 | else:
275 | outputs = self._get_intermediate_layers_not_chunked(x, n)
276 | if norm:
277 | outputs = [self.norm(out) for out in outputs]
278 | class_tokens = [out[:, 0] for out in outputs]
279 | outputs = [out[:, 1:] for out in outputs]
280 | if reshape:
281 | B, _, w, h = x.shape
282 | outputs = [
283 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
284 | for out in outputs
285 | ]
286 | if return_class_token:
287 | return tuple(zip(outputs, class_tokens))
288 | return tuple(outputs)
289 |
290 | def forward(self, *args, is_training=False, **kwargs):
291 | ret = self.forward_features(*args, **kwargs)
292 | if is_training:
293 | return ret
294 | else:
295 | return self.head(ret["x_norm_clstoken"])
296 |
297 |
298 | def init_weights_vit_timm(module: nn.Module, name: str = ""):
299 | """ViT weight initialization, original timm impl (for reproducibility)"""
300 | if isinstance(module, nn.Linear):
301 | trunc_normal_(module.weight, std=0.02)
302 | if module.bias is not None:
303 | nn.init.zeros_(module.bias)
304 |
305 |
306 | def vit_small(patch_size=16, **kwargs):
307 | model = DinoVisionTransformer(
308 | patch_size=patch_size,
309 | embed_dim=384,
310 | depth=12,
311 | num_heads=6,
312 | mlp_ratio=4,
313 | block_fn=partial(Block, attn_class=MemEffAttention),
314 | **kwargs,
315 | )
316 | return model
317 |
318 |
319 | def vit_base(patch_size=16, **kwargs):
320 | model = DinoVisionTransformer(
321 | patch_size=patch_size,
322 | embed_dim=768,
323 | depth=12,
324 | num_heads=12,
325 | mlp_ratio=4,
326 | block_fn=partial(Block, attn_class=MemEffAttention),
327 | **kwargs,
328 | )
329 | return model
330 |
331 |
332 | def vit_large(patch_size=16, **kwargs):
333 | model = DinoVisionTransformer(
334 | patch_size=patch_size,
335 | embed_dim=1024,
336 | depth=24,
337 | num_heads=16,
338 | mlp_ratio=4,
339 | block_fn=partial(Block, attn_class=MemEffAttention),
340 | **kwargs,
341 | )
342 | return model
343 |
344 |
345 | def vit_giant2(patch_size=16, **kwargs):
346 | """
347 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
348 | """
349 | model = DinoVisionTransformer(
350 | patch_size=patch_size,
351 | embed_dim=1536,
352 | depth=40,
353 | num_heads=24,
354 | mlp_ratio=4,
355 | block_fn=partial(Block, attn_class=MemEffAttention),
356 | **kwargs,
357 | )
358 | return model
--------------------------------------------------------------------------------
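A short usage sketch for the frozen DINOv2 backbone defined above (the import path and the ViT-S/14 geometry below are assumptions for illustration; the repository root must be on PYTHONPATH): `get_intermediate_layers(..., reshape=True)` returns dense patch features laid out as a feature map.

import torch
from lib.models.MicKey.modules.DINO_modules.dinov2 import vit_small

vit = vit_small(patch_size=14, img_size=518)   # DINOv2-style ViT-S/14 geometry
x = torch.randn(1, 3, 224, 224)                # H and W must be multiples of the patch size
(feats,) = vit.get_intermediate_layers(x, n=1, reshape=True)
print(feats.shape)                             # torch.Size([1, 384, 16, 16])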
/lib/models/MicKey/modules/DINO_modules/layers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .dino_head import DINOHead
8 | from .mlp import Mlp
9 | from .patch_embed import PatchEmbed
10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
11 | from .block import NestedTensorBlock
12 | from .attention import MemEffAttention
13 |
--------------------------------------------------------------------------------
/lib/models/MicKey/modules/DINO_modules/layers/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10 |
11 | import logging
12 |
13 | from torch import Tensor
14 | from torch import nn
15 |
16 |
17 | logger = logging.getLogger("dinov2")
18 |
19 |
20 | try:
21 | from xformers.ops import memory_efficient_attention, unbind, fmha
22 |
23 | XFORMERS_AVAILABLE = True
24 | except ImportError:
25 | logger.warning("xFormers not available")
26 | XFORMERS_AVAILABLE = False
27 |
28 |
29 | class Attention(nn.Module):
30 | def __init__(
31 | self,
32 | dim: int,
33 | num_heads: int = 8,
34 | qkv_bias: bool = False,
35 | proj_bias: bool = True,
36 | attn_drop: float = 0.0,
37 | proj_drop: float = 0.0,
38 | ) -> None:
39 | super().__init__()
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim**-0.5
43 |
44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45 | self.attn_drop = nn.Dropout(attn_drop)
46 | self.proj = nn.Linear(dim, dim, bias=proj_bias)
47 | self.proj_drop = nn.Dropout(proj_drop)
48 |
49 | def forward(self, x: Tensor) -> Tensor:
50 | B, N, C = x.shape
51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52 |
53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54 | attn = q @ k.transpose(-2, -1)
55 |
56 | attn = attn.softmax(dim=-1)
57 | attn = self.attn_drop(attn)
58 |
59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60 | x = self.proj(x)
61 | x = self.proj_drop(x)
62 | return x
63 |
64 |
65 | class MemEffAttention(Attention):
66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67 | if not XFORMERS_AVAILABLE:
68 | assert attn_bias is None, "xFormers is required for nested tensors usage"
69 | return super().forward(x)
70 |
71 | B, N, C = x.shape
72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73 |
74 | q, k, v = unbind(qkv, 2)
75 |
76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77 | x = x.reshape([B, N, C])
78 |
79 | x = self.proj(x)
80 | x = self.proj_drop(x)
81 | return x
82 |
--------------------------------------------------------------------------------
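A shape-level sketch of the attention layer above (import path assumed from this repository layout; sizes are illustrative): both `Attention` and `MemEffAttention` map `[B, N, C]` to `[B, N, C]`, and `MemEffAttention` falls back to the plain implementation whenever xFormers is unavailable and no `attn_bias` is given.

import torch
from lib.models.MicKey.modules.DINO_modules.layers.attention import MemEffAttention

attn = MemEffAttention(dim=64, num_heads=8, qkv_bias=True)
tokens = torch.randn(2, 10, 64)
print(attn(tokens).shape)  # torch.Size([2, 10, 64])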
/lib/models/MicKey/modules/DINO_modules/layers/block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | import logging
12 | from typing import Callable, List, Any, Tuple, Dict
13 |
14 | import torch
15 | from torch import nn, Tensor
16 |
17 | from .attention import Attention, MemEffAttention
18 | from .drop_path import DropPath
19 | from .layer_scale import LayerScale
20 | from .mlp import Mlp
21 |
22 |
23 | logger = logging.getLogger("dinov2")
24 |
25 |
26 | try:
27 | from xformers.ops import fmha
28 | from xformers.ops import scaled_index_add, index_select_cat
29 |
30 | XFORMERS_AVAILABLE = True
31 | except ImportError:
32 | logger.warning("xFormers not available")
33 | XFORMERS_AVAILABLE = False
34 |
35 |
36 | class Block(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | num_heads: int,
41 | mlp_ratio: float = 4.0,
42 | qkv_bias: bool = False,
43 | proj_bias: bool = True,
44 | ffn_bias: bool = True,
45 | drop: float = 0.0,
46 | attn_drop: float = 0.0,
47 | init_values=None,
48 | drop_path: float = 0.0,
49 | act_layer: Callable[..., nn.Module] = nn.GELU,
50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51 | attn_class: Callable[..., nn.Module] = Attention,
52 | ffn_layer: Callable[..., nn.Module] = Mlp,
53 | ) -> None:
54 | super().__init__()
55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56 | self.norm1 = norm_layer(dim)
57 | self.attn = attn_class(
58 | dim,
59 | num_heads=num_heads,
60 | qkv_bias=qkv_bias,
61 | proj_bias=proj_bias,
62 | attn_drop=attn_drop,
63 | proj_drop=drop,
64 | )
65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67 |
68 | self.norm2 = norm_layer(dim)
69 | mlp_hidden_dim = int(dim * mlp_ratio)
70 | self.mlp = ffn_layer(
71 | in_features=dim,
72 | hidden_features=mlp_hidden_dim,
73 | act_layer=act_layer,
74 | drop=drop,
75 | bias=ffn_bias,
76 | )
77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79 |
80 | self.sample_drop_ratio = drop_path
81 |
82 | def forward(self, x: Tensor) -> Tensor:
83 | def attn_residual_func(x: Tensor) -> Tensor:
84 | return self.ls1(self.attn(self.norm1(x)))
85 |
86 | def ffn_residual_func(x: Tensor) -> Tensor:
87 | return self.ls2(self.mlp(self.norm2(x)))
88 |
89 | if self.training and self.sample_drop_ratio > 0.1:
90 | # the overhead is compensated only for a drop path rate larger than 0.1
91 | x = drop_add_residual_stochastic_depth(
92 | x,
93 | residual_func=attn_residual_func,
94 | sample_drop_ratio=self.sample_drop_ratio,
95 | )
96 | x = drop_add_residual_stochastic_depth(
97 | x,
98 | residual_func=ffn_residual_func,
99 | sample_drop_ratio=self.sample_drop_ratio,
100 | )
101 | elif self.training and self.sample_drop_ratio > 0.0:
102 | x = x + self.drop_path1(attn_residual_func(x))
103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104 | else:
105 | x = x + attn_residual_func(x)
106 | x = x + ffn_residual_func(x)
107 | return x
108 |
109 |
110 | def drop_add_residual_stochastic_depth(
111 | x: Tensor,
112 | residual_func: Callable[[Tensor], Tensor],
113 | sample_drop_ratio: float = 0.0,
114 | ) -> Tensor:
115 | # 1) extract subset using permutation
116 | b, n, d = x.shape
117 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119 | x_subset = x[brange]
120 |
121 | # 2) apply residual_func to get residual
122 | residual = residual_func(x_subset)
123 |
124 | x_flat = x.flatten(1)
125 | residual = residual.flatten(1)
126 |
127 | residual_scale_factor = b / sample_subset_size
128 |
129 | # 3) add the residual
130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131 | return x_plus_residual.view_as(x)
132 |
133 |
134 | def get_branges_scales(x, sample_drop_ratio=0.0):
135 | b, n, d = x.shape
136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138 | residual_scale_factor = b / sample_subset_size
139 | return brange, residual_scale_factor
140 |
141 |
142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143 | if scaling_vector is None:
144 | x_flat = x.flatten(1)
145 | residual = residual.flatten(1)
146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147 | else:
148 | x_plus_residual = scaled_index_add(
149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150 | )
151 | return x_plus_residual
152 |
153 |
154 | attn_bias_cache: Dict[Tuple, Any] = {}
155 |
156 |
157 | def get_attn_bias_and_cat(x_list, branges=None):
158 | """
159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache
160 | """
161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163 | if all_shapes not in attn_bias_cache.keys():
164 | seqlens = []
165 | for b, x in zip(batch_sizes, x_list):
166 | for _ in range(b):
167 | seqlens.append(x.shape[1])
168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169 | attn_bias._batch_sizes = batch_sizes
170 | attn_bias_cache[all_shapes] = attn_bias
171 |
172 | if branges is not None:
173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174 | else:
175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176 | cat_tensors = torch.cat(tensors_bs1, dim=1)
177 |
178 | return attn_bias_cache[all_shapes], cat_tensors
179 |
180 |
181 | def drop_add_residual_stochastic_depth_list(
182 | x_list: List[Tensor],
183 | residual_func: Callable[[Tensor, Any], Tensor],
184 | sample_drop_ratio: float = 0.0,
185 | scaling_vector=None,
186 | ) -> Tensor:
187 | # 1) generate random set of indices for dropping samples in the batch
188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189 | branges = [s[0] for s in branges_scales]
190 | residual_scale_factors = [s[1] for s in branges_scales]
191 |
192 | # 2) get attention bias and index+concat the tensors
193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194 |
195 | # 3) apply residual_func to get residual, and split the result
196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197 |
198 | outputs = []
199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201 | return outputs
202 |
203 |
204 | class NestedTensorBlock(Block):
205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206 | """
207 | x_list contains a list of tensors to nest together and run
208 | """
209 | assert isinstance(self.attn, MemEffAttention)
210 |
211 | if self.training and self.sample_drop_ratio > 0.0:
212 |
213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214 | return self.attn(self.norm1(x), attn_bias=attn_bias)
215 |
216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217 | return self.mlp(self.norm2(x))
218 |
219 | x_list = drop_add_residual_stochastic_depth_list(
220 | x_list,
221 | residual_func=attn_residual_func,
222 | sample_drop_ratio=self.sample_drop_ratio,
223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224 | )
225 | x_list = drop_add_residual_stochastic_depth_list(
226 | x_list,
227 | residual_func=ffn_residual_func,
228 | sample_drop_ratio=self.sample_drop_ratio,
229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230 | )
231 | return x_list
232 | else:
233 |
234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236 |
237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238 | return self.ls2(self.mlp(self.norm2(x)))
239 |
240 | attn_bias, x = get_attn_bias_and_cat(x_list)
241 | x = x + attn_residual_func(x, attn_bias=attn_bias)
242 | x = x + ffn_residual_func(x)
243 | return attn_bias.split(x)
244 |
245 | def forward(self, x_or_x_list):
246 | if isinstance(x_or_x_list, Tensor):
247 | return super().forward(x_or_x_list)
248 | elif isinstance(x_or_x_list, list):
249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250 | return self.forward_nested(x_or_x_list)
251 | else:
252 | raise AssertionError
253 |
--------------------------------------------------------------------------------
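A quick shape check for the transformer block above (import path assumed; sizes illustrative): with the default `drop_path=0.0` and `init_values=None`, `Block` reduces to a plain pre-norm attention + MLP residual stack and preserves the token layout.

import torch
from lib.models.MicKey.modules.DINO_modules.layers.block import Block

blk = Block(dim=64, num_heads=8, mlp_ratio=4.0, qkv_bias=True)
tokens = torch.randn(2, 50, 64)
print(blk(tokens).shape)  # torch.Size([2, 50, 64])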
/lib/models/MicKey/modules/DINO_modules/layers/dino_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn.init import trunc_normal_
10 | from torch.nn.utils import weight_norm
11 |
12 |
13 | class DINOHead(nn.Module):
14 | def __init__(
15 | self,
16 | in_dim,
17 | out_dim,
18 | use_bn=False,
19 | nlayers=3,
20 | hidden_dim=2048,
21 | bottleneck_dim=256,
22 | mlp_bias=True,
23 | ):
24 | super().__init__()
25 | nlayers = max(nlayers, 1)
26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
27 | self.apply(self._init_weights)
28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
29 | self.last_layer.weight_g.data.fill_(1)
30 |
31 | def _init_weights(self, m):
32 | if isinstance(m, nn.Linear):
33 | trunc_normal_(m.weight, std=0.02)
34 | if isinstance(m, nn.Linear) and m.bias is not None:
35 | nn.init.constant_(m.bias, 0)
36 |
37 | def forward(self, x):
38 | x = self.mlp(x)
39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12
40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
41 | x = self.last_layer(x)
42 | return x
43 |
44 |
45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
46 | if nlayers == 1:
47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias)
48 | else:
49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
50 | if use_bn:
51 | layers.append(nn.BatchNorm1d(hidden_dim))
52 | layers.append(nn.GELU())
53 | for _ in range(nlayers - 2):
54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
55 | if use_bn:
56 | layers.append(nn.BatchNorm1d(hidden_dim))
57 | layers.append(nn.GELU())
58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
59 | return nn.Sequential(*layers)
60 |
--------------------------------------------------------------------------------
/lib/models/MicKey/modules/DINO_modules/layers/drop_path.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10 |
11 |
12 | from torch import nn
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16 | if drop_prob == 0.0 or not training:
17 | return x
18 | keep_prob = 1 - drop_prob
19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21 | if keep_prob > 0.0:
22 | random_tensor.div_(keep_prob)
23 | output = x * random_tensor
24 | return output
25 |
26 |
27 | class DropPath(nn.Module):
28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29 |
30 | def __init__(self, drop_prob=None):
31 | super(DropPath, self).__init__()
32 | self.drop_prob = drop_prob
33 |
34 | def forward(self, x):
35 | return drop_path(x, self.drop_prob, self.training)
36 |
--------------------------------------------------------------------------------
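A small numerical check for `drop_path` above (import path assumed): dropped samples are zeroed and survivors are rescaled by `1/keep_prob`, so the output matches the input in expectation.

import torch
from lib.models.MicKey.modules.DINO_modules.layers.drop_path import drop_path

x = torch.ones(100000, 1)
out = drop_path(x, drop_prob=0.2, training=True)
# Each sample is either zeroed or scaled by 1/0.8 = 1.25, so the mean stays close to 1.
print(out.unique())       # tensor([0.0000, 1.2500])
print(out.mean().item())  # ~1.0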
/lib/models/MicKey/modules/DINO_modules/layers/layer_scale.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8 |
9 | from typing import Union
10 |
11 | import torch
12 | from torch import Tensor
13 | from torch import nn
14 |
15 |
16 | class LayerScale(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | init_values: Union[float, Tensor] = 1e-5,
21 | inplace: bool = False,
22 | ) -> None:
23 | super().__init__()
24 | self.inplace = inplace
25 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
29 |
--------------------------------------------------------------------------------
/lib/models/MicKey/modules/DINO_modules/layers/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10 |
11 |
12 | from typing import Callable, Optional
13 |
14 | from torch import Tensor, nn
15 |
16 |
17 | class Mlp(nn.Module):
18 | def __init__(
19 | self,
20 | in_features: int,
21 | hidden_features: Optional[int] = None,
22 | out_features: Optional[int] = None,
23 | act_layer: Callable[..., nn.Module] = nn.GELU,
24 | drop: float = 0.0,
25 | bias: bool = True,
26 | ) -> None:
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x: Tensor) -> Tensor:
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
--------------------------------------------------------------------------------
/lib/models/MicKey/modules/DINO_modules/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # References:
8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10 |
11 | from typing import Callable, Optional, Tuple, Union
12 |
13 | from torch import Tensor
14 | import torch.nn as nn
15 |
16 |
17 | def make_2tuple(x):
18 | if isinstance(x, tuple):
19 | assert len(x) == 2
20 | return x
21 |
22 | assert isinstance(x, int)
23 | return (x, x)
24 |
25 |
26 | class PatchEmbed(nn.Module):
27 | """
28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29 |
30 | Args:
31 | img_size: Image size.
32 | patch_size: Patch token size.
33 | in_chans: Number of input image channels.
34 | embed_dim: Number of linear projection output channels.
35 | norm_layer: Normalization layer.
36 | """
37 |
38 | def __init__(
39 | self,
40 | img_size: Union[int, Tuple[int, int]] = 224,
41 | patch_size: Union[int, Tuple[int, int]] = 16,
42 | in_chans: int = 3,
43 | embed_dim: int = 768,
44 | norm_layer: Optional[Callable] = None,
45 | flatten_embedding: bool = True,
46 | ) -> None:
47 | super().__init__()
48 |
49 | image_HW = make_2tuple(img_size)
50 | patch_HW = make_2tuple(patch_size)
51 | patch_grid_size = (
52 | image_HW[0] // patch_HW[0],
53 | image_HW[1] // patch_HW[1],
54 | )
55 |
56 | self.img_size = image_HW
57 | self.patch_size = patch_HW
58 | self.patches_resolution = patch_grid_size
59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60 |
61 | self.in_chans = in_chans
62 | self.embed_dim = embed_dim
63 |
64 | self.flatten_embedding = flatten_embedding
65 |
66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68 |
69 | def forward(self, x: Tensor) -> Tensor:
70 | _, _, H, W = x.shape
71 | patch_H, patch_W = self.patch_size
72 |
73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75 |
76 | x = self.proj(x) # B C H W
77 | H, W = x.size(2), x.size(3)
78 | x = x.flatten(2).transpose(1, 2) # B HW C
79 | x = self.norm(x)
80 | if not self.flatten_embedding:
81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82 | return x
83 |
84 | def flops(self) -> float:
85 | Ho, Wo = self.patches_resolution
86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87 | if self.norm is not None:
88 | flops += Ho * Wo * self.embed_dim
89 | return flops
90 |
--------------------------------------------------------------------------------
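A shape sketch for `PatchEmbed` above (import path assumed; sizes illustrative): a 224x224 RGB image with 16x16 patches becomes 196 tokens of dimension 768.

import torch
from lib.models.MicKey.modules.DINO_modules.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # torch.Size([1, 196, 768])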
/lib/models/MicKey/modules/DINO_modules/layers/swiglu_ffn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Callable, Optional
8 |
9 | from torch import Tensor, nn
10 | import torch.nn.functional as F
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | act_layer: Callable[..., nn.Module] = None,
20 | drop: float = 0.0,
21 | bias: bool = True,
22 | ) -> None:
23 | super().__init__()
24 | out_features = out_features or in_features
25 | hidden_features = hidden_features or in_features
26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28 |
29 | def forward(self, x: Tensor) -> Tensor:
30 | x12 = self.w12(x)
31 | x1, x2 = x12.chunk(2, dim=-1)
32 | hidden = F.silu(x1) * x2
33 | return self.w3(hidden)
34 |
35 |
36 | try:
37 | from xformers.ops import SwiGLU
38 |
39 | XFORMERS_AVAILABLE = True
40 | except ImportError:
41 | SwiGLU = SwiGLUFFN
42 | XFORMERS_AVAILABLE = False
43 |
44 |
45 | class SwiGLUFFNFused(SwiGLU):
46 | def __init__(
47 | self,
48 | in_features: int,
49 | hidden_features: Optional[int] = None,
50 | out_features: Optional[int] = None,
51 | act_layer: Callable[..., nn.Module] = None,
52 | drop: float = 0.0,
53 | bias: bool = True,
54 | ) -> None:
55 | out_features = out_features or in_features
56 | hidden_features = hidden_features or in_features
57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58 | super().__init__(
59 | in_features=in_features,
60 | hidden_features=hidden_features,
61 | out_features=out_features,
62 | bias=bias,
63 | )
64 |
--------------------------------------------------------------------------------
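The hidden-width rule in `SwiGLUFFNFused` above follows the usual SwiGLU convention: shrink the requested width by 2/3 (the gated FFN has three projections instead of two) and round up to a multiple of 8. A quick worked example with an illustrative nominal width of 4096:

hidden_features = 4096
print((int(hidden_features * 2 / 3) + 7) // 8 * 8)  # 2736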
/lib/models/MicKey/modules/att_layers/attention.py:
--------------------------------------------------------------------------------
1 | """
2 | Some functions are borrowed from LoFTR: Detector-Free Local
3 | Feature Matching with Transformers (https://github.com/zju3dv/LoFTR) and modified here.
4 | If using this code, please consider citing LoFTR.
5 | """
6 |
7 | import torch
8 | from torch.nn import Module, Dropout
9 | import torch.nn.functional as F
10 |
11 | def elu_feature_map(x):
12 | return torch.nn.functional.elu(x) + 1
13 |
14 | class Attention(Module):
15 | def __init__(self, eps=1e-6, use_dropout=False, attention='', attention_dropout=0.1):
16 | super().__init__()
17 | self.feature_map = elu_feature_map
18 | self.eps = eps
19 | self.use_dropout = use_dropout
20 | self.dropout = Dropout(attention_dropout)
21 | self.attention = attention
22 |
23 | def forward_full(self, queries, keys, values):
24 | """ Multi-head scaled dot-product attention, a.k.a full attention.
25 | Args:
26 | queries: [N, L, H, D]
27 | keys: [N, S, H, D]
28 | values: [N, S, H, D]
29 | Returns:
30 | queried_values: (N, L, H, D)
31 | """
32 |
33 | # Compute the unnormalized attention
34 | QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
35 |
36 | # Compute the attention and the weighted average
37 |         softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
38 | A = torch.softmax(softmax_temp * QK, dim=2)
39 | if self.use_dropout:
40 | A = self.dropout(A)
41 |
42 | queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
43 |
44 | return queried_values.contiguous()
45 |
46 | def forward_linear(self, queries, keys, values):
47 | """ Multi-Head linear attention proposed in "Transformers are RNNs"
48 | Args:
49 | queries: [N, L, H, D]
50 | keys: [N, S, H, D]
51 | values: [N, S, H, D]
52 | Returns:
53 | queried_values: (N, L, H, D)
54 | """
55 | Q = self.feature_map(queries)
56 | K = self.feature_map(keys)
57 |
58 | v_length = values.size(1)
59 | values = values / v_length # prevent fp16 overflow
60 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
61 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
62 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
63 |
64 | return queried_values.contiguous()
65 |
66 | def forward_flash(self, q, k, v):
67 | args = [x.half().contiguous() for x in [q, k, v]]
68 | v = F.scaled_dot_product_attention(*args).to(q.dtype)
69 | return v.contiguous()
70 |
71 | def forward(self, queries, keys, values):
72 |         """ Multi-head attention dispatcher: linear, flash, or full, depending on self.attention.
73 | Args:
74 | queries: [N, L, H, D]
75 | keys: [N, S, H, D]
76 | values: [N, S, H, D]
77 | Returns:
78 | x: queried values (N, L, H, D)
79 | """
80 | if self.attention == 'linear':
81 | x = self.forward_linear(queries, keys, values)
82 | elif self.attention == 'flash':
83 | x = self.forward_flash(queries, keys, values)
84 | else:
85 | x = self.forward_full(queries, keys, values)
86 | return x
87 |
--------------------------------------------------------------------------------
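A shape-level sketch of the attention module above (import path assumed from this repository layout; sizes illustrative): queries and keys/values may have different sequence lengths, and the output follows the query layout.

import torch
from lib.models.MicKey.modules.att_layers.attention import Attention

att = Attention(attention='linear')
q = torch.randn(2, 100, 8, 16)   # [N, L, H, D]
k = torch.randn(2, 120, 8, 16)   # [N, S, H, D]
v = torch.randn(2, 120, 8, 16)   # [N, S, H, D]
print(att(q, k, v).shape)        # torch.Size([2, 100, 8, 16])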
/lib/models/MicKey/modules/att_layers/transformer.py:
--------------------------------------------------------------------------------
1 |
2 | import copy
3 | import math
4 | import torch
5 | import torch.nn as nn
6 | from einops.einops import rearrange
7 | from lib.models.MicKey.modules.att_layers.transformer_utils import EncoderLayer
8 | import torch.nn.functional as F
9 |
10 | class PositionEncodingSine(nn.Module):
11 | """
12 | This is a sinusoidal position encoding that generalized to 2-dimensional images
13 | """
14 |
15 | def __init__(self, d_model, max_shape=(256, 256)):
16 | """
17 | Args:
18 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
19 |             Note: the original LoFTR implementation exposes a `temp_bug_fix` flag for a known
20 |                 issue in its positional-encoding maths (https://github.com/zju3dv/LoFTR/issues/41),
21 |                 which has little impact on the final performance; that flag is not exposed by
22 |                 this module.
23 | """
24 | super().__init__()
25 |
26 | pe = torch.zeros((d_model, *max_shape))
27 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
28 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
29 | div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
30 | div_term = div_term[:, None, None] # [C//4, 1, 1]
31 | pe[0::4, :, :] = torch.sin(x_position * div_term)
32 | pe[1::4, :, :] = torch.cos(x_position * div_term)
33 | pe[2::4, :, :] = torch.sin(y_position * div_term)
34 | pe[3::4, :, :] = torch.cos(y_position * div_term)
35 |
36 | self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
37 |
38 | def forward(self, x):
39 | """
40 | Args:
41 | x: [N, C, H, W]
42 | """
43 | return x + self.pe[:, :, :x.size(2), :x.size(3)]
44 |
45 | class Transformer_self_att(nn.Module):
46 |     """Self-attention transformer module.
47 |     Arguments:
48 |         d_model: Feature dimension of the input feature maps.
49 |         num_layers: Number of self-attention encoder layers (add_posEnc optionally adds a sinusoidal positional encoding).
50 | """
51 |
52 | def __init__(self, d_model, num_layers, add_posEnc=False):
53 | super(Transformer_self_att, self).__init__()
54 |
55 | # Define the transformer parameters
56 | self.d_model = d_model
57 |
58 | # TODO: Expose parameters to config file
59 | layer_names = ['self'] * num_layers
60 | attention = 'linear'
61 | self.nheads = 8
62 | self.layer_names = layer_names
63 | encoder_layer = EncoderLayer(d_model, self.nheads, attention)
64 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
65 | self._reset_parameters()
66 | self.add_posEnc = add_posEnc
67 | self.posEnc = PositionEncodingSine(d_model)
68 |
69 | def _reset_parameters(self):
70 | for p in self.parameters():
71 | if p.dim() > 1:
72 | nn.init.xavier_uniform_(p)
73 |
74 |
75 | def forward(self, feats):
76 | """
77 |         Runs the self-attention module over a single feature map.
78 |         Args:
79 |             feats: Input features ([N, d_model, im_size/down_factor, im_size/down_factor]).
80 |         Output:
81 |             feats: Self-attended features with the same layout
82 |                 ([N, d_model, im_size/down_factor, im_size/down_factor]).
86 | """
87 |
88 |         assert self.d_model == feats.size(1), "Feature dimension must match the transformer d_model"
89 |
90 | b, c, h, w = feats.size()
91 |
92 | if self.add_posEnc:
93 | feats = self.posEnc(feats)
94 |
95 | feats = rearrange(feats, 'n c h w -> n (h w) c')
96 |
97 | # Apply linear self attention to feats
98 | for layer, name in zip(self.layers, self.layer_names):
99 | feats = layer(feats, feats)
100 |
101 | feats = feats.transpose(2, 1).reshape((b, c, h, w))
102 |
103 | return feats
104 |
105 | class Transformer_att(nn.Module):
106 |     """Self- and cross-attention transformer module.
107 |     Arguments:
108 |         d_model: Feature dimension of the input feature maps.
109 |         num_layers: Number of (self, cross) attention layer pairs (add_posEnc optionally adds a sinusoidal positional encoding).
110 | """
111 |
112 | def __init__(self, d_model, num_layers, add_posEnc=False):
113 | super(Transformer_att, self).__init__()
114 |
115 | # Define the transformer parameters
116 | self.d_model = d_model
117 |
118 | # TODO: Expose parameters to config file
119 | layer_names = ['self', 'cross'] * num_layers
120 | attention = 'linear'
121 | self.nheads = 8
122 | self.layer_names = layer_names
123 | encoder_layer = EncoderLayer(d_model, self.nheads, attention)
124 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
125 | self._reset_parameters()
126 | self.add_posEnc = add_posEnc
127 | self.posEnc = PositionEncodingSine(d_model)
128 |
129 | def _reset_parameters(self):
130 | for p in self.parameters():
131 | if p.dim() > 1:
132 | nn.init.xavier_uniform_(p)
133 |
134 |
135 | def forward(self, feats0, feats1):
136 | """
137 |         Runs the self- and cross-attention module over both feature maps.
138 |         Args:
139 |             feats0: Features from image A (source) ([N, d_model, im_size/down_factor, im_size/down_factor]).
140 |             feats1: Features from image B (destination) ([N, d_model, im_size/down_factor, im_size/down_factor]).
141 |         Output:
142 |             feats0: Self- and cross-attended features corresponding to image A (source)
143 |                 ([N, d_model, im_size/down_factor, im_size/down_factor]).
144 |             feats1: Self- and cross-attended features corresponding to image B (destination)
145 |                 ([N, d_model, im_size/down_factor, im_size/down_factor]).
146 | """
147 |
148 |         assert self.d_model == feats0.size(1), "Feature dimension must match the transformer d_model"
149 |
150 | b, c, h, w = feats0.size()
151 |
152 | if self.add_posEnc:
153 | feats0 = self.posEnc(feats0)
154 | feats1 = self.posEnc(feats1)
155 |
156 | feats0 = rearrange(feats0, 'n c h w -> n (h w) c')
157 | feats1 = rearrange(feats1, 'n c h w -> n (h w) c')
158 |
159 |         # Apply linear self- and cross-attention to feats0 and feats1
160 | for layer, name in zip(self.layers, self.layer_names):
161 | if name == 'self':
162 | feats0 = layer(feats0, feats0)
163 | feats1 = layer(feats1, feats1)
164 | elif name == 'cross':
165 | feats0, feats1 = layer(feats0, feats1), layer(feats1, feats0)
166 | else:
167 | raise KeyError
168 |
169 | feats0 = feats0.transpose(2, 1).reshape((b, c, h, w))
170 | feats1 = feats1.transpose(2, 1).reshape((b, c, h, w))
171 |
172 | return feats0, feats1
173 |
--------------------------------------------------------------------------------
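A minimal usage sketch for the self-attention module above (import path assumed; sizes illustrative): the module consumes and returns `[N, C, H, W]` feature maps, optionally adding the sinusoidal positional encoding first.

import torch
from lib.models.MicKey.modules.att_layers.transformer import Transformer_self_att

net = Transformer_self_att(d_model=128, num_layers=2, add_posEnc=True)
feats = torch.randn(1, 128, 30, 40)
print(net(feats).shape)  # torch.Size([1, 128, 30, 40])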
/lib/models/MicKey/modules/att_layers/transformer_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from lib.models.MicKey.modules.att_layers.attention import Attention
5 |
6 | class EncoderLayer(nn.Module):
7 | """
8 |     Transformer encoder layer implementing linear self- and cross-attention.
9 | Arguments:
10 | d_model: Feature dimension of the input feature maps (default: 128d).
11 | nhead: Number of heads in the multi-head attention.
12 |         attention: Type of attention for the transformer block. Options: linear, flash, full.
13 | """
14 | def __init__(self, d_model, nhead, attention='linear'):
15 | super(EncoderLayer, self).__init__()
16 |
17 | # Transformer encoder layer parameters
18 | self.dim = d_model // nhead
19 | self.nhead = nhead
20 |
21 | # multi-head attention definition
22 | self.q_proj = nn.Linear(d_model, d_model, bias=False)
23 | self.k_proj = nn.Linear(d_model, d_model, bias=False)
24 | self.v_proj = nn.Linear(d_model, d_model, bias=False)
25 | # full_att = False if attention == 'linear' else True
26 | self.attention = Attention(attention=attention)
27 | self.merge = nn.Linear(d_model, d_model, bias=False)
28 |
29 | # feed-forward network
30 | self.mlp = nn.Sequential(
31 | nn.Linear(d_model*2, d_model*2, bias=False),
32 | nn.ReLU(True),
33 | nn.Linear(d_model*2, d_model, bias=False),
34 | )
35 |
36 | # norm and dropout
37 | self.norm1 = nn.LayerNorm(d_model)
38 | self.norm2 = nn.LayerNorm(d_model)
39 |
40 | def forward(self, x, source):
41 | """
42 | Args:
43 | x (torch.Tensor): [N, L, C] (L = im_size/down_factor ** 2)
44 |             source (torch.Tensor): [N, S, C] with S = (im_size/down_factor) ** 2
45 |                 (source == x for self-attention; features from the other image for cross-attention)
50 | """
51 | bs = x.size(0)
52 | query, key, value = x, source, source
53 |
54 | # multi-head attention
55 | query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
56 | key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
57 | value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
58 | message = self.attention(query, key, value) # [N, L, (H, D)]
59 | message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C]
60 | message = self.norm1(message)
61 |
62 | # feed-forward network
63 | message = self.mlp(torch.cat([x, message], dim=2))
64 | message = self.norm2(message)
65 |
66 | return x + message
67 |
--------------------------------------------------------------------------------
/lib/models/MicKey/modules/compute_correspondences.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from lib.models.MicKey.modules.mickey_extractor import MicKey_Extractor
4 | from lib.models.MicKey.modules.utils.feature_matcher import featureMatcher
5 |
6 | class ComputeCorrespondences(nn.Module):
7 | def __init__(self, cfg):
8 | super().__init__()
9 |
10 | # Feature extractor
11 | self.extractor = MicKey_Extractor(cfg['MICKEY'])
12 |
13 | self.dsc_dim = cfg['MICKEY']['DSC_HEAD']['LAST_DIM']
14 |
15 | # Feature matcher
16 | self.matcher = featureMatcher(cfg['FEATURE_MATCHER'])
17 |
18 | self.down_factor = cfg['MICKEY']['DINOV2']['DOWN_FACTOR']
19 |
20 | def get_abs_kpts_coordinates(self, kpts):
21 |
22 | B, C, H, W = kpts.shape
23 |
24 |         # Compute the absolute grid position of every keypoint cell
25 | x_abs_pos = torch.arange(W).view(1, 1, W).tile([B, H, 1]).to(kpts.device)
26 | y_abs_pos = torch.arange(H).view(1, H, 1).tile([B, 1, W]).to(kpts.device)
27 | abs_pos = torch.concat([x_abs_pos.unsqueeze(1), y_abs_pos.unsqueeze(1)], dim=1)
28 |
29 | kpts_abs_pos = (kpts + abs_pos) * self.down_factor
30 |
31 | return kpts_abs_pos
32 |
33 | def prepare_kpts_dsc(self, kpt, depth, scr, dsc):
34 |
35 | B, _, H, W = kpt.shape
36 | num_kpts = (H * W)
37 |
38 | kpt = kpt.view(B, 2, num_kpts)
39 | depth = depth.view(B, 1, num_kpts)
40 | scr = scr.view(B, 1, num_kpts)
41 | dsc = dsc.view(B, self.dsc_dim, num_kpts)
42 |
43 | return kpt, depth, scr, dsc
44 |
45 | # Standalone helper that builds a score matrix from keypoint scores only (combined with matching scores during training)
46 | def kp_matrix_scores(self, sc0, sc1):
47 |
48 | # matrix with "probability" of sampling a correspondence based on keypoint scores only
49 | scores = torch.matmul(sc0.transpose(2, 1).contiguous(), sc1)
50 | return scores
51 |
52 | def forward(self, data):
53 |
54 | # Compute detection and descriptor maps
55 | im0 = data['image0']
56 | im1 = data['image1']
57 |
58 | # Extract features independently from im0 and im1
59 | kps0, depth0, scr0, dsc0 = self.extractor(im0)
60 | kps1, depth1, scr1, dsc1 = self.extractor(im1)
61 |
62 | kps0 = self.get_abs_kpts_coordinates(kps0)
63 | kps1 = self.get_abs_kpts_coordinates(kps1)
64 |
65 | # Store keypoint map shapes and auxiliary outputs for logging and later use
66 | _, _, H_kp0, W_kp0 = kps0.shape
67 | _, _, H_kp1, W_kp1 = kps1.shape
68 | data['kps0_shape'] = [H_kp0, W_kp0]
69 | data['kps1_shape'] = [H_kp1, W_kp1]
70 | data['depth0_map'] = depth0
71 | data['depth1_map'] = depth1
72 | data['down_factor'] = self.down_factor
73 |
74 | # Reshape kpts and descriptors to [B, num_kpts, dim]
75 | kps0, depth0, scr0, dsc0 = self.prepare_kpts_dsc(kps0, depth0, scr0, dsc0)
76 | kps1, depth1, scr1, dsc1 = self.prepare_kpts_dsc(kps1, depth1, scr1, dsc1)
77 |
78 | # get correspondences
79 | scores = self.matcher(kps0, dsc0, kps1, dsc1)
80 |
81 | data['kps0'] = kps0
82 | data['depth_kp0'] = depth0
83 | data['scr0'] = scr0
84 | data['kps1'] = kps1
85 | data['depth_kp1'] = depth1
86 | data['scr1'] = scr1
87 | data['scores'] = scores
88 | data['dsc0'] = dsc0
89 | data['dsc1'] = dsc1
90 | data['kp_scores'] = self.kp_matrix_scores(scr0, scr1)
91 |
92 | return kps0, dsc0, kps1, dsc1
93 |
--------------------------------------------------------------------------------
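To illustrate get_abs_kpts_coordinates above (a sketch with hypothetical numbers, assuming a down_factor of 14 as for DINOv2 ViT-L/14): each cell of the low-resolution keypoint grid predicts an offset in [0, 1], and the absolute pixel position is (grid index + offset) * down_factor, so an offset of (0.25, 0.75) in cell (x=3, y=2) lands at (45.5, 38.5).

import torch

down_factor = 14
kpts = torch.zeros(1, 2, 4, 4)                 # [B, 2, H, W] relative offsets in [0, 1]
kpts[0, :, 2, 3] = torch.tensor([0.25, 0.75])  # offset predicted at grid cell x=3, y=2

# same grid construction as in get_abs_kpts_coordinates
x_abs = torch.arange(4).view(1, 1, 4).tile([1, 4, 1])
y_abs = torch.arange(4).view(1, 4, 1).tile([1, 1, 4])
abs_pos = torch.cat([x_abs.unsqueeze(1), y_abs.unsqueeze(1)], dim=1)

kpts_abs = (kpts + abs_pos) * down_factor
print(kpts_abs[0, :, 2, 3])                    # tensor([45.5000, 38.5000])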
/lib/models/MicKey/modules/loss/loss_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from lib.utils.metrics import vcre_loss
4 |
5 | def compute_angular_error(R, t, Rgt_i, tgt_i):
6 | loss_rot, rot_err = rot_angle_loss(R, Rgt_i)
7 | loss_trans, t_err = trans_ang_loss(t, tgt_i)
8 |
9 | max_loss, _ = torch.max(torch.cat((loss_rot, loss_trans), dim=-1), dim=-1)
10 | return max_loss, loss_rot, loss_trans
11 |
12 | def compute_angular_error_weighted(R, t, Rgt_i, tgt_i, weights_t):
13 | loss_rot, rot_err = rot_angle_loss(R, Rgt_i)
14 | loss_trans, t_err = trans_ang_loss(t, tgt_i)
15 |
16 | max_loss, _ = torch.max(torch.cat((loss_rot, loss_trans * weights_t), dim=-1), dim=-1)
17 | return max_loss, loss_rot, loss_trans
18 |
19 | def ess_sq_euclidean_error(E, Egt):
20 |
21 | B = E.shape[0]
22 | E_norm = E/E[:, 2, 2].view(B, 1, 1)
23 | Egt_norm = Egt/Egt[:, 2, 2].view(B, 1, 1)
24 | return torch.pow(E_norm-Egt_norm, 2).view(B, -1).sum(1)
25 |
26 | def compute_pose_loss(R, t, Rgt_i, tgt_i, K0=None, K1=None, soft_clipping=True):
27 | loss_rot, rot_err = rot_angle_loss(R, Rgt_i)
28 | loss_trans = trans_l1_loss(t, tgt_i)
29 |
30 | if soft_clipping:
31 | loss_trans_soft = torch.tanh(loss_trans/0.9)  # soft clipping of the translation error (metres)
32 | loss_rot_soft = torch.tanh(loss_rot/0.9)  # soft clipping of the rotation error (radians)
33 |
34 | loss = loss_rot_soft + loss_trans_soft
35 | else:
36 | loss = loss_rot + loss_trans
37 |
38 | return loss, loss_rot, loss_trans
39 |
40 | def compute_vcre_loss(R, t, Rgt, tgt, K0, K1, soft_clipping=True):
41 |
42 | B = R.shape[0]
43 | Tgt = torch.zeros((B, 4, 4)).float().to(R.device)
44 | Tgt[:, :3, :3] = Rgt
45 | Tgt[:, :3, 3:] = tgt.transpose(2, 1)
46 |
47 | # Inv pose:
48 | R_inv = R.transpose(2, 1)
49 | t_inv = (-1 * R_inv @ t.transpose(2, 1)).transpose(2, 1)
50 | Tgt_inv = torch.zeros((B, 4, 4)).float().to(R.device)
51 | Rgt_inv = Rgt.transpose(2, 1)
52 | tgt_inv = (-1 * Rgt_inv @ tgt.transpose(2, 1)).transpose(2, 1)
53 | Tgt_inv[:, :3, :3] = Rgt_inv
54 | Tgt_inv[:, :3, 3:] = tgt_inv.transpose(2, 1)
55 |
56 | loss0 = vcre_loss(R, t, Tgt, K0)
57 | loss1 = vcre_loss(R_inv, t_inv, Tgt_inv, K1)
58 |
59 | loss = (loss1 + loss0) / 2.
60 | if soft_clipping:
61 | loss = torch.tanh(loss / 80)
62 |
63 | loss_rot, rot_err = rot_angle_loss(R, Rgt)
64 | loss_trans = trans_l1_loss(t, tgt)
65 |
66 | return loss, loss_rot, loss_trans
67 |
68 | def trans_ang_loss(t, tgt):
69 | """Computes L1 loss for translation vector ANGULAR error
70 | Input:
71 | t - estimated translation vector [B, 1, 3]
72 | tgt - ground-truth translation vector [B, 1, 3]
73 | Output: translation_loss
74 | """
75 |
76 | scale_t = torch.linalg.norm(t, dim=-1)
77 | scale_tgt = torch.linalg.norm(tgt, dim=-1)
78 |
79 | cosine = (t @ tgt.transpose(1, 2)).squeeze(-1) / (scale_t * scale_tgt + 1e-6)
80 | cosine = torch.clip(cosine, -0.99999, 0.99999) # handle numerical errors and NaNs
81 | t_ang_err = torch.acos(cosine)
82 | t_ang_err = torch.minimum(t_ang_err, math.pi - t_ang_err)
83 | return torch.abs(t_ang_err), t_ang_err
84 |
85 | def trans_l1_loss(t, tgt):
86 | """Computes L1 loss for translation vector
87 | Input:
88 | t - estimated translation vector [B, 1, 3]
89 | tgt - ground-truth translation vector [B, 1, 3]
90 | Output: translation_loss
91 | """
92 |
93 | return torch.abs(t - tgt).sum(-1)
94 |
95 | def rot_angle_loss(R, Rgt):
96 | """
97 | Computes rotation loss using L1 error of residual rotation angle [radians]
98 | Input:
99 | R - estimated rotation matrix [B, 3, 3]
100 | Rgt - groundtruth rotation matrix [B, 3, 3]
101 | Output: rotation_loss
102 | """
103 |
104 | residual = R.transpose(1, 2) @ Rgt
105 | trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1)
106 | cosine = (trace - 1) / 2
107 | cosine = torch.clip(cosine, -0.99999, 0.99999) # handle numerical errors and NaNs
108 | R_err = torch.acos(cosine)
109 | loss = torch.abs(R_err).unsqueeze(-1)
110 | return loss, R_err
111 |
112 |
113 | def to_homogeneous_torch_batched(u_xys: torch.Tensor):
114 | batch_size, _, num_pts = u_xys.shape
115 | ones = torch.ones((batch_size, 1, num_pts)).float().to(u_xys.device)
116 | u_xyhs = torch.concat([u_xys, ones], dim=1)
117 | return u_xyhs
118 |
119 |
120 |
--------------------------------------------------------------------------------
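A quick numerical sanity check of rot_angle_loss above (a sketch; the import path mirrors this file's location): a 10-degree rotation about the z-axis against an identity ground truth should yield a residual angle of roughly 0.175 rad.

import math
import torch
from lib.models.MicKey.modules.loss.loss_utils import rot_angle_loss

theta = math.radians(10.0)
Rz = torch.tensor([[math.cos(theta), -math.sin(theta), 0.0],
                   [math.sin(theta),  math.cos(theta), 0.0],
                   [0.0,              0.0,             1.0]]).unsqueeze(0)  # [1, 3, 3]
Rgt = torch.eye(3).unsqueeze(0)

loss, err = rot_angle_loss(Rz, Rgt)
print(err)   # ~0.1745 rad, i.e. a 10-degree residual rotation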
/lib/models/MicKey/modules/loss/solvers.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def weighted_procrustes(A, B, w=None, use_weights=True, use_mask=False, eps=1e-16, check_rank=True):
4 | """
5 | A: torch tensor B x N x 3
6 | B: torch tensor B x N x 3
7 | w: torch tensor B x N
8 | """
9 | # https://ieeexplore.ieee.org/document/88573
10 | # https://github.com/chrischoy/DeepGlobalRegistration/blob/master/core/registration.py#L160
11 | # Refer to Mapfree procrustes
12 | assert len(A) == len(B)
13 | if use_weights:
14 | W1 = torch.abs(w).sum(1, keepdim=True)
15 | w_norm = (w / (W1 + eps)).unsqueeze(-1)
16 | a_mean = (w_norm * A).sum(1, keepdim=True)
17 | b_mean = (w_norm * B).sum(1, keepdim=True)
18 | # print('Possible ERROR:')
19 | # print('check this: Sxy = (Y - muy).t().mm(w_norm * (X - mux)).cpu().double(). Repo DeepGlobalRegistration')
20 |
21 | A_c = A - a_mean
22 | B_c = B - b_mean
23 |
24 | if use_mask:
25 | # Covariance matrix
26 | H = A_c.transpose(1, 2) @ (w.unsqueeze(-1) * B_c)
27 | else:
28 | # Covariance matrix
29 | H = A_c.transpose(1, 2) @ (w_norm * B_c)  # w_norm is already [B, N, 1]
30 |
31 | else:
32 | a_mean = A.mean(axis=1, keepdim=True)
33 | b_mean = B.mean(axis=1, keepdim=True)
34 |
35 | A_c = A - a_mean
36 | B_c = B - b_mean
37 |
38 | # Covariance matrix
39 | H = A_c.transpose(1, 2) @ B_c
40 |
41 | if check_rank:
42 | if (torch.linalg.matrix_rank(H) == 1).sum() > 0:
43 | return None, None, False
44 |
45 | U, S, V = torch.svd(H)
46 | # Fixes orientation such that Det(R) = + 1
47 | Z = torch.eye(3).unsqueeze(0).repeat(A.shape[0], 1, 1).to(A.device)
48 | Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ V.transpose(1, 2)))
49 | # Rotation matrix
50 | R = V @ Z @ U.transpose(1, 2)
51 | # Translation vector
52 | t = b_mean - a_mean @ R.transpose(1, 2)
53 |
54 | return R, t, True
--------------------------------------------------------------------------------
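A self-contained sanity check for weighted_procrustes (a sketch using the unweighted path; the import path mirrors this file's location). Points B are generated from A with a known rotation and translation, and the solver should recover them; the convention, per the code above, is B ≈ A @ R.transpose(1, 2) + t.

import math
import torch
from lib.models.MicKey.modules.loss.solvers import weighted_procrustes

torch.manual_seed(0)
A = torch.randn(1, 50, 3)                      # [B, N, 3] source points

theta = math.radians(30.0)
R_gt = torch.tensor([[math.cos(theta), -math.sin(theta), 0.0],
                     [math.sin(theta),  math.cos(theta), 0.0],
                     [0.0,              0.0,             1.0]]).unsqueeze(0)
t_gt = torch.tensor([[[0.5, -1.0, 2.0]]])      # [B, 1, 3]

B_pts = A @ R_gt.transpose(1, 2) + t_gt        # apply the known rigid transform
R_est, t_est, valid = weighted_procrustes(A, B_pts, use_weights=False)
print(valid, torch.allclose(R_est, R_gt, atol=1e-4), torch.allclose(t_est, t_gt, atol=1e-4))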
/lib/models/MicKey/modules/mickey_extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from lib.models.MicKey.modules.DINO_modules.dinov2 import vit_large
4 | from lib.models.MicKey.modules.att_layers.transformer import Transformer_self_att
5 | from lib.models.MicKey.modules.utils.extractor_utils import desc_l2norm, BasicBlock
6 |
7 | class MicKey_Extractor(nn.Module):
8 | def __init__(self, cfg, dinov2_weights=None):
9 | super().__init__()
10 |
11 | # Define DINOv2 extractor
12 | self.dino_channels = cfg['DINOV2']['CHANNEL_DIM']
13 | self.dino_downfactor = cfg['DINOV2']['DOWN_FACTOR']
14 | if dinov2_weights is None:
15 | dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/"
16 | "dinov2_vitl14/dinov2_vitl14_pretrain.pth",
17 | map_location="cpu")
18 | vit_kwargs = dict(img_size= 518,
19 | patch_size= 14,
20 | init_values = 1.0,
21 | ffn_layer = "mlp",
22 | block_chunks = 0,
23 | )
24 |
25 | self.dinov2_vitl14 = vit_large(**vit_kwargs)
26 | self.dinov2_vitl14.load_state_dict(dinov2_weights)
27 | self.dinov2_vitl14.requires_grad_(False)
28 | self.dinov2_vitl14.eval()
29 |
30 | # Define whether DINOv2 runs on float16 or float32
31 | if cfg['DINOV2']['FLOAT16']:
32 | self.amp_dtype = torch.float16
33 | self.dinov2_vitl14.to(self.amp_dtype)
34 | else:
35 | self.amp_dtype = torch.float32
36 |
37 | # Define MicKey's heads
38 | self.depth_head = DeepResBlock_depth(cfg)
39 | self.det_offset = DeepResBlock_offset(cfg)
40 | self.dsc_head = DeepResBlock_desc(cfg)
41 | self.det_head = DeepResBlock_det(cfg)
42 |
43 | def forward(self, x):
44 |
45 | B, C, H, W = x.shape
46 | x = x[:, :, :self.dino_downfactor * (H//self.dino_downfactor), :self.dino_downfactor * (W//self.dino_downfactor)]
47 |
48 | with torch.no_grad():
49 | dinov2_features = self.dinov2_vitl14.forward_features(x.to(self.amp_dtype))
50 | dinov2_features = dinov2_features['x_norm_patchtokens'].permute(0, 2, 1).\
51 | reshape(B, self.dino_channels, H // self.dino_downfactor, W // self.dino_downfactor).float()
52 |
53 | scrs = self.det_head(dinov2_features)
54 | kpts = self.det_offset(dinov2_features)
55 | depths = self.depth_head(dinov2_features)
56 | dscs = self.dsc_head(dinov2_features)
57 |
58 | return kpts, depths, scrs, dscs
59 |
60 | def train(self, mode: bool = True):  # keep the frozen DINOv2 backbone in eval mode; only toggle the MicKey heads
61 | self.dsc_head.train(mode)
62 | self.depth_head.train(mode)
63 | self.det_offset.train(mode)
64 | self.det_head.train(mode)
65 |
66 |
67 | class DeepResBlock_det(torch.nn.Module):
68 | def __init__(self, config, padding_mode = 'zeros'):
69 | super().__init__()
70 |
71 | bn = config['KP_HEADS']['BN']
72 | in_channels = config['DINOV2']['CHANNEL_DIM']
73 | block_dims = config['KP_HEADS']['BLOCKS_DIM']
74 | add_posEnc = config['KP_HEADS']['POS_ENCODING']
75 |
76 | self.resblock1 = BasicBlock(in_channels, block_dims[0], stride=1, bn=bn, padding_mode=padding_mode)
77 | self.resblock2 = BasicBlock(block_dims[0], block_dims[1], stride=1, bn=bn, padding_mode=padding_mode)
78 | self.resblock3 = BasicBlock(block_dims[1], block_dims[2], stride=1, bn=bn, padding_mode=padding_mode)
79 | self.resblock4 = BasicBlock(block_dims[2], block_dims[3], stride=1, bn=bn, padding_mode=padding_mode)
80 |
81 | self.score = nn.Conv2d(block_dims[3], 1, kernel_size=1, stride=1, padding=0, bias=False)
82 |
83 | self.use_softmax = config['KP_HEADS']['USE_SOFTMAX']
84 | self.sigmoid = torch.nn.Sigmoid()
85 | self.logsigmoid = torch.nn.LogSigmoid()
86 | self.softmax = torch.nn.Softmax(dim=-1)
87 |
88 | # Softmax temperature; a large value flattens the score distribution to allow more exploration with the REINFORCE algorithm
89 | self.tmp_softmax = 100
90 |
91 | self.eps = nn.Parameter(torch.tensor(1e-16), requires_grad=False)
92 | self.offset_par1 = nn.Parameter(torch.tensor(0.5), requires_grad=False)
93 | self.offset_par2 = nn.Parameter(torch.tensor(2.), requires_grad=False)
94 | self.ones_kernel = nn.Parameter(torch.ones((1, 1, 3, 3)), requires_grad=False)
95 |
96 | self.att_layer = Transformer_self_att(d_model=128, num_layers=3, add_posEnc=add_posEnc)
97 |
98 | def remove_borders(self, score_map: torch.Tensor, borders: int):
99 | '''
100 | Masks out the borders of the score map to avoid detections at the image boundaries
101 | '''
102 | shape = score_map.shape
103 | mask = torch.ones_like(score_map)
104 |
105 | mask[:, :, 0:borders, :] = 0
106 | mask[:, :, :, 0:borders] = 0
107 | mask[:, :, shape[2] - borders:shape[2], :] = 0
108 | mask[:, :, :, shape[3] - borders:shape[3]] = 0
109 |
110 | return mask * score_map
111 |
112 | def remove_brd_and_softmax(self, scores, borders):
113 |
114 | B = scores.shape[0]
115 |
116 | scores = scores - (scores.view(B, -1).mean(-1).view(B, 1, 1, 1) + self.eps).detach()
117 | exp_scores = torch.exp(scores / self.tmp_softmax)
118 |
119 | # remove borders
120 | exp_scores = self.remove_borders(exp_scores, borders=borders)
121 |
122 | # apply softmax
123 | sum_scores = exp_scores.sum(-1).sum(-1).view(B, 1, 1, 1)
124 | return exp_scores / (sum_scores + self.eps)
125 |
126 | def forward(self, feature_volume):
127 |
128 | x = self.resblock1(feature_volume)
129 | x = self.resblock2(x)
130 | x = self.resblock3(x)
131 | x = self.att_layer(x)
132 | x = self.resblock4(x)
133 |
134 | # Predict a score for each xy location
135 | scores = self.score(x)
136 |
137 | if self.use_softmax:
138 | scores = self.remove_brd_and_softmax(scores, 3)
139 | else:
140 | scores = self.remove_borders(self.sigmoid(scores), borders=3)
141 |
142 | return scores
143 |
144 |
145 | class DeepResBlock_offset(torch.nn.Module):
146 | def __init__(self, config, padding_mode = 'zeros'):
147 | super().__init__()
148 |
149 | bn = config['KP_HEADS']['BN']
150 | in_channels = config['DINOV2']['CHANNEL_DIM']
151 | block_dims = config['KP_HEADS']['BLOCKS_DIM']
152 | add_posEnc = config['KP_HEADS']['POS_ENCODING']
153 | self.sigmoid = torch.nn.Sigmoid()
154 |
155 | self.resblock1 = BasicBlock(in_channels, block_dims[0], stride=1, bn=bn, padding_mode=padding_mode)
156 | self.resblock2 = BasicBlock(block_dims[0], block_dims[1], stride=1, bn=bn, padding_mode=padding_mode)
157 | self.resblock3 = BasicBlock(block_dims[1], block_dims[2], stride=1, bn=bn, padding_mode=padding_mode)
158 | self.resblock4 = BasicBlock(block_dims[2], block_dims[3], stride=1, bn=bn, padding_mode=padding_mode)
159 |
160 | self.xy_offset = nn.Conv2d(block_dims[3], 2, kernel_size=1, stride=1, padding=0, bias=False)
161 |
162 | self.att_layer = Transformer_self_att(d_model=128, num_layers=3, add_posEnc=add_posEnc)
163 |
164 | def forward(self, feature_volume):
165 |
166 | x = self.resblock1(feature_volume)
167 | x = self.resblock2(x)
168 | x = self.resblock3(x)
169 | x = self.att_layer(x)
170 | x = self.resblock4(x)
171 |
172 | # Predict xy offsets
173 | xy_offsets = self.xy_offset(x)
174 |
175 | # Offset goes from 0 to 1
176 | xy_offsets = self.sigmoid(xy_offsets)
177 |
178 | return xy_offsets
179 |
180 |
181 | class DeepResBlock_depth(torch.nn.Module):
182 | def __init__(self, config, padding_mode = 'zeros'):
183 | super().__init__()
184 |
185 | bn = config['KP_HEADS']['BN']
186 | in_channels = config['DINOV2']['CHANNEL_DIM']
187 | block_dims = config['KP_HEADS']['BLOCKS_DIM']
188 | add_posEnc = config['KP_HEADS']['POS_ENCODING']
189 |
190 | self.use_depth_sigmoid = config['KP_HEADS']['USE_DEPTHSIGMOID']
191 | self.max_depth = config['KP_HEADS']['MAX_DEPTH']
192 | self.sigmoid = torch.nn.Sigmoid()
193 |
194 | self.resblock1 = BasicBlock(in_channels, block_dims[0], stride=1, bn=bn, padding_mode=padding_mode)
195 | self.resblock2 = BasicBlock(block_dims[0], block_dims[1], stride=1, bn=bn, padding_mode=padding_mode)
196 | self.resblock3 = BasicBlock(block_dims[1], block_dims[2], stride=1, bn=bn, padding_mode=padding_mode)
197 | self.resblock4 = BasicBlock(block_dims[2], block_dims[3], stride=1, bn=bn, padding_mode=padding_mode)
198 |
199 | self.depth = nn.Conv2d(block_dims[3], 1, kernel_size=1, stride=1, padding=0, bias=False)
200 |
201 | self.att_layer = Transformer_self_att(d_model=128, num_layers=3, add_posEnc=add_posEnc)
202 |
203 | def forward(self, feature_volume):
204 |
205 | x = self.resblock1(feature_volume)
206 | x = self.resblock2(x)
207 | x = self.resblock3(x)
208 | x = self.att_layer(x)
209 | x = self.resblock4(x)
210 |
211 | # Predict a depth value for each xy location
212 | # depths = torch.clip(self.depth(x), min=1e-3, max=500)
213 | if self.use_depth_sigmoid:
214 | depths = self.max_depth * self.sigmoid(self.depth(x))
215 | else:
216 | depths = self.depth(x)
217 |
218 | return depths
219 |
220 |
221 | class DeepResBlock_desc(torch.nn.Module):
222 | def __init__(self, config, padding_mode = 'zeros'):
223 | super().__init__()
224 |
225 | bn = config['KP_HEADS']['BN']
226 | last_dim = config['DSC_HEAD']['LAST_DIM']
227 | in_channels = config['DINOV2']['CHANNEL_DIM']
228 | block_dims = config['KP_HEADS']['BLOCKS_DIM']
229 | add_posEnc = config['DSC_HEAD']['POS_ENCODING']
230 | self.norm_desc = config['DSC_HEAD']['NORM_DSC']
231 |
232 | self.resblock1 = BasicBlock(in_channels, block_dims[0], stride=1, bn=bn, padding_mode=padding_mode)
233 | self.resblock2 = BasicBlock(block_dims[0], block_dims[1], stride=1, bn=bn, padding_mode=padding_mode)
234 | self.resblock3 = BasicBlock(block_dims[1], block_dims[2], stride=1, bn=bn, padding_mode=padding_mode)
235 | self.resblock4 = BasicBlock(block_dims[2], last_dim, stride=1, bn=bn, padding_mode=padding_mode)
236 |
237 | self.att_layer = Transformer_self_att(d_model=128, num_layers=3, add_posEnc=add_posEnc)
238 |
239 |
240 | def forward(self, feature_volume):
241 |
242 | x = self.resblock1(feature_volume)
243 | x = self.resblock2(x)
244 | x = self.resblock3(x)
245 | x = self.att_layer(x)
246 | x = self.resblock4(x, relu=False)
247 |
248 | if self.norm_desc:
249 | x = desc_l2norm(x)
250 |
251 | return x
252 |
253 |
--------------------------------------------------------------------------------
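For reference, a small sketch of the shape arithmetic in MicKey_Extractor.forward (illustrative values: DOWN_FACTOR=14 and CHANNEL_DIM=1024, as for DINOv2 ViT-L/14; the real values come from the config):

H, W, down_factor, channels = 518, 728, 14, 1024
h, w = H // down_factor, W // down_factor
print(h, w)  # 37 52 -> dinov2_features: [B, 1024, 37, 52]
# Each head then predicts one output per grid cell:
#   kpts (xy offsets): [B, 2, 37, 52], depths and scores: [B, 1, 37, 52], dscs: [B, LAST_DIM, 37, 52]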
/lib/models/MicKey/modules/utils/extractor_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def desc_l2norm(desc: torch.Tensor):
7 | '''descriptors with shape NxC or NxCxHxW'''
8 | eps_l2_norm = 1e-10
9 | desc = desc / desc.pow(2).sum(dim=1, keepdim=True).add(eps_l2_norm).pow(0.5)
10 | return desc
11 |
12 | class BasicBlock(nn.Module):
13 | '''Pre-activation version of the BasicBlock.'''
14 | expansion = 1
15 |
16 | def __init__(self, in_planes, planes, stride=1, bn=True, padding_mode='zeros'):
17 | super(BasicBlock, self).__init__()
18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, padding_mode=padding_mode)
19 | self.bn1 = nn.BatchNorm2d(planes) if bn else nn.Identity()
20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode)
21 | self.bn2 = nn.BatchNorm2d(planes) if bn else nn.Identity()
22 |
23 | if stride != 1 or in_planes != self.expansion*planes:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
26 | )
27 |
28 | def forward(self, x, relu=True):
29 | shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = self.bn2(self.conv2(out))
32 | out += shortcut
33 | if relu:
34 | out = F.relu(out)
35 | return out
36 |
--------------------------------------------------------------------------------
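A quick check (sketch) that desc_l2norm above returns unit-length descriptors along the channel dimension:

import torch
from lib.models.MicKey.modules.utils.extractor_utils import desc_l2norm

d = torch.randn(2, 128, 10, 10)            # NxCxHxW descriptors
dn = desc_l2norm(d)
print(dn.pow(2).sum(dim=1).sqrt().mean())  # ~1.0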
/lib/models/MicKey/modules/utils/feature_matcher.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def arange_like(x, dim: int):
6 | return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1
7 |
8 | class featureMatcher(nn.Module):
9 | def __init__(self, cfg):
10 | super().__init__()
11 |
12 | if cfg['TYPE'] == 'DualSoftmax':
13 | self.matching_mat = dualSoftmax(cfg['DUAL_SOFTMAX'])
14 | elif cfg['TYPE'] == 'Sinkhorn':
15 | self.matching_mat = sinkhorn(cfg['SINKHORN'])
16 | else:
17 | raise NotImplementedError('[ERROR]: feature matcher type not recognized: {}'.format(cfg['TYPE']))
18 |
19 | def get_matches_list(self, scores, min_conf=0.0):
20 |
21 | # Note: only supports batch_size = 1
22 |
23 | # Keep mutual matches with score above min_conf.
24 | max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
25 | indices0, indices1 = max0.indices, max1.indices
26 | mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
27 | # zero = scores.new_tensor(0)
28 | zero = torch.tensor(0).to(scores.device).float()
29 | mscores0 = torch.where(mutual0, max0.values.exp(), zero)
30 | valid0 = mutual0 & (mscores0 > min_conf)
31 | # indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
32 | minus_one = torch.tensor(-1).to(scores.device)
33 | indices0 = torch.where(valid0, indices0, minus_one)
34 |
35 | valid = indices0 > -1
36 |
37 | idx0 = arange_like(indices0, 1)[valid[0]].unsqueeze(1)
38 | idx1 = indices0[valid].unsqueeze(1)
39 |
40 | matches = torch.concat([idx0, idx1], dim=1)
41 |
42 | batch_idx = torch.tile(torch.arange(1).view(1, 1), [1, len(matches)]).reshape(-1)
43 | scores_matches = scores[batch_idx, idx0[:, 0], idx1[:, 0]]
44 | _, idx_sorted = torch.sort(scores_matches, descending=True)
45 |
46 | return matches[idx_sorted]
47 |
48 | def forward(self, kpt0, dsc0, kpt1, dsc1):
49 |
50 | scores = self.matching_mat(dsc0, dsc1)
51 |
52 | return scores
53 |
54 | class dualSoftmax(nn.Module):
55 | def __init__(self, cfg):
56 | super().__init__()
57 |
58 | self.temperature = cfg['TEMPERATURE']
59 | self.use_dustbin = False
60 | if cfg['USE_DUSTBIN']:
61 | self.dustbin_score = nn.Parameter(torch.tensor(1.))
62 | self.use_dustbin = True
63 |
64 | def forward(self, dsc0, dsc1):
65 | scores = torch.matmul(dsc0.transpose(1, 2).contiguous(), dsc1) / self.temperature
66 |
67 | if self.use_dustbin:
68 | b, m, n = scores.shape
69 |
70 | bins0 = self.dustbin_score.expand(b, m, 1)
71 | bins1 = self.dustbin_score.expand(b, 1, n)
72 | alpha = self.dustbin_score.expand(b, 1, 1)
73 |
74 | couplings = torch.cat([torch.cat([scores, bins0], -1),
75 | torch.cat([bins1, alpha], -1)], 1)
76 |
77 | couplings = F.softmax(couplings, 1) * F.softmax(couplings, 2)
78 | scores = couplings[:, :-1, :-1]
79 |
80 | else:
81 | scores = F.softmax(scores, 1) * F.softmax(scores, 2)
82 |
83 | return scores
84 |
85 | class sinkhorn(nn.Module):
86 | def __init__(self, cfg):
87 | super().__init__()
88 |
89 | self.dustbin_score = nn.Parameter(torch.tensor(cfg['DUSTBIN_SCORE_INIT']))
90 | self.sinkhorn_iterations = cfg['NUM_IT']
91 | self.descriptor_dim = 128
92 |
93 | def log_sinkhorn_iterations(self, Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor,
94 | iters: int) -> torch.Tensor:
95 | """ Perform Sinkhorn Normalization in Log-space for stability"""
96 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
97 | for _ in range(iters):
98 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2)
99 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1)
100 | return Z + u.unsqueeze(2) + v.unsqueeze(1)
101 |
102 | def log_optimal_transport(self, scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor:
103 | """ Perform Differentiable Optimal Transport in Log-space for stability"""
104 | b, m, n = scores.shape
105 | # one = scores.new_tensor(1)
106 | one = torch.ones((), device=scores.device)
107 | ms, ns = (m * one).to(scores), (n * one).to(scores)
108 |
109 | bins0 = alpha.expand(b, m, 1)
110 | bins1 = alpha.expand(b, 1, n)
111 | alpha = alpha.expand(b, 1, 1)
112 |
113 | couplings = torch.cat([torch.cat([scores, bins0], -1),
114 | torch.cat([bins1, alpha], -1)], 1)
115 |
116 | norm = - (ms + ns).log()
117 | log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm])
118 | log_nu = torch.cat([norm.expand(n), ms.log()[None] + norm])
119 | log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1)
120 |
121 | Z = self.log_sinkhorn_iterations(couplings, log_mu, log_nu, iters)
122 | Z = Z - norm # multiply probabilities by M+N
123 | return Z.exp()
124 |
125 | def forward(self, dsc0, dsc1, tmp=None):  # tmp is unused; default added so the matcher can call with two arguments
126 |
127 | # Compute matching descriptor distance.
128 | scores = torch.einsum('bdn,bdm->bnm', dsc0, dsc1)
129 | scores = scores / self.descriptor_dim**.5
130 |
131 | # scores = torch.matmul(dsc0.transpose(1, 2).contiguous(), dsc1)
132 |
133 | scores = self.log_optimal_transport(
134 | scores, self.dustbin_score,
135 | iters=self.sinkhorn_iterations)
136 |
137 | return scores[:, :-1, :-1]
138 |
--------------------------------------------------------------------------------
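A toy usage sketch of the dual-softmax matcher above (the config keys follow the DUAL_SOFTMAX block read in __init__; the values here are illustrative, not the project defaults):

import torch
import torch.nn.functional as F
from lib.models.MicKey.modules.utils.feature_matcher import dualSoftmax

cfg = {'TEMPERATURE': 0.1, 'USE_DUSTBIN': False}
matcher = dualSoftmax(cfg)

dsc0 = F.normalize(torch.randn(1, 128, 5), dim=1)  # [B, D, N0] unit descriptors
dsc1 = F.normalize(torch.randn(1, 128, 6), dim=1)  # [B, D, N1]
scores = matcher(dsc0, dsc1)
print(scores.shape)  # torch.Size([1, 5, 6]); product of row- and column-wise softmaxes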
/lib/models/builder.py:
--------------------------------------------------------------------------------
1 | import torch.cuda
2 |
3 | from lib.models.MicKey.compute_pose import MickeyRelativePose
4 |
5 | def build_model(cfg, checkpoint=''):
6 |
7 | if cfg.MODEL == 'MicKey':
8 |
9 | model = MickeyRelativePose(cfg)
10 |
11 | checkpoint_loaded = torch.load(checkpoint, map_location='cpu')  # load on CPU; moved to GPU below if available
12 | model.on_load_checkpoint(checkpoint_loaded)
13 | model.load_state_dict(checkpoint_loaded['state_dict'])
14 |
15 | if torch.cuda.is_available():
16 | model = model.cuda()
17 | model.eval()
18 | return model
19 | else:
20 | raise NotImplementedError()
21 |
--------------------------------------------------------------------------------
/lib/utils/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def data_to_model_device(data, model):
4 | '''Move all tensors in data dictionary to the same device as model'''
5 |
6 | try:
7 | device = next(model.parameters()).device
8 | except StopIteration:
9 | # in case the model has no parameters (baseline models)
10 | device = 'cpu'
11 |
12 | for k, v in data.items():
13 | if torch.is_tensor(v):
14 | data[k] = v.to(device)
15 |
16 | return data
17 |
--------------------------------------------------------------------------------
/lib/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc
2 |
3 | import torch
4 | import numpy as np
5 | from collections import defaultdict
6 | from lib.benchmarks.reprojection import get_grid_multipleheight
7 | from lib.models.MicKey.modules.utils.training_utils import project_2d
8 |
9 | # global variable, avoids creating it again
10 | eye_coords_glob = get_grid_multipleheight()
11 |
12 | def pose_error_torch(R, t, Tgt, reduce=None):
13 | """Compute angular, scale and euclidean error of translation vector (metric). Compute angular rotation error."""
14 |
15 | Rgt = Tgt[:, :3, :3] # [B, 3, 3]
16 | tgt = Tgt[:, :3, 3:].transpose(1, 2) # [B, 1, 3]
17 |
18 | scale_t = torch.linalg.norm(t, dim=-1)
19 | scale_tgt = torch.linalg.norm(tgt, dim=-1)
20 |
21 | cosine = (t @ tgt.transpose(1, 2)).squeeze(-1) / (scale_t * scale_tgt + 1e-9)
22 | cosine = torch.clip(cosine, -1.0, 1.0) # handle numerical errors
23 | t_ang_err = torch.rad2deg(torch.acos(cosine))
24 | t_ang_err = torch.minimum(t_ang_err, 180 - t_ang_err)
25 |
26 | t_scale_err = scale_t / scale_tgt
27 | t_scale_err_sym = torch.maximum(scale_t / scale_tgt, scale_tgt / scale_t)
28 | t_euclidean_err = torch.linalg.norm(t - tgt, dim=-1)
29 |
30 | residual = R.transpose(1, 2) @ Rgt
31 | trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1)
32 | cosine = (trace - 1) / 2
33 | cosine = torch.clip(cosine, -1., 1.) # handle numerical errors
34 | R_err = torch.rad2deg(torch.acos(cosine))
35 |
36 | if reduce is None:
37 | def fn(x): return x
38 | elif reduce == 'mean':
39 | fn = torch.mean
40 | elif reduce == 'median':
41 | fn = torch.median
42 |
43 | t_ang_err = fn(t_ang_err)
44 | t_scale_err = fn(t_scale_err)
45 | t_euclidean_err = fn(t_euclidean_err)
46 | R_err = fn(R_err)
47 |
48 | errors = {'t_err_ang': t_ang_err,
49 | 't_err_scale': t_scale_err,
50 | 't_err_scale_sym': t_scale_err_sym,
51 | 't_err_euc': t_euclidean_err,
52 | 'R_err': R_err}
53 | return errors
54 |
55 |
56 | def vcre_loss(R, t, Tgt, K0, H=720):
57 | """Compute Virtual Correspondences Reprojection Error in torch (with batches)."""
58 |
59 | B = R.shape[0]
60 | Rgt = Tgt[:, :3, :3] # [B, 3, 3]
61 | tgt = Tgt[:, :3, 3:].transpose(1, 2) # [B, 1, 3]
62 |
63 | eye_coords = torch.from_numpy(eye_coords_glob).unsqueeze(0)[:, :, :3].to(R.device).float()
64 | eye_coords = torch.tile(eye_coords, [B, 1, 1])
65 |
66 | # obtain ground-truth position of projected points
67 | uv_gt = project_2d(eye_coords, K0)
68 |
69 | # Avoid breaking gradients due to inplace operation
70 | eye_coord_tmp = (R @ eye_coords.transpose(2, 1) + t.transpose(2, 1))
71 | eyes_residual = (Rgt.transpose(2, 1) @ eye_coord_tmp - Rgt.transpose(2, 1) @ tgt.transpose(2, 1)).transpose(2, 1)
72 |
73 | uv_pred = project_2d(eyes_residual, K0)
74 |
75 | uv_gt = torch.clip(uv_gt, 0, H)
76 | uv_pred = torch.clip(uv_pred, 0, H)
77 |
78 | repr_err = ((((uv_gt - uv_pred) ** 2.).sum(-1) + 1e-6) ** 0.5).mean(-1).view(B, 1)
79 |
80 | return repr_err
81 |
82 |
83 | def vcre_torch(R, t, Tgt, K0, reduce=None, H=720, W=540):
84 | """Compute Virtual Correspondences Reprojection Error in torch (with batches)."""
85 |
86 | B = R.shape[0]
87 | Rgt = Tgt[:, :3, :3] # [B, 3, 3]
88 | tgt = Tgt[:, :3, 3:].transpose(1, 2) # [B, 1, 3]
89 |
90 | eye_coords = torch.from_numpy(eye_coords_glob).unsqueeze(0).to(R.device).float()
91 | eye_coords = torch.tile(eye_coords, [B, 1, 1])
92 |
93 | # obtain ground-truth position of projected points
94 | uv_gt = project_2d(eye_coords[:, :, :3], K0)
95 |
96 | # residual transformation
97 | cam2w_est = torch.tile(torch.eye(4).view(1, 4, 4), [B, 1, 1]).to(R.device).float()
98 | cam2w_est[:, :3, :3] = R
99 | cam2w_est[:, :3, -1] = t[:, 0]
100 |
101 | cam2w_gt = torch.tile(torch.eye(4).view(1, 4, 4), [B, 1, 1]).to(R.device).float()
102 | cam2w_gt[:, :3, :3] = Rgt
103 | cam2w_gt[:, :3, -1] = tgt[:, 0]
104 |
105 | # residual reprojection
106 | eyes_residual = (torch.linalg.inv(cam2w_gt) @ cam2w_est @ eye_coords.transpose(2, 1)).transpose(2, 1)
107 | uv_pred = project_2d(eyes_residual[:, :, :3], K0)
108 |
109 | uv_gt[:, :, 0], uv_pred[:, :, 0] = torch.clip(uv_gt[:, :, 0], 0, W), torch.clip(uv_pred[:, :, 0], 0, W)
110 | uv_gt[:, :, 1], uv_pred[:, :, 1] = torch.clip(uv_gt[:, :, 1], 0, H), torch.clip(uv_pred[:, :, 1], 0, H)
111 |
112 | repr_err = ((((uv_gt - uv_pred) ** 2.).sum(-1) + 1e-6) ** 0.5).mean(-1).view(B, 1)
113 |
114 | if reduce is None:
115 | def fn(x): return x
116 | elif reduce == 'mean':
117 | fn = torch.mean
118 | elif reduce == 'median':
119 | fn = torch.median
120 |
121 | repr_err = fn(repr_err)
122 |
123 | errors = {'repr_err': repr_err}
124 |
125 | return errors
126 |
127 |
128 |
129 | def error_auc(errors, thresholds):
130 | """
131 | Args:
132 | errors (list): [N,]
133 | thresholds (list)
134 | """
135 | errors = np.nan_to_num(errors, nan=float('inf')) # convert nans to inf
136 | errors = [0] + sorted(list(errors))
137 | recall = list(np.linspace(0, 1, len(errors)))
138 |
139 | aucs = []
140 | for thr in thresholds:
141 | last_index = np.searchsorted(errors, thr)
142 | y = recall[:last_index] + [recall[last_index-1]]
143 | x = errors[:last_index] + [thr]
144 | aucs.append(np.trapz(y, x) / thr)
145 |
146 | return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
147 |
148 |
149 | def ecdf(x):
150 | """Get Empirical Cumulative Distribution Function (ECDF) given samples x [N,]"""
151 | cd = np.linspace(0, 1, x.shape[0])
152 | v = np.sort(x)
153 | return v, cd
154 |
155 |
156 | def print_auc_table(agg_metrics):
157 | pose_error = np.maximum(agg_metrics['R_err'], agg_metrics['t_err_ang'])
158 | auc_pose = error_auc(pose_error, (5, 10, 20))
159 | print('Pose error AUC @ 5/10/20deg: {0:.3f}/{1:.3f}/{2:.3f}'.format(*auc_pose.values()))
160 |
161 | auc_rotation = error_auc(agg_metrics['R_err'], (5, 10, 20))
162 | print('Rotation error AUC @ 5/10/20deg: {0:.3f}/{1:.3f}/{2:.3f}'.format(*auc_rotation.values()))
163 |
164 | auc_translation_ang = error_auc(agg_metrics['t_err_ang'], (5, 10, 20))
165 | print(
166 | 'Translation angular error AUC @ 5/10/20deg: {0:.3f}/{1:.3f}/{2:.3f}'.format(*auc_translation_ang.values()))
167 |
168 | auc_translation_euc = error_auc(agg_metrics['t_err_euc'], (0.1, 0.5, 1))
169 | print(
170 | 'Translation Euclidean error AUC @ 0.1/0.5/1m: {0:.3f}/{1:.3f}/{2:.3f}'.format(*auc_translation_euc.values()))
171 |
172 |
173 | def precision(agg_metrics, rot_threshold, trans_threshold):
174 | '''Provides ratio of samples with rotation error < rot_threshold AND translation error < trans_threshold'''
175 | mask_rot = agg_metrics['R_err'] <= rot_threshold
176 | mask_trans = agg_metrics['t_err_euc'] <= trans_threshold
177 | prec = (mask_rot * mask_trans).mean()
178 | return prec
179 |
180 |
181 | def A_metrics(t_scale_err_sym):
182 | """Returns A1/A2/A3 metrics of translation vector norm given the "symmetric" scale error
183 | where
184 | t_scale_err_sym = torch.maximum((t_norm_gt / t_norm_pred), (t_norm_pred / t_norm_gt))
185 | """
186 |
187 | if not torch.is_tensor(t_scale_err_sym):
188 | t_scale_err_sym = torch.from_numpy(t_scale_err_sym)
189 |
190 | thresh = t_scale_err_sym
191 | a1 = (thresh < 1.25).float().mean()
192 | a2 = (thresh < 1.25 ** 2).float().mean()
193 | a3 = (thresh < 1.25 ** 3).float().mean()
194 | return a1, a2, a3
195 |
196 |
197 | class MetricsAccumulator:
198 | """Accumulates metrics and aggregates them when requested"""
199 |
200 | def __init__(self):
201 | self.data = defaultdict(list)
202 |
203 | def accumulate(self, data):
204 | for key, value in data.items():
205 | self.data[key].append(value)
206 |
207 | def aggregate(self):
208 | res = dict()
209 | for key in self.data.keys():
210 | res[key] = torch.cat(self.data[key]).view(-1).cpu().numpy()
211 | return res
212 |
--------------------------------------------------------------------------------
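A short example (sketch) of the aggregation utilities above on synthetic errors:

import numpy as np
import torch
from lib.utils.metrics import error_auc, A_metrics

rot_err = np.array([1.0, 3.0, 8.0, 25.0])          # degrees
print(error_auc(rot_err, thresholds=(5, 10, 20)))  # {'auc@5': ..., 'auc@10': ..., 'auc@20': ...}

scale_err_sym = torch.tensor([1.1, 1.3, 2.0])
print(A_metrics(scale_err_sym))                    # fraction of samples below 1.25, 1.25**2, 1.25**3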
/resources/environment.yml:
--------------------------------------------------------------------------------
1 | name: mickey
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - python=3.8.17
7 | - pip=23.2.1
8 | - pip:
9 | - einops==0.6.1
10 | - lazy-loader==0.3
11 | - lightning-utilities==0.9.0
12 | - matplotlib==3.7.2
13 | - numpy==1.24.4
14 | - omegaconf==2.3.0
15 | - open3d==0.17.0
16 | - opencv-python==4.8.0.74
17 | - protobuf==4.23.4
18 | - pytorch-lightning==2.0.6
19 | - tensorboard==2.13.0
20 | - tensorboard-data-server==0.7.1
21 | - timm==0.6.7
22 | - torch==2.0.1
23 | - torchmetrics==1.0.2
24 | - torchvision==0.15.2
25 | - tqdm==4.65.1
26 | - transforms3d==0.4.1
27 | - xformers==0.0.20
28 | - yacs==0.1.8
29 | - pillow==10.0.0
30 | - scipy==1.10.1
31 | - pyrender==0.1.45
32 | - trimesh==4.0.2
33 |
--------------------------------------------------------------------------------
/resources/teaser_mickey.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nianticlabs/mickey/2391be8a35491e7b43481c069f5dab65030839b9/resources/teaser_mickey.png
--------------------------------------------------------------------------------
/submission.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | from collections import defaultdict
4 | from dataclasses import dataclass
5 | from zipfile import ZipFile
6 |
7 | import torch
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from config.default import cfg
12 | from lib.datasets.datamodules import DataModule
13 | from lib.models.builder import build_model
14 | from lib.utils.data import data_to_model_device
15 | from transforms3d.quaternions import mat2quat
16 |
17 | @dataclass
18 | class Pose:
19 | image_name: str
20 | q: np.ndarray
21 | t: np.ndarray
22 | inliers: float
23 |
24 | def __str__(self) -> str:
25 | formatter = {'float': lambda v: f'{v:.6f}'}
26 | max_line_width = 1000
27 | q_str = np.array2string(self.q, formatter=formatter, max_line_width=max_line_width)[1:-1]
28 | t_str = np.array2string(self.t, formatter=formatter, max_line_width=max_line_width)[1:-1]
29 | return f'{self.image_name} {q_str} {t_str} {self.inliers}'
30 |
31 |
32 | def predict(loader, model):
33 | results_dict = defaultdict(list)
34 |
35 | for data in tqdm(loader):
36 |
37 | # run inference
38 | data = data_to_model_device(data, model)
39 | with torch.no_grad():
40 | R_batched, t_batched = model(data)
41 |
42 | for i_batch in range(len(data['scene_id'])):
43 | R = R_batched[i_batch].unsqueeze(0).detach().cpu().numpy()
44 | t = t_batched[i_batch].reshape(-1).detach().cpu().numpy()
45 | inliers = data['inliers'][i_batch].item()
46 |
47 | scene = data['scene_id'][i_batch]
48 | query_img = data['pair_names'][1][i_batch]
49 |
50 | # ignore frames without poses (e.g. not enough feature matches)
51 | if np.isnan(R).any() or np.isnan(t).any() or np.isinf(t).any():
52 | continue
53 |
54 | # populate results_dict
55 | estimated_pose = Pose(image_name=query_img,
56 | q=mat2quat(R).reshape(-1),
57 | t=t.reshape(-1),
58 | inliers=inliers)
59 | results_dict[scene].append(estimated_pose)
60 |
61 | return results_dict
62 |
63 |
64 | def save_submission(results_dict: dict, output_path: Path):
65 | with ZipFile(output_path, 'w') as zip:
66 | for scene, poses in results_dict.items():
67 | poses_str = '\n'.join((str(pose) for pose in poses))
68 | zip.writestr(f'pose_{scene}.txt', poses_str.encode('utf-8'))
69 |
70 |
71 | def eval(args):
72 | # Load configs
73 | cfg.merge_from_file('config/datasets/mapfree.yaml')
74 | cfg.merge_from_file(args.config)
75 |
76 | # Create dataloader
77 | if args.split == 'test':
78 | cfg.TRAINING.BATCH_SIZE = 8
79 | cfg.TRAINING.NUM_WORKERS = 8
80 | dataloader = DataModule(cfg, drop_last_val=False).test_dataloader()
81 | elif args.split == 'val':
82 | cfg.TRAINING.BATCH_SIZE = 12
83 | cfg.TRAINING.NUM_WORKERS = 8
84 | dataloader = DataModule(cfg, drop_last_val=False).val_dataloader()
85 | else:
86 | raise NotImplementedError(f'Invalid split: {args.split}')
87 |
88 | # Create model
89 | model = build_model(cfg, args.checkpoint)
90 |
91 | # Get predictions from model
92 | results_dict = predict(dataloader, model)
93 |
94 | # Save predictions to txt per scene within zip
95 | args.output_root.mkdir(parents=True, exist_ok=True)
96 | save_submission(results_dict, args.output_root / 'submission.zip')
97 |
98 | if __name__ == '__main__':
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument('--config', help='path to config file')
101 | parser.add_argument('--checkpoint',
102 | help='path to model checkpoint (models with learned parameters)', default='')
103 | parser.add_argument('--output_root', '-o', type=Path, default=Path('results/'))
104 | parser.add_argument('--split', choices=('val', 'test'), default='test',
105 | help='Dataset split to use for evaluation. Choose from test or val. Default: test')
106 | args = parser.parse_args()
107 | eval(args)
108 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | # do this before importing numpy! (doing it right up here in case numpy is dependency of e.g. json)
4 | os.environ["MKL_NUM_THREADS"] = "1" # noqa: E402
5 | os.environ["NUMEXPR_NUM_THREADS"] = "1" # noqa: E402
6 | os.environ["OMP_NUM_THREADS"] = "1" # noqa: E402
7 | os.environ["OPENBLAS_NUM_THREADS"] = "1" # noqa: E402
8 |
9 | import pytorch_lightning as pl
10 | import torch
11 | from pytorch_lightning.loggers import TensorBoardLogger
12 |
13 | from config.default import cfg
14 | from lib.datasets.datamodules import DataModuleTraining
15 | from lib.models.MicKey.model import MicKeyTrainingModel
16 | from lib.models.MicKey.modules.utils.training_utils import create_exp_name, create_result_dir
17 | import random
18 | import shutil
19 |
20 | def train_model(args):
21 |
22 | cfg.merge_from_file(args.dataset_config)
23 | cfg.merge_from_file(args.config)
24 |
25 | exp_name = create_exp_name(args.experiment, cfg)
26 | print('Start training of ' + exp_name)
27 |
28 | cfg.DATASET.SEED = random.randint(0, 1000000)
29 |
30 | model = MicKeyTrainingModel(cfg)
31 |
32 | checkpoint_vcre_callback = pl.callbacks.ModelCheckpoint(
33 | filename='{epoch}-best_vcre',
34 | save_last=True,
35 | save_top_k=1,
36 | verbose=True,
37 | monitor='val_vcre/auc_vcre',
38 | mode='max'
39 | )
40 |
41 | checkpoint_pose_callback = pl.callbacks.ModelCheckpoint(
42 | filename='{epoch}-best_pose',
43 | save_last=True,
44 | save_top_k=1,
45 | verbose=True,
46 | monitor='val_AUC_pose/auc_pose',
47 | mode='max'
48 | )
49 |
50 | epochend_callback = pl.callbacks.ModelCheckpoint(
51 | filename='e{epoch}-last',
52 | save_top_k=1,
53 | every_n_epochs=1,
54 | save_on_train_epoch_end=True
55 | )
56 |
57 | lr_monitoring_callback = pl.callbacks.LearningRateMonitor(logging_interval='step')
58 | logger = TensorBoardLogger(save_dir=args.path_weights, name=exp_name)
59 |
60 | trainer = pl.Trainer(devices=cfg.TRAINING.NUM_GPUS,
61 | log_every_n_steps=cfg.TRAINING.LOG_INTERVAL,
62 | val_check_interval=cfg.TRAINING.VAL_INTERVAL,
63 | limit_val_batches=cfg.TRAINING.VAL_BATCHES,
64 | max_epochs=cfg.TRAINING.EPOCHS,
65 | logger=logger,
66 | callbacks=[checkpoint_pose_callback, lr_monitoring_callback, epochend_callback, checkpoint_vcre_callback],
67 | num_sanity_val_steps=0,
68 | gradient_clip_val=cfg.TRAINING.GRAD_CLIP)
69 |
70 | datamodule_end = DataModuleTraining(cfg)
71 | print('Training with {:.2f}/{:.2f} image overlap'.format(cfg.DATASET.MIN_OVERLAP_SCORE, cfg.DATASET.MAX_OVERLAP_SCORE))
72 |
73 | create_result_dir(logger.log_dir + '/config.yaml')
74 | shutil.copyfile(args.config, logger.log_dir + '/config.yaml')
75 |
76 | if args.resume:
77 | ckpt_path = args.resume
78 | else:
79 | ckpt_path = None
80 |
81 | trainer.fit(model, datamodule_end, ckpt_path=ckpt_path)
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--config', help='path to config file', default='config/MicKey/curriculum_learning.yaml')
86 | parser.add_argument('--dataset_config', help='path to dataset config file', default='config/datasets/mapfree.yaml')
87 | parser.add_argument('--experiment', help='experiment name', default='MicKey_default')
88 | parser.add_argument('--path_weights', help='path to the directory to save the weights', default='weights/')
89 | parser.add_argument('--resume', help='resume from checkpoint path', default=None)
90 | args = parser.parse_args()
91 | train_model(args)
--------------------------------------------------------------------------------