├── LICENSE ├── README.md ├── benchmark ├── config.py ├── mapfree.py ├── metrics.py ├── reprojection.py ├── test_metrics.py └── utils.py ├── config ├── MicKey │ ├── curriculum_learning.yaml │ ├── curriculum_learning_warm_up.yaml │ ├── overlap_score.yaml │ └── overlap_score_warm_up.yaml ├── datasets │ └── mapfree.yaml └── default.py ├── data └── toy_example │ ├── im0.jpg │ ├── im1.jpg │ └── intrinsics.txt ├── demo_inference.py ├── lib ├── benchmarks │ ├── reprojection.py │ └── utils.py ├── datasets │ ├── datamodules.py │ ├── mapfree.py │ ├── sampler.py │ └── utils.py ├── models │ ├── MicKey │ │ ├── compute_pose.py │ │ ├── model.py │ │ └── modules │ │ │ ├── DINO_modules │ │ │ ├── dinov2.py │ │ │ └── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── attention.py │ │ │ │ ├── block.py │ │ │ │ ├── dino_head.py │ │ │ │ ├── drop_path.py │ │ │ │ ├── layer_scale.py │ │ │ │ ├── mlp.py │ │ │ │ ├── patch_embed.py │ │ │ │ └── swiglu_ffn.py │ │ │ ├── att_layers │ │ │ ├── attention.py │ │ │ ├── transformer.py │ │ │ └── transformer_utils.py │ │ │ ├── compute_correspondences.py │ │ │ ├── loss │ │ │ ├── loss_class.py │ │ │ ├── loss_utils.py │ │ │ └── solvers.py │ │ │ ├── mickey_extractor.py │ │ │ └── utils │ │ │ ├── extractor_utils.py │ │ │ ├── feature_matcher.py │ │ │ ├── probabilisticProcrustes.py │ │ │ └── training_utils.py │ └── builder.py └── utils │ ├── data.py │ ├── metrics.py │ └── visualization.py ├── resources ├── environment.yml └── teaser_mickey.png ├── submission.py └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright © Niantic, Inc. 2024. Patent Pending. 2 | 3 | All rights reserved. 4 | 5 | 6 | 7 | ======================================================================================= 8 | 9 | 10 | 11 | This Software is licensed under the terms of the following Matching 2D Images in 3D: 12 | Metric Relative Pose from Metric Correspondences license which allows for non-commercial 13 | use only. For any other use of the software not covered by the terms of this license, 14 | please contact partnerships@nianticlabs.com 15 | 16 | 17 | 18 | ======================================================================================= 19 | 20 | 21 | 22 | Matching 2D Images in 3D: Metric Relative Pose from Metric Correspondences License 23 | 24 | 25 | This Agreement is made by and between the Licensor and the Licensee as 26 | defined and identified below. 27 | 28 | 29 | 1. Definitions. 30 | 31 | In this Agreement (“the Agreement”) the following words shall have the 32 | following meanings: 33 | 34 | "Authors" shall mean A. Barroso-Laguna, S. Munukutla, V. Prisacariu, E. Brachmann 35 | "Licensee" Shall mean the person or organization agreeing to use the 36 | Software in accordance with these terms and conditions. 37 | "Licensor" shall mean Niantic Inc., a company organized and existing under 38 | the laws of Delaware, whose principal place of business is at 1 Ferry Building, 39 | Suite 200, San Francisco, 94111. 40 | "Software" shall mean the Matching 2D Images in 3D: Metric Relative Pose 41 | from Metric Correspondences Software uploaded by Licensor to the GitHub 42 | repository at https://github.com/nianticlabs/mickey 43 | on April 10th 2024 in source code or object code form and any 44 | accompanying documentation as well as any modifications or additions uploaded 45 | to the same GitHub repository by Licensor. 46 | 47 | 48 | 2. License. 
49 | 50 | 2.1 The Licensor has all necessary rights to grant a license under: (i) 51 | copyright and rights in the nature of copyright subsisting in the Software; and 52 | (ii) certain patent rights resulting from a patent application(s) filed by the 53 | Licensor in the United States and/or other jurisdictions in connection with the 54 | Software. The Licensor grants the Licensee for the duration of this Agreement, 55 | a free of charge, non-sublicenseable, non-exclusive, non-transferable copyright 56 | and patent license (in consequence of said patent application(s)) to use the 57 | Software for non-commercial purpose only, including teaching and research at 58 | educational institutions and research at not-for-profit research institutions 59 | in accordance with the provisions of this Agreement. Non-commercial use 60 | expressly excludes any profit-making or commercial activities, including without 61 | limitation sale, license, manufacture or development of commercial products, use in 62 | commercially-sponsored research, use at a laboratory or other facility owned or 63 | controlled (whether in whole or in part) by a commercial entity, provision of 64 | consulting service, use for or on behalf of any commercial entity, use in 65 | research where a commercial party obtains rights to research results or any 66 | other benefit, and use of the code in any models, model weights or code 67 | resulting from such procedure in any commercial product. Notwithstanding the 68 | foregoing restrictions, you can use this code for publishing comparison results 69 | for academic papers, including retraining on your own data. Any use of the 70 | Software for any purpose other than pursuant to the license grant set forth 71 | above shall automatically terminate this License. 72 | 73 | 74 | 2.2 The Licensee is permitted to make modifications to the Software 75 | provided that any distribution of such modifications is in accordance with 76 | Clause 3. 77 | 78 | 2.3 Except as expressly permitted by this Agreement and save to the 79 | extent and in the circumstances expressly required to be permitted by law, the 80 | Licensee is not permitted to rent, lease, sell, offer to sell, or loan the 81 | Software or its associated documentation. 82 | 83 | 84 | 3. Redistribution and modifications 85 | 86 | 3.1 The Licensee may reproduce and distribute copies of the Software, with 87 | or without modifications, in source format only and only to this same GitHub 88 | repository , and provided that any and every distribution is accompanied by an 89 | unmodified copy of this License and that the following copyright notice is 90 | always displayed in an obvious manner: Copyright © Niantic, Inc. 2018. All 91 | rights reserved. 92 | 93 | 94 | 3.2 In the case where the Software has been modified, any distribution must 95 | include prominent notices indicating which files have been changed. 96 | 97 | 3.3 The Licensee shall cause any work that it distributes or publishes, 98 | that in whole or in part contains or is derived from the Software or any part 99 | thereof (“Work based on the Software”), to be licensed as a whole at no charge 100 | to all third parties entitled to a license to the Software under the terms of 101 | this License and on the same terms provided in this License. 102 | 103 | 104 | 4. Duration. 105 | 106 | This Agreement is effective until the Licensee terminates it by destroying 107 | the Software, any Work based on the Software, and its documentation together 108 | with all copies. 
It will also terminate automatically if the Licensee fails to 109 | abide by its terms. Upon automatic termination the Licensee agrees to destroy 110 | all copies of the Software, Work based on the Software, and its documentation. 111 | 112 | 113 | 5. Disclaimer of Warranties. 114 | 115 | The Software is provided as is. To the maximum extent permitted by law, 116 | Licensor provides no warranties or conditions of any kind, either express or 117 | implied, including without limitation, any warranties or condition of title, 118 | non-infringement or fitness for a particular purpose. 119 | 120 | 121 | 6. LIMITATION OF LIABILITY. 122 | 123 | IN NO EVENT SHALL THE LICENSOR AND/OR AUTHORS BE LIABLE FOR ANY DIRECT, 124 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY OR CONSEQUENTIAL DAMAGES (INCLUDING 125 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 126 | DATA OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 127 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE 128 | OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 129 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 130 | 131 | 132 | 7. Indemnity. 133 | 134 | The Licensee shall indemnify the Licensor and/or Authors against all third 135 | party claims that may be asserted against or suffered by the Licensor and/or 136 | Authors and which relate to use of the Software by the Licensee. 137 | 138 | 139 | 8. Intellectual Property. 140 | 141 | 8.1 As between the Licensee and Licensor, copyright and all other 142 | intellectual property rights subsisting in or in connection with the Software 143 | and supporting information shall remain at all times the property of the 144 | Licensor. The Licensee shall acquire no rights in any such material except as 145 | expressly provided in this Agreement. 146 | 147 | 8.2 No permission is granted to use the trademarks or product names of the 148 | Licensor except as required for reasonable and customary use in describing the 149 | origin of the Software and for the purposes of abiding by the terms of Clause 150 | 3.1. 151 | 152 | 8.3 The Licensee shall promptly notify the Licensor of any improvement or 153 | new use of the Software (“Improvements”) in sufficient detail for Licensor to 154 | evaluate the Improvements. The Licensee hereby grants the Licensor and its 155 | affiliates a non-exclusive, fully paid-up, royalty-free, irrevocable and 156 | perpetual license to all Improvements for non-commercial academic research and 157 | teaching purposes upon creation of such improvements. 158 | 159 | 8.4 The Licensee grants an exclusive first option to the Licensor to be 160 | exercised by the Licensor within three (3) years of the date of notification of 161 | an Improvement under Clause 8.3 to use any the Improvement for commercial 162 | purposes on terms to be negotiated and agreed by Licensee and Licensor in good 163 | faith within a period of six (6) months from the date of exercise of the said 164 | option (including without limitation any royalty share in net income from such 165 | commercialization payable to the Licensee, as the case may be). 166 | 167 | 168 | 9. Acknowledgements. 169 | 170 | The Licensee shall acknowledge the Authors and use of the Software in the 171 | publication of any work that uses, or results that are achieved through, the 172 | use of the Software. 
The following citation shall be included in the 173 | acknowledgement: “Matching 2D Images in 3D: Metric Relative Pose 174 | from Metric Correspondences", by A. Barroso-Laguna, S. Munukutla, 175 | V. Prisacariu, E. Brachmann, CVPR 2024. 176 | 177 | 178 | 10. Governing Law. 179 | 180 | This Agreement shall be governed by, construed and interpreted in 181 | accordance with English law and the parties submit to the exclusive 182 | jurisdiction of the English courts. 183 | 184 | 185 | 11. Termination. 186 | 187 | Upon termination of this Agreement, the licenses granted hereunder will 188 | terminate and Sections 5, 6, 7, 8, 9, 10 and 11 shall survive any termination 189 | of this Agreement. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Matching 2D Images in 3D: Metric Relative Pose from Metric Correspondences
3 | 
4 | Axel Barroso-Laguna · Sowmya Munukutla · Victor Adrian Prisacariu · Eric Brachmann
12 | CVPR 2024 (Oral)
13 | Project Page | Paper | arXiv | Supplemental
14 | 
15 | 16 | This is the reference implementation of the paper **"Matching 2D Images in 3D: Metric Relative Pose from Metric Correspondences"** presented at **CVPR 2024**. 17 | 18 | The paper introduces **M**etr**ic Key**points (MicKey), a feature detection pipeline that regresses keypoint positions in camera space. 19 | MicKey presents a differentiable approach to establish metric correspondences via descriptor matching. From the metric correspondences, MicKey recovers metric relative poses. 20 | MicKey is trained in an end-to-end fashion using differentiable pose optimization and requires only image pairs and their ground truth relative poses for supervision. 21 | 22 |

23 | ![teaser](resources/teaser_mickey.png) 24 |

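The pipeline described above boils down to turning descriptor matches plus predicted metric depth into 3D-3D correspondences and then solving for a metric relative pose. As a rough illustration of that last step only, here is a minimal, non-differentiable weighted Kabsch/Procrustes solve in plain NumPy (the function name is ours; MicKey's actual solver in `lib/models/MicKey/modules/utils/probabilisticProcrustes.py` is differentiable, probabilistic, and wrapped in a RANSAC loop, so treat this as a sketch of the underlying geometry, not the implementation):

```python
import numpy as np

def kabsch_relative_pose(X_ref, X_dst, weights=None):
    """Return R (3x3), t (3,) such that X_dst ~= R @ X_ref + t.

    X_ref, X_dst: (N, 3) metric 3D keypoints (e.g. 2D keypoints back-projected
                  with predicted depth and the camera intrinsics).
    weights:      optional (N,) per-correspondence weights (e.g. matching scores).
    """
    X_ref, X_dst = np.asarray(X_ref, float), np.asarray(X_dst, float)
    w = np.ones(len(X_ref)) if weights is None else np.asarray(weights, float)
    w = w / w.sum()

    # weighted centroids of both point sets
    mu_ref = (w[:, None] * X_ref).sum(axis=0)
    mu_dst = (w[:, None] * X_dst).sum(axis=0)

    # weighted cross-covariance and its SVD
    H = (w[:, None] * (X_ref - mu_ref)).T @ (X_dst - mu_dst)
    U, _, Vt = np.linalg.svd(H)

    # enforce a proper rotation (det(R) = +1)
    D = np.diag([1.0, 1.0, np.sign(np.linalg.det(Vt.T @ U.T))])
    R = Vt.T @ D @ U.T
    t = mu_dst - R @ mu_ref
    return R, t
```

Because the correspondences are metric (in meters), the recovered translation is metric as well, which is what the Map-free evaluation below measures.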
25 | 26 | ## Setup 27 | 28 | Assuming a fresh [Anaconda](https://www.anaconda.com/download/) distribution, you can install dependencies with: 29 | ```shell 30 | conda env create -f resources/environment.yml 31 | conda activate mickey 32 | ``` 33 | We ran our experiments with PyTorch 2.0.1, CUDA 11.6, Python 3.8.17 and Debian GNU/Linux 11. 34 | 35 | ## Evaluating MicKey 36 | MicKey aims to address the problem of instant Augmented Reality (AR) introduced in the [Map-free benchmark](https://research.nianticlabs.com/mapfree-reloc-benchmark). 37 | In the Map-free setup, instead of building 3D maps from hundreds of images and scale calibrations, the benchmark proposes using only one photo of a scene as the map. 38 | The Map-free benchmark then evaluates the accuracy of the estimated metric relative pose between the reference image (the map) 39 | and the query image (the user). 40 | 41 | ### Download Map-free dataset 42 | You can find the Map-free dataset on [their project page](https://research.nianticlabs.com/mapfree-reloc-benchmark/dataset). 43 | Extract the test.zip file into `data/mapfree`. Optionally, if you want to train MicKey, also download the train and val zip files. 44 | 45 | ### Pre-trained Models 46 | We provide two [MicKey models](https://storage.googleapis.com/niantic-lon-static/research/mickey/assets/mickey_weights.zip). 47 | * _mickey.ckpt_: These are the default weights for MicKey, trained without using the overlapping scores provided in the Map-free dataset and following the curriculum learning strategy described in the paper. 48 | * _mickey_sc.ckpt_: These are the weights obtained when training MicKey using the min and max overlapping scores defined in Map-free. 49 | 50 | Extract mickey_weights.zip into `weights/`. In the zip file, we also provide the default configuration needed to run the evaluation. 51 | 52 | ### Run the submission script 53 | Similar to the Map-free code base, we provide a [submission script](submission.py) to generate submission files: 54 | 55 | ```shell 56 | python submission.py --config path/to/config --checkpoint path/to/checkpoint --o results/your_method 57 | ``` 58 | The resulting file `results/your_method/submission.zip` can be uploaded to the Map-free [online benchmark website](https://research.nianticlabs.com/mapfree-reloc-benchmark) and compared against existing methods in the [leaderboard](https://research.nianticlabs.com/mapfree-reloc-benchmark/leaderboard). 59 | 60 | ### Run the local evaluation 61 | The Map-free benchmark does not provide ground-truth poses for the test set, but we can still evaluate our method locally on the validation set: 62 | ```shell 63 | python submission.py --config path/to/config --checkpoint path/to/checkpoint --o results/your_method --split val 64 | ``` 65 | and evaluate it as: 66 | ```shell 67 | python -m benchmark.mapfree --submission_path results/your_method/submission.zip --split val 68 | ``` 69 | 70 | ### Download MicKey correspondences and depth files 71 | We provide the depth maps and correspondences computed by MicKey. 72 | - Download [MicKey depth maps](https://storage.googleapis.com/niantic-lon-static/research/map-free-reloc/assets/mickey_depths.tar.gz). 73 | - Download [MicKey correspondences](https://storage.googleapis.com/niantic-lon-static/research/map-free-reloc/assets/mickey_correspondences.zip).
74 | - Extract the contents of both files to `data/mapfree` 75 | 76 | Refer to the [Map-free benchmark](https://github.com/nianticlabs/map-free-reloc/tree/main?tab=readme-ov-file#feature-matching--scale-from-depth-baselines) to learn how to load precomputed correspondences and depth maps in their feature matching pipeline. 77 | 78 | ### Running MicKey on custom images 79 | We provide a [demo script](demo_inference.py) to run the relative pose estimation pipeline on custom image pairs. 80 | As an example, we store two images with their respective intrinsics in `data/toy_example`. 81 | The script computes their metric relative pose and saves the corresponding depth and keypoint score maps in the image folder. 82 | Run the demo script as: 83 | ```shell 84 | python demo_inference.py --im_path_ref data/toy_example/im0.jpg \ 85 | --im_path_dst data/toy_example/im1.jpg \ 86 | --intrinsics data/toy_example/intrinsics.txt \ 87 | --checkpoint path/to/checkpoint \ 88 | --config path/to/config 89 | ``` 90 | 91 | To generate the 3D assets shown on [MicKey's webpage](https://nianticlabs.github.io/mickey/), you can turn on the 92 | `--generate_3D_vis` flag. This will generate a rendered image with the input images, their computed 3D camera positions, 93 | and the set of 3D point inliers. 94 | 95 | ## Training MicKey 96 | Besides the test scripts, we also provide the code to train MicKey. 97 | 98 | We provide two default configurations in `config/MicKey/`: 99 | * _curriculum_learning.yaml_: This configuration follows the curriculum learning approach detailed in the paper. 100 | It hence does not use any image overlapping information, only relative ground truth poses, during training. 101 | * _overlap_score.yaml_: This configuration relies on the image overlapping information to choose only solvable image pairs during training. 102 | 103 | Besides the two default configurations, we also provide warm-up configurations to speed up training. 104 | These configurations use low-resolution images and do not add the null hypothesis (refer to Section 3.1.4 of the paper for details). 105 | We recommend initializing MicKey with these configurations and then 106 | fine-tuning the network with the default ones (which use high-resolution images and the null hypothesis). 107 | They can be found under `config/MicKey/`: 108 | * _curriculum_learning_warm_up.yaml_ 109 | * _overlap_score_warm_up.yaml_ 110 | 111 | To train the default MicKey model, use: 112 | ```shell 113 | python train.py --config config/MicKey/curriculum_learning.yaml \ 114 | --dataset_config config/datasets/mapfree.yaml \ 115 | --experiment experiment_name \ 116 | --path_weights path/to/checkpoint/folder 117 | ``` 118 | Resume training from a checkpoint by adding `--resume {path_to_checkpoint}`. 119 | 120 | The top models, according to the validation loss, the VCRE metric, and the pose AUC score, are saved during training. 121 | TensorBoard results and checkpoints are saved into the folder `dir/to/weights/experiment_name`. 122 | 123 | Note that by default, the configuration is set to use 4 GPUs. 124 | You can reduce the expected number of GPUs in the config file (e.g., _NUM_GPUS: 1_). 125 | 126 | ## Changelog 127 | - 19 September 2024: Added vectorized RANSAC and warm-up configurations. 128 | - 13 August 2024: Added visualization code. 129 | - 7 June 2024: Added precomputed depth maps and keypoint correspondences.
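If you want to inspect the ground-truth files used by the local evaluation above, the helpers in `benchmark/utils.py` can be reused directly. A minimal sketch, assuming the validation split was extracted to `data/mapfree/val` and using a made-up scene name (`s00460`):

```python
from pathlib import Path

from benchmark.utils import load_K, load_poses

scene = Path('data/mapfree/val/s00460')      # hypothetical scene folder
K, W, H = load_K(scene / 'intrinsics.txt')   # per-frame 3x3 intrinsics, image width/height
with (scene / 'poses.txt').open('r', encoding='utf-8') as f:
    # frame number -> (quaternion, camera centre, confidence), in cam2world convention
    poses = load_poses(f, load_confidence=False)

frame = next(iter(poses))
q, t, _ = poses[frame]
print(f'frame {frame}: centre={t}, quaternion={q}, K=\n{K[frame]}')
```

The same `load_poses` call, with `load_confidence=True`, is what `benchmark/mapfree.py` uses to read your estimated poses out of `submission.zip`.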
130 | 131 | ## BibTeX 132 | If you use this code in your research, please consider citing our paper: 133 | 134 | ```bibtex 135 | @inproceedings{barroso2024mickey, 136 | title={Matching 2D Images in 3D: Metric Relative Pose from Metric Correspondences}, 137 | author={Barroso-Laguna, Axel and Munukutla, Sowmya and Prisacariu, Victor and Brachmann, Eric}, 138 | booktitle={CVPR}, 139 | year={2024} 140 | } 141 | ``` 142 | 143 | ## License 144 | Copyright © Niantic, Inc. 2024. Patent Pending. All rights reserved. This code is for non-commercial use. Please see the [license](LICENSE) file for terms. 145 | 146 | ## Acknowledgements 147 | We use parts of code from different repositories. We thank the authors and maintainers of the following repositories. 148 | - [Map-free](https://research.nianticlabs.com/mapfree-reloc-benchmark) 149 | - [ACE](https://github.com/nianticlabs/ace) 150 | - [ACE0](https://github.com/nianticlabs/acezero) 151 | - [DUSt3R](https://github.com/naver/dust3r) 152 | - [RoMa](https://github.com/Parskatt/RoMa) 153 | - [DINOv2](https://github.com/facebookresearch/dinov2) 154 | - [LoFTR](https://github.com/zju3dv/LoFTR) 155 | - [DPT](https://github.com/isl-org/DPT) 156 | - [ExtremeRotation](https://github.com/RuojinCai/ExtremeRotation_code) 157 | 158 | -------------------------------------------------------------------------------- /benchmark/config.py: -------------------------------------------------------------------------------- 1 | # translation and rotation thresholds [meters, degrees] 2 | # used to compute Precision and AUC considering Pose Error 3 | t_threshold = 0.25 4 | R_threshold = 5 5 | 6 | # reprojection (VCRE) threshold [pixels] 7 | # used to compute Precision and AUC considering VCRE 8 | vcre_threshold = 90 9 | -------------------------------------------------------------------------------- /benchmark/mapfree.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | from pathlib import Path 4 | from zipfile import ZipFile 5 | from io import TextIOWrapper 6 | import json 7 | import logging 8 | import numpy as np 9 | 10 | from benchmark.utils import load_poses, subsample_poses, load_K, precision_recall 11 | from benchmark.metrics import MetricManager, Inputs 12 | import benchmark.config as config 13 | from config.default import cfg 14 | 15 | def plot_perfect_curve(P): 16 | total_bins = 1000 17 | prec_values = [] 18 | ratio_values = [] 19 | for i in range(total_bins): 20 | ratio_tmp = i/total_bins 21 | value = min(1, P / ratio_tmp) 22 | prec_values.append(value) 23 | ratio_values.append(ratio_tmp) 24 | return prec_values, ratio_values 25 | 26 | def compute_scene_metrics(dataset_path: Path, submission_zip: ZipFile, scene: str): 27 | metric_manager = MetricManager() 28 | 29 | # load intrinsics and poses 30 | try: 31 | K, W, H = load_K(dataset_path / scene / 'intrinsics.txt') 32 | with (dataset_path / scene / 'poses.txt').open('r', encoding='utf-8') as gt_poses_file: 33 | gt_poses = load_poses(gt_poses_file, load_confidence=False) 34 | except FileNotFoundError as e: 35 | logging.error(f'Could not find ground-truth dataset files: {e}') 36 | raise 37 | else: 38 | logging.info( 39 | f'Loaded ground-truth intrinsics and poses for scene {scene}') 40 | 41 | # try to load estimated poses from submission 42 | try: 43 | with submission_zip.open(f'pose_{scene}.txt') as estimated_poses_file: 44 | estimated_poses_file_wrapper = TextIOWrapper( 45 | estimated_poses_file, encoding='utf-8') 46 | 
estimated_poses = load_poses( 47 | estimated_poses_file_wrapper, load_confidence=True) 48 | except KeyError as e: 49 | logging.warning( 50 | f'Submission does not have estimates for scene {scene}.') 51 | return dict(), len(gt_poses) 52 | except UnicodeDecodeError as e: 53 | logging.error('Unsupported file encoding: please use UTF-8') 54 | raise 55 | else: 56 | logging.info(f'Loaded estimated poses for scene {scene}') 57 | 58 | # The val/test set is subsampled by a factor of 5 59 | gt_poses = subsample_poses(gt_poses, subsample=5) 60 | 61 | # failures encode how many frames did not have an estimate 62 | # e.g. user/method did not provide an estimate for that frame 63 | # it's different from when an estimate is provided with low confidence! 64 | failures = 0 65 | 66 | # Results encoded as dict 67 | # key: metric name; value: list of values (one per frame). 68 | # e.g. results['t_err'] = [1.2, 0.3, 0.5, ...] 69 | results = defaultdict(list) 70 | 71 | # compute metrics per frame 72 | for frame_num, (q_gt, t_gt, _) in gt_poses.items(): 73 | if frame_num not in estimated_poses: 74 | failures += 1 75 | continue 76 | 77 | q_est, t_est, confidence = estimated_poses[frame_num] 78 | inputs = Inputs(q_gt=q_gt, t_gt=t_gt, q_est=q_est, t_est=t_est, 79 | confidence=confidence, K=K[frame_num], W=W, H=H) 80 | metric_manager(inputs, results) 81 | 82 | return results, failures 83 | 84 | 85 | def aggregate_results(all_results, all_failures): 86 | # aggregate metrics 87 | median_metrics = defaultdict(list) 88 | all_metrics = defaultdict(list) 89 | for scene_results in all_results.values(): 90 | for metric, values in scene_results.items(): 91 | median_metrics[metric].append(np.median(values)) 92 | all_metrics[metric].extend(values) 93 | all_metrics = {k: np.array(v) for k, v in all_metrics.items()} 94 | assert all([v.ndim == 1 for v in all_metrics.values()] 95 | ), 'invalid metrics shape' 96 | 97 | # compute avg median metrics 98 | avg_median_metrics = {metric: np.mean( 99 | values) for metric, values in median_metrics.items()} 100 | 101 | # compute precision/AUC for pose error and reprojection errors 102 | accepted_poses = (all_metrics['trans_err'] < config.t_threshold) * \ 103 | (all_metrics['rot_err'] < config.R_threshold) 104 | accepted_vcre = all_metrics['reproj_err'] < config.vcre_threshold 105 | total_samples = len(next(iter(all_metrics.values()))) + all_failures 106 | 107 | prec_pose = np.sum(accepted_poses) / total_samples 108 | prec_vcre = np.sum(accepted_vcre) / total_samples 109 | 110 | # compute AUC for pose and VCRE 111 | pose_prec_values, pose_recall_values, auc_pose = precision_recall( 112 | inliers=all_metrics['confidence'], tp=accepted_poses, failures=all_failures) 113 | vcre_prec_values, vcre_recall_values, auc_vcre = precision_recall( 114 | inliers=all_metrics['confidence'], tp=accepted_vcre, failures=all_failures) 115 | 116 | curves_data = {} 117 | curves_data['vcre_prec_values'], curves_data['vcre_recall_values'] = vcre_prec_values, vcre_recall_values 118 | curves_data['pose_prec_values'], curves_data['pose_recall_values'] = pose_prec_values, pose_recall_values 119 | 120 | # output metrics 121 | output_metrics = dict() 122 | output_metrics['Average Median Translation Error'] = avg_median_metrics['trans_err'] 123 | output_metrics['Average Median Rotation Error'] = avg_median_metrics['rot_err'] 124 | output_metrics['Average Median Reprojection Error'] = avg_median_metrics['reproj_err'] 125 | output_metrics[f'Precision @ Pose Error < ({config.t_threshold*100}cm, {config.R_threshold}deg)'] = 
prec_pose 126 | output_metrics[f'AUC @ Pose Error < ({config.t_threshold*100}cm, {config.R_threshold}deg)'] = auc_pose 127 | output_metrics[f'Precision @ VCRE < {config.vcre_threshold}px'] = prec_vcre 128 | output_metrics[f'AUC @ VCRE < {config.vcre_threshold}px'] = auc_vcre 129 | output_metrics[f'Estimates for % of frames'] = len(all_metrics['trans_err']) / total_samples 130 | return output_metrics, curves_data 131 | 132 | 133 | def count_unexpected_scenes(scenes: tuple, submission_zip: ZipFile): 134 | submission_scenes = [fname[5:-4] 135 | for fname in submission_zip.namelist() if fname.startswith("pose_")] 136 | return len(set(submission_scenes) - set(scenes)) 137 | 138 | def main(args): 139 | dataset_path = args.dataset_path / args.split 140 | scenes = tuple(f.name for f in dataset_path.iterdir() if f.is_dir()) 141 | 142 | try: 143 | submission_zip = ZipFile(args.submission_path, 'r') 144 | except FileNotFoundError as e: 145 | logging.error(f'Could not find ZIP file in path {args.submission_path}') 146 | return 147 | 148 | all_results = dict() 149 | all_failures = 0 150 | for scene in scenes: 151 | metrics, failures = compute_scene_metrics( 152 | dataset_path, submission_zip, scene) 153 | all_results[scene] = metrics 154 | all_failures += failures 155 | 156 | if all_failures > 0: 157 | logging.warning( 158 | f'Submission is missing pose estimates for {all_failures} frames') 159 | 160 | unexpected_scene_count = count_unexpected_scenes(scenes, submission_zip) 161 | if unexpected_scene_count > 0: 162 | logging.warning( 163 | f'Submission contains estimates for {unexpected_scene_count} scenes outside the {args.split} set') 164 | 165 | if all((len(metrics) == 0 for metrics in all_results.values())): 166 | logging.error( 167 | f'Submission does not have any valid pose estimates') 168 | return 169 | 170 | output_metrics, curves_data = aggregate_results(all_results, all_failures) 171 | output_json = json.dumps(output_metrics, indent=2) 172 | print(output_json) 173 | 174 | 175 | if __name__ == '__main__': 176 | parser = argparse.ArgumentParser( 177 | 'eval', description='Evaluate submissions for the MapFree dataset benchmark') 178 | parser.add_argument('--submission_path', type=Path, default='', 179 | help='Path to the submission ZIP file') 180 | parser.add_argument('--split', choices=('val', 'test'), default='test', 181 | help='Dataset split to use for evaluation. Default: test') 182 | parser.add_argument('--log', choices=('warning', 'info', 'error'), 183 | default='warning', help='Logging level. Default: warning') 184 | parser.add_argument('--dataset_path', type=Path, default=None, 185 | help='Path to the dataset folder') 186 | 187 | args = parser.parse_args() 188 | 189 | if args.dataset_path is None: 190 | cfg.merge_from_file('config/datasets/mapfree.yaml') 191 | args.dataset_path = Path(cfg.DATASET.DATA_ROOT) 192 | 193 | logging.basicConfig(level=args.log.upper()) 194 | try: 195 | main(args) 196 | except Exception: 197 | logging.error("Unexpected behaviour. 
Exiting.") 198 | 199 | -------------------------------------------------------------------------------- /benchmark/metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Callable 3 | 4 | import numpy as np 5 | 6 | from benchmark.reprojection import reprojection_error 7 | from benchmark.utils import VARIANTS_ANGLE_SIN, quat_angle_error 8 | 9 | 10 | @dataclass 11 | class Inputs: 12 | q_gt: np.array 13 | t_gt: np.array 14 | q_est: np.array 15 | t_est: np.array 16 | confidence: float 17 | K: np.array 18 | W: int 19 | H: int 20 | 21 | def __post_init__(self): 22 | assert self.q_gt.shape == (4,), 'invalid gt quaternion shape' 23 | assert self.t_gt.shape == (3,), 'invalid gt translation shape' 24 | assert self.q_est.shape == (4,), 'invalid estimated quaternion shape' 25 | assert self.t_est.shape == (3,), 'invalid estimated translation shape' 26 | assert self.confidence >= 0, 'confidence must be non negative' 27 | assert self.K.shape == (3, 3), 'invalid K shape' 28 | assert self.W > 0, 'invalid image width' 29 | assert self.H > 0, 'invalid image height' 30 | 31 | 32 | class MyDict(dict): 33 | def register(self, fn) -> Callable: 34 | """Registers a function within dict(fn_name -> fn_ref). 35 | This is used to evaluate all registered metrics in MetricManager.__call__()""" 36 | self[fn.__name__] = fn 37 | return fn 38 | 39 | 40 | class MetricManager: 41 | _metrics = MyDict() 42 | 43 | def __call__(self, inputs: Inputs, results: dict) -> None: 44 | for metric, metric_fn in self._metrics.items(): 45 | results[metric].append(metric_fn(inputs)) 46 | 47 | @staticmethod 48 | @_metrics.register 49 | def trans_err(inputs: Inputs) -> np.float64: 50 | return np.linalg.norm(inputs.t_est - inputs.t_gt) 51 | 52 | @staticmethod 53 | @_metrics.register 54 | def rot_err(inputs: Inputs, variant: str = VARIANTS_ANGLE_SIN) -> np.float64: 55 | return quat_angle_error(label=inputs.q_est, pred=inputs.q_gt, variant=variant)[0, 0] 56 | 57 | @staticmethod 58 | @_metrics.register 59 | def reproj_err(inputs: Inputs) -> float: 60 | return reprojection_error( 61 | q_est=inputs.q_est, t_est=inputs.t_est, q_gt=inputs.q_gt, t_gt=inputs.t_gt, K=inputs.K, 62 | W=inputs.W, H=inputs.H) 63 | 64 | @staticmethod 65 | @_metrics.register 66 | def confidence(inputs: Inputs) -> float: 67 | return inputs.confidence 68 | -------------------------------------------------------------------------------- /benchmark/reprojection.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | from transforms3d.quaternions import quat2mat 5 | 6 | 7 | def project(pts: np.ndarray, K: np.ndarray, img_size: List[int] or Tuple[int] = None) -> np.ndarray: 8 | """Projects 3D points to image plane. 
9 | 10 | Args: 11 | - pts [N, 3/4]: points in camera coordinates (homogeneous or non-homogeneous) 12 | - K [3, 3]: intrinsic matrix 13 | - img_size (width, height): optional, clamp projection to image borders 14 | Outputs: 15 | - uv [N, 2]: coordinates of projected points 16 | """ 17 | 18 | assert len(pts.shape) == 2, 'incorrect number of dimensions' 19 | assert pts.shape[1] in [3, 4], 'invalid dimension size' 20 | assert K.shape == (3, 3), 'incorrect intrinsic shape' 21 | 22 | uv_h = (K @ pts[:, :3].T).T 23 | uv = uv_h[:, :2] / uv_h[:, -1:] 24 | 25 | if img_size is not None: 26 | uv[:, 0] = np.clip(uv[:, 0], 0, img_size[0]) 27 | uv[:, 1] = np.clip(uv[:, 1], 0, img_size[1]) 28 | 29 | return uv 30 | 31 | 32 | def get_grid_multipleheight() -> np.ndarray: 33 | # create grid of points 34 | ar_grid_step = 0.3 35 | ar_grid_num_x = 7 36 | ar_grid_num_y = 4 37 | ar_grid_num_z = 7 38 | ar_grid_z_offset = 1.8 39 | ar_grid_y_offset = 0 40 | 41 | ar_grid_x_pos = np.arange(0, ar_grid_num_x)-(ar_grid_num_x-1)/2 42 | ar_grid_x_pos *= ar_grid_step 43 | 44 | ar_grid_y_pos = np.arange(0, ar_grid_num_y)-(ar_grid_num_y-1)/2 45 | ar_grid_y_pos *= ar_grid_step 46 | ar_grid_y_pos += ar_grid_y_offset 47 | 48 | ar_grid_z_pos = np.arange(0, ar_grid_num_z).astype(float) 49 | ar_grid_z_pos *= ar_grid_step 50 | ar_grid_z_pos += ar_grid_z_offset 51 | 52 | xx, yy, zz = np.meshgrid(ar_grid_x_pos, ar_grid_y_pos, ar_grid_z_pos) 53 | ones = np.ones(xx.shape[0]*xx.shape[1]*xx.shape[2]) 54 | eye_coords = np.concatenate([c.reshape(-1, 1) 55 | for c in (xx, yy, zz, ones)], axis=-1) 56 | return eye_coords 57 | 58 | 59 | # global variable, avoids creating it again 60 | eye_coords_glob = get_grid_multipleheight() 61 | 62 | 63 | def reprojection_error( 64 | q_est: np.ndarray, t_est: np.ndarray, q_gt: np.ndarray, t_gt: np.ndarray, K: np.ndarray, 65 | W: int, H: int) -> float: 66 | eye_coords = eye_coords_glob 67 | 68 | # obtain ground-truth position of projected points 69 | uv_gt = project(eye_coords, K, (W, H)) 70 | 71 | # residual transformation 72 | cam2w_est = np.eye(4) 73 | cam2w_est[:3, :3] = quat2mat(q_est) 74 | cam2w_est[:3, -1] = t_est 75 | cam2w_gt = np.eye(4) 76 | cam2w_gt[:3, :3] = quat2mat(q_gt) 77 | cam2w_gt[:3, -1] = t_gt 78 | 79 | # residual reprojection 80 | eyes_residual = (np.linalg.inv(cam2w_est) @ cam2w_gt @ eye_coords.T).T 81 | uv_pred = project(eyes_residual, K, (W, H)) 82 | 83 | # get reprojection error 84 | repr_err = np.linalg.norm(uv_gt - uv_pred, ord=2, axis=1) 85 | mean_repr_err = float(repr_err.mean().item()) 86 | return mean_repr_err 87 | -------------------------------------------------------------------------------- /benchmark/test_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from transforms3d.euler import euler2quat 4 | from transforms3d.quaternions import axangle2quat, qmult, quat2mat, rotate_vector 5 | 6 | from benchmark.metrics import Inputs, MetricManager 7 | from benchmark.reprojection import project 8 | from benchmark.utils import VARIANTS_ANGLE_COS, VARIANTS_ANGLE_SIN 9 | 10 | 11 | def createInput(q_gt=None, t_gt=None, q_est=None, t_est=None, confidence=None, K=None, W=None, H=None): 12 | q_gt = np.zeros(4) if q_gt is None else q_gt 13 | t_gt = np.zeros(3) if t_gt is None else t_gt 14 | q_est = np.zeros(4) if q_est is None else q_est 15 | t_est = np.zeros(3) if t_est is None else t_est 16 | confidence = 0. 
if confidence is None else confidence 17 | K = np.eye(3) if K is None else K 18 | H = 1 if H is None else H 19 | W = 1 if W is None else W 20 | return Inputs(q_gt=q_gt, t_gt=t_gt, q_est=q_est, t_est=t_est, confidence=confidence, K=K, W=W, H=H) 21 | 22 | 23 | def randomQuat(): 24 | angles = np.random.uniform(0, 2*np.pi, 3) 25 | q = euler2quat(*angles) 26 | return q 27 | 28 | 29 | class TestMetrics: 30 | @pytest.mark.parametrize('run_number', range(50)) 31 | def test_t_err_tinvariance(self, run_number: int) -> None: 32 | """Computes the translation error given an initial translation and displacement of this 33 | translation. The translation error must be equal to the norm of the displacement.""" 34 | mean, var = 5, 10 35 | t0 = np.random.normal(mean, var, (3,)) 36 | displacement = np.random.normal(mean, var, (3,)) 37 | 38 | i = createInput(t_gt=t0, t_est=t0+displacement) 39 | trans_err = MetricManager.trans_err(i) 40 | assert np.isclose(trans_err, np.linalg.norm(displacement)) 41 | 42 | @pytest.mark.parametrize('run_number', range(50)) 43 | def test_trans_err_rinvariance(self, run_number: int) -> None: 44 | """Computes the translation error given estimated and gt vectors. 45 | The translation error must be the same for a rotated version of those vectors 46 | (same random rotation)""" 47 | mean, var = 5, 10 48 | t0 = np.random.normal(mean, var, (3,)) 49 | t1 = np.random.normal(mean, var, (3,)) 50 | q = randomQuat() 51 | 52 | i = createInput(t_gt=t0, t_est=t1) 53 | trans_err = MetricManager.trans_err(i) 54 | 55 | ir = createInput(t_gt=rotate_vector(t0, q), t_est=rotate_vector(t1, q)) 56 | trans_err_r = MetricManager.trans_err(ir) 57 | 58 | assert np.isclose(trans_err, trans_err_r) 59 | 60 | @pytest.mark.parametrize('run_number', range(50)) 61 | @pytest.mark.parametrize('dtype', (np.float64, np.float32)) 62 | def test_rot_err_raxis(self, run_number: int, dtype: type) -> None: 63 | """Test rotation error for rotations around a random axis. 64 | 65 | Note: We create GT as high precision, and only downcast when calling rot_err. 66 | """ 67 | q = randomQuat().astype(np.float64) 68 | 69 | axis = np.random.uniform(low=-1, high=1, size=3).astype(np.float64) 70 | angle = np.float64(np.random.uniform(low=-np.pi, high=np.pi)) 71 | qres = axangle2quat(vector=axis, theta=angle, is_normalized=False).astype(np.float64) 72 | 73 | i = createInput(q_gt=q.astype(dtype), q_est=qmult(q, qres).astype(dtype)) 74 | rot_err = MetricManager.rot_err(i) 75 | assert isinstance(rot_err, np.float64) 76 | rot_err_expected = np.abs(np.degrees(angle)) 77 | # if we add up errors, we want them to be positive 78 | assert 0. 
<= rot_err 79 | rtol = 1.e-5 # numpy default 80 | atol = 1.e-8 # numpy default 81 | if isinstance(dtype, np.float32): 82 | atol = 1.e-7 # 1/50 test might fail at 1.e-8 83 | assert np.isclose(rot_err, rot_err_expected, rtol=rtol, atol=atol) 84 | 85 | @pytest.mark.parametrize('run_number', range(50)) 86 | def test_r_err_mat(self, run_number: int) -> None: 87 | q0 = randomQuat() 88 | q1 = randomQuat() 89 | 90 | i = createInput(q_gt=q0, q_est=q1) 91 | rot_err = MetricManager.rot_err(i) 92 | 93 | R0 = quat2mat(q0) 94 | R1 = quat2mat(q1) 95 | Rres = R1 @ R0.T 96 | theta = (np.trace(Rres) - 1)/2 97 | theta = np.clip(theta, -1, 1) 98 | angle = np.degrees(np.arccos(theta)) 99 | 100 | assert np.isclose(angle, rot_err) 101 | 102 | def test_reproj_error_identity(self): 103 | """Test that reprojection error is zero if poses match""" 104 | q = randomQuat() 105 | t = np.random.normal(0, 10, (3,)) 106 | i = createInput(q_gt=q, t_gt=t, q_est=q, t_est=t) 107 | 108 | reproj_err = MetricManager.reproj_err(i) 109 | assert np.isclose(reproj_err, 0) 110 | 111 | @pytest.mark.parametrize('run_number', range(10)) 112 | @pytest.mark.parametrize('variant', (VARIANTS_ANGLE_SIN,)) 113 | @pytest.mark.parametrize('dtype', (np.float64,)) 114 | def test_r_err_small(self, run_number: int, variant: str, dtype: type) -> None: 115 | """Test rotation error for small angle differences. 116 | 117 | Note: We create GT as high precision, and only downcast when calling rot_err. 118 | """ 119 | scales_failed = [] 120 | for scale in np.logspace(start=-1, stop=-9, num=9, base=10, dtype=dtype): 121 | q = randomQuat().astype(np.float64) 122 | angle = np.float64(np.random.uniform(low=-np.pi, high=np.pi)) * scale 123 | assert isinstance(angle, np.float64) 124 | axis = np.random.uniform(low=-1., high=1., size=3).astype(np.float64) 125 | assert axis.dtype == np.float64 126 | qres = axangle2quat(vector=axis, theta=angle, is_normalized=False).astype(np.float64) 127 | assert qres.dtype == np.float64 128 | 129 | i = createInput(q_gt=q.astype(dtype), q_est=qmult(q, qres).astype(dtype)) 130 | 131 | # We expect the error to always be np.float64 for highest acc. 132 | rot_err = MetricManager.rot_err(i, variant=variant) 133 | assert isinstance(rot_err, np.float64) 134 | rot_err_expected = np.abs(np.degrees(angle)) 135 | assert isinstance(rot_err_expected, type(rot_err)) 136 | 137 | # if we add up errors, we want them to be positive 138 | assert 0. 
<= rot_err 139 | 140 | # check accuracy for one magnitude higher tolerance than the angle 141 | tol = 0.1 * scale 142 | # need to be more permissive for lower precision 143 | if dtype == np.float32: 144 | tol = 1.e3 * scale 145 | 146 | # cast to dtype for checking 147 | rot_err = rot_err.astype(dtype) 148 | rot_err_expected = rot_err_expected.astype(dtype) 149 | 150 | if variant == VARIANTS_ANGLE_SIN: 151 | assert np.isclose(rot_err, rot_err_expected, rtol=tol, atol=tol) 152 | elif variant == VARIANTS_ANGLE_COS: 153 | if not np.isclose(rot_err, rot_err_expected, rtol=tol, atol=tol): 154 | print(f"[variant '{variant}'] raises an error for\n" 155 | f"\trot_err: {rot_err}" 156 | f"\trot_err_expected: {rot_err_expected}" 157 | f"\trtol: {tol}" 158 | f"\tatol: {tol}") 159 | scales_failed.append(scale) 160 | if len(scales_failed): 161 | pytest.fail(f"Variant {variant} failed at scales {scales_failed}") 162 | 163 | 164 | def test_projection() -> None: 165 | xyz = np.array(((10, 20, 30), (10, 30, 50), (-20, -15, 5), 166 | (-20, -50, 10)), dtype=np.float32) 167 | K = np.eye(3) 168 | 169 | uv = np.array(((1/3, 2/3), (1/5, 3/5), (-4, -3), 170 | (-2, -5)), dtype=np.float32) 171 | assert np.allclose(uv, project(xyz, K)) 172 | 173 | uv = np.array(((1/3, 2/3), (1/5, 3/5), (0, 0), (0, 0)), dtype=np.float32) 174 | assert np.allclose(uv, project(xyz, K, img_size=(5, 5))) 175 | -------------------------------------------------------------------------------- /benchmark/utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import typing 3 | import logging 4 | 5 | import numpy as np 6 | from transforms3d.quaternions import qinverse, rotate_vector, qmult 7 | 8 | VARIANTS_ANGLE_SIN = 'sin' 9 | VARIANTS_ANGLE_COS = 'cos' 10 | 11 | 12 | def convert_world2cam_to_cam2world(q, t): 13 | qinv = qinverse(q) 14 | tinv = -rotate_vector(t, qinv) 15 | return qinv, tinv 16 | 17 | 18 | def load_poses(file: typing.IO, load_confidence: bool = False): 19 | """Load poses from text file and converts them to cam2world convention (t is the camera center in world coordinates) 20 | 21 | The text file encodes world2cam poses with the format: 22 | imgpath qw qx qy qz tx ty tz [confidence] 23 | where qw qx qy qz is the quaternion encoding rotation, 24 | and tx ty tz is the translation vector, 25 | and confidence is a float encoding confidence, for estimated poses 26 | """ 27 | 28 | expected_parts = 9 if load_confidence else 8 29 | 30 | poses = dict() 31 | for line_number, line in enumerate(file.readlines()): 32 | parts = tuple(line.strip().split(' ')) 33 | 34 | # if 'tensor' in parts[-1]: 35 | # print('ERROR: confidence is a tensor') 36 | # parts = list(parts) 37 | # parts[-1] = parts[-1].split('[')[-1].split(']')[0] 38 | if len(parts) != expected_parts: 39 | logging.warning( 40 | f'Invalid number of fields in file {file.name} line {line_number}.' 41 | f' Expected {expected_parts}, received {len(parts)}. Ignoring line.') 42 | continue 43 | 44 | try: 45 | name = parts[0] 46 | if '#' in name: 47 | logging.info(f'Ignoring comment line in {file.name} line {line_number}') 48 | continue 49 | frame_num = int(name[-9:-4]) 50 | except ValueError: 51 | logging.warning( 52 | f'Invalid frame number in file {file.name} line {line_number}.' 53 | f' Expected formatting "seq1/frame_00000.jpg". 
Ignoring line.') 54 | continue 55 | 56 | try: 57 | parts_float = tuple(map(float, parts[1:])) 58 | if any(np.isnan(v) or np.isinf(v) for v in parts_float): 59 | raise ValueError() 60 | qw, qx, qy, qz, tx, ty, tz = parts_float[:7] 61 | confidence = parts_float[7] if load_confidence else None 62 | except ValueError: 63 | logging.warning( 64 | f'Error parsing pose in file {file.name} line {line_number}. Ignoring line.') 65 | continue 66 | 67 | q = np.array((qw, qx, qy, qz), dtype=np.float64) 68 | t = np.array((tx, ty, tz), dtype=np.float64) 69 | 70 | if np.isclose(np.linalg.norm(q), 0): 71 | logging.warning( 72 | f'Error parsing pose in file {file.name} line {line_number}. ' 73 | 'Quaternion must have non-zero norm. Ignoring line.') 74 | continue 75 | 76 | q, t = convert_world2cam_to_cam2world(q, t) 77 | poses[frame_num] = (q, t, confidence) 78 | return poses 79 | 80 | 81 | def subsample_poses(poses: dict, subsample: int = 1): 82 | return {k: v for i, (k, v) in enumerate(poses.items()) if i % subsample == 0} 83 | 84 | 85 | def load_K(file_path: Path): 86 | K = dict() 87 | with file_path.open('r', encoding='utf-8') as f: 88 | for line in f.readlines(): 89 | if '#' in line: 90 | continue 91 | line = line.strip().split(' ') 92 | 93 | frame_num = int(line[0][-9:-4]) 94 | fx, fy, cx, cy, W, H = map(float, line[1:]) 95 | K[frame_num] = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) 96 | return K, W, H 97 | 98 | 99 | def quat_angle_error(label, pred, variant=VARIANTS_ANGLE_SIN) -> np.ndarray: 100 | assert label.shape == (4,) 101 | assert pred.shape == (4,) 102 | assert variant in (VARIANTS_ANGLE_SIN, VARIANTS_ANGLE_COS), \ 103 | f"Need variant to be in ({VARIANTS_ANGLE_SIN}, {VARIANTS_ANGLE_COS})" 104 | 105 | if len(label.shape) == 1: 106 | label = np.expand_dims(label, axis=0) 107 | if len(label.shape) != 2 or label.shape[0] != 1 or label.shape[1] != 4: 108 | raise RuntimeError(f"Unexpected shape of label: {label.shape}, expected: (1, 4)") 109 | 110 | if len(pred.shape) == 1: 111 | pred = np.expand_dims(pred, axis=0) 112 | if len(pred.shape) != 2 or pred.shape[0] != 1 or pred.shape[1] != 4: 113 | raise RuntimeError(f"Unexpected shape of pred: {pred.shape}, expected: (1, 4)") 114 | 115 | label = label.astype(np.float64) 116 | pred = pred.astype(np.float64) 117 | 118 | q1 = pred / np.linalg.norm(pred, axis=1, keepdims=True) 119 | q2 = label / np.linalg.norm(label, axis=1, keepdims=True) 120 | if variant == VARIANTS_ANGLE_COS: 121 | d = np.abs(np.sum(np.multiply(q1, q2), axis=1, keepdims=True)) 122 | d = np.clip(d, a_min=-1, a_max=1) 123 | angle = 2. * np.degrees(np.arccos(d)) 124 | elif variant == VARIANTS_ANGLE_SIN: 125 | if q1.shape[0] != 1 or q2.shape[0] != 1: 126 | raise NotImplementedError(f"Multiple angles is todo") 127 | # https://www.researchgate.net/post/How_do_I_calculate_the_smallest_angle_between_two_quaternions/5d6ed4a84f3a3e1ed3656616/citation/download 128 | sine = qmult(q1[0], qinverse(q2[0])) # note: takes first element in 2D array 129 | # 114.59 = 2. * 180. / pi 130 | angle = np.arcsin(np.linalg.norm(sine[1:], keepdims=True)) * 114.59155902616465 131 | angle = np.expand_dims(angle, axis=0) 132 | 133 | return angle.astype(np.float64) 134 | 135 | 136 | def precision_recall(inliers, tp, failures): 137 | """ 138 | Computes Precision/Recall plot for a set of poses given inliers (confidence) and wether the 139 | estimated pose error (whatever it may be) is within a threshold. 140 | Each point in the plot is obtained by choosing a threshold for inliers (i.e. inlier_thr). 
141 | Recall measures how many images have inliers >= inlier_thr 142 | Precision measures how many images that have inliers >= inlier_thr have 143 | estimated pose error <= pose_threshold (measured by counting tps) 144 | Where pose_threshold is (trans_thr[m], rot_thr[deg]) 145 | 146 | Inputs: 147 | - inliers [N] 148 | - terr [N] 149 | - rerr [N] 150 | - failures (int) 151 | - pose_threshold (tuple float) 152 | Output 153 | - precision [N] 154 | - recall [N] 155 | - average_precision (scalar) 156 | """ 157 | 158 | assert len(inliers) == len(tp), 'unequal shapes' 159 | 160 | # sort by inliers (descending order) 161 | inliers = np.array(inliers) 162 | sort_idx = np.argsort(inliers)[::-1] 163 | inliers = inliers[sort_idx] 164 | tp = np.array(tp).reshape(-1)[sort_idx] 165 | 166 | # get idxs where inliers change (avoid tied up values) 167 | distinct_value_indices = np.where(np.diff(inliers))[0] 168 | threshold_idxs = np.r_[distinct_value_indices, inliers.size - 1] 169 | 170 | # compute prec/recall 171 | N = inliers.shape[0] 172 | rec = np.arange(N, dtype=np.float32) + 1 173 | cum_tp = np.cumsum(tp) 174 | prec = cum_tp[threshold_idxs] / rec[threshold_idxs] 175 | rec = rec[threshold_idxs] / (float(N) + float(failures)) 176 | 177 | # invert order and ensures (prec=1, rec=0) point 178 | last_ind = rec.searchsorted(rec[-1]) 179 | sl = slice(last_ind, None, -1) 180 | prec = np.r_[prec[sl], 1] 181 | rec = np.r_[rec[sl], 0] 182 | 183 | # compute average precision (AUC) as the weighted average of precisions 184 | average_precision = np.abs(np.sum(np.diff(rec) * np.array(prec)[:-1])) 185 | 186 | return prec, rec, average_precision 187 | -------------------------------------------------------------------------------- /config/MicKey/curriculum_learning.yaml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 'MicKey' 3 | DEBUG: False 4 | MICKEY: 5 | DINOV2: 6 | DOWN_FACTOR: 14 7 | CHANNEL_DIM: 1024 8 | FLOAT16: True 9 | 10 | KP_HEADS: 11 | BLOCKS_DIM: [512, 256, 128, 64] 12 | BN: True 13 | USE_SOFTMAX: True 14 | USE_DEPTHSIGMOID: False 15 | MAX_DEPTH: 60 16 | POS_ENCODING: True 17 | 18 | DSC_HEAD: 19 | LAST_DIM: 128 20 | BLOCKS_DIM: [512, 256, 128] 21 | BN: True 22 | NORM_DSC: True 23 | POS_ENCODING: True 24 | 25 | FEATURE_MATCHER: 26 | TYPE: 'DualSoftmax' 27 | DUAL_SOFTMAX: 28 | TEMPERATURE: 0.1 29 | USE_DUSTBIN: True 30 | SINKHORN: 31 | NUM_IT: 10 32 | DUSTBIN_SCORE_INIT: 1. 33 | USE_TRANSFORMER: False 34 | 35 | TRAINING: 36 | NUM_GPUS: 4 37 | BATCH_SIZE: 8 # BS for each dataloader (in every GPU) 38 | NUM_WORKERS: 8 39 | SAMPLER: 'scene_balance' 40 | N_SAMPLES_SCENE: 100 41 | SAMPLE_WITH_REPLACEMENT: True 42 | LR: 1e-4 43 | LOG_INTERVAL: 50 44 | VAL_INTERVAL: 0.5 45 | VAL_BATCHES: 100 46 | EPOCHS: 100 47 | 48 | DATASET: 49 | HEIGHT: 720 50 | WIDTH: 540 51 | 52 | MIN_OVERLAP_SCORE: 0.0 # [train only] discard data with overlap_score < min_overlap_score 53 | MAX_OVERLAP_SCORE: 1.0 # [train only] discard data with overlap_score < min_overlap_score 54 | 55 | LOSS_CLASS: 56 | 57 | LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR 58 | SOFT_CLIPPING: True # It indicates if it soft-clips the loss values. 
59 | 60 | POSE_ERR: 61 | MAX_LOSS_VALUE: 1.5 62 | MAX_LOSS_SOFTVALUE: 0.8 63 | VCRE: 64 | MAX_LOSS_VALUE: 90 65 | MAX_LOSS_SOFTVALUE: 0.8 66 | 67 | GENERATE_HYPOTHESES: 68 | SCORE_TEMPERATURE: 20 69 | IT_MATCHES: 20 70 | IT_RANSAC: 20 71 | INLIER_3D_TH: 0.3 72 | INLIER_REF_TH: 0.15 73 | NUM_REF_STEPS: 4 74 | NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability 75 | 76 | NULL_HYPOTHESIS: 77 | ADD_NULL_HYPOTHESIS: True 78 | TH_OUTLIERS: 0.35 79 | 80 | CURRICULUM_LEARNING: 81 | TRAIN_CURRICULUM: True # It indicates if MicKey should be trained with curriculum learning 82 | TRAIN_WITH_TOPK: True # It indicates if MicKey should be trained only with top image pairs 83 | TOPK_INIT: 30 84 | TOPK: 80 85 | 86 | SAMPLER: 87 | NUM_SAMPLES_MATCHES: 512 88 | 89 | PROCRUSTES: 90 | IT_MATCHES: 20 91 | IT_RANSAC: 100 92 | NUM_SAMPLED_MATCHES: 2048 93 | NUM_CORR_3D_3D: 3 94 | NUM_REFINEMENTS: 4 95 | TH_INLIER: 0.15 96 | TH_SOFT_INLIER: 0.3 97 | 98 | -------------------------------------------------------------------------------- /config/MicKey/curriculum_learning_warm_up.yaml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 'MicKey' 3 | DEBUG: False 4 | MICKEY: 5 | DINOV2: 6 | DOWN_FACTOR: 14 7 | CHANNEL_DIM: 1024 8 | FLOAT16: True 9 | 10 | KP_HEADS: 11 | BLOCKS_DIM: [512, 256, 128, 64] 12 | BN: True 13 | USE_SOFTMAX: True 14 | USE_DEPTHSIGMOID: False 15 | MAX_DEPTH: 60 16 | POS_ENCODING: True 17 | 18 | DSC_HEAD: 19 | LAST_DIM: 128 20 | BLOCKS_DIM: [512, 256, 128] 21 | BN: True 22 | NORM_DSC: True 23 | POS_ENCODING: True 24 | 25 | FEATURE_MATCHER: 26 | TYPE: 'DualSoftmax' 27 | DUAL_SOFTMAX: 28 | TEMPERATURE: 0.1 29 | USE_DUSTBIN: True 30 | SINKHORN: 31 | NUM_IT: 10 32 | DUSTBIN_SCORE_INIT: 1. 33 | USE_TRANSFORMER: False 34 | 35 | TRAINING: 36 | NUM_GPUS: 4 37 | BATCH_SIZE: 24 # BS for each dataloader (in every GPU) 38 | NUM_WORKERS: 8 39 | SAMPLER: 'scene_balance' 40 | N_SAMPLES_SCENE: 100 41 | SAMPLE_WITH_REPLACEMENT: True 42 | LR: 1e-4 43 | LOG_INTERVAL: 50 44 | VAL_INTERVAL: 0.5 45 | VAL_BATCHES: 100 46 | EPOCHS: 100 47 | 48 | DATASET: 49 | HEIGHT: 480 50 | WIDTH: 360 51 | 52 | MIN_OVERLAP_SCORE: 0.0 # [train only] discard data with overlap_score < min_overlap_score 53 | MAX_OVERLAP_SCORE: 1.0 # [train only] discard data with overlap_score < min_overlap_score 54 | 55 | LOSS_CLASS: 56 | 57 | LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR 58 | SOFT_CLIPPING: True # It indicates if it soft-clips the loss values. 
59 | 60 | POSE_ERR: 61 | MAX_LOSS_VALUE: 1.5 62 | MAX_LOSS_SOFTVALUE: 0.8 63 | VCRE: 64 | MAX_LOSS_VALUE: 90 65 | MAX_LOSS_SOFTVALUE: 0.8 66 | 67 | GENERATE_HYPOTHESES: 68 | SCORE_TEMPERATURE: 20 69 | IT_MATCHES: 20 70 | IT_RANSAC: 20 71 | INLIER_3D_TH: 0.3 72 | INLIER_REF_TH: 0.15 73 | NUM_REF_STEPS: 4 74 | NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability 75 | 76 | NULL_HYPOTHESIS: 77 | ADD_NULL_HYPOTHESIS: False 78 | TH_OUTLIERS: 0.35 79 | 80 | CURRICULUM_LEARNING: 81 | TRAIN_CURRICULUM: True # It indicates if MicKey should be trained with curriculum learning 82 | TRAIN_WITH_TOPK: True # It indicates if MicKey should be trained only with top image pairs 83 | TOPK_INIT: 30 84 | TOPK: 80 85 | 86 | SAMPLER: 87 | NUM_SAMPLES_MATCHES: 64 88 | 89 | PROCRUSTES: 90 | IT_MATCHES: 20 91 | IT_RANSAC: 100 92 | NUM_SAMPLED_MATCHES: 2048 93 | NUM_CORR_3D_3D: 3 94 | NUM_REFINEMENTS: 4 95 | TH_INLIER: 0.15 96 | TH_SOFT_INLIER: 0.3 97 | 98 | -------------------------------------------------------------------------------- /config/MicKey/overlap_score.yaml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 'MicKey' 3 | DEBUG: False 4 | MICKEY: 5 | DINOV2: 6 | DOWN_FACTOR: 14 7 | CHANNEL_DIM: 1024 8 | FLOAT16: True 9 | 10 | KP_HEADS: 11 | BLOCKS_DIM: [512, 256, 128, 64] 12 | BN: True 13 | USE_SOFTMAX: True 14 | USE_DEPTHSIGMOID: False 15 | MAX_DEPTH: 60 16 | POS_ENCODING: True 17 | 18 | DSC_HEAD: 19 | LAST_DIM: 128 20 | BLOCKS_DIM: [512, 256, 128] 21 | BN: True 22 | NORM_DSC: True 23 | POS_ENCODING: True 24 | 25 | FEATURE_MATCHER: 26 | TYPE: 'DualSoftmax' 27 | DUAL_SOFTMAX: 28 | TEMPERATURE: 0.1 29 | USE_DUSTBIN: True 30 | SINKHORN: 31 | NUM_IT: 10 32 | DUSTBIN_SCORE_INIT: 1. 33 | USE_TRANSFORMER: False 34 | 35 | TRAINING: 36 | NUM_GPUS: 4 37 | BATCH_SIZE: 8 # BS for each dataloader (in every GPU) 38 | NUM_WORKERS: 8 39 | SAMPLER: 'scene_balance' 40 | N_SAMPLES_SCENE: 100 41 | SAMPLE_WITH_REPLACEMENT: True 42 | LR: 1e-4 43 | LOG_INTERVAL: 50 44 | VAL_INTERVAL: 0.5 45 | VAL_BATCHES: 100 46 | EPOCHS: 100 47 | 48 | DATASET: 49 | HEIGHT: 720 50 | WIDTH: 540 51 | 52 | MIN_OVERLAP_SCORE: 0.4 # [train only] discard data with overlap_score < min_overlap_score 53 | MAX_OVERLAP_SCORE: 0.8 # [train only] discard data with overlap_score < min_overlap_score 54 | 55 | LOSS_CLASS: 56 | 57 | LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR 58 | SOFT_CLIPPING: True # It indicates if it soft-clips the loss values. 
59 | 60 | POSE_ERR: 61 | MAX_LOSS_VALUE: 1.5 62 | MAX_LOSS_SOFTVALUE: 0.8 63 | VCRE: 64 | MAX_LOSS_VALUE: 90 65 | MAX_LOSS_SOFTVALUE: 0.8 66 | 67 | GENERATE_HYPOTHESES: 68 | SCORE_TEMPERATURE: 20 69 | IT_MATCHES: 20 70 | IT_RANSAC: 20 71 | INLIER_3D_TH: 0.3 72 | INLIER_REF_TH: 0.15 73 | NUM_REF_STEPS: 4 74 | NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability 75 | 76 | NULL_HYPOTHESIS: 77 | ADD_NULL_HYPOTHESIS: True 78 | TH_OUTLIERS: 0.35 79 | 80 | CURRICULUM_LEARNING: 81 | TRAIN_CURRICULUM: False # It indicates if MicKey should be trained with curriculum learning 82 | TRAIN_WITH_TOPK: False # It indicates if MicKey should be trained only with top image pairs 83 | TOPK_INIT: 30 84 | TOPK: 80 85 | 86 | SAMPLER: 87 | NUM_SAMPLES_MATCHES: 512 88 | 89 | PROCRUSTES: 90 | IT_MATCHES: 20 91 | IT_RANSAC: 100 92 | NUM_SAMPLED_MATCHES: 2048 93 | NUM_CORR_3D_3D: 3 94 | NUM_REFINEMENTS: 4 95 | TH_INLIER: 0.15 96 | TH_SOFT_INLIER: 0.3 -------------------------------------------------------------------------------- /config/MicKey/overlap_score_warm_up.yaml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 'MicKey' 3 | DEBUG: False 4 | MICKEY: 5 | DINOV2: 6 | DOWN_FACTOR: 14 7 | CHANNEL_DIM: 1024 8 | FLOAT16: True 9 | 10 | KP_HEADS: 11 | BLOCKS_DIM: [512, 256, 128, 64] 12 | BN: True 13 | USE_SOFTMAX: True 14 | USE_DEPTHSIGMOID: False 15 | MAX_DEPTH: 60 16 | POS_ENCODING: True 17 | 18 | DSC_HEAD: 19 | LAST_DIM: 128 20 | BLOCKS_DIM: [512, 256, 128] 21 | BN: True 22 | NORM_DSC: True 23 | POS_ENCODING: True 24 | 25 | FEATURE_MATCHER: 26 | TYPE: 'DualSoftmax' 27 | DUAL_SOFTMAX: 28 | TEMPERATURE: 0.1 29 | USE_DUSTBIN: True 30 | SINKHORN: 31 | NUM_IT: 10 32 | DUSTBIN_SCORE_INIT: 1. 33 | USE_TRANSFORMER: False 34 | 35 | TRAINING: 36 | NUM_GPUS: 4 37 | BATCH_SIZE: 24 # BS for each dataloader (in every GPU) 38 | NUM_WORKERS: 8 39 | SAMPLER: 'scene_balance' 40 | N_SAMPLES_SCENE: 100 41 | SAMPLE_WITH_REPLACEMENT: True 42 | LR: 1e-4 43 | LOG_INTERVAL: 50 44 | VAL_INTERVAL: 0.5 45 | VAL_BATCHES: 100 46 | EPOCHS: 100 47 | 48 | DATASET: 49 | HEIGHT: 480 50 | WIDTH: 360 51 | 52 | MIN_OVERLAP_SCORE: 0.4 # [train only] discard data with overlap_score < min_overlap_score 53 | MAX_OVERLAP_SCORE: 0.8 # [train only] discard data with overlap_score < min_overlap_score 54 | 55 | LOSS_CLASS: 56 | 57 | LOSS_FUNCTION: "VCRE" # VCRE or POSE_ERR 58 | SOFT_CLIPPING: True # It indicates if it soft-clips the loss values. 
59 | 60 | POSE_ERR: 61 | MAX_LOSS_VALUE: 1.5 62 | MAX_LOSS_SOFTVALUE: 0.8 63 | VCRE: 64 | MAX_LOSS_VALUE: 90 65 | MAX_LOSS_SOFTVALUE: 0.8 66 | 67 | GENERATE_HYPOTHESES: 68 | SCORE_TEMPERATURE: 20 69 | IT_MATCHES: 20 70 | IT_RANSAC: 20 71 | INLIER_3D_TH: 0.3 72 | INLIER_REF_TH: 0.15 73 | NUM_REF_STEPS: 4 74 | NUM_CORR_3d3d: 8 # Bigger number of 3d-3d correspondences helps stability 75 | 76 | NULL_HYPOTHESIS: 77 | ADD_NULL_HYPOTHESIS: False 78 | TH_OUTLIERS: 0.35 79 | 80 | CURRICULUM_LEARNING: 81 | TRAIN_CURRICULUM: False # It indicates if MicKey should be trained with curriculum learning 82 | TRAIN_WITH_TOPK: False # It indicates if MicKey should be trained only with top image pairs 83 | TOPK_INIT: 30 84 | TOPK: 80 85 | 86 | SAMPLER: 87 | NUM_SAMPLES_MATCHES: 64 88 | 89 | PROCRUSTES: 90 | IT_MATCHES: 20 91 | IT_RANSAC: 100 92 | NUM_SAMPLED_MATCHES: 2048 93 | NUM_CORR_3D_3D: 3 94 | NUM_REFINEMENTS: 4 95 | TH_INLIER: 0.15 96 | TH_SOFT_INLIER: 0.3 -------------------------------------------------------------------------------- /config/datasets/mapfree.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | DATA_SOURCE: 'MapFree' 3 | DATA_ROOT: 'data/' 4 | SCENES: None # should be a list [] or None. If none, use all scenes. 5 | AUGMENTATION_TYPE: None 6 | HEIGHT: 720 7 | WIDTH: 540 8 | MIN_OVERLAP_SCORE: 0.2 # [train only] discard data with overlap_score < min_overlap_score 9 | MAX_OVERLAP_SCORE: 0.7 # [train only] discard data with overlap_score < min_overlap_score 10 | SEED: 66 -------------------------------------------------------------------------------- /config/default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _CN = CN() 4 | 5 | ############## Model ############## 6 | _CN.MODEL = None # options: ['MicKey'] 7 | _CN.DEBUG = False 8 | 9 | # MicKey configuration 10 | _CN.MICKEY = CN() 11 | 12 | _CN.MICKEY.DINOV2 = CN() 13 | _CN.MICKEY.DINOV2.DOWN_FACTOR = None 14 | _CN.MICKEY.DINOV2.CHANNEL_DIM = None 15 | _CN.MICKEY.DINOV2.FLOAT16 = None 16 | 17 | _CN.MICKEY.KP_HEADS = CN() 18 | _CN.MICKEY.KP_HEADS.BLOCKS_DIM = None 19 | _CN.MICKEY.KP_HEADS.BN = None 20 | _CN.MICKEY.KP_HEADS.USE_SOFTMAX = None 21 | _CN.MICKEY.KP_HEADS.USE_DEPTHSIGMOID = None 22 | _CN.MICKEY.KP_HEADS.MAX_DEPTH = None 23 | _CN.MICKEY.KP_HEADS.POS_ENCODING = None 24 | 25 | _CN.MICKEY.DSC_HEAD = CN() 26 | _CN.MICKEY.DSC_HEAD.LAST_DIM = None 27 | _CN.MICKEY.DSC_HEAD.BLOCKS_DIM = None 28 | _CN.MICKEY.DSC_HEAD.BN = None 29 | _CN.MICKEY.DSC_HEAD.NORM_DSC = None 30 | _CN.MICKEY.DSC_HEAD.POS_ENCODING = None 31 | 32 | 33 | _CN.FEATURE_MATCHER = CN() 34 | _CN.FEATURE_MATCHER.TYPE = None 35 | _CN.FEATURE_MATCHER.DUAL_SOFTMAX = CN() 36 | _CN.FEATURE_MATCHER.DUAL_SOFTMAX.TEMPERATURE = None 37 | _CN.FEATURE_MATCHER.DUAL_SOFTMAX.USE_DUSTBIN = None 38 | _CN.FEATURE_MATCHER.SINKHORN = CN() 39 | _CN.FEATURE_MATCHER.SINKHORN.NUM_IT = None 40 | _CN.FEATURE_MATCHER.SINKHORN.DUSTBIN_SCORE_INIT = None 41 | _CN.FEATURE_MATCHER.USE_TRANSFORMER = None 42 | _CN.FEATURE_MATCHER.TOP_KEYPOINTS = False 43 | 44 | # LOSS_CLASS 45 | _CN.LOSS_CLASS = CN() 46 | _CN.LOSS_CLASS.LOSS_FUNCTION = None 47 | _CN.LOSS_CLASS.SOFT_CLIPPING = None 48 | 49 | _CN.LOSS_CLASS.POSE_ERR = CN() 50 | _CN.LOSS_CLASS.POSE_ERR.MAX_LOSS_VALUE = None 51 | _CN.LOSS_CLASS.POSE_ERR.MAX_LOSS_SOFTVALUE = None 52 | 53 | _CN.LOSS_CLASS.VCRE = CN() 54 | _CN.LOSS_CLASS.VCRE.MAX_LOSS_VALUE = None 55 | _CN.LOSS_CLASS.VCRE.MAX_LOSS_SOFTVALUE = None 56 | 57 | 
_CN.LOSS_CLASS.GENERATE_HYPOTHESES = CN() 58 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.SCORE_TEMPERATURE = None 59 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.IT_MATCHES = None 60 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.IT_RANSAC = None 61 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.INLIER_3D_TH = None 62 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.INLIER_REF_TH = None 63 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.NUM_REF_STEPS = None 64 | _CN.LOSS_CLASS.GENERATE_HYPOTHESES.NUM_CORR_3d3d = None 65 | 66 | _CN.LOSS_CLASS.CURRICULUM_LEARNING = CN() 67 | _CN.LOSS_CLASS.CURRICULUM_LEARNING.TRAIN_CURRICULUM = None 68 | _CN.LOSS_CLASS.CURRICULUM_LEARNING.TRAIN_WITH_TOPK = None 69 | _CN.LOSS_CLASS.CURRICULUM_LEARNING.TOPK_INIT = None 70 | _CN.LOSS_CLASS.CURRICULUM_LEARNING.TOPK = None 71 | 72 | _CN.LOSS_CLASS.NULL_HYPOTHESIS = CN() 73 | _CN.LOSS_CLASS.NULL_HYPOTHESIS.ADD_NULL_HYPOTHESIS = None 74 | _CN.LOSS_CLASS.NULL_HYPOTHESIS.TH_OUTLIERS = None 75 | 76 | _CN.LOSS_CLASS.SAMPLER = CN() 77 | _CN.LOSS_CLASS.SAMPLER.NUM_SAMPLES_MATCHES = None 78 | 79 | 80 | # Procrustes RANSAC options 81 | _CN.PROCRUSTES = CN() 82 | _CN.PROCRUSTES.IT_MATCHES = None 83 | _CN.PROCRUSTES.IT_RANSAC = None 84 | _CN.PROCRUSTES.NUM_SAMPLED_MATCHES = None 85 | _CN.PROCRUSTES.NUM_CORR_3D_3D = None 86 | _CN.PROCRUSTES.NUM_REFINEMENTS = None 87 | _CN.PROCRUSTES.TH_INLIER = None 88 | _CN.PROCRUSTES.TH_SOFT_INLIER = None 89 | 90 | 91 | 92 | 93 | # Training Procrustes RANSAC options 94 | _CN.PROCRUSTES_TRAINING = CN() 95 | _CN.PROCRUSTES_TRAINING.MAX_CORR_DIST = None 96 | _CN.PROCRUSTES_TRAINING.REFINE = False #refine pose with ICP 97 | 98 | 99 | ############## Dataset ############## 100 | _CN.DATASET = CN() 101 | # 1. data config 102 | _CN.DATASET.DATA_SOURCE = None # options: ['ScanNet', '7Scenes', 'MapFree'] 103 | _CN.DATASET.SCENES = None # scenes to use (for 7Scenes/MapFree); should be a list []; If none, use all scenes. 
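# How these placeholders get filled at runtime (a sketch; demo_inference.py uses exactly this yacs mechanism,
# while the precise merge order used by train.py is an assumption):
#
#   from config.default import cfg
#   cfg.merge_from_file('config/datasets/mapfree.yaml')             # data root, image size, overlap limits
#   cfg.merge_from_file('config/MicKey/curriculum_learning.yaml')   # model, loss and training hyper-parameters
#
# yacs only accepts keys that are already declared in this file, so new options must be added here first.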
104 | _CN.DATASET.DATA_ROOT = None # path to dataset folder 105 | _CN.DATASET.SEED = None # SEED for dataset generation 106 | _CN.DATASET.NPZ_ROOT = None # path to npz files containing pairs of frame indices per sample 107 | _CN.DATASET.MIN_OVERLAP_SCORE = None # discard data with overlap_score < min_overlap_score 108 | _CN.DATASET.MAX_OVERLAP_SCORE = None # discard data with overlap_score > max_overlap_score 109 | _CN.DATASET.CONSECUTIVE_PAIRS = None 110 | _CN.DATASET.FRAME_RATE = None 111 | _CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'colorjitter'] 112 | _CN.DATASET.BLACK_WHITE = False # if true, transform images to black & white 113 | _CN.DATASET.PAIRS_TXT = CN() # Path to text file defining the train/val/test pairs (7Scenes) 114 | _CN.DATASET.PAIRS_TXT.TRAIN = None 115 | _CN.DATASET.PAIRS_TXT.VAL = None 116 | _CN.DATASET.PAIRS_TXT.TEST = None 117 | _CN.DATASET.PAIRS_TXT.ONE_NN = False # If true, keeps only reference image w/ highest similarity to each query 118 | _CN.DATASET.HEIGHT = None 119 | _CN.DATASET.WIDTH = None 120 | 121 | ############# TRAINING ############# 122 | _CN.TRAINING = CN() 123 | # Data Loader settings 124 | _CN.TRAINING.BATCH_SIZE = None 125 | _CN.TRAINING.NUM_WORKERS = None 126 | _CN.TRAINING.NUM_GPUS = None 127 | _CN.TRAINING.SAMPLER = None # options: ['random', 'scene_balance'] 128 | _CN.TRAINING.N_SAMPLES_SCENE = None # if 'scene_balance' sampler, the number of samples to get per scene 129 | _CN.TRAINING.SAMPLE_WITH_REPLACEMENT = None # if 'scene_balance' sampler, whether to sample with replacement 130 | 131 | # Training settings 132 | _CN.TRAINING.LR = None 133 | _CN.TRAINING.LR_STEP_INTERVAL = None 134 | _CN.TRAINING.LR_STEP_GAMMA = None # multiplicative factor of LR every LR_STEP_INTERVAL steps 135 | _CN.TRAINING.VAL_INTERVAL = None 136 | _CN.TRAINING.VAL_BATCHES = None 137 | _CN.TRAINING.LOG_INTERVAL = None 138 | _CN.TRAINING.EPOCHS = None 139 | _CN.TRAINING.GRAD_CLIP = 0. # Indicates the L2 norm at which to clip the gradient.
Disabled if 0 140 | 141 | cfg = _CN -------------------------------------------------------------------------------- /data/toy_example/im0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nianticlabs/mickey/2391be8a35491e7b43481c069f5dab65030839b9/data/toy_example/im0.jpg -------------------------------------------------------------------------------- /data/toy_example/im1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nianticlabs/mickey/2391be8a35491e7b43481c069f5dab65030839b9/data/toy_example/im1.jpg -------------------------------------------------------------------------------- /data/toy_example/intrinsics.txt: -------------------------------------------------------------------------------- 1 | im0.jpg 549.7018 549.7018 268.6665 351.8357 540 720 2 | im1.jpg 549.0616 549.0616 268.8559 351.8485 540 720 -------------------------------------------------------------------------------- /demo_inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from lib.models.builder import build_model 4 | from lib.datasets.utils import correct_intrinsic_scale 5 | from lib.utils.visualization import prepare_score_map, colorize_depth, get_render, create_point_cloud_from_inliers 6 | from config.default import cfg 7 | import numpy as np 8 | from pathlib import Path 9 | import cv2 10 | 11 | 12 | def read_color_image(path, resize): 13 | """ 14 | Args: 15 | resize (tuple): align image to depthmap, in (w, h). 16 | Returns: 17 | image (torch.tensor): (3, h, w) 18 | """ 19 | # read and resize image 20 | cv_type = cv2.IMREAD_COLOR 21 | image = cv2.imread(str(path), cv_type) 22 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 23 | if resize is not None: 24 | image = cv2.resize(image, resize) 25 | 26 | # (h, w, 3) -> (3, h, w) and normalized 27 | image = torch.from_numpy(image).float().permute(2, 0, 1) / 255 28 | 29 | return image.unsqueeze(0) 30 | 31 | def read_intrinsics(path_intrinsics, resize): 32 | Ks = {} 33 | with Path(path_intrinsics).open('r') as f: 34 | for line in f.readlines(): 35 | if '#' in line: 36 | continue 37 | 38 | line = line.strip().split(' ') 39 | img_name = line[0] 40 | fx, fy, cx, cy, W, H = map(float, line[1:]) 41 | 42 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) 43 | if resize is not None: 44 | K = correct_intrinsic_scale(K, resize[0] / W, resize[1] / H).numpy() 45 | Ks[img_name] = K 46 | return Ks 47 | 48 | def generate_3d_vis(data, n_matches, root_dir, batch_id, 49 | use_3d_color_coded=True, color_src_frame=[223, 71, 28], color_dst_frame=[83, 154, 218], 50 | add_dst_lines=True, add_ref_lines=True, add_ref_pts=True, add_points=True, 51 | size_box=0.03, size_2d=0.015, cam_size=1., max_conf_th=0.8, add_confidence=True, 52 | angle_y=0, angle_x=-25, cam_offset_x=0.1, cam_offset_y=0.0, cam_offset_z=-2): 53 | 54 | print('Generating 3D visualization image...') 55 | 56 | # Generate point cloud from inlier matches 57 | point_cloud = create_point_cloud_from_inliers(data, batch_id, use_3d_color_coded) 58 | 59 | # Prepare the data: 60 | data_np = {} 61 | data_np['K_color0'] = data['K_color0'].detach().cpu().numpy() 62 | data_np['K_color1'] = data['K_color1'].detach().cpu().numpy() 63 | data_np['image0'] = 255 * data['image0'].permute(0, 2, 3, 1).detach().cpu().numpy() 64 | data_np['image1'] = 255 * data['image1'].permute(0, 2, 3, 1).detach().cpu().numpy() 
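# Note on the conversion above: the network consumes images as normalized (B, 3, H, W) tensors
# (see read_color_image), so for rendering they are permuted back to (B, H, W, 3) and rescaled to
# the 0-255 range, while 'K_color0'/'K_color1' become (B, 3, 3) numpy intrinsics.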
65 | 66 | R = data['R'][batch_id][np.newaxis].detach().cpu().numpy() 67 | t = data['t'][batch_id].detach().cpu().numpy().reshape(-1) 68 | P = np.eye(4) 69 | P[:3, :3] = R 70 | P[:3, 3] = t 71 | 72 | # Render the image with camera and 3D points 73 | frame = get_render(P, data, batch_id, point_cloud, color_src_frame, color_dst_frame, 74 | angle_y, angle_x, cam_offset_x, cam_offset_y, cam_offset_z, cam_size, size_box, size_2d, 75 | add_ref_lines, add_dst_lines, add_ref_pts, add_points, n_matches, max_conf_th, add_confidence) 76 | 77 | cv2.imwrite(root_dir + '/3d_vis.png', cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)) 78 | 79 | def run_demo_inference(args): 80 | 81 | # Select device 82 | use_cuda = torch.cuda.is_available() 83 | device = torch.device('cuda:0' if use_cuda else 'cpu') 84 | 85 | print('Preparing data...') 86 | 87 | # Prepare config file 88 | cfg.merge_from_file(args.config) 89 | 90 | # Prepare the model 91 | model = build_model(cfg, checkpoint=args.checkpoint) 92 | 93 | # Load demo images 94 | im0 = read_color_image(args.im_path_ref, args.resize).to(device) 95 | im1 = read_color_image(args.im_path_dst, args.resize).to(device) 96 | 97 | # Load intrinsics 98 | K = read_intrinsics(args.intrinsics, args.resize) 99 | 100 | # Prepare data for MicKey 101 | batch_id = 0 102 | im0_name = args.im_path_ref.split('/')[-1] 103 | im1_name = args.im_path_dst.split('/')[-1] 104 | data = {} 105 | data['image0'] = im0 106 | data['image1'] = im1 107 | data['K_color0'] = torch.from_numpy(K[im0_name]).unsqueeze(0).to(device) 108 | data['K_color1'] = torch.from_numpy(K[im1_name]).unsqueeze(0).to(device) 109 | 110 | # Run inference 111 | print('Running MicKey relative pose estimation...') 112 | model(data, return_inliers=args.generate_3D_vis) 113 | 114 | # Pose, inliers and score are stored in: 115 | # data['R'] = R 116 | # data['t'] = t 117 | # data['inliers'] = inliers 118 | 119 | print('Saving depth and score maps in image directory ...') 120 | depth0_map = colorize_depth(data['depth0_map'][batch_id], invalid_mask=(data['depth0_map'][batch_id] < 0.001).cpu()[0]) 121 | depth1_map = colorize_depth(data['depth1_map'][batch_id], invalid_mask=(data['depth1_map'][batch_id] < 0.001).cpu()[0]) 122 | score0_map = prepare_score_map(data['scr0'][batch_id], data['image0'][batch_id], temperature=0.5) 123 | score1_map = prepare_score_map(data['scr1'][batch_id], data['image1'][batch_id], temperature=0.5) 124 | 125 | ext_im0 = args.im_path_ref.split('.')[-1] 126 | ext_im1 = args.im_path_dst.split('.')[-1] 127 | 128 | cv2.imwrite(args.im_path_ref.replace(ext_im0, 'score.jpg'), score0_map) 129 | cv2.imwrite(args.im_path_dst.replace(ext_im1, 'score.jpg'), score1_map) 130 | 131 | cv2.imwrite(args.im_path_ref.replace(ext_im0, 'depth.jpg'), depth0_map) 132 | cv2.imwrite(args.im_path_dst.replace(ext_im1, 'depth.jpg'), depth1_map) 133 | 134 | if args.generate_3D_vis: 135 | # We use the maximum possible number of inliers to draw the confidence 136 | n_matches = model.e2e_Procrustes.num_samples_matches 137 | dir_name = '/'.join(args.im_path_ref.split('/')[:-1]) 138 | generate_3d_vis(data, n_matches, dir_name, batch_id) 139 | 140 | if __name__ == '__main__': 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('--im_path_ref', help='path to reference image', default='data/toy_example/im0.jpg') 143 | parser.add_argument('--im_path_dst', help='path to destination image', default='data/toy_example/im1.jpg') 144 | parser.add_argument('--intrinsics', help='path to intrinsics file', 
default='data/toy_example/intrinsics.txt') 145 | parser.add_argument('--resize', nargs=2, type=int, help='resize applied to the image and intrinsics (w, h)', default=None) 146 | parser.add_argument('--config', help='path to config file', default='weights/mickey_weights/config.yaml') 147 | parser.add_argument('--checkpoint', help='path to model checkpoint', 148 | default='weights/mickey_weights/mickey.ckpt') 149 | parser.add_argument('--generate_3D_vis', help='Set to True to generate a 3D visualisation of the output poses', 150 | default=False) 151 | args = parser.parse_args() 152 | 153 | run_demo_inference(args) 154 | 155 | -------------------------------------------------------------------------------- /lib/benchmarks/reprojection.py: -------------------------------------------------------------------------------- 1 | # Code from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc 2 | 3 | from typing import List, Tuple 4 | import numpy as np 5 | from transforms3d.quaternions import quat2mat 6 | 7 | def project(pts: np.ndarray, K: np.ndarray, img_size: List[int] or Tuple[int] = None) -> np.ndarray: 8 | """Projects 3D points to image plane. 9 | 10 | Args: 11 | - pts [N, 3/4]: points in camera coordinates (homogeneous or non-homogeneous) 12 | - K [3, 3]: intrinsic matrix 13 | - img_size (width, height): optional, clamp projection to image borders 14 | Outputs: 15 | - uv [N, 2]: coordinates of projected points 16 | """ 17 | 18 | assert len(pts.shape) == 2, 'incorrect number of dimensions' 19 | assert pts.shape[1] in [3, 4], 'invalid dimension size' 20 | assert K.shape == (3, 3), 'incorrect intrinsic shape' 21 | 22 | uv_h = (K @ pts[:, :3].T).T 23 | uv = uv_h[:, :2] / uv_h[:, -1:] 24 | 25 | if img_size is not None: 26 | uv[:, 0] = np.clip(uv[:, 0], 0, img_size[0]) 27 | uv[:, 1] = np.clip(uv[:, 1], 0, img_size[1]) 28 | 29 | return uv 30 | 31 | 32 | def get_grid_multipleheight() -> np.ndarray: 33 | # create grid of points 34 | ar_grid_step = 0.3 35 | ar_grid_num_x = 7 36 | ar_grid_num_y = 4 37 | ar_grid_num_z = 7 38 | ar_grid_z_offset = 1.8 39 | ar_grid_y_offset = 0 40 | 41 | ar_grid_x_pos = np.arange(0, ar_grid_num_x)-(ar_grid_num_x-1)/2 42 | ar_grid_x_pos *= ar_grid_step 43 | 44 | ar_grid_y_pos = np.arange(0, ar_grid_num_y)-(ar_grid_num_y-1)/2 45 | ar_grid_y_pos *= ar_grid_step 46 | ar_grid_y_pos += ar_grid_y_offset 47 | 48 | ar_grid_z_pos = np.arange(0, ar_grid_num_z).astype(float) 49 | ar_grid_z_pos *= ar_grid_step 50 | ar_grid_z_pos += ar_grid_z_offset 51 | 52 | xx, yy, zz = np.meshgrid(ar_grid_x_pos, ar_grid_y_pos, ar_grid_z_pos) 53 | ones = np.ones(xx.shape[0]*xx.shape[1]*xx.shape[2]) 54 | eye_coords = np.concatenate([c.reshape(-1, 1) 55 | for c in (xx, yy, zz, ones)], axis=-1) 56 | return eye_coords 57 | 58 | 59 | # global variable, avoids creating it again 60 | eye_coords_glob = get_grid_multipleheight() 61 | 62 | 63 | def reprojection_error( 64 | q_est: np.ndarray, t_est: np.ndarray, q_gt: np.ndarray, t_gt: np.ndarray, K: np.ndarray, 65 | W: int, H: int) -> float: 66 | eye_coords = eye_coords_glob 67 | 68 | # obtain ground-truth position of projected points 69 | uv_gt = project(eye_coords, K, (W, H)) 70 | 71 | # residual transformation 72 | cam2w_est = np.eye(4) 73 | cam2w_est[:3, :3] = quat2mat(q_est) 74 | cam2w_est[:3, -1] = t_est 75 | cam2w_gt = np.eye(4) 76 | cam2w_gt[:3, :3] = quat2mat(q_gt) 77 | cam2w_gt[:3, -1] = t_gt 78 | 79 | # residual reprojection 80 | eyes_residual = (np.linalg.inv(cam2w_est) @ cam2w_gt @ eye_coords.T).T 81 | uv_pred = project(eyes_residual, 
K, (W, H)) 82 | 83 | # get reprojection error 84 | repr_err = np.linalg.norm(uv_gt - uv_pred, ord=2, axis=1) 85 | mean_repr_err = float(repr_err.mean().item()) 86 | return mean_repr_err 87 | -------------------------------------------------------------------------------- /lib/benchmarks/utils.py: -------------------------------------------------------------------------------- 1 | # Code from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc 2 | 3 | from pathlib import Path 4 | import typing 5 | import logging 6 | 7 | import numpy as np 8 | from transforms3d.quaternions import qinverse, rotate_vector, qmult 9 | 10 | VARIANTS_ANGLE_SIN = 'sin' 11 | VARIANTS_ANGLE_COS = 'cos' 12 | 13 | 14 | def convert_world2cam_to_cam2world(q, t): 15 | qinv = qinverse(q) 16 | tinv = -rotate_vector(t, qinv) 17 | return qinv, tinv 18 | 19 | 20 | def load_poses(file: typing.IO, load_confidence: bool = False): 21 | """Load poses from text file and converts them to cam2world convention (t is the camera center in world coordinates) 22 | 23 | The text file encodes world2cam poses with the format: 24 | imgpath qw qx qy qz tx ty tz [confidence] 25 | where qw qx qy qz is the quaternion encoding rotation, 26 | and tx ty tz is the translation vector, 27 | and confidence is a float encoding confidence, for estimated poses 28 | """ 29 | 30 | expected_parts = 9 if load_confidence else 8 31 | 32 | poses = dict() 33 | for line_number, line in enumerate(file.readlines()): 34 | parts = tuple(line.strip().split(' ')) 35 | 36 | # if 'tensor' in parts[-1]: 37 | # print('ERROR: confidence is a tensor') 38 | # parts = list(parts) 39 | # parts[-1] = parts[-1].split('[')[-1].split(']')[0] 40 | if len(parts) != expected_parts: 41 | logging.warning( 42 | f'Invalid number of fields in file {file.name} line {line_number}.' 43 | f' Expected {expected_parts}, received {len(parts)}. Ignoring line.') 44 | continue 45 | 46 | try: 47 | name = parts[0] 48 | if '#' in name: 49 | logging.info(f'Ignoring comment line in {file.name} line {line_number}') 50 | continue 51 | frame_num = int(name[-9:-4]) 52 | except ValueError: 53 | logging.warning( 54 | f'Invalid frame number in file {file.name} line {line_number}.' 55 | f' Expected formatting "seq1/frame_00000.jpg". Ignoring line.') 56 | continue 57 | 58 | try: 59 | parts_float = tuple(map(float, parts[1:])) 60 | if any(np.isnan(v) or np.isinf(v) for v in parts_float): 61 | raise ValueError() 62 | qw, qx, qy, qz, tx, ty, tz = parts_float[:7] 63 | confidence = parts_float[7] if load_confidence else None 64 | except ValueError: 65 | logging.warning( 66 | f'Error parsing pose in file {file.name} line {line_number}. Ignoring line.') 67 | continue 68 | 69 | q = np.array((qw, qx, qy, qz), dtype=np.float64) 70 | t = np.array((tx, ty, tz), dtype=np.float64) 71 | 72 | if np.isclose(np.linalg.norm(q), 0): 73 | logging.warning( 74 | f'Error parsing pose in file {file.name} line {line_number}. ' 75 | 'Quaternion must have non-zero norm. 
Ignoring line.') 76 | continue 77 | 78 | q, t = convert_world2cam_to_cam2world(q, t) 79 | poses[frame_num] = (q, t, confidence) 80 | return poses 81 | 82 | 83 | def subsample_poses(poses: dict, subsample: int = 1): 84 | return {k: v for i, (k, v) in enumerate(poses.items()) if i % subsample == 0} 85 | 86 | 87 | def load_K(file_path: Path): 88 | K = dict() 89 | with file_path.open('r', encoding='utf-8') as f: 90 | for line in f.readlines(): 91 | if '#' in line: 92 | continue 93 | line = line.strip().split(' ') 94 | 95 | frame_num = int(line[0][-9:-4]) 96 | fx, fy, cx, cy, W, H = map(float, line[1:]) 97 | K[frame_num] = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) 98 | return K, W, H 99 | 100 | 101 | def quat_angle_error(label, pred, variant=VARIANTS_ANGLE_SIN) -> np.ndarray: 102 | assert label.shape == (4,) 103 | assert pred.shape == (4,) 104 | assert variant in (VARIANTS_ANGLE_SIN, VARIANTS_ANGLE_COS), \ 105 | f"Need variant to be in ({VARIANTS_ANGLE_SIN}, {VARIANTS_ANGLE_COS})" 106 | 107 | if len(label.shape) == 1: 108 | label = np.expand_dims(label, axis=0) 109 | if len(label.shape) != 2 or label.shape[0] != 1 or label.shape[1] != 4: 110 | raise RuntimeError(f"Unexpected shape of label: {label.shape}, expected: (1, 4)") 111 | 112 | if len(pred.shape) == 1: 113 | pred = np.expand_dims(pred, axis=0) 114 | if len(pred.shape) != 2 or pred.shape[0] != 1 or pred.shape[1] != 4: 115 | raise RuntimeError(f"Unexpected shape of pred: {pred.shape}, expected: (1, 4)") 116 | 117 | label = label.astype(np.float64) 118 | pred = pred.astype(np.float64) 119 | 120 | q1 = pred / np.linalg.norm(pred, axis=1, keepdims=True) 121 | q2 = label / np.linalg.norm(label, axis=1, keepdims=True) 122 | if variant == VARIANTS_ANGLE_COS: 123 | d = np.abs(np.sum(np.multiply(q1, q2), axis=1, keepdims=True)) 124 | d = np.clip(d, a_min=-1, a_max=1) 125 | angle = 2. * np.degrees(np.arccos(d)) 126 | elif variant == VARIANTS_ANGLE_SIN: 127 | if q1.shape[0] != 1 or q2.shape[0] != 1: 128 | raise NotImplementedError(f"Multiple angles is todo") 129 | # https://www.researchgate.net/post/How_do_I_calculate_the_smallest_angle_between_two_quaternions/5d6ed4a84f3a3e1ed3656616/citation/download 130 | sine = qmult(q1[0], qinverse(q2[0])) # note: takes first element in 2D array 131 | # 114.59 = 2. * 180. / pi 132 | angle = np.arcsin(np.linalg.norm(sine[1:], keepdims=True)) * 114.59155902616465 133 | angle = np.expand_dims(angle, axis=0) 134 | 135 | return angle.astype(np.float64) 136 | 137 | 138 | def precision_recall(inliers, tp, failures): 139 | """ 140 | Computes Precision/Recall plot for a set of poses given inliers (confidence) and wether the 141 | estimated pose error (whatever it may be) is within a threshold. 142 | Each point in the plot is obtained by choosing a threshold for inliers (i.e. inlier_thr). 
143 | Recall measures how many images have inliers >= inlier_thr 144 | Precision measures how many images that have inliers >= inlier_thr also have 145 | an estimated pose error within the pose threshold (measured by counting tps), 146 | where the pose threshold is, e.g., (trans_thr[m], rot_thr[deg]) 147 | 148 | Inputs: 149 | - inliers [N]: per-image confidence (e.g. number of inliers) 150 | - tp [N]: per-image flag, true if the estimated pose error is within the threshold 151 | - failures (int): number of images for which no pose was estimated 152 | 153 | Output 154 | - precision [N] 155 | - recall [N] 156 | - average_precision (scalar) 157 | 158 | """ 159 | 160 | assert len(inliers) == len(tp), 'unequal shapes' 161 | 162 | # sort by inliers (descending order) 163 | inliers = np.array(inliers) 164 | sort_idx = np.argsort(inliers)[::-1] 165 | inliers = inliers[sort_idx] 166 | tp = np.array(tp).reshape(-1)[sort_idx] 167 | 168 | # get idxs where inliers change (avoid tied values) 169 | distinct_value_indices = np.where(np.diff(inliers))[0] 170 | threshold_idxs = np.r_[distinct_value_indices, inliers.size - 1] 171 | 172 | # compute prec/recall 173 | N = inliers.shape[0] 174 | rec = np.arange(N, dtype=np.float32) + 1 175 | cum_tp = np.cumsum(tp) 176 | prec = cum_tp[threshold_idxs] / rec[threshold_idxs] 177 | rec = rec[threshold_idxs] / (float(N) + float(failures)) 178 | 179 | # invert order and ensure the (prec=1, rec=0) point 180 | last_ind = rec.searchsorted(rec[-1]) 181 | sl = slice(last_ind, None, -1) 182 | prec = np.r_[prec[sl], 1] 183 | rec = np.r_[rec[sl], 0] 184 | 185 | # compute average precision (AUC) as the weighted average of precisions 186 | average_precision = np.abs(np.sum(np.diff(rec) * np.array(prec)[:-1])) 187 | 188 | return prec, rec, average_precision 189 | -------------------------------------------------------------------------------- /lib/datasets/datamodules.py: -------------------------------------------------------------------------------- 1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc 2 | 3 | import torch.utils as utils 4 | from torchvision.transforms import ColorJitter, Grayscale 5 | import pytorch_lightning as pl 6 | 7 | from torch.utils.data import DataLoader 8 | from lib.datasets.sampler import RandomConcatSampler 9 | from lib.datasets.mapfree import MapFreeDataset 10 | 11 | 12 | class DataModule(pl.LightningDataModule): 13 | def __init__(self, cfg, drop_last_val=True): 14 | super().__init__() 15 | self.cfg = cfg 16 | self.drop_last_val = drop_last_val 17 | 18 | datasets = {'MapFree': MapFreeDataset} 19 | 20 | assert cfg.DATASET.DATA_SOURCE in datasets.keys(), 'invalid DATA_SOURCE, this dataset is not implemented' 21 | self.dataset_type = datasets[cfg.DATASET.DATA_SOURCE] 22 | 23 | def get_sampler(self, dataset, reset_epoch=False): 24 | if self.cfg.TRAINING.SAMPLER == 'scene_balance': 25 | sampler = RandomConcatSampler(dataset, 26 | self.cfg.TRAINING.N_SAMPLES_SCENE, 27 | self.cfg.TRAINING.SAMPLE_WITH_REPLACEMENT, 28 | shuffle=True, 29 | reset_on_iter=reset_epoch 30 | ) 31 | else: 32 | sampler = None 33 | return sampler 34 | 35 | def train_dataloader(self): 36 | transforms = ColorJitter() if self.cfg.DATASET.AUGMENTATION_TYPE == 'colorjitter' else None 37 | transforms = Grayscale( 38 | num_output_channels=3) if self.cfg.DATASET.BLACK_WHITE else transforms 39 | 40 | dataset = self.dataset_type(self.cfg, 'train', transforms=transforms) 41 | sampler = self.get_sampler(dataset) 42 | 43 | dataloader = utils.data.DataLoader(dataset, 44 | batch_size=self.cfg.TRAINING.BATCH_SIZE, 45 | num_workers=self.cfg.TRAINING.NUM_WORKERS, 46 | sampler=sampler 47 | ) 48 | return dataloader 49
| 50 | def val_dataloader(self): 51 | dataset = self.dataset_type(self.cfg, 'val') 52 | dataloader = utils.data.DataLoader(dataset, 53 | batch_size=self.cfg.TRAINING.BATCH_SIZE, 54 | num_workers=self.cfg.TRAINING.NUM_WORKERS, 55 | sampler=None, 56 | drop_last=self.drop_last_val 57 | ) 58 | return dataloader 59 | 60 | def test_dataloader(self): 61 | dataset = self.dataset_type(self.cfg, 'test') 62 | dataloader = utils.data.DataLoader(dataset, 63 | batch_size=self.cfg.TRAINING.BATCH_SIZE, 64 | num_workers=self.cfg.TRAINING.NUM_WORKERS, 65 | shuffle=False, 66 | drop_last=self.drop_last_val) 67 | return dataloader 68 | 69 | 70 | class DataModuleTraining(pl.LightningDataModule): 71 | def __init__(self, cfg): 72 | super().__init__() 73 | self.cfg = cfg 74 | self.seed = cfg.DATASET.SEED 75 | 76 | datasets = {'MapFree': MapFreeDataset} 77 | 78 | assert cfg.DATASET.DATA_SOURCE in datasets.keys(), 'invalid DATA_SOURCE, this dataset is not implemented' 79 | self.dataset_type = datasets[cfg.DATASET.DATA_SOURCE] 80 | 81 | def get_sampler(self, dataset, reset_epoch=False, seed=66): 82 | if self.cfg.TRAINING.SAMPLER == 'scene_balance': 83 | sampler = RandomConcatSampler(dataset, 84 | self.cfg.TRAINING.N_SAMPLES_SCENE, 85 | self.cfg.TRAINING.SAMPLE_WITH_REPLACEMENT, 86 | shuffle=True, 87 | reset_on_iter=reset_epoch, 88 | seed=seed) 89 | else: 90 | sampler = None 91 | return sampler 92 | 93 | def train_dataloader(self): 94 | transforms = ColorJitter() if self.cfg.DATASET.AUGMENTATION_TYPE == 'colorjitter' else None 95 | transforms = Grayscale( 96 | num_output_channels=3) if self.cfg.DATASET.BLACK_WHITE else transforms 97 | 98 | dataset = self.dataset_type(self.cfg, 'train', transforms=transforms) 99 | sampler = self.get_sampler(dataset, seed=self.seed) 100 | dataloader = DataLoader(dataset, 101 | batch_size=self.cfg.TRAINING.BATCH_SIZE, 102 | num_workers=self.cfg.TRAINING.NUM_WORKERS, 103 | sampler=sampler) 104 | return dataloader 105 | 106 | def val_dataloader(self): 107 | dataset = self.dataset_type(self.cfg, 'val') 108 | sampler = self.get_sampler(dataset, reset_epoch=True) 109 | # sampler = None 110 | dataloader = DataLoader(dataset, 111 | batch_size=self.cfg.TRAINING.BATCH_SIZE, 112 | num_workers=self.cfg.TRAINING.NUM_WORKERS, 113 | sampler=sampler, 114 | drop_last=True) 115 | return dataloader 116 | 117 | def test_dataloader(self): 118 | dataset = self.dataset_type(self.cfg, 'test') 119 | dataloader = DataLoader(dataset, 120 | batch_size=self.cfg.TRAINING.BATCH_SIZE, 121 | num_workers=self.cfg.TRAINING.NUM_WORKERS, 122 | shuffle=False) 123 | return dataloader -------------------------------------------------------------------------------- /lib/datasets/mapfree.py: -------------------------------------------------------------------------------- 1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc 2 | 3 | from pathlib import Path 4 | import torch 5 | import torch.utils.data as data 6 | import numpy as np 7 | from transforms3d.quaternions import qinverse, qmult, rotate_vector, quat2mat 8 | from lib.datasets.utils import read_color_image, read_depth_image, correct_intrinsic_scale 9 | 10 | class MapFreeScene(data.Dataset): 11 | def __init__( 12 | self, scene_root, resize, sample_factor=1, overlap_limits=None, transforms=None, 13 | test_scene=False): 14 | super().__init__() 15 | 16 | self.scene_root = Path(scene_root) 17 | self.resize = resize 18 | self.sample_factor = sample_factor 19 | self.transforms = transforms 20 | self.test_scene = test_scene 21 | 22 | # load 
absolute poses 23 | self.poses = self.read_poses(self.scene_root) 24 | 25 | # read intrinsics 26 | self.K, self.K_ori = self.read_intrinsics(self.scene_root, resize) 27 | 28 | # load pairs 29 | self.pairs = self.load_pairs(self.scene_root, overlap_limits, self.sample_factor) 30 | 31 | @staticmethod 32 | def read_intrinsics(scene_root: Path, resize=None): 33 | Ks = {} 34 | K_ori = {} 35 | with (scene_root / 'intrinsics.txt').open('r') as f: 36 | for line in f.readlines(): 37 | if '#' in line: 38 | continue 39 | 40 | line = line.strip().split(' ') 41 | img_name = line[0] 42 | fx, fy, cx, cy, W, H = map(float, line[1:]) 43 | 44 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) 45 | K_ori[img_name] = K 46 | if resize is not None: 47 | K = correct_intrinsic_scale(K, resize[0] / W, resize[1] / H) 48 | Ks[img_name] = K 49 | return Ks, K_ori 50 | 51 | @staticmethod 52 | def read_poses(scene_root: Path): 53 | """ 54 | Returns a dictionary that maps: img_path -> (q, t) where 55 | np.array q = (qw, qx qy qz) quaternion encoding rotation matrix; 56 | np.array t = (tx ty tz) translation vector; 57 | (q, t) encodes absolute pose (world-to-camera), i.e. X_c = R(q) X_W + t 58 | """ 59 | poses = {} 60 | with (scene_root / 'poses.txt').open('r') as f: 61 | for line in f.readlines(): 62 | if '#' in line: 63 | continue 64 | 65 | line = line.strip().split(' ') 66 | img_name = line[0] 67 | qt = np.array(list(map(float, line[1:]))) 68 | poses[img_name] = (qt[:4], qt[4:]) 69 | return poses 70 | 71 | def load_pairs(self, scene_root: Path, overlap_limits: tuple = None, sample_factor: int = 1): 72 | """ 73 | For training scenes, filter pairs of frames based on overlap (pre-computed in overlaps.npz) 74 | For test/val scenes, pairs are formed between keyframe and every other sample_factor query frames. 75 | If sample_factor == 1, all query frames are used. 
Note: sample_factor applicable only to test/val 76 | Returns: 77 | pairs: nd.array [Npairs, 4], where each column represents seaA, imA, seqB, imB, respectively 78 | """ 79 | pairs = self.load_pairs_overlap(scene_root, overlap_limits, sample_factor) 80 | 81 | return pairs 82 | 83 | def load_pairs_overlap(self, scene_root: Path, overlap_limits: tuple = None, sample_factor: int = 1): 84 | overlaps_path = scene_root / 'overlaps.npz' 85 | 86 | if overlaps_path.exists(): 87 | f = np.load(overlaps_path, allow_pickle=True) 88 | idxs, overlaps = f['idxs'], f['overlaps'] 89 | if overlap_limits is not None: 90 | min_overlap, max_overlap = overlap_limits 91 | mask = (overlaps > min_overlap) * (overlaps < max_overlap) 92 | idxs = idxs[mask] 93 | return idxs.copy() 94 | else: 95 | idxs = np.zeros((len(self.poses) - 1, 4), dtype=np.uint16) 96 | idxs[:, 2] = 1 97 | idxs[:, 3] = np.array([int(fn[-9:-4]) 98 | for fn in self.poses.keys() if 'seq0' not in fn], dtype=np.uint16) 99 | return idxs[::sample_factor] 100 | 101 | def get_pair_path(self, pair): 102 | seqA, imgA, seqB, imgB = pair 103 | return (f'seq{seqA}/frame_{imgA:05}.jpg', f'seq{seqB}/frame_{imgB:05}.jpg') 104 | 105 | def __len__(self): 106 | return len(self.pairs) 107 | 108 | def __getitem__(self, index): 109 | # image paths (relative to scene_root) 110 | im1_path, im2_path = self.get_pair_path(self.pairs[index]) 111 | 112 | # load color images 113 | image1 = read_color_image(self.scene_root / im1_path, 114 | self.resize, augment_fn=self.transforms) 115 | image2 = read_color_image(self.scene_root / im2_path, 116 | self.resize, augment_fn=self.transforms) 117 | 118 | # get absolute pose of im0 and im1 119 | if self.test_scene: 120 | t1, t2, c1, c2 = np.zeros([3]), np.zeros([3]), np.zeros([3]), np.zeros([3]) 121 | q1, q2 = np.zeros([4]), np.zeros([4]) 122 | T = np.zeros([4, 4]) 123 | else: 124 | # quaternion and translation vector that transforms World-to-Cam 125 | q1, t1 = self.poses[im1_path] 126 | # quaternion and translation vector that transforms World-to-Cam 127 | q2, t2 = self.poses[im2_path] 128 | c1 = rotate_vector(-t1, qinverse(q1)) # center of camera 1 in world coordinates) 129 | c2 = rotate_vector(-t2, qinverse(q2)) # center of camera 2 in world coordinates) 130 | 131 | # get 4 x 4 relative pose transformation matrix (from im1 to im2) 132 | # for val set, q1,t1 is the identity pose, so the relative pose matches the absolute pose 133 | q12 = qmult(q2, qinverse(q1)) 134 | t12 = t2 - rotate_vector(t1, q12) 135 | T = np.eye(4, dtype=np.float32) 136 | T[:3, :3] = quat2mat(q12) 137 | T[:3, -1] = t12 138 | 139 | T = torch.from_numpy(T) 140 | 141 | data = { 142 | 'image0': image1, # (3, h, w) 143 | 'image1': image2, 144 | 'T_0to1': T, # (4, 4) # relative pose 145 | 'abs_q_0': q1, 146 | 'abs_c_0': c1, 147 | 'abs_q_1': q2, 148 | 'abs_c_1': c2, 149 | 'K_color0': self.K[im1_path], # (3, 3) 150 | 'Kori_color0': self.K_ori[im1_path], # (3, 3) 151 | 'K_color1': self.K[im2_path], # (3, 3) 152 | 'Kori_color1': self.K_ori[im2_path], # (3, 3) 153 | 'dataset_name': 'Mapfree', 154 | 'scene_id': self.scene_root.stem, 155 | 'scene_root': str(self.scene_root), 156 | 'pair_id': index*self.sample_factor, 157 | 'pair_names': (im1_path, im2_path), 158 | } 159 | 160 | return data 161 | 162 | 163 | class MapFreeDataset(data.ConcatDataset): 164 | def __init__(self, cfg, mode, transforms=None): 165 | assert mode in ['train', 'val', 'test'], 'Invalid dataset mode' 166 | 167 | data_root = Path(cfg.DATASET.DATA_ROOT) / mode 168 | resize = (cfg.DATASET.WIDTH, 
cfg.DATASET.HEIGHT) 169 | 170 | if mode=='test': 171 | test_scene = True 172 | else: 173 | test_scene = False 174 | 175 | overlap_limits = (cfg.DATASET.MIN_OVERLAP_SCORE, cfg.DATASET.MAX_OVERLAP_SCORE) 176 | sample_factor = {'train': 1, 'val': 5, 'test': 5}[mode] 177 | 178 | scenes = cfg.DATASET.SCENES 179 | if scenes is None: 180 | # Locate all scenes of the current dataset 181 | scenes = [s.name for s in data_root.iterdir() if s.is_dir()] 182 | 183 | if cfg.DEBUG: 184 | if mode=='train': 185 | scenes = scenes[:30] 186 | elif mode=='val': 187 | scenes = scenes[:10] 188 | 189 | # Init dataset objects for each scene 190 | data_srcs = [ 191 | MapFreeScene( 192 | data_root / scene, resize, sample_factor, overlap_limits, transforms, 193 | test_scene) for scene in scenes] 194 | super().__init__(data_srcs) 195 | -------------------------------------------------------------------------------- /lib/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc 2 | # From https://github.com/zju3dv/LoFTR/blob/261baf641cb9ada07dd9746e420ada7fe8a03152/src/datasets/sampler.py 3 | import torch 4 | from torch.utils.data import Sampler, ConcatDataset 5 | 6 | class RandomConcatSampler(Sampler): 7 | """ Random sampler for ConcatDataset. At each epoch, `n_samples_per_subset` samples will be drawn from each subset 8 | in the ConcatDataset. If `subset_replacement` is ``True``, sampling within each subset will be done with replacement. 9 | However, it is impossible to sample data without replacement between epochs, unless building a stateful sampler that lives across the entire training phase. 10 | 11 | In the current implementation, the randomness of sampling is ensured whether or not the sampler is recreated across epochs and whether or not `torch.manual_seed()` is called. 12 | Args: 13 | shuffle (bool): shuffle the randomly sampled indices across all sub-datasets. 14 | repeat (int): repeatedly use the sampled indices multiple times for training. 15 | [arXiv:1902.05509, arXiv:1901.09335] 16 | NOTE: Don't re-initialize the sampler between epochs (it will lead to repeated samples) 17 | NOTE: This sampler behaves differently from DistributedSampler. 18 | It assumes the dataset is split across ranks instead of replicated. 19 | TODO: Add a `set_epoch()` method to fulfill sampling without replacement across epochs.
20 | ref: https://github.com/PyTorchLightning/pytorch-lightning/blob/e9846dd758cfb1500eb9dba2d86f6912eb487587/pytorch_lightning/trainer/training_loop.py#L373 21 | """ 22 | 23 | def __init__(self, 24 | data_source: ConcatDataset, 25 | n_samples_per_subset: int, 26 | subset_replacement: bool = True, 27 | shuffle: bool = True, 28 | repeat: int = 1, 29 | seed: int = 66, 30 | reset_on_iter: bool = False): 31 | if not isinstance(data_source, ConcatDataset): 32 | raise TypeError("data_source should be torch.utils.data.ConcatDataset") 33 | 34 | self.data_source = data_source 35 | self.n_subset = len(self.data_source.datasets) 36 | self.n_samples_per_subset = n_samples_per_subset 37 | self.n_samples = self.n_subset * self.n_samples_per_subset * repeat 38 | self.subset_replacement = subset_replacement 39 | self.repeat = repeat 40 | self.shuffle = shuffle 41 | self.seed = seed 42 | self.reset_on_iter = reset_on_iter # If true, recreate random seed to that samples are the same every epoch 43 | self.generator = torch.manual_seed(self.seed) 44 | assert self.repeat >= 1 45 | 46 | def __len__(self): 47 | return self.n_samples 48 | 49 | def __iter__(self): 50 | if self.reset_on_iter: 51 | self.generator = torch.manual_seed(self.seed) 52 | 53 | indices = [] 54 | # sample from each sub-dataset 55 | for d_idx in range(self.n_subset): 56 | low = 0 if d_idx == 0 else self.data_source.cumulative_sizes[d_idx - 1] 57 | high = self.data_source.cumulative_sizes[d_idx] 58 | if self.subset_replacement: 59 | rand_tensor = torch.randint(low, high, (self.n_samples_per_subset,), 60 | generator=self.generator, dtype=torch.int64) 61 | else: # sample without replacement 62 | len_subset = len(self.data_source.datasets[d_idx]) 63 | rand_tensor = torch.randperm(len_subset, generator=self.generator) + low 64 | if len_subset >= self.n_samples_per_subset: 65 | rand_tensor = rand_tensor[:self.n_samples_per_subset] 66 | else: # padding with replacement 67 | rand_tensor_replacement = torch.randint( 68 | low, high, (self.n_samples_per_subset - len_subset,), 69 | generator=self.generator, dtype=torch.int64) 70 | rand_tensor = torch.cat([rand_tensor, rand_tensor_replacement]) 71 | indices.append(rand_tensor) 72 | indices = torch.cat(indices) 73 | if self.shuffle: # shuffle the sampled dataset (from multiple subsets) 74 | rand_tensor = torch.randperm(len(indices), generator=self.generator) 75 | indices = indices[rand_tensor] 76 | 77 | # repeat the sampled indices (can be used for RepeatAugmentation or pure RepeatSampling) 78 | if self.repeat > 1: 79 | repeat_indices = [indices.clone() for _ in range(self.repeat - 1)] 80 | if self.shuffle: 81 | def _choice(x): return x[torch.randperm(len(x), generator=self.generator)] 82 | repeat_indices = map(_choice, repeat_indices) 83 | indices = torch.cat([indices, *repeat_indices], 0) 84 | 85 | assert indices.shape[0] == self.n_samples 86 | return iter(indices.tolist()) 87 | -------------------------------------------------------------------------------- /lib/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from numpy.linalg import inv 7 | 8 | def imread(path, augment_fn=None): 9 | cv_type = cv2.IMREAD_COLOR 10 | image = cv2.imread(str(path), cv_type) 11 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 12 | 13 | if augment_fn is not None: 14 | image = cv2.imread(str(path), cv2.IMREAD_COLOR) 15 | 
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 16 | image = augment_fn(image) 17 | image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) 18 | return image # (h, w, 3) 19 | 20 | 21 | def get_resized_wh(w, h, resize=None): 22 | if resize is not None: # resize the longer edge 23 | scale = resize / max(h, w) 24 | w_new, h_new = int(round(w * scale)), int(round(h * scale)) 25 | else: 26 | w_new, h_new = w, h 27 | return w_new, h_new 28 | 29 | 30 | def get_divisible_wh(w, h, df=None): 31 | if df is not None: 32 | w_new, h_new = map(lambda x: int(x // df * df), [w, h]) 33 | else: 34 | w_new, h_new = w, h 35 | return w_new, h_new 36 | 37 | 38 | def pad_bottom_right(inp, pad_size, df, ret_mask=False, border=3): 39 | assert isinstance(pad_size, int) and pad_size >= max( 40 | inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}" 41 | mask = None 42 | if inp.ndim == 2: 43 | padded = np.zeros((pad_size, pad_size), dtype=inp.dtype) 44 | padded[:inp.shape[0], :inp.shape[1]] = inp 45 | if ret_mask: 46 | mask = np.zeros((pad_size, pad_size), dtype=bool) 47 | mask[:inp.shape[0], :inp.shape[1]] = True 48 | elif inp.ndim == 3: 49 | padded = np.zeros((pad_size, pad_size, inp.shape[2]), dtype=inp.dtype) 50 | padded[:inp.shape[0], :inp.shape[1], :] = inp 51 | if ret_mask: 52 | 53 | mask = np.zeros((1, pad_size//df, pad_size//df)) 54 | mask[:, :inp.shape[0]//df-border, :inp.shape[1]//df-border] = 1 55 | 56 | else: 57 | raise NotImplementedError() 58 | return padded, mask 59 | 60 | 61 | def read_color_image(path, resize=(640, 480), augment_fn=None): 62 | """ 63 | Args: 64 | resize (tuple): align image to depthmap, in (w, h). 65 | augment_fn (callable, optional): augments images with pre-defined visual effects 66 | Returns: 67 | image (torch.tensor): (3, h, w) 68 | """ 69 | # read and resize image 70 | image = imread(path, None) 71 | image = cv2.resize(image, resize) 72 | 73 | # (h, w, 3) -> (3, h, w) and normalized 74 | image = torch.from_numpy(image).float().permute(2, 0, 1) / 255 75 | if augment_fn: 76 | image = augment_fn(image) 77 | return image 78 | 79 | 80 | def read_depth_image(path): 81 | depth = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) 82 | depth = depth / 1000 83 | depth = torch.from_numpy(depth).float() # (h, w) 84 | return depth 85 | 86 | def correct_intrinsic_scale(K, scale_x, scale_y): 87 | '''Given an intrinsic matrix (3x3) and two scale factors, returns the new intrinsic matrix corresponding to 88 | the new coordinates x' = scale_x * x; y' = scale_y * y 89 | Source: https://dsp.stackexchange.com/questions/6055/how-does-resizing-an-image-affect-the-intrinsic-camera-matrix 90 | ''' 91 | 92 | transform = torch.eye(3) 93 | transform[0, 0] = scale_x 94 | transform[0, 2] = scale_x / 2 - 0.5 95 | transform[1, 1] = scale_y 96 | transform[1, 2] = scale_y / 2 - 0.5 97 | Kprime = transform @ K 98 | 99 | return Kprime 100 | 101 | def define_sampling_grid(im_size, feats_downsample=4, step=1): 102 | """ 103 | Auxiliary function to generate the sampling grid from the feature map 104 | Args: 105 | im_size: original image size that goes into the network 106 | feats_downsample: rescaling factor that happens within the architecture due to downsampling steps 107 | Output: 108 | indexes_mat: dense grid sampling indexes, size: (im_size/feats_downsample, im_size/feats_downsample) 109 | """ 110 | 111 | feats_size = int(im_size/feats_downsample) 112 | grid_size = int(im_size/feats_downsample/step) 113 | 114 | indexes = np.asarray(range(0, feats_size, step))[:grid_size] 115 | indexes_x = indexes.reshape((1, len(indexes), 1)) 
116 | indexes_y = indexes.reshape((len(indexes), 1, 1)) 117 | 118 | indexes_x = np.tile(indexes_x, [len(indexes), 1, 1]) 119 | indexes_y = np.tile(indexes_y, [1, len(indexes), 1]) 120 | 121 | indexes_mat = np.concatenate([indexes_x, indexes_y], axis=-1) 122 | indexes_mat = indexes_mat.reshape((grid_size*grid_size, 2)) 123 | 124 | return indexes_mat 125 | 126 | 127 | -------------------------------------------------------------------------------- /lib/models/MicKey/compute_pose.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | from lib.models.MicKey.modules.compute_correspondences import ComputeCorrespondences 4 | from lib.models.MicKey.modules.utils.probabilisticProcrustes import e2eProbabilisticProcrustesSolver 5 | 6 | class MickeyRelativePose(pl.LightningModule): 7 | # Compute the metric relative pose between two input images, with given intrinsics (for the pose solver). 8 | 9 | def __init__(self, cfg): 10 | super().__init__() 11 | 12 | # Define MicKey architecture and matching module: 13 | self.compute_matches = ComputeCorrespondences(cfg) 14 | 15 | # Metric solver 16 | self.e2e_Procrustes = e2eProbabilisticProcrustesSolver(cfg) 17 | 18 | self.is_eval_model(True) 19 | 20 | def forward(self, data, return_inliers=False): 21 | 22 | self.compute_matches(data) 23 | data['final_scores'] = data['scores'] * data['kp_scores'] 24 | 25 | if return_inliers: 26 | # Returns the inlier list as well: 27 | R, t, inliers, inliers_list = self.e2e_Procrustes.estimate_pose_vectorized(data, return_inliers=True) 28 | data['inliers_list'] = inliers_list 29 | else: 30 | # If the inlier list is not needed: 31 | R, t, inliers = self.e2e_Procrustes.estimate_pose_vectorized(data, return_inliers=False) 32 | 33 | data['R'] = R 34 | data['t'] = t 35 | data['inliers'] = inliers 36 | 37 | return R, t 38 | 39 | def on_load_checkpoint(self, checkpoint): 40 | # This function avoids loading the DINOv2 weights, which are not stored in MicKey's checkpoint. 41 | # This saves memory during training: since DINOv2 is frozen and never updated, there is no need to store 42 | # its weights in every checkpoint. 43 | 44 | # Recover DINOv2 features from pretrained weights.
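# The loop below copies the frozen DINOv2 parameters from the freshly built ComputeCorrespondences
# module into checkpoint['state_dict'], so that the subsequent state-dict loading finds every expected
# key even though the saved checkpoint deliberately omits the pretrained, never-updated DINOv2 weights.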
45 | for param_tensor in self.compute_matches.state_dict(): 46 | if 'dinov2'in param_tensor: 47 | checkpoint['state_dict']['compute_matches.'+param_tensor] = \ 48 | self.compute_matches.state_dict()[param_tensor] 49 | 50 | def is_eval_model(self, is_eval): 51 | if is_eval: 52 | self.compute_matches.extractor.depth_head.eval() 53 | self.compute_matches.extractor.det_offset.eval() 54 | self.compute_matches.extractor.dsc_head.eval() 55 | self.compute_matches.extractor.det_head.eval() 56 | else: 57 | self.compute_matches.extractor.depth_head.train() 58 | self.compute_matches.extractor.det_offset.train() 59 | self.compute_matches.extractor.dsc_head.train() 60 | self.compute_matches.extractor.det_head.train() 61 | -------------------------------------------------------------------------------- /lib/models/MicKey/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytorch_lightning as pl 3 | 4 | from lib.models.MicKey.modules.loss.loss_class import MetricPoseLoss 5 | from lib.models.MicKey.modules.compute_correspondences import ComputeCorrespondences 6 | from lib.models.MicKey.modules.utils.training_utils import log_image_matches, debug_reward_matches_log, vis_inliers 7 | from lib.models.MicKey.modules.utils.probabilisticProcrustes import e2eProbabilisticProcrustesSolver 8 | 9 | from lib.utils.metrics import pose_error_torch, vcre_torch 10 | from lib.benchmarks.utils import precision_recall 11 | 12 | class MicKeyTrainingModel(pl.LightningModule): 13 | def __init__(self, cfg): 14 | super().__init__() 15 | 16 | # Store MicKey's configuration 17 | self.cfg = cfg 18 | 19 | # Define MicKey architecture and matching module: 20 | self.compute_matches = ComputeCorrespondences(cfg) 21 | self.is_eval_model(False) 22 | 23 | # Loss function class 24 | self.loss_fn = MetricPoseLoss(cfg) 25 | 26 | # Metric solvers 27 | self.e2e_Procrustes = e2eProbabilisticProcrustesSolver(cfg) 28 | 29 | # Logger parameters 30 | self.counter_batch = 0 31 | self.log_store_ims = True 32 | self.log_max_ims = 5 33 | self.log_im_counter_train = 0 34 | self.log_im_counter_val = 0 35 | self.log_interval = cfg.TRAINING.LOG_INTERVAL 36 | 37 | # Define curriculum learning parameters: 38 | self.curriculum_learning = cfg.LOSS_CLASS.CURRICULUM_LEARNING.TRAIN_CURRICULUM 39 | self.topK = cfg.LOSS_CLASS.CURRICULUM_LEARNING.TOPK_INIT 40 | self.topK_max = cfg.LOSS_CLASS.CURRICULUM_LEARNING.TOPK 41 | 42 | # Lightning configurations 43 | self.automatic_optimization = False # This property activates manual optimization. 
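# Manual optimization is used because backward_step() runs a custom, multi-stage backward pass:
# it calls avg_loss.backward() first and then propagates the resulting gradients into the log-scores,
# keypoint offsets and depths with explicit torch.autograd.backward() calls, before clipping gradients
# and stepping the optimizer itself; that sequence does not fit Lightning's automatic optimization loop.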
44 | self.multi_gpu = True 45 | self.validation_step_outputs = [] 46 | # torch.autograd.set_detect_anomaly(True) 47 | 48 | def forward(self, data): 49 | self.compute_matches(data) 50 | 51 | def training_step(self, batch, batch_idx): 52 | 53 | self(batch) 54 | self.prepare_batch_for_loss(batch, batch_idx) 55 | 56 | avg_loss, outputs, probs_grad, num_its = self.loss_fn(batch) 57 | 58 | training_step_ok = self.backward_step(batch, outputs, probs_grad, avg_loss, num_its) 59 | self.tensorboard_log_step(batch, avg_loss, outputs, probs_grad, training_step_ok) 60 | 61 | def on_train_epoch_end(self): 62 | if self.curriculum_learning: 63 | self.topK = min(self.topK_max, self.topK + 5) 64 | self.loss_fn.topK = self.topK 65 | 66 | def validation_step(self, batch, batch_idx): 67 | 68 | self.is_eval_model(True) 69 | self(batch) 70 | self.prepare_batch_for_loss(batch, batch_idx) 71 | 72 | # validation metrics 73 | avg_loss, outputs, probs_grad, num_its = self.loss_fn(batch) 74 | outputs['loss'] = avg_loss 75 | 76 | # Metric pose evaluation 77 | R_ours, t_m_ours, inliers_ours = self.e2e_Procrustes.estimate_pose_vectorized(batch) 78 | outputs_metric_ours = pose_error_torch(R_ours, t_m_ours, batch['T_0to1'], reduce=None) 79 | outputs['metric_ours_t_err_ang'] = outputs_metric_ours['t_err_ang'] 80 | outputs['metric_ours_t_err_euc'] = outputs_metric_ours['t_err_euc'] 81 | outputs['metric_ours_R_err'] = outputs_metric_ours['R_err'] 82 | outputs['metric_inliers'] = inliers_ours 83 | 84 | outputs_vcre_ours = vcre_torch(R_ours, t_m_ours, batch['T_0to1'], batch['Kori_color0'], reduce=None) 85 | outputs['metric_ours_vcre'] = outputs_vcre_ours['repr_err'] 86 | 87 | self.validation_step_outputs.append(outputs) 88 | 89 | return outputs 90 | 91 | def backward_step(self, batch, outputs, probs_grad, avg_loss, num_its): 92 | opt = self.optimizers() 93 | 94 | # update model 95 | opt.zero_grad() 96 | 97 | if num_its == 0: 98 | print('No valid hypotheses were generated') 99 | return False 100 | 101 | # Generate gradients for learning keypoint offsets 102 | avg_loss.backward() 103 | 104 | invalid_probs = torch.isnan(probs_grad[0]).any() 105 | invalid_kps0 = (torch.isnan(outputs['kps0'].grad).any() or torch.isinf(outputs['kps0'].grad).any()) 106 | invalid_kps1 = (torch.isnan(outputs['kps1'].grad).any() or torch.isinf(outputs['kps1'].grad).any()) 107 | invalid_depth0 = (torch.isnan(outputs['depth0'].grad).any() or torch.isinf(outputs['depth0'].grad).any()) 108 | invalid_depth1 = (torch.isnan(outputs['depth1'].grad).any() or torch.isinf(outputs['depth1'].grad).any()) 109 | 110 | if invalid_probs: 111 | print('Found NaN/Inf in probs!') 112 | return False 113 | 114 | if invalid_depth0 or invalid_depth1: 115 | print('Found NaN/Inf in depth0/depth1 gradients!') 116 | return False 117 | 118 | if batch['kps0'].requires_grad: 119 | 120 | if invalid_kps0 or invalid_kps1: 121 | print('Found NaN/Inf in kps0/kps1 gradients!') 122 | return False 123 | 124 | torch.autograd.backward((torch.log(batch['final_scores'] + 1e-16), 125 | batch['kps0'], batch['kps1'], batch['depth_kp0'], batch['depth_kp1']), 126 | (probs_grad[0], outputs['kps0'].grad, outputs['kps1'].grad, 127 | outputs['depth0'].grad, outputs['depth1'].grad)) 128 | elif batch['depth_kp0'].requires_grad: 129 | torch.autograd.backward((torch.log(batch['final_scores'] + 1e-16), 130 | batch['depth_kp0'], batch['depth_kp1']), 131 | (probs_grad[0], outputs['depth0'].grad, outputs['depth1'].grad)) 132 | else: 133 | torch.autograd.backward((torch.log(batch['final_scores'] + 1e-16)), 134 | 
(probs_grad[0])) 135 | 136 | # add gradient clipping after backward to avoid gradient exploding 137 | torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=5) 138 | 139 | # check if the gradients of the training parameters contain nan values 140 | nans = sum([torch.isnan(param.grad).any() for param in list(self.parameters()) if param.grad is not None]) 141 | if nans != 0: 142 | print("parameter gradients includes {} nan values".format(nans)) 143 | return False 144 | 145 | opt.step() 146 | 147 | return True 148 | 149 | def tensorboard_log_step(self, batch, avg_loss, outputs, probs_grad, training_step_ok): 150 | 151 | self.log('train/loss', avg_loss.detach()) 152 | self.log('train/loss_rot', outputs['avg_loss_rot'].detach()) 153 | self.log('train/loss_trans', outputs['avg_loss_trans'].detach()) 154 | 155 | if self.log_store_ims: 156 | if self.counter_batch % self.log_interval == 0: 157 | self.counter_batch = 0 158 | 159 | # If training with curriculum learning, not all image pairs have valid gradients 160 | # ensure selecting one image pair for logging that has reward information 161 | batch_id = torch.where(outputs['mask_topk'] == 1.)[0][0].item() 162 | 163 | # Metric pose evaluation 164 | R_ours, t_m_ours, inliers_ours, inliers_list_ours = self.e2e_Procrustes.estimate_pose_vectorized(batch, return_inliers=True) 165 | outputs_metric_ours = pose_error_torch(R_ours, t_m_ours, batch['T_0to1'], reduce=None) 166 | self.log('train_metric_pose/ours_t_err_ang', outputs_metric_ours['t_err_ang'].mean().detach()) 167 | self.log('train_metric_pose/ours_t_err_euc', outputs_metric_ours['t_err_euc'].mean().detach()) 168 | self.log('train_metric_pose/ours_R_err', outputs_metric_ours['R_err'].mean().detach()) 169 | 170 | outputs_vcre_ours = vcre_torch(R_ours, t_m_ours, batch['T_0to1'], batch['Kori_color0'], reduce=None) 171 | self.log('train_vcre/repr_err', outputs_vcre_ours['repr_err'].mean().detach()) 172 | 173 | im_inliers = vis_inliers(inliers_list_ours, batch, batch_i=batch_id) 174 | 175 | im_matches, sc_map0, sc_map1, depth_map0, depth_map1 = log_image_matches(self.compute_matches.matcher, 176 | batch, train_depth=True, 177 | batch_i=batch_id) 178 | 179 | tensorboard = self.logger.experiment 180 | tensorboard.add_image('training_matching/best_inliers', im_inliers, global_step=self.log_im_counter_train) 181 | tensorboard.add_image('training_matching/best_matches_desc', im_matches, global_step=self.log_im_counter_train) 182 | tensorboard.add_image('training_scores/map0', sc_map0, global_step=self.log_im_counter_train) 183 | tensorboard.add_image('training_scores/map1', sc_map1, global_step=self.log_im_counter_train) 184 | tensorboard.add_image('training_depth/map0', depth_map0[0], global_step=self.log_im_counter_train) 185 | tensorboard.add_image('training_depth/map1', depth_map1[0], global_step=self.log_im_counter_train) 186 | if training_step_ok: 187 | try: 188 | im_rewards, rew_kp0, rew_kp1 = debug_reward_matches_log(batch, probs_grad, batch_i=batch_id) 189 | tensorboard.add_image('training_rewards/pair0', im_rewards, global_step=self.log_im_counter_train) 190 | except ValueError: 191 | print('[WARNING]: Failed to log reward image. Selected image is not in topK image pairs. 
') 192 | 193 | self.log_im_counter_train += 1 194 | 195 | torch.cuda.empty_cache() 196 | self.counter_batch += 1 197 | 198 | def prepare_batch_for_loss(self, batch, batch_idx): 199 | 200 | batch['batch_idx'] = batch_idx 201 | batch['final_scores'] = batch['scores'] * batch['kp_scores'] 202 | 203 | return batch 204 | 205 | def on_validation_epoch_end(self): 206 | 207 | # aggregates metrics/losses from all validation steps 208 | aggregated = {} 209 | for key in self.validation_step_outputs[0].keys(): 210 | aggregated[key] = torch.stack([x[key] for x in self.validation_step_outputs]) 211 | 212 | # compute stats 213 | mean_R_loss = aggregated['avg_loss_rot'].mean() 214 | mean_t_loss = aggregated['avg_loss_trans'].mean() 215 | mean_loss = aggregated['loss'].mean() 216 | 217 | # Metric stats: 218 | metric_ours_t_err_ang = aggregated['metric_ours_t_err_ang'].mean() 219 | metric_ours_t_err_euc = aggregated['metric_ours_t_err_euc'].mean() 220 | metric_ours_R_err = aggregated['metric_ours_R_err'].mean() 221 | 222 | metric_ours_vcre = aggregated['metric_ours_vcre'].mean() 223 | 224 | # compute precision/AUC for pose error 225 | t_threshold = 0.25 226 | R_threshold = 5 227 | accepted_poses_ours = (aggregated['metric_ours_t_err_euc'].view(-1) < t_threshold) * \ 228 | (aggregated['metric_ours_R_err'].view(-1) < R_threshold) 229 | 230 | inliers = aggregated['metric_inliers'].view(-1).detach().cpu().numpy() 231 | 232 | prec_pose_ours = accepted_poses_ours.sum()/len(accepted_poses_ours) 233 | 234 | _, _, auc_pose = precision_recall(inliers=inliers, tp=accepted_poses_ours.detach().cpu().numpy(), failures=0) 235 | 236 | # compute precision/AUC for pose error 237 | t_threshold = 0.5 238 | R_threshold = 10 239 | accepted_poses_ours = (aggregated['metric_ours_t_err_euc'].view(-1) < t_threshold) * \ 240 | (aggregated['metric_ours_R_err'].view(-1) < R_threshold) 241 | 242 | inliers = aggregated['metric_inliers'].view(-1).detach().cpu().numpy() 243 | 244 | prec_pose_ours_10 = accepted_poses_ours.sum() / len(accepted_poses_ours) 245 | 246 | _, _, auc_pose_10 = precision_recall(inliers=inliers, tp=accepted_poses_ours.detach().cpu().numpy(), failures=0) 247 | 248 | 249 | # compute precision/AUC for reprojection errors 250 | px_threshold = 90 251 | accepted_vcre_ours = aggregated['metric_ours_vcre'].view(-1) < px_threshold 252 | 253 | prec_vcre_ours = accepted_vcre_ours.sum()/len(accepted_vcre_ours) 254 | 255 | _, _, auc_vcre = precision_recall(inliers=inliers, tp=accepted_vcre_ours.detach().cpu().numpy(), failures=0) 256 | 257 | # log stats 258 | self.log('val_loss/loss_R', mean_R_loss, sync_dist=self.multi_gpu) 259 | self.log('val_loss/loss_t', mean_t_loss, sync_dist=self.multi_gpu) 260 | self.log('val_loss/loss', mean_loss, sync_dist=self.multi_gpu) 261 | 262 | self.log('val_metric_pose/ours_t_err_ang', metric_ours_t_err_ang, sync_dist=self.multi_gpu) 263 | self.log('val_metric_pose/ours_t_err_euc', metric_ours_t_err_euc, sync_dist=self.multi_gpu) 264 | self.log('val_metric_pose/ours_R_err', metric_ours_R_err, sync_dist=self.multi_gpu) 265 | 266 | self.log('val_vcre/auc_vcre', auc_vcre, sync_dist=self.multi_gpu) 267 | self.log('val_vcre/prec_vcre_ours', prec_vcre_ours, sync_dist=self.multi_gpu) 268 | self.log('val_vcre/metric_ours_vcre', metric_ours_vcre, sync_dist=self.multi_gpu) 269 | 270 | self.log('val_AUC_pose/prec_pose_ours', prec_pose_ours, sync_dist=self.multi_gpu) 271 | self.log('val_AUC_pose/auc_pose', torch.tensor(auc_pose), sync_dist=self.multi_gpu) 272 | 273 | self.log('val_AUC_pose/prec_pose_ours_10', 
prec_pose_ours_10, sync_dist=self.multi_gpu) 274 | self.log('val_AUC_pose/auc_pose_10', torch.tensor(auc_pose_10), sync_dist=self.multi_gpu) 275 | 276 | self.validation_step_outputs.clear() # free memory 277 | 278 | self.is_eval_model(False) 279 | 280 | return mean_loss 281 | 282 | def configure_optimizers(self): 283 | tcfg = self.cfg.TRAINING 284 | opt = torch.optim.Adam(self.parameters(), lr=tcfg.LR, eps=1e-6) 285 | if tcfg.LR_STEP_INTERVAL: 286 | scheduler = torch.optim.lr_scheduler.StepLR( 287 | opt, tcfg.LR_STEP_INTERVAL, tcfg.LR_STEP_GAMMA) 288 | return {'optimizer': opt, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}} 289 | return opt 290 | 291 | def on_save_checkpoint(self, checkpoint): 292 | # As DINOv2 is pre-trained (and not fine-tuned), avoid saving its weights (this should help with memory). 293 | dinov2_keys = [] 294 | for key in checkpoint['state_dict'].keys(): 295 | if 'dinov2' in key: 296 | dinov2_keys.append(key) 297 | for key in dinov2_keys: 298 | del checkpoint['state_dict'][key] 299 | 300 | def on_load_checkpoint(self, checkpoint): 301 | 302 | # Recover DINOv2 features from pretrained weights. 303 | for param_tensor in self.compute_matches.state_dict(): 304 | if 'dinov2' in param_tensor: 305 | checkpoint['state_dict']['compute_matches.'+param_tensor] = \ 306 | self.compute_matches.state_dict()[param_tensor] 307 | 308 | def is_eval_model(self, is_eval): 309 | if is_eval: 310 | self.compute_matches.extractor.depth_head.eval() 311 | self.compute_matches.extractor.det_offset.eval() 312 | self.compute_matches.extractor.dsc_head.eval() 313 | self.compute_matches.extractor.det_head.eval() 314 | else: 315 | self.compute_matches.extractor.depth_head.train() 316 | self.compute_matches.extractor.det_offset.train() 317 | self.compute_matches.extractor.dsc_head.train() 318 | self.compute_matches.extractor.det_head.train() 319 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/dinov2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
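# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): the training module
# above drops the frozen DINOv2 weights in on_save_checkpoint and re-inserts
# them from the freshly built model in on_load_checkpoint. The names below
# (TinyModel, save_without_backbone, load_with_backbone) are hypothetical and
# only illustrate that save/load pattern.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)   # stands in for the frozen DINOv2 encoder
        self.head = nn.Linear(8, 2)       # trainable task head
        for p in self.backbone.parameters():
            p.requires_grad = False

def save_without_backbone(model, path):
    # Keep only the trainable tensors (cf. on_save_checkpoint above).
    state = {k: v for k, v in model.state_dict().items() if 'backbone' not in k}
    torch.save(state, path)

def load_with_backbone(model, path):
    # Re-insert the frozen backbone tensors from the in-memory model
    # (cf. on_load_checkpoint above) before calling load_state_dict.
    state = torch.load(path)
    for k, v in model.state_dict().items():
        if 'backbone' in k:
            state[k] = v
    model.load_state_dict(state)

if __name__ == '__main__':
    model = TinyModel()
    save_without_backbone(model, '/tmp/head_only.ckpt')
    load_with_backbone(model, '/tmp/head_only.ckpt')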
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | from functools import partial 12 | import math 13 | import logging 14 | from typing import Sequence, Tuple, Union, Callable 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.utils.checkpoint 19 | from torch.nn.init import trunc_normal_ 20 | 21 | from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block 22 | 23 | 24 | def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: 25 | if not depth_first and include_root: 26 | fn(module=module, name=name) 27 | for child_name, child_module in module.named_children(): 28 | child_name = ".".join((name, child_name)) if name else child_name 29 | named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) 30 | if depth_first and include_root: 31 | fn(module=module, name=name) 32 | return module 33 | 34 | 35 | class BlockChunk(nn.ModuleList): 36 | def forward(self, x): 37 | for b in self: 38 | x = b(x) 39 | return x 40 | 41 | 42 | class DinoVisionTransformer(nn.Module): 43 | def __init__( 44 | self, 45 | img_size=224, 46 | patch_size=16, 47 | in_chans=3, 48 | embed_dim=768, 49 | depth=12, 50 | num_heads=12, 51 | mlp_ratio=4.0, 52 | qkv_bias=True, 53 | ffn_bias=True, 54 | proj_bias=True, 55 | drop_path_rate=0.0, 56 | drop_path_uniform=False, 57 | init_values=None, # for layerscale: None or 0 => no layerscale 58 | embed_layer=PatchEmbed, 59 | act_layer=nn.GELU, 60 | block_fn=Block, 61 | ffn_layer="mlp", 62 | block_chunks=1, 63 | ): 64 | """ 65 | Args: 66 | img_size (int, tuple): input image size 67 | patch_size (int, tuple): patch size 68 | in_chans (int): number of input channels 69 | embed_dim (int): embedding dimension 70 | depth (int): depth of transformer 71 | num_heads (int): number of attention heads 72 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 73 | qkv_bias (bool): enable bias for qkv if True 74 | proj_bias (bool): enable bias for proj in attn if True 75 | ffn_bias (bool): enable bias for ffn if True 76 | drop_path_rate (float): stochastic depth rate 77 | drop_path_uniform (bool): apply uniform drop rate across blocks 78 | weight_init (str): weight init scheme 79 | init_values (float): layer-scale init values 80 | embed_layer (nn.Module): patch embedding layer 81 | act_layer (nn.Module): MLP activation layer 82 | block_fn (nn.Module): transformer block class 83 | ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" 84 | block_chunks: (int) split block sequence into block_chunks units for FSDP wrap 85 | """ 86 | super().__init__() 87 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 88 | 89 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 90 | self.num_tokens = 1 91 | self.n_blocks = depth 92 | self.num_heads = num_heads 93 | self.patch_size = patch_size 94 | 95 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 96 | num_patches = self.patch_embed.num_patches 97 | 98 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 99 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 100 | 101 | if drop_path_uniform is True: 102 | dpr = [drop_path_rate] * depth 103 | else: 104 | dpr = [x.item() for x in torch.linspace(0, 
drop_path_rate, depth)] # stochastic depth decay rule 105 | 106 | if ffn_layer == "mlp": 107 | ffn_layer = Mlp 108 | elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": 109 | ffn_layer = SwiGLUFFNFused 110 | elif ffn_layer == "identity": 111 | 112 | def f(*args, **kwargs): 113 | return nn.Identity() 114 | 115 | ffn_layer = f 116 | else: 117 | raise NotImplementedError 118 | 119 | blocks_list = [ 120 | block_fn( 121 | dim=embed_dim, 122 | num_heads=num_heads, 123 | mlp_ratio=mlp_ratio, 124 | qkv_bias=qkv_bias, 125 | proj_bias=proj_bias, 126 | ffn_bias=ffn_bias, 127 | drop_path=dpr[i], 128 | norm_layer=norm_layer, 129 | act_layer=act_layer, 130 | ffn_layer=ffn_layer, 131 | init_values=init_values, 132 | ) 133 | for i in range(depth) 134 | ] 135 | if block_chunks > 0: 136 | self.chunked_blocks = True 137 | chunked_blocks = [] 138 | chunksize = depth // block_chunks 139 | for i in range(0, depth, chunksize): 140 | # this is to keep the block index consistent if we chunk the block list 141 | chunked_blocks.append([nn.Identity()] * i + blocks_list[i: i + chunksize]) 142 | self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) 143 | else: 144 | self.chunked_blocks = False 145 | self.blocks = nn.ModuleList(blocks_list) 146 | 147 | self.norm = norm_layer(embed_dim) 148 | self.head = nn.Identity() 149 | 150 | self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) 151 | 152 | self.init_weights() 153 | for param in self.parameters(): 154 | param.requires_grad = False 155 | 156 | @property 157 | def device(self): 158 | return self.cls_token.device 159 | 160 | def init_weights(self): 161 | trunc_normal_(self.pos_embed, std=0.02) 162 | nn.init.normal_(self.cls_token, std=1e-6) 163 | named_apply(init_weights_vit_timm, self) 164 | 165 | def interpolate_pos_encoding(self, x, w, h): 166 | previous_dtype = x.dtype 167 | npatch = x.shape[1] - 1 168 | N = self.pos_embed.shape[1] - 1 169 | if npatch == N and w == h: 170 | return self.pos_embed 171 | pos_embed = self.pos_embed.float() 172 | class_pos_embed = pos_embed[:, 0] 173 | patch_pos_embed = pos_embed[:, 1:] 174 | dim = x.shape[-1] 175 | w0 = w // self.patch_size 176 | h0 = h // self.patch_size 177 | # we add a small number to avoid floating point error in the interpolation 178 | # see discussion at https://github.com/facebookresearch/dino/issues/8 179 | w0, h0 = w0 + 0.1, h0 + 0.1 180 | 181 | patch_pos_embed = nn.functional.interpolate( 182 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), 183 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 184 | mode="bicubic", 185 | ) 186 | 187 | assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] 188 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) 189 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) 190 | 191 | def prepare_tokens_with_masks(self, x, masks=None): 192 | B, nc, w, h = x.shape 193 | x = self.patch_embed(x) 194 | if masks is not None: 195 | x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) 196 | 197 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 198 | x = x + self.interpolate_pos_encoding(x, w, h) 199 | 200 | return x 201 | 202 | def forward_features_list(self, x_list, masks_list): 203 | x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] 204 | for blk in self.blocks: 205 | x = blk(x) 206 | 207 | all_x = x 208 | output = [] 209 | for x, masks in 
zip(all_x, masks_list): 210 | x_norm = self.norm(x) 211 | output.append( 212 | { 213 | "x_norm_clstoken": x_norm[:, 0], 214 | "x_norm_patchtokens": x_norm[:, 1:], 215 | "x_prenorm": x, 216 | "masks": masks, 217 | } 218 | ) 219 | return output 220 | 221 | def forward_features(self, x, masks=None): 222 | if isinstance(x, list): 223 | return self.forward_features_list(x, masks) 224 | 225 | x = self.prepare_tokens_with_masks(x, masks) 226 | 227 | for blk in self.blocks: 228 | x = blk(x) 229 | 230 | x_norm = self.norm(x) 231 | return { 232 | "x_norm_clstoken": x_norm[:, 0], 233 | "x_norm_patchtokens": x_norm[:, 1:], 234 | "x_prenorm": x, 235 | "masks": masks, 236 | } 237 | 238 | def _get_intermediate_layers_not_chunked(self, x, n=1): 239 | x = self.prepare_tokens_with_masks(x) 240 | # If n is an int, take the n last blocks. If it's a list, take them 241 | output, total_block_len = [], len(self.blocks) 242 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 243 | for i, blk in enumerate(self.blocks): 244 | x = blk(x) 245 | if i in blocks_to_take: 246 | output.append(x) 247 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 248 | return output 249 | 250 | def _get_intermediate_layers_chunked(self, x, n=1): 251 | x = self.prepare_tokens_with_masks(x) 252 | output, i, total_block_len = [], 0, len(self.blocks[-1]) 253 | # If n is an int, take the n last blocks. If it's a list, take them 254 | blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n 255 | for block_chunk in self.blocks: 256 | for blk in block_chunk[i:]: # Passing the nn.Identity() 257 | x = blk(x) 258 | if i in blocks_to_take: 259 | output.append(x) 260 | i += 1 261 | assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" 262 | return output 263 | 264 | def get_intermediate_layers( 265 | self, 266 | x: torch.Tensor, 267 | n: Union[int, Sequence] = 1, # Layers or n last layers to take 268 | reshape: bool = False, 269 | return_class_token: bool = False, 270 | norm=True, 271 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 272 | if self.chunked_blocks: 273 | outputs = self._get_intermediate_layers_chunked(x, n) 274 | else: 275 | outputs = self._get_intermediate_layers_not_chunked(x, n) 276 | if norm: 277 | outputs = [self.norm(out) for out in outputs] 278 | class_tokens = [out[:, 0] for out in outputs] 279 | outputs = [out[:, 1:] for out in outputs] 280 | if reshape: 281 | B, _, w, h = x.shape 282 | outputs = [ 283 | out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() 284 | for out in outputs 285 | ] 286 | if return_class_token: 287 | return tuple(zip(outputs, class_tokens)) 288 | return tuple(outputs) 289 | 290 | def forward(self, *args, is_training=False, **kwargs): 291 | ret = self.forward_features(*args, **kwargs) 292 | if is_training: 293 | return ret 294 | else: 295 | return self.head(ret["x_norm_clstoken"]) 296 | 297 | 298 | def init_weights_vit_timm(module: nn.Module, name: str = ""): 299 | """ViT weight initialization, original timm impl (for reproducibility)""" 300 | if isinstance(module, nn.Linear): 301 | trunc_normal_(module.weight, std=0.02) 302 | if module.bias is not None: 303 | nn.init.zeros_(module.bias) 304 | 305 | 306 | def vit_small(patch_size=16, **kwargs): 307 | model = DinoVisionTransformer( 308 | patch_size=patch_size, 309 | embed_dim=384, 310 | depth=12, 311 | num_heads=6, 312 | mlp_ratio=4, 
313 | block_fn=partial(Block, attn_class=MemEffAttention), 314 | **kwargs, 315 | ) 316 | return model 317 | 318 | 319 | def vit_base(patch_size=16, **kwargs): 320 | model = DinoVisionTransformer( 321 | patch_size=patch_size, 322 | embed_dim=768, 323 | depth=12, 324 | num_heads=12, 325 | mlp_ratio=4, 326 | block_fn=partial(Block, attn_class=MemEffAttention), 327 | **kwargs, 328 | ) 329 | return model 330 | 331 | 332 | def vit_large(patch_size=16, **kwargs): 333 | model = DinoVisionTransformer( 334 | patch_size=patch_size, 335 | embed_dim=1024, 336 | depth=24, 337 | num_heads=16, 338 | mlp_ratio=4, 339 | block_fn=partial(Block, attn_class=MemEffAttention), 340 | **kwargs, 341 | ) 342 | return model 343 | 344 | 345 | def vit_giant2(patch_size=16, **kwargs): 346 | """ 347 | Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 348 | """ 349 | model = DinoVisionTransformer( 350 | patch_size=patch_size, 351 | embed_dim=1536, 352 | depth=40, 353 | num_heads=24, 354 | mlp_ratio=4, 355 | block_fn=partial(Block, attn_class=MemEffAttention), 356 | **kwargs, 357 | ) 358 | return model -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .dino_head import DINOHead 8 | from .mlp import Mlp 9 | from .patch_embed import PatchEmbed 10 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 11 | from .block import NestedTensorBlock 12 | from .attention import MemEffAttention 13 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
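# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): how the
# DinoVisionTransformer above is typically queried for dense patch features.
# With reshape=True, get_intermediate_layers returns the patch tokens as a
# [B, C, H/patch, W/patch] map. Assumes the repository root is on PYTHONPATH;
# the input size and n=1 are arbitrary example values.
# ---------------------------------------------------------------------------
import torch
from lib.models.MicKey.modules.DINO_modules.dinov2 import vit_small

vit = vit_small(patch_size=16).eval()        # 384-d tokens; all weights frozen in __init__
image = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    (feat,) = vit.get_intermediate_layers(image, n=1, reshape=True)
# feat: [1, 384, 14, 14] -> one 384-d descriptor per 16x16 patch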
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 10 | 11 | import logging 12 | 13 | from torch import Tensor 14 | from torch import nn 15 | 16 | 17 | logger = logging.getLogger("dinov2") 18 | 19 | 20 | try: 21 | from xformers.ops import memory_efficient_attention, unbind, fmha 22 | 23 | XFORMERS_AVAILABLE = True 24 | except ImportError: 25 | logger.warning("xFormers not available") 26 | XFORMERS_AVAILABLE = False 27 | 28 | 29 | class Attention(nn.Module): 30 | def __init__( 31 | self, 32 | dim: int, 33 | num_heads: int = 8, 34 | qkv_bias: bool = False, 35 | proj_bias: bool = True, 36 | attn_drop: float = 0.0, 37 | proj_drop: float = 0.0, 38 | ) -> None: 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim**-0.5 43 | 44 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 45 | self.attn_drop = nn.Dropout(attn_drop) 46 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 47 | self.proj_drop = nn.Dropout(proj_drop) 48 | 49 | def forward(self, x: Tensor) -> Tensor: 50 | B, N, C = x.shape 51 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 52 | 53 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 54 | attn = q @ k.transpose(-2, -1) 55 | 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class MemEffAttention(Attention): 66 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 67 | if not XFORMERS_AVAILABLE: 68 | assert attn_bias is None, "xFormers is required for nested tensors usage" 69 | return super().forward(x) 70 | 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 73 | 74 | q, k, v = unbind(qkv, 2) 75 | 76 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 77 | x = x.reshape([B, N, C]) 78 | 79 | x = self.proj(x) 80 | x = self.proj_drop(x) 81 | return x 82 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
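# ---------------------------------------------------------------------------
# Illustrative check (not part of the repository sources): the Attention module
# above computes softmax(q k^T / sqrt(d)) v per head. With dropout disabled,
# the manual formulation matches PyTorch's fused scaled_dot_product_attention:
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F

B, H, N, D = 2, 8, 10, 16                    # batch, heads, tokens, head dim
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
manual = torch.softmax((q * D ** -0.5) @ k.transpose(-2, -1), dim=-1) @ v
fused = F.scaled_dot_product_attention(q, k, v)
assert torch.allclose(manual, fused, atol=1e-5)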
6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | import logging 12 | from typing import Callable, List, Any, Tuple, Dict 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | 17 | from .attention import Attention, MemEffAttention 18 | from .drop_path import DropPath 19 | from .layer_scale import LayerScale 20 | from .mlp import Mlp 21 | 22 | 23 | logger = logging.getLogger("dinov2") 24 | 25 | 26 | try: 27 | from xformers.ops import fmha 28 | from xformers.ops import scaled_index_add, index_select_cat 29 | 30 | XFORMERS_AVAILABLE = True 31 | except ImportError: 32 | logger.warning("xFormers not available") 33 | XFORMERS_AVAILABLE = False 34 | 35 | 36 | class Block(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int, 41 | mlp_ratio: float = 4.0, 42 | qkv_bias: bool = False, 43 | proj_bias: bool = True, 44 | ffn_bias: bool = True, 45 | drop: float = 0.0, 46 | attn_drop: float = 0.0, 47 | init_values=None, 48 | drop_path: float = 0.0, 49 | act_layer: Callable[..., nn.Module] = nn.GELU, 50 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 51 | attn_class: Callable[..., nn.Module] = Attention, 52 | ffn_layer: Callable[..., nn.Module] = Mlp, 53 | ) -> None: 54 | super().__init__() 55 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 56 | self.norm1 = norm_layer(dim) 57 | self.attn = attn_class( 58 | dim, 59 | num_heads=num_heads, 60 | qkv_bias=qkv_bias, 61 | proj_bias=proj_bias, 62 | attn_drop=attn_drop, 63 | proj_drop=drop, 64 | ) 65 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 66 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 67 | 68 | self.norm2 = norm_layer(dim) 69 | mlp_hidden_dim = int(dim * mlp_ratio) 70 | self.mlp = ffn_layer( 71 | in_features=dim, 72 | hidden_features=mlp_hidden_dim, 73 | act_layer=act_layer, 74 | drop=drop, 75 | bias=ffn_bias, 76 | ) 77 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 78 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 79 | 80 | self.sample_drop_ratio = drop_path 81 | 82 | def forward(self, x: Tensor) -> Tensor: 83 | def attn_residual_func(x: Tensor) -> Tensor: 84 | return self.ls1(self.attn(self.norm1(x))) 85 | 86 | def ffn_residual_func(x: Tensor) -> Tensor: 87 | return self.ls2(self.mlp(self.norm2(x))) 88 | 89 | if self.training and self.sample_drop_ratio > 0.1: 90 | # the overhead is compensated only for a drop path rate larger than 0.1 91 | x = drop_add_residual_stochastic_depth( 92 | x, 93 | residual_func=attn_residual_func, 94 | sample_drop_ratio=self.sample_drop_ratio, 95 | ) 96 | x = drop_add_residual_stochastic_depth( 97 | x, 98 | residual_func=ffn_residual_func, 99 | sample_drop_ratio=self.sample_drop_ratio, 100 | ) 101 | elif self.training and self.sample_drop_ratio > 0.0: 102 | x = x + self.drop_path1(attn_residual_func(x)) 103 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 104 | else: 105 | x = x + attn_residual_func(x) 106 | x = x + ffn_residual_func(x) 107 | return x 108 | 109 | 110 | def drop_add_residual_stochastic_depth( 111 | x: Tensor, 112 | residual_func: Callable[[Tensor], Tensor], 113 | sample_drop_ratio: float = 0.0, 114 | ) -> Tensor: 115 | # 1) extract subset using permutation 116 | b, n, d = x.shape 117 | sample_subset_size = max(int(b * (1 - 
sample_drop_ratio)), 1) 118 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 119 | x_subset = x[brange] 120 | 121 | # 2) apply residual_func to get residual 122 | residual = residual_func(x_subset) 123 | 124 | x_flat = x.flatten(1) 125 | residual = residual.flatten(1) 126 | 127 | residual_scale_factor = b / sample_subset_size 128 | 129 | # 3) add the residual 130 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 131 | return x_plus_residual.view_as(x) 132 | 133 | 134 | def get_branges_scales(x, sample_drop_ratio=0.0): 135 | b, n, d = x.shape 136 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 137 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 138 | residual_scale_factor = b / sample_subset_size 139 | return brange, residual_scale_factor 140 | 141 | 142 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 143 | if scaling_vector is None: 144 | x_flat = x.flatten(1) 145 | residual = residual.flatten(1) 146 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 147 | else: 148 | x_plus_residual = scaled_index_add( 149 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 150 | ) 151 | return x_plus_residual 152 | 153 | 154 | attn_bias_cache: Dict[Tuple, Any] = {} 155 | 156 | 157 | def get_attn_bias_and_cat(x_list, branges=None): 158 | """ 159 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 160 | """ 161 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 162 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 163 | if all_shapes not in attn_bias_cache.keys(): 164 | seqlens = [] 165 | for b, x in zip(batch_sizes, x_list): 166 | for _ in range(b): 167 | seqlens.append(x.shape[1]) 168 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 169 | attn_bias._batch_sizes = batch_sizes 170 | attn_bias_cache[all_shapes] = attn_bias 171 | 172 | if branges is not None: 173 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 174 | else: 175 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 176 | cat_tensors = torch.cat(tensors_bs1, dim=1) 177 | 178 | return attn_bias_cache[all_shapes], cat_tensors 179 | 180 | 181 | def drop_add_residual_stochastic_depth_list( 182 | x_list: List[Tensor], 183 | residual_func: Callable[[Tensor, Any], Tensor], 184 | sample_drop_ratio: float = 0.0, 185 | scaling_vector=None, 186 | ) -> Tensor: 187 | # 1) generate random set of indices for dropping samples in the batch 188 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 189 | branges = [s[0] for s in branges_scales] 190 | residual_scale_factors = [s[1] for s in branges_scales] 191 | 192 | # 2) get attention bias and index+concat the tensors 193 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 194 | 195 | # 3) apply residual_func to get residual, and split the result 196 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 197 | 198 | outputs = [] 199 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 200 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 201 | return outputs 202 | 
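# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): the batch-level
# stochastic depth above evaluates the residual branch on a random subset of
# the batch only, then index-adds the result scaled by b / subset_size so the
# residual keeps its expected magnitude. Toy example with an identity
# residual function:
# ---------------------------------------------------------------------------
import torch

b, n, d = 8, 4, 2
x = torch.randn(b, n, d)
sample_drop_ratio = 0.5
subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = torch.randperm(b)[:subset_size]
residual = x[brange]                          # residual_func = identity for clarity
out = torch.index_add(x.flatten(1), 0, brange,
                      residual.flatten(1), alpha=b / subset_size).view_as(x)
# rows in brange received x + (b / subset_size) * residual; the rest are unchanged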
203 | 204 | class NestedTensorBlock(Block): 205 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 206 | """ 207 | x_list contains a list of tensors to nest together and run 208 | """ 209 | assert isinstance(self.attn, MemEffAttention) 210 | 211 | if self.training and self.sample_drop_ratio > 0.0: 212 | 213 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 214 | return self.attn(self.norm1(x), attn_bias=attn_bias) 215 | 216 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 217 | return self.mlp(self.norm2(x)) 218 | 219 | x_list = drop_add_residual_stochastic_depth_list( 220 | x_list, 221 | residual_func=attn_residual_func, 222 | sample_drop_ratio=self.sample_drop_ratio, 223 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 224 | ) 225 | x_list = drop_add_residual_stochastic_depth_list( 226 | x_list, 227 | residual_func=ffn_residual_func, 228 | sample_drop_ratio=self.sample_drop_ratio, 229 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 230 | ) 231 | return x_list 232 | else: 233 | 234 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 235 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 236 | 237 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 238 | return self.ls2(self.mlp(self.norm2(x))) 239 | 240 | attn_bias, x = get_attn_bias_and_cat(x_list) 241 | x = x + attn_residual_func(x, attn_bias=attn_bias) 242 | x = x + ffn_residual_func(x) 243 | return attn_bias.split(x) 244 | 245 | def forward(self, x_or_x_list): 246 | if isinstance(x_or_x_list, Tensor): 247 | return super().forward(x_or_x_list) 248 | elif isinstance(x_or_x_list, list): 249 | assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" 250 | return self.forward_nested(x_or_x_list) 251 | else: 252 | raise AssertionError 253 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
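# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the repository sources): the nested-tensor
# path above concatenates token lists of different lengths into one sequence
# and restricts attention with a block-diagonal mask (xformers'
# BlockDiagonalMask). The same masking expressed with plain PyTorch, for
# intuition only:
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F

seqlens = [3, 5]                              # two images contributing 3 and 5 tokens
total = sum(seqlens)
mask = torch.zeros(total, total, dtype=torch.bool)
start = 0
for n in seqlens:
    mask[start:start + n, start:start + n] = True   # attend only within the same image
    start += n
q = k = v = torch.randn(1, 1, total, 16)      # [B, heads, N, head_dim]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)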
6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn.init import trunc_normal_ 10 | from torch.nn.utils import weight_norm 11 | 12 | 13 | class DINOHead(nn.Module): 14 | def __init__( 15 | self, 16 | in_dim, 17 | out_dim, 18 | use_bn=False, 19 | nlayers=3, 20 | hidden_dim=2048, 21 | bottleneck_dim=256, 22 | mlp_bias=True, 23 | ): 24 | super().__init__() 25 | nlayers = max(nlayers, 1) 26 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 27 | self.apply(self._init_weights) 28 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 29 | self.last_layer.weight_g.data.fill_(1) 30 | 31 | def _init_weights(self, m): 32 | if isinstance(m, nn.Linear): 33 | trunc_normal_(m.weight, std=0.02) 34 | if isinstance(m, nn.Linear) and m.bias is not None: 35 | nn.init.constant_(m.bias, 0) 36 | 37 | def forward(self, x): 38 | x = self.mlp(x) 39 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 40 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 41 | x = self.last_layer(x) 42 | return x 43 | 44 | 45 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 46 | if nlayers == 1: 47 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 48 | else: 49 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 50 | if use_bn: 51 | layers.append(nn.BatchNorm1d(hidden_dim)) 52 | layers.append(nn.GELU()) 53 | for _ in range(nlayers - 2): 54 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 55 | if use_bn: 56 | layers.append(nn.BatchNorm1d(hidden_dim)) 57 | layers.append(nn.GELU()) 58 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 59 | return nn.Sequential(*layers) 60 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 10 | 11 | 12 | from torch import nn 13 | 14 | 15 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 16 | if drop_prob == 0.0 or not training: 17 | return x 18 | keep_prob = 1 - drop_prob 19 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 20 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 21 | if keep_prob > 0.0: 22 | random_tensor.div_(keep_prob) 23 | output = x * random_tensor 24 | return output 25 | 26 | 27 | class DropPath(nn.Module): 28 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 29 | 30 | def __init__(self, drop_prob=None): 31 | super(DropPath, self).__init__() 32 | self.drop_prob = drop_prob 33 | 34 | def forward(self, x): 35 | return drop_path(x, self.drop_prob, self.training) 36 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 8 | 9 | from typing import Union 10 | 11 | import torch 12 | from torch import Tensor 13 | from torch import nn 14 | 15 | 16 | class LayerScale(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | init_values: Union[float, Tensor] = 1e-5, 21 | inplace: bool = False, 22 | ) -> None: 23 | super().__init__() 24 | self.inplace = inplace 25 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 29 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 10 | 11 | 12 | from typing import Callable, Optional 13 | 14 | from torch import Tensor, nn 15 | 16 | 17 | class Mlp(nn.Module): 18 | def __init__( 19 | self, 20 | in_features: int, 21 | hidden_features: Optional[int] = None, 22 | out_features: Optional[int] = None, 23 | act_layer: Callable[..., nn.Module] = nn.GELU, 24 | drop: float = 0.0, 25 | bias: bool = True, 26 | ) -> None: 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x: Tensor) -> Tensor: 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # References: 8 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 9 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 10 | 11 | from typing import Callable, Optional, Tuple, Union 12 | 13 | from torch import Tensor 14 | import torch.nn as nn 15 | 16 | 17 | def make_2tuple(x): 18 | if isinstance(x, tuple): 19 | assert len(x) == 2 20 | return x 21 | 22 | assert isinstance(x, int) 23 | return (x, x) 24 | 25 | 26 | class PatchEmbed(nn.Module): 27 | """ 28 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 29 | 30 | Args: 31 | img_size: Image size. 32 | patch_size: Patch token size. 33 | in_chans: Number of input image channels. 34 | embed_dim: Number of linear projection output channels. 35 | norm_layer: Normalization layer. 
36 | """ 37 | 38 | def __init__( 39 | self, 40 | img_size: Union[int, Tuple[int, int]] = 224, 41 | patch_size: Union[int, Tuple[int, int]] = 16, 42 | in_chans: int = 3, 43 | embed_dim: int = 768, 44 | norm_layer: Optional[Callable] = None, 45 | flatten_embedding: bool = True, 46 | ) -> None: 47 | super().__init__() 48 | 49 | image_HW = make_2tuple(img_size) 50 | patch_HW = make_2tuple(patch_size) 51 | patch_grid_size = ( 52 | image_HW[0] // patch_HW[0], 53 | image_HW[1] // patch_HW[1], 54 | ) 55 | 56 | self.img_size = image_HW 57 | self.patch_size = patch_HW 58 | self.patches_resolution = patch_grid_size 59 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 60 | 61 | self.in_chans = in_chans 62 | self.embed_dim = embed_dim 63 | 64 | self.flatten_embedding = flatten_embedding 65 | 66 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 67 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 68 | 69 | def forward(self, x: Tensor) -> Tensor: 70 | _, _, H, W = x.shape 71 | patch_H, patch_W = self.patch_size 72 | 73 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 74 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 75 | 76 | x = self.proj(x) # B C H W 77 | H, W = x.size(2), x.size(3) 78 | x = x.flatten(2).transpose(1, 2) # B HW C 79 | x = self.norm(x) 80 | if not self.flatten_embedding: 81 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 82 | return x 83 | 84 | def flops(self) -> float: 85 | Ho, Wo = self.patches_resolution 86 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 87 | if self.norm is not None: 88 | flops += Ho * Wo * self.embed_dim 89 | return flops 90 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/DINO_modules/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Callable, Optional 8 | 9 | from torch import Tensor, nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | act_layer: Callable[..., nn.Module] = None, 20 | drop: float = 0.0, 21 | bias: bool = True, 22 | ) -> None: 23 | super().__init__() 24 | out_features = out_features or in_features 25 | hidden_features = hidden_features or in_features 26 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 27 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | x12 = self.w12(x) 31 | x1, x2 = x12.chunk(2, dim=-1) 32 | hidden = F.silu(x1) * x2 33 | return self.w3(hidden) 34 | 35 | 36 | try: 37 | from xformers.ops import SwiGLU 38 | 39 | XFORMERS_AVAILABLE = True 40 | except ImportError: 41 | SwiGLU = SwiGLUFFN 42 | XFORMERS_AVAILABLE = False 43 | 44 | 45 | class SwiGLUFFNFused(SwiGLU): 46 | def __init__( 47 | self, 48 | in_features: int, 49 | hidden_features: Optional[int] = None, 50 | out_features: Optional[int] = None, 51 | act_layer: Callable[..., nn.Module] = None, 52 | drop: float = 0.0, 53 | bias: bool = True, 54 | ) -> None: 55 | out_features = out_features or in_features 56 | hidden_features = hidden_features or in_features 57 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 58 | super().__init__( 59 | in_features=in_features, 60 | hidden_features=hidden_features, 61 | out_features=out_features, 62 | bias=bias, 63 | ) 64 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/att_layers/attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some functions are borrowed from LoFTR: Detector-Free Local 3 | Feature Matching with Transformers (https://github.com/zju3dv/LoFTR) and modified here. 4 | If using this code, please consider citing LoFTR. 5 | """ 6 | 7 | import torch 8 | from torch.nn import Module, Dropout 9 | import torch.nn.functional as F 10 | 11 | def elu_feature_map(x): 12 | return torch.nn.functional.elu(x) + 1 13 | 14 | class Attention(Module): 15 | def __init__(self, eps=1e-6, use_dropout=False, attention='', attention_dropout=0.1): 16 | super().__init__() 17 | self.feature_map = elu_feature_map 18 | self.eps = eps 19 | self.use_dropout = use_dropout 20 | self.dropout = Dropout(attention_dropout) 21 | self.attention = attention 22 | 23 | def forward_full(self, queries, keys, values): 24 | """ Multi-head scaled dot-product attention, a.k.a full attention. 25 | Args: 26 | queries: [N, L, H, D] 27 | keys: [N, S, H, D] 28 | values: [N, S, H, D] 29 | Returns: 30 | queried_values: (N, L, H, D) 31 | """ 32 | 33 | # Compute the unnormalized attention 34 | QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) 35 | 36 | # Compute the attention and the weighted average 37 | softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) 38 | A = torch.softmax(softmax_temp * QK, dim=2) 39 | if self.use_dropout: 40 | A = self.dropout(A) 41 | 42 | queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) 43 | 44 | return queried_values.contiguous() 45 | 46 | def forward_linear(self, queries, keys, values): 47 | """ Multi-Head linear attention proposed in "Transformers are RNNs" 48 | Args: 49 | queries: [N, L, H, D] 50 | keys: [N, S, H, D] 51 | values: [N, S, H, D] 52 | Returns: 53 | queried_values: (N, L, H, D) 54 | """ 55 | Q = self.feature_map(queries) 56 | K = self.feature_map(keys) 57 | 58 | v_length = values.size(1) 59 | values = values / v_length # prevent fp16 overflow 60 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V 61 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 62 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 63 | 64 | return queried_values.contiguous() 65 | 66 | def forward_flash(self, q, k, v): 67 | args = [x.half().contiguous() for x in [q, k, v]] 68 | v = F.scaled_dot_product_attention(*args).to(q.dtype) 69 | return v.contiguous() 70 | 71 | def forward(self, queries, keys, values): 72 | """ Multi-Head linear attention proposed in "Transformers are RNNs" 73 | Args: 74 | queries: [N, L, H, D] 75 | keys: [N, S, H, D] 76 | values: [N, S, H, D] 77 | Returns: 78 | x: queried values (N, L, H, D) 79 | """ 80 | if self.attention == 'linear': 81 | x = self.forward_linear(queries, keys, values) 82 | elif self.attention == 'flash': 83 | x = self.forward_flash(queries, keys, values) 84 | else: 85 | x = self.forward_full(queries, keys, values) 86 | return x 87 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/att_layers/transformer.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from einops.einops import rearrange 7 | from lib.models.MicKey.modules.att_layers.transformer_utils import EncoderLayer 8 | import torch.nn.functional as F 9 | 10 | class PositionEncodingSine(nn.Module): 11 | """ 12 | This is a sinusoidal position encoding that generalized to 2-dimensional images 13 | """ 14 | 15 | def __init__(self, d_model, max_shape=(256, 256)): 16 | """ 17 | Args: 18 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels 19 | temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), 20 | the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact 21 | on the final performance. For now, we keep both impls for backward compatability. 22 | We will remove the buggy impl after re-training all variants of our released models. 
23 | """ 24 | super().__init__() 25 | 26 | pe = torch.zeros((d_model, *max_shape)) 27 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) 28 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) 29 | div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) 30 | div_term = div_term[:, None, None] # [C//4, 1, 1] 31 | pe[0::4, :, :] = torch.sin(x_position * div_term) 32 | pe[1::4, :, :] = torch.cos(x_position * div_term) 33 | pe[2::4, :, :] = torch.sin(y_position * div_term) 34 | pe[3::4, :, :] = torch.cos(y_position * div_term) 35 | 36 | self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] 37 | 38 | def forward(self, x): 39 | """ 40 | Args: 41 | x: [N, C, H, W] 42 | """ 43 | return x + self.pe[:, :, :x.size(2), :x.size(3)] 44 | 45 | class Transformer_self_att(nn.Module): 46 | """This class implement self attention transformer module. 47 | Arguments: 48 | d_model: Feature dimension after feature extractor (default: 1024d). 49 | aggregator_conf: Configuration dictionary containing the parameters for the transformer module. 50 | """ 51 | 52 | def __init__(self, d_model, num_layers, add_posEnc=False): 53 | super(Transformer_self_att, self).__init__() 54 | 55 | # Define the transformer parameters 56 | self.d_model = d_model 57 | 58 | # TODO: Expose parameters to config file 59 | layer_names = ['self'] * num_layers 60 | attention = 'linear' 61 | self.nheads = 8 62 | self.layer_names = layer_names 63 | encoder_layer = EncoderLayer(d_model, self.nheads, attention) 64 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) 65 | self._reset_parameters() 66 | self.add_posEnc = add_posEnc 67 | self.posEnc = PositionEncodingSine(d_model) 68 | 69 | def _reset_parameters(self): 70 | for p in self.parameters(): 71 | if p.dim() > 1: 72 | nn.init.xavier_uniform_(p) 73 | 74 | 75 | def forward(self, feats): 76 | """ 77 | Runs the common self and cross-attention module. 78 | Args: 79 | feats_a: Features from image A (source) ([N, d_model, im_size/down_factor, im_size/down_factor]). 80 | feats_b: Features from image B (destination) ([N, d_model, im_size/down_factor, im_size/down_factor]). 81 | Output: 82 | feats_a: Self and cross-attended features corresponding to image A (source) 83 | ([N, d_model, im_size/down_factor, im_size/down_factor]) 84 | feats_b: Self and cross-attended features corresponding to image B (destination) 85 | ([N, d_model, im_size/down_factor, im_size/down_factor]). 86 | """ 87 | 88 | assert self.d_model == feats.size(1), "The feature size and transformer must be equal" 89 | 90 | b, c, h, w = feats.size() 91 | 92 | if self.add_posEnc: 93 | feats = self.posEnc(feats) 94 | 95 | feats = rearrange(feats, 'n c h w -> n (h w) c') 96 | 97 | # Apply linear self attention to feats 98 | for layer, name in zip(self.layers, self.layer_names): 99 | feats = layer(feats, feats) 100 | 101 | feats = feats.transpose(2, 1).reshape((b, c, h, w)) 102 | 103 | return feats 104 | 105 | class Transformer_att(nn.Module): 106 | """This class implement self attention transformer module. 107 | Arguments: 108 | d_model: Feature dimension after feature extractor (default: 1024d). 109 | aggregator_conf: Configuration dictionary containing the parameters for the transformer module. 
110 | """ 111 | 112 | def __init__(self, d_model, num_layers, add_posEnc=False): 113 | super(Transformer_att, self).__init__() 114 | 115 | # Define the transformer parameters 116 | self.d_model = d_model 117 | 118 | # TODO: Expose parameters to config file 119 | layer_names = ['self', 'cross'] * num_layers 120 | attention = 'linear' 121 | self.nheads = 8 122 | self.layer_names = layer_names 123 | encoder_layer = EncoderLayer(d_model, self.nheads, attention) 124 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) 125 | self._reset_parameters() 126 | self.add_posEnc = add_posEnc 127 | self.posEnc = PositionEncodingSine(d_model) 128 | 129 | def _reset_parameters(self): 130 | for p in self.parameters(): 131 | if p.dim() > 1: 132 | nn.init.xavier_uniform_(p) 133 | 134 | 135 | def forward(self, feats0, feats1): 136 | """ 137 | Runs the common self and cross-attention module. 138 | Args: 139 | feats_a: Features from image A (source) ([N, d_model, im_size/down_factor, im_size/down_factor]). 140 | feats_b: Features from image B (destination) ([N, d_model, im_size/down_factor, im_size/down_factor]). 141 | Output: 142 | feats_a: Self and cross-attended features corresponding to image A (source) 143 | ([N, d_model, im_size/down_factor, im_size/down_factor]) 144 | feats_b: Self and cross-attended features corresponding to image B (destination) 145 | ([N, d_model, im_size/down_factor, im_size/down_factor]). 146 | """ 147 | 148 | assert self.d_model == feats0.size(1), "The feature size and transformer must be equal" 149 | 150 | b, c, h, w = feats0.size() 151 | 152 | if self.add_posEnc: 153 | feats0 = self.posEnc(feats0) 154 | feats1 = self.posEnc(feats1) 155 | 156 | feats0 = rearrange(feats0, 'n c h w -> n (h w) c') 157 | feats1 = rearrange(feats1, 'n c h w -> n (h w) c') 158 | 159 | # Apply linear self attention to feats 160 | for layer, name in zip(self.layers, self.layer_names): 161 | if name == 'self': 162 | feats0 = layer(feats0, feats0) 163 | feats1 = layer(feats1, feats1) 164 | elif name == 'cross': 165 | feats0, feats1 = layer(feats0, feats1), layer(feats1, feats0) 166 | else: 167 | raise KeyError 168 | 169 | feats0 = feats0.transpose(2, 1).reshape((b, c, h, w)) 170 | feats1 = feats1.transpose(2, 1).reshape((b, c, h, w)) 171 | 172 | return feats0, feats1 173 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/att_layers/transformer_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from lib.models.MicKey.modules.att_layers.attention import Attention 5 | 6 | class EncoderLayer(nn.Module): 7 | """ 8 | Transformer encoder layer containing the linear self and cross-attention, and the epipolar attention. 9 | Arguments: 10 | d_model: Feature dimension of the input feature maps (default: 128d). 11 | nhead: Number of heads in the multi-head attention. 12 | attention: Type of attention for the common transformer block. Options: linear, full. 
13 | """ 14 | def __init__(self, d_model, nhead, attention='linear'): 15 | super(EncoderLayer, self).__init__() 16 | 17 | # Transformer encoder layer parameters 18 | self.dim = d_model // nhead 19 | self.nhead = nhead 20 | 21 | # multi-head attention definition 22 | self.q_proj = nn.Linear(d_model, d_model, bias=False) 23 | self.k_proj = nn.Linear(d_model, d_model, bias=False) 24 | self.v_proj = nn.Linear(d_model, d_model, bias=False) 25 | # full_att = False if attention == 'linear' else True 26 | self.attention = Attention(attention=attention) 27 | self.merge = nn.Linear(d_model, d_model, bias=False) 28 | 29 | # feed-forward network 30 | self.mlp = nn.Sequential( 31 | nn.Linear(d_model*2, d_model*2, bias=False), 32 | nn.ReLU(True), 33 | nn.Linear(d_model*2, d_model, bias=False), 34 | ) 35 | 36 | # norm and dropout 37 | self.norm1 = nn.LayerNorm(d_model) 38 | self.norm2 = nn.LayerNorm(d_model) 39 | 40 | def forward(self, x, source): 41 | """ 42 | Args: 43 | x (torch.Tensor): [N, L, C] (L = im_size/down_factor ** 2) 44 | source (torch.Tensor): [N, S, C] 45 | if is_epi_att: 46 | S = (im_size/down_factor/step_grid) ** 2 * sampling_dim 47 | else: 48 | S = im_size/down_factor ** 2 49 | is_epi_att (bool): Indicates whether it applies epipolar cross-attention 50 | """ 51 | bs = x.size(0) 52 | query, key, value = x, source, source 53 | 54 | # multi-head attention 55 | query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] 56 | key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] 57 | value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) 58 | message = self.attention(query, key, value) # [N, L, (H, D)] 59 | message = self.merge(message.view(bs, -1, self.nhead * self.dim)) # [N, L, C] 60 | message = self.norm1(message) 61 | 62 | # feed-forward network 63 | message = self.mlp(torch.cat([x, message], dim=2)) 64 | message = self.norm2(message) 65 | 66 | return x + message 67 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/compute_correspondences.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lib.models.MicKey.modules.mickey_extractor import MicKey_Extractor 4 | from lib.models.MicKey.modules.utils.feature_matcher import featureMatcher 5 | 6 | class ComputeCorrespondences(nn.Module): 7 | def __init__(self, cfg): 8 | super().__init__() 9 | 10 | # Feature extractor 11 | self.extractor = MicKey_Extractor(cfg['MICKEY']) 12 | 13 | self.dsc_dim = cfg['MICKEY']['DSC_HEAD']['LAST_DIM'] 14 | 15 | # Feature matcher 16 | self.matcher = featureMatcher(cfg['FEATURE_MATCHER']) 17 | 18 | self.down_factor = cfg['MICKEY']['DINOV2']['DOWN_FACTOR'] 19 | 20 | def get_abs_kpts_coordinates(self, kpts): 21 | 22 | B, C, H, W = kpts.shape 23 | 24 | # Compute offset for every kp grid 25 | x_abs_pos = torch.arange(W).view(1, 1, W).tile([B, H, 1]).to(kpts.device) 26 | y_abs_pos = torch.arange(H).view(1, H, 1).tile([B, 1, W]).to(kpts.device) 27 | abs_pos = torch.concat([x_abs_pos.unsqueeze(1), y_abs_pos.unsqueeze(1)], dim=1) 28 | 29 | kpts_abs_pos = (kpts + abs_pos) * self.down_factor 30 | 31 | return kpts_abs_pos 32 | 33 | def prepare_kpts_dsc(self, kpt, depth, scr, dsc): 34 | 35 | B, _, H, W = kpt.shape 36 | num_kpts = (H * W) 37 | 38 | kpt = kpt.view(B, 2, num_kpts) 39 | depth = depth.view(B, 1, num_kpts) 40 | scr = scr.view(B, 1, num_kpts) 41 | dsc = dsc.view(B, self.dsc_dim, num_kpts) 42 | 43 | return kpt, depth, scr, dsc 44 
| 45 | # Independent method to only combine matching and keypoint scores during training 46 | def kp_matrix_scores(self, sc0, sc1): 47 | 48 | # matrix with "probability" of sampling a correspondence based on keypoint scores only 49 | scores = torch.matmul(sc0.transpose(2, 1).contiguous(), sc1) 50 | return scores 51 | 52 | def forward(self, data): 53 | 54 | # Compute detection and descriptor maps 55 | im0 = data['image0'] 56 | im1 = data['image1'] 57 | 58 | # Extract independently features from im0 and im1 59 | kps0, depth0, scr0, dsc0 = self.extractor(im0) 60 | kps1, depth1, scr1, dsc1 = self.extractor(im1) 61 | 62 | kps0 = self.get_abs_kpts_coordinates(kps0) 63 | kps1 = self.get_abs_kpts_coordinates(kps1) 64 | 65 | # Log shape for logging purposes 66 | _, _, H_kp0, W_kp0 = kps0.shape 67 | _, _, H_kp1, W_kp1 = kps1.shape 68 | data['kps0_shape'] = [H_kp0, W_kp0] 69 | data['kps1_shape'] = [H_kp1, W_kp1] 70 | data['depth0_map'] = depth0 71 | data['depth1_map'] = depth1 72 | data['down_factor'] = self.down_factor 73 | 74 | # Reshape kpts and descriptors to [B, num_kpts, dim] 75 | kps0, depth0, scr0, dsc0 = self.prepare_kpts_dsc(kps0, depth0, scr0, dsc0) 76 | kps1, depth1, scr1, dsc1 = self.prepare_kpts_dsc(kps1, depth1, scr1, dsc1) 77 | 78 | # get correspondences 79 | scores = self.matcher(kps0, dsc0, kps1, dsc1) 80 | 81 | data['kps0'] = kps0 82 | data['depth_kp0'] = depth0 83 | data['scr0'] = scr0 84 | data['kps1'] = kps1 85 | data['depth_kp1'] = depth1 86 | data['scr1'] = scr1 87 | data['scores'] = scores 88 | data['dsc0'] = dsc0 89 | data['dsc1'] = dsc1 90 | data['kp_scores'] = self.kp_matrix_scores(scr0, scr1) 91 | 92 | return kps0, dsc0, kps1, dsc1 93 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/loss/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from lib.utils.metrics import vcre_loss 4 | 5 | def compute_angular_error(R, t, Rgt_i, tgt_i): 6 | loss_rot, rot_err = rot_angle_loss(R, Rgt_i) 7 | loss_trans, t_err = trans_ang_loss(t, tgt_i) 8 | 9 | max_loss, _ = torch.max(torch.cat((loss_rot, loss_trans), dim=-1), dim=-1) 10 | return max_loss, loss_rot, loss_trans 11 | 12 | def compute_angular_error_weighted(R, t, Rgt_i, tgt_i, weights_t): 13 | loss_rot, rot_err = rot_angle_loss(R, Rgt_i) 14 | loss_trans, t_err = trans_ang_loss(t, tgt_i) 15 | 16 | max_loss, _ = torch.max(torch.cat((loss_rot, loss_trans * weights_t), dim=-1), dim=-1) 17 | return max_loss, loss_rot, loss_trans 18 | 19 | def ess_sq_euclidean_error(E, Egt): 20 | 21 | B = E.shape[0] 22 | E_norm = E/E[:, 2, 2].view(B, 1, 1) 23 | Egt_norm = Egt/Egt[:, 2, 2].view(B, 1, 1) 24 | return torch.pow(E_norm-Egt_norm, 2).view(B, -1).sum(1) 25 | 26 | def compute_pose_loss(R, t, Rgt_i, tgt_i, K0=None, K1=None, soft_clipping=True): 27 | loss_rot, rot_err = rot_angle_loss(R, Rgt_i) 28 | loss_trans = trans_l1_loss(t, tgt_i) 29 | 30 | if soft_clipping: 31 | loss_trans_soft = torch.tanh(loss_trans/0.9) # xm ~ ? 32 | loss_rot_soft = torch.tanh(loss_rot/0.9) # xrads=xdeg ~ ? 
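        # (Illustrative note, not in the original file) tanh(x / 0.9) soft-clips each
        # term into [0, 1): small errors keep a near-linear gradient, while a single
        # badly wrong pose saturates instead of dominating the batch loss; the 0.9
        # scale roughly marks ~0.9 m of translation error or ~0.9 rad of rotation
        # error before saturation sets in.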
33 | 34 | loss = loss_rot_soft + loss_trans_soft 35 | else: 36 | loss = loss_rot + loss_trans 37 | 38 | return loss, loss_rot, loss_trans 39 | 40 | def compute_vcre_loss(R, t, Rgt, tgt, K0, K1, soft_clipping=True): 41 | 42 | B = R.shape[0] 43 | Tgt = torch.zeros((B, 4, 4)).float().to(R.device) 44 | Tgt[:, :3, :3] = Rgt 45 | Tgt[:, :3, 3:] = tgt.transpose(2, 1) 46 | 47 | # Inv pose: 48 | R_inv = R.transpose(2, 1) 49 | t_inv = (-1 * R_inv @ t.transpose(2, 1)).transpose(2, 1) 50 | Tgt_inv = torch.zeros((B, 4, 4)).float().to(R.device) 51 | Rgt_inv = Rgt.transpose(2, 1) 52 | tgt_inv = (-1 * Rgt_inv @ tgt.transpose(2, 1)).transpose(2, 1) 53 | Tgt_inv[:, :3, :3] = Rgt_inv 54 | Tgt_inv[:, :3, 3:] = tgt_inv.transpose(2, 1) 55 | 56 | loss0 = vcre_loss(R, t, Tgt, K0) 57 | loss1 = vcre_loss(R_inv, t_inv, Tgt_inv, K1) 58 | 59 | loss = (loss1 + loss0) / 2. 60 | if soft_clipping: 61 | loss = torch.tanh(loss / 80) 62 | 63 | loss_rot, rot_err = rot_angle_loss(R, Rgt) 64 | loss_trans = trans_l1_loss(t, tgt) 65 | 66 | return loss, loss_rot, loss_trans 67 | 68 | def trans_ang_loss(t, tgt): 69 | """Computes L1 loss for translation vector ANGULAR error 70 | Input: 71 | t - estimated translation vector [B, 1, 3] 72 | tgt - ground-truth translation vector [B, 1, 3] 73 | Output: translation_loss 74 | """ 75 | 76 | scale_t = torch.linalg.norm(t, dim=-1) 77 | scale_tgt = torch.linalg.norm(tgt, dim=-1) 78 | 79 | cosine = (t @ tgt.transpose(1, 2)).squeeze(-1) / (scale_t * scale_tgt + 1e-6) 80 | cosine = torch.clip(cosine, -0.99999, 0.99999) # handle numerical errors and NaNs 81 | t_ang_err = torch.acos(cosine) 82 | t_ang_err = torch.minimum(t_ang_err, math.pi - t_ang_err) 83 | return torch.abs(t_ang_err - torch.zeros_like(t_ang_err)), t_ang_err 84 | 85 | def trans_l1_loss(t, tgt): 86 | """Computes L1 loss for translation vector 87 | Input: 88 | t - estimated translation vector [B, 1, 3] 89 | tgt - ground-truth translation vector [B, 1, 3] 90 | Output: translation_loss 91 | """ 92 | 93 | return torch.abs(t - tgt).sum(-1) 94 | 95 | def rot_angle_loss(R, Rgt): 96 | """ 97 | Computes rotation loss using L1 error of residual rotation angle [radians] 98 | Input: 99 | R - estimated rotation matrix [B, 3, 3] 100 | Rgt - groundtruth rotation matrix [B, 3, 3] 101 | Output: rotation_loss 102 | """ 103 | 104 | residual = R.transpose(1, 2) @ Rgt 105 | trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1) 106 | cosine = (trace - 1) / 2 107 | cosine = torch.clip(cosine, -0.99999, 0.99999) # handle numerical errors and NaNs 108 | R_err = torch.acos(cosine) 109 | loss = torch.abs(R_err - torch.zeros_like(R_err)).unsqueeze(-1) 110 | return loss, R_err 111 | 112 | 113 | def to_homogeneous_torch_batched(u_xys: torch.Tensor): 114 | batch_size, _, num_pts = u_xys.shape 115 | ones = torch.ones((batch_size, 1, num_pts)).float().to(u_xys.device) 116 | u_xyhs = torch.concat([u_xys, ones], dim=1) 117 | return u_xyhs 118 | 119 | 120 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/loss/solvers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def weighted_procrustes(A, B, w=None, use_weights=True, use_mask=False, eps=1e-16, check_rank=True): 4 | """ 5 | X: torch tensor B x N x 3 6 | Y: torch tensor B x N x 3 7 | w: torch tensor B x N 8 | """ 9 | # https://ieeexplore.ieee.org/document/88573 10 | # https://github.com/chrischoy/DeepGlobalRegistration/blob/master/core/registration.py#L160 11 | # Refer to Mapfree procrustes 12 
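    # What follows is a weighted Procrustes / Kabsch solve for R, t such that B ~= R A + t:
    #   1. compute (weight-normalised) centroids a_mean, b_mean and centre both point sets;
    #   2. build the weighted covariance H = A_c^T W B_c;
    #   3. SVD H = U S V^T, then R = V Z U^T, where Z = diag(1, 1, sign(det(U V^T)))
    #      forces a proper rotation with det(R) = +1;
    #   4. recover the translation from the centroids, t = b_mean - R a_mean
    #      (written below in row-vector form as b_mean - a_mean @ R^T).
    # With check_rank set, the solver returns (None, None, False) if any H in the batch
    # has rank 1, so the caller can discard that degenerate hypothesis.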
| assert len(A) == len(B) 13 | if use_weights: 14 | W1 = torch.abs(w).sum(1, keepdim=True) 15 | w_norm = (w / (W1 + eps)).unsqueeze(-1) 16 | a_mean = (w_norm * A).sum(1, keepdim=True) 17 | b_mean = (w_norm * B).sum(1, keepdim=True) 18 | # print('Possible ERROR:') 19 | # print('check this: Sxy = (Y - muy).t().mm(w_norm * (X - mux)).cpu().double(). Repo DeepGlobalRegistration') 20 | 21 | A_c = A - a_mean 22 | B_c = B - b_mean 23 | 24 | if use_mask: 25 | # Covariance matrix 26 | H = A_c.transpose(1, 2) @ (w.unsqueeze(-1) * B_c) 27 | else: 28 | # Covariance matrix 29 | H = A_c.transpose(1, 2) @ (w_norm.unsqueeze(-1) * B_c) 30 | 31 | else: 32 | a_mean = A.mean(axis=1, keepdim=True) 33 | b_mean = B.mean(axis=1, keepdim=True) 34 | 35 | A_c = A - a_mean 36 | B_c = B - b_mean 37 | 38 | # Covariance matrix 39 | H = A_c.transpose(1, 2) @ B_c 40 | 41 | if check_rank: 42 | if (torch.linalg.matrix_rank(H) == 1).sum() > 0: 43 | return None, None, False 44 | 45 | U, S, V = torch.svd(H) 46 | # Fixes orientation such that Det(R) = + 1 47 | Z = torch.eye(3).unsqueeze(0).repeat(A.shape[0], 1, 1).to(A.device) 48 | Z[:, -1, -1] = torch.sign(torch.linalg.det(U @ V.transpose(1, 2))) 49 | # Rotation matrix 50 | R = V @ Z @ U.transpose(1, 2) 51 | # Translation vector 52 | t = b_mean - a_mean @ R.transpose(1, 2) 53 | 54 | return R, t, True -------------------------------------------------------------------------------- /lib/models/MicKey/modules/mickey_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from lib.models.MicKey.modules.DINO_modules.dinov2 import vit_large 4 | from lib.models.MicKey.modules.att_layers.transformer import Transformer_self_att 5 | from lib.models.MicKey.modules.utils.extractor_utils import desc_l2norm, BasicBlock 6 | 7 | class MicKey_Extractor(nn.Module): 8 | def __init__(self, cfg, dinov2_weights=None): 9 | super().__init__() 10 | 11 | # Define DINOv2 extractor 12 | self.dino_channels = cfg['DINOV2']['CHANNEL_DIM'] 13 | self.dino_downfactor = cfg['DINOV2']['DOWN_FACTOR'] 14 | if dinov2_weights is None: 15 | dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/" 16 | "dinov2_vitl14/dinov2_vitl14_pretrain.pth", 17 | map_location="cpu") 18 | vit_kwargs = dict(img_size= 518, 19 | patch_size= 14, 20 | init_values = 1.0, 21 | ffn_layer = "mlp", 22 | block_chunks = 0, 23 | ) 24 | 25 | self.dinov2_vitl14 = vit_large(**vit_kwargs) 26 | self.dinov2_vitl14.load_state_dict(dinov2_weights) 27 | self.dinov2_vitl14.requires_grad_(False) 28 | self.dinov2_vitl14.eval() 29 | 30 | # Define whether DINOv2 runs on float16 or float32 31 | if cfg['DINOV2']['FLOAT16']: 32 | self.amp_dtype = torch.float16 33 | self.dinov2_vitl14.to(self.amp_dtype) 34 | else: 35 | self.amp_dtype = torch.float32 36 | 37 | # Define MicKey's heads 38 | self.depth_head = DeepResBlock_depth(cfg) 39 | self.det_offset = DeepResBlock_offset(cfg) 40 | self.dsc_head = DeepResBlock_desc(cfg) 41 | self.det_head = DeepResBlock_det(cfg) 42 | 43 | def forward(self, x): 44 | 45 | B, C, H, W = x.shape 46 | x = x[:, :, :self.dino_downfactor * (H//self.dino_downfactor), :self.dino_downfactor * (W//self.dino_downfactor)] 47 | 48 | with torch.no_grad(): 49 | dinov2_features = self.dinov2_vitl14.forward_features(x.to(self.amp_dtype)) 50 | dinov2_features = dinov2_features['x_norm_patchtokens'].permute(0, 2, 1).\ 51 | reshape(B, self.dino_channels, H // self.dino_downfactor, W // self.dino_downfactor).float() 52 | 53 | scrs = 
self.det_head(dinov2_features) 54 | kpts = self.det_offset(dinov2_features) 55 | depths = self.depth_head(dinov2_features) 56 | dscs = self.dsc_head(dinov2_features) 57 | 58 | return kpts, depths, scrs, dscs 59 | 60 | def train(self, mode: bool = True): 61 | self.dsc_head.train(mode) 62 | self.depth_head.train(mode) 63 | self.det_offset.train(mode) 64 | self.det_head.train(mode) 65 | 66 | 67 | class DeepResBlock_det(torch.nn.Module): 68 | def __init__(self, config, padding_mode = 'zeros'): 69 | super().__init__() 70 | 71 | bn = config['KP_HEADS']['BN'] 72 | in_channels = config['DINOV2']['CHANNEL_DIM'] 73 | block_dims = config['KP_HEADS']['BLOCKS_DIM'] 74 | add_posEnc = config['KP_HEADS']['POS_ENCODING'] 75 | 76 | self.resblock1 = BasicBlock(in_channels, block_dims[0], stride=1, bn=bn, padding_mode=padding_mode) 77 | self.resblock2 = BasicBlock(block_dims[0], block_dims[1], stride=1, bn=bn, padding_mode=padding_mode) 78 | self.resblock3 = BasicBlock(block_dims[1], block_dims[2], stride=1, bn=bn, padding_mode=padding_mode) 79 | self.resblock4 = BasicBlock(block_dims[2], block_dims[3], stride=1, bn=bn, padding_mode=padding_mode) 80 | 81 | self.score = nn.Conv2d(block_dims[3], 1, kernel_size=1, stride=1, padding=0, bias=False) 82 | 83 | self.use_softmax = config['KP_HEADS']['USE_SOFTMAX'] 84 | self.sigmoid = torch.nn.Sigmoid() 85 | self.logsigmoid = torch.nn.LogSigmoid() 86 | self.softmax = torch.nn.Softmax(dim=-1) 87 | 88 | # Allow more exploration with reinforce algorithm 89 | self.tmp_softmax = 100 90 | 91 | self.eps = nn.Parameter(torch.tensor(1e-16), requires_grad=False) 92 | self.offset_par1 = nn.Parameter(torch.tensor(0.5), requires_grad=False) 93 | self.offset_par2 = nn.Parameter(torch.tensor(2.), requires_grad=False) 94 | self.ones_kernel = nn.Parameter(torch.ones((1, 1, 3, 3)), requires_grad=False) 95 | 96 | self.att_layer = Transformer_self_att(d_model=128, num_layers=3, add_posEnc=add_posEnc) 97 | 98 | def remove_borders(self, score_map: torch.Tensor, borders: int): 99 | ''' 100 | It removes the borders of the image to avoid detections on the corners 101 | ''' 102 | shape = score_map.shape 103 | mask = torch.ones_like(score_map) 104 | 105 | mask[:, :, 0:borders, :] = 0 106 | mask[:, :, :, 0:borders] = 0 107 | mask[:, :, shape[2] - borders:shape[2], :] = 0 108 | mask[:, :, :, shape[3] - borders:shape[3]] = 0 109 | 110 | return mask * score_map 111 | 112 | def remove_brd_and_softmax(self, scores, borders): 113 | 114 | B = scores.shape[0] 115 | 116 | scores = scores - (scores.view(B, -1).mean(-1).view(B, 1, 1, 1) + self.eps).detach() 117 | exp_scores = torch.exp(scores / self.tmp_softmax) 118 | 119 | # remove borders 120 | exp_scores = self.remove_borders(exp_scores, borders=borders) 121 | 122 | # apply softmax 123 | sum_scores = exp_scores.sum(-1).sum(-1).view(B, 1, 1, 1) 124 | return exp_scores / (sum_scores + self.eps) 125 | 126 | def forward(self, feature_volume): 127 | 128 | x = self.resblock1(feature_volume) 129 | x = self.resblock2(x) 130 | x = self.resblock3(x) 131 | x = self.att_layer(x) 132 | x = self.resblock4(x) 133 | 134 | # Predict xy scores 135 | scores = self.score(x) 136 | 137 | if self.use_softmax: 138 | scores = self.remove_brd_and_softmax(scores, 3) 139 | else: 140 | scores = self.remove_borders(self.sigmoid(scores), borders=3) 141 | 142 | return scores 143 | 144 | 145 | class DeepResBlock_offset(torch.nn.Module): 146 | def __init__(self, config, padding_mode = 'zeros'): 147 | super().__init__() 148 | 149 | bn = config['KP_HEADS']['BN'] 150 | in_channels = 
config['DINOV2']['CHANNEL_DIM'] 151 | block_dims = config['KP_HEADS']['BLOCKS_DIM'] 152 | add_posEnc = config['KP_HEADS']['POS_ENCODING'] 153 | self.sigmoid = torch.nn.Sigmoid() 154 | 155 | self.resblock1 = BasicBlock(in_channels, block_dims[0], stride=1, bn=bn, padding_mode=padding_mode) 156 | self.resblock2 = BasicBlock(block_dims[0], block_dims[1], stride=1, bn=bn, padding_mode=padding_mode) 157 | self.resblock3 = BasicBlock(block_dims[1], block_dims[2], stride=1, bn=bn, padding_mode=padding_mode) 158 | self.resblock4 = BasicBlock(block_dims[2], block_dims[3], stride=1, bn=bn, padding_mode=padding_mode) 159 | 160 | self.xy_offset = nn.Conv2d(block_dims[3], 2, kernel_size=1, stride=1, padding=0, bias=False) 161 | 162 | self.att_layer = Transformer_self_att(d_model=128, num_layers=3, add_posEnc=add_posEnc) 163 | 164 | def forward(self, feature_volume): 165 | 166 | x = self.resblock1(feature_volume) 167 | x = self.resblock2(x) 168 | x = self.resblock3(x) 169 | x = self.att_layer(x) 170 | x = self.resblock4(x) 171 | 172 | # Predict xy offsets 173 | xy_offsets = self.xy_offset(x) 174 | 175 | # Offset goes from 0 to 1 176 | xy_offsets = self.sigmoid(xy_offsets) 177 | 178 | return xy_offsets 179 | 180 | 181 | class DeepResBlock_depth(torch.nn.Module): 182 | def __init__(self, config, padding_mode = 'zeros'): 183 | super().__init__() 184 | 185 | bn = config['KP_HEADS']['BN'] 186 | in_channels = config['DINOV2']['CHANNEL_DIM'] 187 | block_dims = config['KP_HEADS']['BLOCKS_DIM'] 188 | add_posEnc = config['KP_HEADS']['POS_ENCODING'] 189 | 190 | self.use_depth_sigmoid = config['KP_HEADS']['USE_DEPTHSIGMOID'] 191 | self.max_depth = config['KP_HEADS']['MAX_DEPTH'] 192 | self.sigmoid = torch.nn.Sigmoid() 193 | 194 | self.resblock1 = BasicBlock(in_channels, block_dims[0], stride=1, bn=bn, padding_mode=padding_mode) 195 | self.resblock2 = BasicBlock(block_dims[0], block_dims[1], stride=1, bn=bn, padding_mode=padding_mode) 196 | self.resblock3 = BasicBlock(block_dims[1], block_dims[2], stride=1, bn=bn, padding_mode=padding_mode) 197 | self.resblock4 = BasicBlock(block_dims[2], block_dims[3], stride=1, bn=bn, padding_mode=padding_mode) 198 | 199 | self.depth = nn.Conv2d(block_dims[3], 1, kernel_size=1, stride=1, padding=0, bias=False) 200 | 201 | self.att_layer = Transformer_self_att(d_model=128, num_layers=3, add_posEnc=add_posEnc) 202 | 203 | def forward(self, feature_volume): 204 | 205 | x = self.resblock1(feature_volume) 206 | x = self.resblock2(x) 207 | x = self.resblock3(x) 208 | x = self.att_layer(x) 209 | x = self.resblock4(x) 210 | 211 | # Predict xy depths 212 | # depths = torch.clip(self.depth(x), min=1e-3, max=500) 213 | if self.use_depth_sigmoid: 214 | depths = self.max_depth * self.sigmoid(self.depth(x)) 215 | else: 216 | depths = self.depth(x) 217 | 218 | return depths 219 | 220 | 221 | class DeepResBlock_desc(torch.nn.Module): 222 | def __init__(self, config, padding_mode = 'zeros'): 223 | super().__init__() 224 | 225 | bn = config['KP_HEADS']['BN'] 226 | last_dim = config['DSC_HEAD']['LAST_DIM'] 227 | in_channels = config['DINOV2']['CHANNEL_DIM'] 228 | block_dims = config['KP_HEADS']['BLOCKS_DIM'] 229 | add_posEnc = config['DSC_HEAD']['POS_ENCODING'] 230 | self.norm_desc = config['DSC_HEAD']['NORM_DSC'] 231 | 232 | self.resblock1 = BasicBlock(in_channels, block_dims[0], stride=1, bn=bn, padding_mode=padding_mode) 233 | self.resblock2 = BasicBlock(block_dims[0], block_dims[1], stride=1, bn=bn, padding_mode=padding_mode) 234 | self.resblock3 = BasicBlock(block_dims[1], block_dims[2], 
stride=1, bn=bn, padding_mode=padding_mode) 235 | self.resblock4 = BasicBlock(block_dims[2], last_dim, stride=1, bn=bn, padding_mode=padding_mode) 236 | 237 | self.att_layer = Transformer_self_att(d_model=128, num_layers=3, add_posEnc=add_posEnc) 238 | 239 | 240 | def forward(self, feature_volume): 241 | 242 | x = self.resblock1(feature_volume) 243 | x = self.resblock2(x) 244 | x = self.resblock3(x) 245 | x = self.att_layer(x) 246 | x = self.resblock4(x, relu=False) 247 | 248 | if self.norm_desc: 249 | x = desc_l2norm(x) 250 | 251 | return x 252 | 253 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/utils/extractor_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def desc_l2norm(desc: torch.Tensor): 7 | '''descriptors with shape NxC or NxCxHxW''' 8 | eps_l2_norm = 1e-10 9 | desc = desc / desc.pow(2).sum(dim=1, keepdim=True).add(eps_l2_norm).pow(0.5) 10 | return desc 11 | 12 | class BasicBlock(nn.Module): 13 | '''Pre-activation version of the BasicBlock.''' 14 | expansion = 1 15 | 16 | def __init__(self, in_planes, planes, stride=1, bn=True, padding_mode='zeros'): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, padding_mode=padding_mode) 19 | self.bn1 = nn.BatchNorm2d(planes) if bn else nn.Identity() 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode) 21 | self.bn2 = nn.BatchNorm2d(planes) if bn else nn.Identity() 22 | 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 26 | ) 27 | 28 | def forward(self, x, relu=True): 29 | shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += shortcut 33 | if relu: 34 | out = F.relu(out) 35 | return out 36 | -------------------------------------------------------------------------------- /lib/models/MicKey/modules/utils/feature_matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def arange_like(x, dim: int): 6 | return x.new_ones(x.shape[dim]).cumsum(0) - 1 # traceable in 1.1 7 | 8 | class featureMatcher(nn.Module): 9 | def __init__(self, cfg): 10 | super().__init__() 11 | 12 | if cfg['TYPE'] == 'DualSoftmax': 13 | self.matching_mat = dualSoftmax(cfg['DUAL_SOFTMAX']) 14 | elif cfg['TYPE'] == 'Sinkhorn': 15 | self.matching_mat = sinkhorn(cfg['SINKHORN']) 16 | else: 17 | print('[ERROR]: feature matcher not recognized') 18 | 19 | def get_matches_list(self, scores, min_conf=0.0): 20 | 21 | # Supports batch_size = 1 22 | 23 | # Get the matches with score above "match_threshold". 
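        # Mutual nearest-neighbour check: slicing with [:, :-1, :-1] drops the last
        # row/column (the dustbin entries, assuming the matrix passed in still carries
        # them), max over dim=2 gives each image-0 keypoint its best image-1 candidate,
        # and max over dim=1 gives the reverse direction. A pair (i, j) survives only if
        # i's best match is j AND j's best match is i, and exp(score) > min_conf; the
        # surviving matches are finally sorted by their matrix score in descending order.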
24 | max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1) 25 | indices0, indices1 = max0.indices, max1.indices 26 | mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0) 27 | # zero = scores.new_tensor(0) 28 | zero = torch.tensor(0).to(scores.device).float() 29 | mscores0 = torch.where(mutual0, max0.values.exp(), zero) 30 | valid0 = mutual0 & (mscores0 > min_conf) 31 | # indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) 32 | minus_one = torch.tensor(-1).to(scores.device) 33 | indices0 = torch.where(valid0, indices0, minus_one) 34 | 35 | valid = indices0 > -1 36 | 37 | idx0 = arange_like(indices0, 1)[valid[0]].unsqueeze(1) 38 | idx1 = indices0[valid].unsqueeze(1) 39 | 40 | matches = torch.concat([idx0, idx1], dim=1) 41 | 42 | batch_idx = torch.tile(torch.arange(1).view(1, 1), [1, len(matches)]).reshape(-1) 43 | scores_matches = scores[batch_idx, idx0[:, 0], idx1[:, 0]] 44 | _, idx_sorted = torch.sort(scores_matches, descending=True) 45 | 46 | return matches[idx_sorted] 47 | 48 | def forward(self, kpt0, dsc0, kpt1, dsc1): 49 | 50 | scores = self.matching_mat(dsc0, dsc1) 51 | 52 | return scores 53 | 54 | class dualSoftmax(nn.Module): 55 | def __init__(self, cfg): 56 | super().__init__() 57 | 58 | self.temperature = cfg['TEMPERATURE'] 59 | self.use_dustbin = False 60 | if cfg['USE_DUSTBIN']: 61 | self.dustbin_score = nn.Parameter(torch.tensor(1.)) 62 | self.use_dustbin = True 63 | 64 | def forward(self, dsc0, dsc1): 65 | scores = torch.matmul(dsc0.transpose(1, 2).contiguous(), dsc1) / self.temperature 66 | 67 | if self.use_dustbin: 68 | b, m, n = scores.shape 69 | 70 | bins0 = self.dustbin_score.expand(b, m, 1) 71 | bins1 = self.dustbin_score.expand(b, 1, n) 72 | alpha = self.dustbin_score.expand(b, 1, 1) 73 | 74 | couplings = torch.cat([torch.cat([scores, bins0], -1), 75 | torch.cat([bins1, alpha], -1)], 1) 76 | 77 | couplings = F.softmax(couplings, 1) * F.softmax(couplings, 2) 78 | scores = couplings[:, :-1, :-1] 79 | 80 | else: 81 | scores = F.softmax(scores, 1) * F.softmax(scores, 2) 82 | 83 | return scores 84 | 85 | class sinkhorn(nn.Module): 86 | def __init__(self, cfg): 87 | super().__init__() 88 | 89 | self.dustbin_score = nn.Parameter(torch.tensor(cfg['DUSTBIN_SCORE_INIT'])) 90 | self.sinkhorn_iterations = cfg['NUM_IT'] 91 | self.descriptor_dim = 128 92 | 93 | def log_sinkhorn_iterations(self, Z: torch.Tensor, log_mu: torch.Tensor, log_nu: torch.Tensor, 94 | iters: int) -> torch.Tensor: 95 | """ Perform Sinkhorn Normalization in Log-space for stability""" 96 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 97 | for _ in range(iters): 98 | u = log_mu - torch.logsumexp(Z + v.unsqueeze(1), dim=2) 99 | v = log_nu - torch.logsumexp(Z + u.unsqueeze(2), dim=1) 100 | return Z + u.unsqueeze(2) + v.unsqueeze(1) 101 | 102 | def log_optimal_transport(self, scores: torch.Tensor, alpha: torch.Tensor, iters: int) -> torch.Tensor: 103 | """ Perform Differentiable Optimal Transport in Log-space for stability""" 104 | b, m, n = scores.shape 105 | # one = scores.new_tensor(1) 106 | one = torch.ones((), device=scores.device) 107 | ms, ns = (m * one).to(scores), (n * one).to(scores) 108 | 109 | bins0 = alpha.expand(b, m, 1) 110 | bins1 = alpha.expand(b, 1, n) 111 | alpha = alpha.expand(b, 1, 1) 112 | 113 | couplings = torch.cat([torch.cat([scores, bins0], -1), 114 | torch.cat([bins1, alpha], -1)], 1) 115 | 116 | norm = - (ms + ns).log() 117 | log_mu = torch.cat([norm.expand(m), ns.log()[None] + norm]) 118 | log_nu = torch.cat([norm.expand(n), 
ms.log()[None] + norm]) 119 | log_mu, log_nu = log_mu[None].expand(b, -1), log_nu[None].expand(b, -1) 120 | 121 | Z = self.log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) 122 | Z = Z - norm # multiply probabilities by M+N 123 | return Z.exp() 124 | 125 | def forward(self, dsc0, dsc1, tmp): 126 | 127 | # Compute matching descriptor distance. 128 | scores = torch.einsum('bdn,bdm->bnm', dsc0, dsc1) 129 | scores = scores / self.descriptor_dim**.5 130 | 131 | # scores = torch.matmul(dsc0.transpose(1, 2).contiguous(), dsc1) 132 | 133 | scores = self.log_optimal_transport( 134 | scores, self.dustbin_score, 135 | iters=self.sinkhorn_iterations) 136 | 137 | return scores[:, :-1, :-1] 138 | -------------------------------------------------------------------------------- /lib/models/builder.py: -------------------------------------------------------------------------------- 1 | import torch.cuda 2 | 3 | from lib.models.MicKey.compute_pose import MickeyRelativePose 4 | 5 | def build_model(cfg, checkpoint=''): 6 | 7 | if cfg.MODEL == 'MicKey': 8 | 9 | model = MickeyRelativePose(cfg) 10 | 11 | checkpoint_loaded = torch.load(checkpoint) 12 | model.on_load_checkpoint(checkpoint_loaded) 13 | model.load_state_dict(checkpoint_loaded['state_dict']) 14 | 15 | if torch.cuda.is_available(): 16 | model = model.cuda() 17 | model.eval() 18 | return model 19 | else: 20 | raise NotImplementedError() 21 | -------------------------------------------------------------------------------- /lib/utils/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def data_to_model_device(data, model): 4 | '''Move all tensors in data dictionary to the same device as model''' 5 | 6 | try: 7 | device = next(model.parameters()).device 8 | except: 9 | # in case the model has no parameters (baseline models) 10 | device = 'cpu' 11 | 12 | for k, v in data.items(): 13 | if torch.is_tensor(v): 14 | data[k] = v.to(device) 15 | 16 | return data 17 | -------------------------------------------------------------------------------- /lib/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Code adapted from Map-free benchmark: https://github.com/nianticlabs/map-free-reloc 2 | 3 | import torch 4 | import numpy as np 5 | from collections import defaultdict 6 | from lib.benchmarks.reprojection import get_grid_multipleheight 7 | from lib.models.MicKey.modules.utils.training_utils import project_2d 8 | 9 | # global variable, avoids creating it again 10 | eye_coords_glob = get_grid_multipleheight() 11 | 12 | def pose_error_torch(R, t, Tgt, reduce=None): 13 | """Compute angular, scale and euclidean error of translation vector (metric). 
Compute angular rotation error.""" 14 | 15 | Rgt = Tgt[:, :3, :3] # [B, 3, 3] 16 | tgt = Tgt[:, :3, 3:].transpose(1, 2) # [B, 1, 3] 17 | 18 | scale_t = torch.linalg.norm(t, dim=-1) 19 | scale_tgt = torch.linalg.norm(tgt, dim=-1) 20 | 21 | cosine = (t @ tgt.transpose(1, 2)).squeeze(-1) / (scale_t * scale_tgt + 1e-9) 22 | cosine = torch.clip(cosine, -1.0, 1.0) # handle numerical errors 23 | t_ang_err = torch.rad2deg(torch.acos(cosine)) 24 | t_ang_err = torch.minimum(t_ang_err, 180 - t_ang_err) 25 | 26 | t_scale_err = scale_t / scale_tgt 27 | t_scale_err_sym = torch.maximum(scale_t / scale_tgt, scale_tgt / scale_t) 28 | t_euclidean_err = torch.linalg.norm(t - tgt, dim=-1) 29 | 30 | residual = R.transpose(1, 2) @ Rgt 31 | trace = torch.diagonal(residual, dim1=-2, dim2=-1).sum(-1) 32 | cosine = (trace - 1) / 2 33 | cosine = torch.clip(cosine, -1., 1.) # handle numerical errors 34 | R_err = torch.rad2deg(torch.acos(cosine)) 35 | 36 | if reduce is None: 37 | def fn(x): return x 38 | elif reduce == 'mean': 39 | fn = torch.mean 40 | elif reduce == 'median': 41 | fn = torch.median 42 | 43 | t_ang_err = fn(t_ang_err) 44 | t_scale_err = fn(t_scale_err) 45 | t_euclidean_err = fn(t_euclidean_err) 46 | R_err = fn(R_err) 47 | 48 | errors = {'t_err_ang': t_ang_err, 49 | 't_err_scale': t_scale_err, 50 | 't_err_scale_sym': t_scale_err_sym, 51 | 't_err_euc': t_euclidean_err, 52 | 'R_err': R_err} 53 | return errors 54 | 55 | 56 | def vcre_loss(R, t, Tgt, K0, H=720): 57 | """Compute Virtual Correspondences Reprojection Error in torch (with batches).""" 58 | 59 | B = R.shape[0] 60 | Rgt = Tgt[:, :3, :3] # [B, 3, 3] 61 | tgt = Tgt[:, :3, 3:].transpose(1, 2) # [B, 1, 3] 62 | 63 | eye_coords = torch.from_numpy(eye_coords_glob).unsqueeze(0)[:, :, :3].to(R.device).float() 64 | eye_coords = torch.tile(eye_coords, [B, 1, 1]) 65 | 66 | # obtain ground-truth position of projected points 67 | uv_gt = project_2d(eye_coords, K0) 68 | 69 | # Avoid breaking gradients due to inplace operation 70 | eye_coord_tmp = (R @ eye_coords.transpose(2, 1) + t.transpose(2, 1)) 71 | eyes_residual = (Rgt.transpose(2, 1) @ eye_coord_tmp -1 * Rgt.transpose(2, 1) @ tgt.transpose(2, 1)).transpose(2, 1) 72 | 73 | uv_pred = project_2d(eyes_residual, K0) 74 | 75 | uv_gt = torch.clip(uv_gt, 0, H) 76 | uv_pred = torch.clip(uv_pred, 0, H) 77 | 78 | repr_err = ((((uv_gt - uv_pred) ** 2.).sum(-1) + 1e-6) ** 0.5).mean(-1).view(B, 1) 79 | 80 | return repr_err 81 | 82 | 83 | def vcre_torch(R, t, Tgt, K0, reduce=None, H=720, W=540): 84 | """Compute Virtual Correspondences Reprojection Error in torch (with batches).""" 85 | 86 | B = R.shape[0] 87 | Rgt = Tgt[:, :3, :3] # [B, 3, 3] 88 | tgt = Tgt[:, :3, 3:].transpose(1, 2) # [B, 1, 3] 89 | 90 | eye_coords = torch.from_numpy(eye_coords_glob).unsqueeze(0).to(R.device).float() 91 | eye_coords = torch.tile(eye_coords, [B, 1, 1]) 92 | 93 | # obtain ground-truth position of projected points 94 | uv_gt = project_2d(eye_coords[:, :, :3], K0) 95 | 96 | # residual transformation 97 | cam2w_est = torch.tile(torch.eye(4).view(1, 4, 4), [B, 1, 1]).to(R.device).float() 98 | cam2w_est[:, :3, :3] = R 99 | cam2w_est[:, :3, -1] = t[:, 0] 100 | 101 | cam2w_gt = torch.tile(torch.eye(4).view(1, 4, 4), [B, 1, 1]).to(R.device).float() 102 | cam2w_gt[:, :3, :3] = Rgt 103 | cam2w_gt[:, :3, -1] = tgt[:, 0] 104 | 105 | # residual reprojection 106 | eyes_residual = (torch.linalg.inv(cam2w_gt) @ cam2w_est @ eye_coords.transpose(2, 1)).transpose(2, 1) 107 | uv_pred = project_2d(eyes_residual[:, :, :3], K0) 108 | 109 | uv_gt[:, :, 0], 
uv_pred[:, :, 0] = torch.clip(uv_gt[:, :, 0], 0, W), torch.clip(uv_pred[:, :, 0], 0, W) 110 | uv_gt[:, :, 1], uv_pred[:, :, 1] = torch.clip(uv_gt[:, :, 1], 0, H), torch.clip(uv_pred[:, :, 1], 0, H) 111 | 112 | repr_err = ((((uv_gt - uv_pred) ** 2.).sum(-1) + 1e-6) ** 0.5).mean(-1).view(B, 1) 113 | 114 | if reduce is None: 115 | def fn(x): return x 116 | elif reduce == 'mean': 117 | fn = torch.mean 118 | elif reduce == 'median': 119 | fn = torch.median 120 | 121 | repr_err = fn(repr_err) 122 | 123 | errors = {'repr_err': repr_err} 124 | 125 | return errors 126 | 127 | 128 | 129 | def error_auc(errors, thresholds): 130 | """ 131 | Args: 132 | errors (list): [N,] 133 | thresholds (list) 134 | """ 135 | errors = np.nan_to_num(errors, nan=float('inf')) # convert nans to inf 136 | errors = [0] + sorted(list(errors)) 137 | recall = list(np.linspace(0, 1, len(errors))) 138 | 139 | aucs = [] 140 | for thr in thresholds: 141 | last_index = np.searchsorted(errors, thr) 142 | y = recall[:last_index] + [recall[last_index-1]] 143 | x = errors[:last_index] + [thr] 144 | aucs.append(np.trapz(y, x) / thr) 145 | 146 | return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)} 147 | 148 | 149 | def ecdf(x): 150 | """Get Empirical Cumulative Distribution Function (ECDF) given samples x [N,]""" 151 | cd = np.linspace(0, 1, x.shape[0]) 152 | v = np.sort(x) 153 | return v, cd 154 | 155 | 156 | def print_auc_table(agg_metrics): 157 | pose_error = np.maximum(agg_metrics['R_err'], agg_metrics['t_err_ang']) 158 | auc_pose = error_auc(pose_error, (5, 10, 20)) 159 | print('Pose error AUC @ 5/10/20deg: {0:.3f}/{1:.3f}/{2:.3f}'.format(*auc_pose.values())) 160 | 161 | auc_rotation = error_auc(agg_metrics['R_err'], (5, 10, 20)) 162 | print('Rotation error AUC @ 5/10/20deg: {0:.3f}/{1:.3f}/{2:.3f}'.format(*auc_rotation.values())) 163 | 164 | auc_translation_ang = error_auc(agg_metrics['t_err_ang'], (5, 10, 20)) 165 | print( 166 | 'Translation angular error AUC @ 5/10/20deg: {0:.3f}/{1:.3f}/{2:.3f}'.format(*auc_translation_ang.values())) 167 | 168 | auc_translation_euc = error_auc(agg_metrics['t_err_euc'], (0.1, 0.5, 1)) 169 | print( 170 | 'Translation Euclidean error AUC @ 0.1/0.5/1m: {0:.3f}/{1:.3f}/{2:.3f}'.format(*auc_translation_euc.values())) 171 | 172 | 173 | def precision(agg_metrics, rot_threshold, trans_threshold): 174 | '''Provides ratio of samples with rotation error < rot_threshold AND translation error < trans_threshold''' 175 | mask_rot = agg_metrics['R_err'] <= rot_threshold 176 | mask_trans = agg_metrics['t_err_euc'] <= trans_threshold 177 | recall = (mask_rot * mask_trans).mean() 178 | return recall 179 | 180 | 181 | def A_metrics(t_scale_err_sym): 182 | """Returns A1/A2/A3 metrics of translation vector norm given the "symmetric" scale error 183 | where 184 | t_scale_err_sym = torch.maximum((t_norm_gt / t_norm_pred), (t_norm_pred / t_norm_gt)) 185 | """ 186 | 187 | if not torch.is_tensor(t_scale_err_sym): 188 | t_scale_err_sym = torch.from_numpy(t_scale_err_sym) 189 | 190 | thresh = t_scale_err_sym 191 | a1 = (thresh < 1.25).float().mean() 192 | a2 = (thresh < 1.25 ** 2).float().mean() 193 | a3 = (thresh < 1.25 ** 3).float().mean() 194 | return a1, a2, a3 195 | 196 | 197 | class MetricsAccumulator: 198 | """Accumulates metrics and aggregates them when requested""" 199 | 200 | def __init__(self): 201 | self.data = defaultdict(list) 202 | 203 | def accumulate(self, data): 204 | for key, value in data.items(): 205 | self.data[key].append(value) 206 | 207 | def aggregate(self): 208 | res = dict() 209 | for 
key in self.data.keys(): 210 | res[key] = torch.cat(self.data[key]).view(-1).cpu().numpy() 211 | return res 212 | -------------------------------------------------------------------------------- /resources/environment.yml: -------------------------------------------------------------------------------- 1 | name: mickey 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - python=3.8.17 7 | - pip=23.2.1 8 | - pip: 9 | - einops==0.6.1 10 | - lazy-loader==0.3 11 | - lightning-utilities==0.9.0 12 | - matplotlib==3.7.2 13 | - numpy==1.24.4 14 | - omegaconf==2.3.0 15 | - open3d==0.17.0 16 | - opencv-python==4.8.0.74 17 | - protobuf==4.23.4 18 | - pytorch-lightning==2.0.6 19 | - tensorboard==2.13.0 20 | - tensorboard-data-server==0.7.1 21 | - timm==0.6.7 22 | - torch==2.0.1 23 | - torchmetrics==1.0.2 24 | - torchvision==0.15.2 25 | - tqdm==4.65.1 26 | - transforms3d==0.4.1 27 | - xformers==0.0.20 28 | - yacs==0.1.8 29 | - pillow==10.0.0 30 | - scipy==1.10.1 31 | - pyrender==0.1.45 32 | - trimesh==4.0.2 33 | -------------------------------------------------------------------------------- /resources/teaser_mickey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nianticlabs/mickey/2391be8a35491e7b43481c069f5dab65030839b9/resources/teaser_mickey.png -------------------------------------------------------------------------------- /submission.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | from collections import defaultdict 4 | from dataclasses import dataclass 5 | from zipfile import ZipFile 6 | 7 | import torch 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from config.default import cfg 12 | from lib.datasets.datamodules import DataModule 13 | from lib.models.builder import build_model 14 | from lib.utils.data import data_to_model_device 15 | from transforms3d.quaternions import mat2quat 16 | 17 | @dataclass 18 | class Pose: 19 | image_name: str 20 | q: np.ndarray 21 | t: np.ndarray 22 | inliers: float 23 | 24 | def __str__(self) -> str: 25 | formatter = {'float': lambda v: f'{v:.6f}'} 26 | max_line_width = 1000 27 | q_str = np.array2string(self.q, formatter=formatter, max_line_width=max_line_width)[1:-1] 28 | t_str = np.array2string(self.t, formatter=formatter, max_line_width=max_line_width)[1:-1] 29 | return f'{self.image_name} {q_str} {t_str} {self.inliers}' 30 | 31 | 32 | def predict(loader, model): 33 | results_dict = defaultdict(list) 34 | 35 | for data in tqdm(loader): 36 | 37 | # run inference 38 | data = data_to_model_device(data, model) 39 | with torch.no_grad(): 40 | R_batched, t_batched = model(data) 41 | 42 | for i_batch in range(len(data['scene_id'])): 43 | R = R_batched[i_batch].unsqueeze(0).detach().cpu().numpy() 44 | t = t_batched[i_batch].reshape(-1).detach().cpu().numpy() 45 | inliers = data['inliers'][i_batch].item() 46 | 47 | scene = data['scene_id'][i_batch] 48 | query_img = data['pair_names'][1][i_batch] 49 | 50 | # ignore frames without poses (e.g. 
not enough feature matches) 51 | if np.isnan(R).any() or np.isnan(t).any() or np.isinf(t).any(): 52 | continue 53 | 54 | # populate results_dict 55 | estimated_pose = Pose(image_name=query_img, 56 | q=mat2quat(R).reshape(-1), 57 | t=t.reshape(-1), 58 | inliers=inliers) 59 | results_dict[scene].append(estimated_pose) 60 | 61 | return results_dict 62 | 63 | 64 | def save_submission(results_dict: dict, output_path: Path): 65 | with ZipFile(output_path, 'w') as zip: 66 | for scene, poses in results_dict.items(): 67 | poses_str = '\n'.join((str(pose) for pose in poses)) 68 | zip.writestr(f'pose_{scene}.txt', poses_str.encode('utf-8')) 69 | 70 | 71 | def eval(args): 72 | # Load configs 73 | cfg.merge_from_file('config/datasets/mapfree.yaml') 74 | cfg.merge_from_file(args.config) 75 | 76 | # Create dataloader 77 | if args.split == 'test': 78 | cfg.TRAINING.BATCH_SIZE = 8 79 | cfg.TRAINING.NUM_WORKERS = 8 80 | dataloader = DataModule(cfg, drop_last_val=False).test_dataloader() 81 | elif args.split == 'val': 82 | cfg.TRAINING.BATCH_SIZE = 12 83 | cfg.TRAINING.NUM_WORKERS = 8 84 | dataloader = DataModule(cfg, drop_last_val=False).val_dataloader() 85 | else: 86 | raise NotImplemented(f'Invalid split: {args.split}') 87 | 88 | # Create model 89 | model = build_model(cfg, args.checkpoint) 90 | 91 | # Get predictions from model 92 | results_dict = predict(dataloader, model) 93 | 94 | # Save predictions to txt per scene within zip 95 | args.output_root.mkdir(parents=True, exist_ok=True) 96 | save_submission(results_dict, args.output_root / 'submission.zip') 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--config', help='path to config file') 101 | parser.add_argument('--checkpoint', 102 | help='path to model checkpoint (models with learned parameters)', default='') 103 | parser.add_argument('--output_root', '-o', type=Path, default=Path('results/')) 104 | parser.add_argument('--split', choices=('val', 'test'), default='test', 105 | help='Dataset split to use for evaluation. Choose from test or val. Default: test') 106 | args = parser.parse_args() 107 | eval(args) 108 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | # do this before importing numpy! (doing it right up here in case numpy is dependency of e.g. 
json) 4 | os.environ["MKL_NUM_THREADS"] = "1" # noqa: E402 5 | os.environ["NUMEXPR_NUM_THREADS"] = "1" # noqa: E402 6 | os.environ["OMP_NUM_THREADS"] = "1" # noqa: E402 7 | os.environ["OPENBLAS_NUM_THREADS"] = "1" # noqa: E402 8 | 9 | import pytorch_lightning as pl 10 | import torch 11 | from pytorch_lightning.loggers import TensorBoardLogger 12 | 13 | from config.default import cfg 14 | from lib.datasets.datamodules import DataModuleTraining 15 | from lib.models.MicKey.model import MicKeyTrainingModel 16 | from lib.models.MicKey.modules.utils.training_utils import create_exp_name, create_result_dir 17 | import random 18 | import shutil 19 | 20 | def train_model(args): 21 | 22 | cfg.merge_from_file(args.dataset_config) 23 | cfg.merge_from_file(args.config) 24 | 25 | exp_name = create_exp_name(args.experiment, cfg) 26 | print('Start training of ' + exp_name) 27 | 28 | cfg.DATASET.SEED = random.randint(0, 1000000) 29 | 30 | model = MicKeyTrainingModel(cfg) 31 | 32 | checkpoint_vcre_callback = pl.callbacks.ModelCheckpoint( 33 | filename='{epoch}-best_vcre', 34 | save_last=True, 35 | save_top_k=1, 36 | verbose=True, 37 | monitor='val_vcre/auc_vcre', 38 | mode='max' 39 | ) 40 | 41 | checkpoint_pose_callback = pl.callbacks.ModelCheckpoint( 42 | filename='{epoch}-best_pose', 43 | save_last=True, 44 | save_top_k=1, 45 | verbose=True, 46 | monitor='val_AUC_pose/auc_pose', 47 | mode='max' 48 | ) 49 | 50 | epochend_callback = pl.callbacks.ModelCheckpoint( 51 | filename='e{epoch}-last', 52 | save_top_k=1, 53 | every_n_epochs=1, 54 | save_on_train_epoch_end=True 55 | ) 56 | 57 | lr_monitoring_callback = pl.callbacks.LearningRateMonitor(logging_interval='step') 58 | logger = TensorBoardLogger(save_dir=args.path_weights, name=exp_name) 59 | 60 | trainer = pl.Trainer(devices=cfg.TRAINING.NUM_GPUS, 61 | log_every_n_steps=cfg.TRAINING.LOG_INTERVAL, 62 | val_check_interval=cfg.TRAINING.VAL_INTERVAL, 63 | limit_val_batches=cfg.TRAINING.VAL_BATCHES, 64 | max_epochs=cfg.TRAINING.EPOCHS, 65 | logger=logger, 66 | callbacks=[checkpoint_pose_callback, lr_monitoring_callback, epochend_callback, checkpoint_vcre_callback], 67 | num_sanity_val_steps=0, 68 | gradient_clip_val=cfg.TRAINING.GRAD_CLIP) 69 | 70 | datamodule_end = DataModuleTraining(cfg) 71 | print('Training with {:.2f}/{:.2f} image overlap'.format(cfg.DATASET.MIN_OVERLAP_SCORE, cfg.DATASET.MAX_OVERLAP_SCORE)) 72 | 73 | create_result_dir(logger.log_dir + '/config.yaml') 74 | shutil.copyfile(args.config, logger.log_dir + '/config.yaml') 75 | 76 | if args.resume: 77 | ckpt_path = args.resume 78 | else: 79 | ckpt_path = None 80 | 81 | trainer.fit(model, datamodule_end, ckpt_path=ckpt_path) 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--config', help='path to config file', default='config/MicKey/curriculum_learning.yaml') 86 | parser.add_argument('--dataset_config', help='path to dataset config file', default='config/datasets/mapfree.yaml') 87 | parser.add_argument('--experiment', help='experiment name', default='MicKey_default') 88 | parser.add_argument('--path_weights', help='path to the directory to save the weights', default='weights/') 89 | parser.add_argument('--resume', help='resume from checkpoint path', default=None) 90 | args = parser.parse_args() 91 | train_model(args) --------------------------------------------------------------------------------
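For reference, the pieces above compose as follows at inference time. This is a minimal sketch that mirrors the eval() flow in submission.py; the MicKey config choice and the checkpoint path are placeholders, and it assumes a trained checkpoint plus the Map-free data layout expected by config/datasets/mapfree.yaml.

import torch

from config.default import cfg
from lib.datasets.datamodules import DataModule
from lib.models.builder import build_model
from lib.utils.data import data_to_model_device

# Merge the dataset config first, then the experiment config (same order as submission.py).
cfg.merge_from_file('config/datasets/mapfree.yaml')
cfg.merge_from_file('config/MicKey/curriculum_learning.yaml')  # placeholder config choice

# build_model() loads the checkpoint, moves the model to GPU when available and sets eval mode.
model = build_model(cfg, checkpoint='weights/mickey.ckpt')     # placeholder checkpoint path

# One batch from the validation split; data_to_model_device() moves its tensors to the model device.
loader = DataModule(cfg, drop_last_val=False).val_dataloader()
data = data_to_model_device(next(iter(loader)), model)

with torch.no_grad():
    R, t = model(data)  # relative rotation [B, 3, 3] and metric translation per image pair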