├── .gitignore ├── LICENSE ├── README.md ├── assets ├── 1min-oral.png ├── comparison-3dmatch.png ├── demo_inputs.png ├── demo_outputs.png ├── dgr.gif ├── frontier.png └── results.npz ├── config.py ├── core ├── __init__.py ├── correspondence.py ├── deep_global_registration.py ├── knn.py ├── loss.py ├── metrics.py ├── registration.py └── trainer.py ├── dataloader ├── base_loader.py ├── data_loaders.py ├── inf_sampler.py ├── kitti_loader.py ├── split │ ├── test_3dmatch.txt │ ├── test_kitti.txt │ ├── test_modelnet40.txt │ ├── test_scan2cad.txt │ ├── train_3dmatch.txt │ ├── train_kitti.txt │ ├── train_modelnet40.txt │ ├── train_scan2cad.txt │ ├── val_3dmatch.txt │ ├── val_kitti.txt │ ├── val_modelnet40.txt │ └── val_scan2cad.txt ├── threedmatch_loader.py └── transforms.py ├── demo.py ├── model ├── __init__.py ├── common.py ├── pyramidnet.py ├── residual_block.py ├── resunet.py └── simpleunet.py ├── requirements.txt ├── scripts ├── analyze_stats.py ├── download_3dmatch.sh ├── test_3dmatch.py ├── test_kitti.py ├── train_3dmatch.sh └── train_kitti.sh ├── train.py └── util ├── __init__.py ├── file.py ├── integration.py ├── pointcloud.py └── timer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Temp files 2 | .DS_Store 3 | __pycache__ 4 | *.swp 5 | *.swo 6 | *.orig 7 | .idea 8 | outputs/ 9 | *.pyc 10 | *.npy 11 | *.pdf 12 | util/*.sh 13 | checkpoints 14 | # *.npz 15 | 3dmatch/ 16 | tmp.txt 17 | output.ply 18 | bunny.ply 19 | test.sh 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Chris Choy (chrischoy@ai.stanford.edu), Wei Dong (weidong@andrew.cmu.edu) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 9 | of the Software, and to permit persons to whom the Software is furnished to do 10 | so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | Please cite the following papers if you use any part of the code. 
25 | 26 | ``` 27 | @inproceedings{choy2020deep, 28 | title={Deep Global Registration}, 29 | author={Choy, Christopher and Dong, Wei and Koltun, Vladlen}, 30 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 31 | year={2020} 32 | } 33 | 34 | @inproceedings{choy2019fully, 35 | author = {Choy, Christopher and Park, Jaesik and Koltun, Vladlen}, 36 | title = {Fully Convolutional Geometric Features}, 37 | booktitle = {ICCV}, 38 | year = {2019}, 39 | } 40 | 41 | @inproceedings{choy20194d, 42 | title={4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks}, 43 | author={Choy, Christopher and Gwak, JunYoung and Savarese, Silvio}, 44 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 45 | year={2019} 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Global Registration 2 | 3 | ## Introduction 4 | This repository contains Python scripts for training and testing [Deep Global Registration, CVPR 2020 Oral](https://node1.chrischoy.org/data/publications/dgr/DGR.pdf). 5 | Deep Global Registration (DGR) proposes a differentiable framework for pairwise registration of real-world 3D scans. DGR consists of the following three modules: 6 | 7 | - a 6-dimensional convolutional network for correspondence confidence prediction 8 | - a differentiable Weighted Procrustes algorithm for closed-form pose estimation 9 | - a robust gradient-based SE(3) optimizer for pose refinement. 10 | 11 | For more details, please check out 12 | 13 | - [CVPR 2020 oral paper](https://node1.chrischoy.org/data/publications/dgr/DGR.pdf) 14 | - [1min oral video](https://youtu.be/stzgn6DkozA) 15 | - [Full CVPR oral presentation](https://youtu.be/Iy17wvo07BU) 16 | 17 | [![1min oral](assets/1min-oral.png)](https://youtu.be/stzgn6DkozA) 18 | 19 | 20 | ## Quick Pipeline Visualization 21 | | Indoor 3DMatch Registration | Outdoor KITTI LiDAR Registration | 22 | |:---------------------------:|:---------------------------:| 23 | | ![](https://chrischoy.github.io/images/publication/dgr/text_100.gif) | ![](https://chrischoy.github.io/images/publication/dgr/kitti1_optimized.gif) | 24 | 25 | ## Related Works 26 | Recent end-to-end frameworks combine feature learning and pose optimization. PointNetLK combines PointNet global features with an iterative pose optimization method. Wang et al., in Deep Closest Point, train graph neural network features by backpropagating through pose optimization. 27 | We further advance this line of work. In particular, our Weighted Procrustes method reduces the complexity of optimization from quadratic to linear and enables the use of dense correspondences for highly accurate registration of real-world scans. 28 | 29 | ## Deep Global Registration 30 | The first component is a 6-dimensional convolutional network that analyzes the geometry of 3D correspondences and estimates their accuracy. Please refer to [High-dim ConvNets, CVPR'20](https://github.com/chrischoy/HighDimConvNets) for more details. 31 | 32 | The second component we develop is a differentiable Weighted Procrustes solver. The Procrustes method provides a closed-form solution for rigid registration in SE(3). A differentiable version of the Procrustes method used for end-to-end registration passes gradients through coordinates, which requires O(N^2) time and memory for N keypoints.
Instead, the Weighted Procrustes method passes gradients through the weights associated with correspondences rather than correspondence coordinates. 33 | The computational complexity of the Weighted Procrustes method is linear in the number of correspondences, allowing the registration pipeline to use dense correspondence sets rather than sparse keypoints. This substantially increases registration accuracy. 34 | 35 | Our third component is a robust optimization module that fine-tunes the alignment produced by the Weighted Procrustes solver and the failure detection module. 36 | This optimization module minimizes a differentiable loss via gradient descent on the continuous SE(3) representation space. The optimization is fast since, unlike ICP, it does not require a nearest-neighbor search in its inner loop. 37 | 38 | ## Configuration 39 | Our network is built on the [MinkowskiEngine](https://github.com/StanfordVL/MinkowskiEngine) and the system requirements are: 40 | 41 | - Ubuntu 14.04 or higher 42 | - CUDA 10.1.243 or higher 43 | - pytorch 1.5 or higher 44 | - python 3.6 or higher 45 | - GCC 7 46 | 47 | You can install MinkowskiEngine and the Python requirements on your system with: 48 | 49 | ```shell 50 | # Install MinkowskiEngine 51 | sudo apt install libopenblas-dev g++-7 52 | pip install torch 53 | export CXX=g++-7; pip install -U MinkowskiEngine --install-option="--blas=openblas" -v 54 | 55 | # Download and setup DeepGlobalRegistration 56 | git clone https://github.com/chrischoy/DeepGlobalRegistration.git 57 | cd DeepGlobalRegistration 58 | pip install -r requirements.txt 59 | ``` 60 | 61 | ## Demo 62 | You may register your own data with the relevant pretrained DGR model: the 3DMatch model is suitable for indoor RGB-D scans, while the KITTI model is for outdoor LiDAR scans. 63 | 64 | | Inlier Model | FCGF Model | Dataset | Voxel Size | Feature Dimension | Performance | Link | 65 | |:------------:|:-----------:|:-------:|:-------------:|:-----------------:|:--------------------------:|:------:| 66 | | ResUNetBN2C | ResUNetBN2C | 3DMatch | 5cm (0.05) | 32 | TE: 7.34cm, RE: 2.43deg | [weights](http://node2.chrischoy.org/data/projects/DGR/ResUNetBN2C-feat32-3dmatch-v0.05.pth) | 67 | | ResUNetBN2C | ResUNetBN2C | KITTI | 30cm (0.3) | 32 | TE: 3.14cm, RE: 0.14deg | [weights](http://node2.chrischoy.org/data/projects/DGR/ResUNetBN2C-feat32-kitti-v0.3.pth) | 68 | 69 | 70 | ```shell 71 | python demo.py 72 | ``` 73 | 74 | | Input Point Clouds | Output Prediction | 75 | |:---------------------------:|:---------------------------:| 76 | | ![](assets/demo_inputs.png) | ![](assets/demo_outputs.png) | 77 | 78 | 79 | ## Experiments 80 | | Comparison | Speed vs. Recall Pareto Frontier | 81 | | ------- | --------------- | 82 | | ![Comparison](assets/comparison-3dmatch.png) | ![Frontier](assets/frontier.png) | 83 | 84 | 85 | ## Training 86 | The entire network depends on pretrained [FCGF models](https://github.com/chrischoy/FCGF#model-zoo). Please download the corresponding models before training.
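To make the demo above concrete, the following is a minimal sketch of registering two scans with a pretrained DGR checkpoint. It is pieced together from the `DeepGlobalRegistration` class in `core/deep_global_registration.py` and the `--weights`/`--pcd0`/`--pcd1` options in `config.py`; the exact contents of `demo.py` may differ, and the checkpoint filename refers to the 3DMatch model linked in the Demo table.

```python
import open3d as o3d

from config import get_config
from core.deep_global_registration import DeepGlobalRegistration

config = get_config()
# Pretrained DGR checkpoint (3DMatch model from the Demo table), downloaded beforehand
config.weights = 'ResUNetBN2C-feat32-3dmatch-v0.05.pth'

# Loads the FCGF backbone and the 6D inlier-prediction network from the checkpoint
dgr = DeepGlobalRegistration(config)

# Two overlapping scans; the defaults of --pcd0/--pcd1 point to the demo fragments
pcd0 = o3d.io.read_point_cloud(config.pcd0)
pcd1 = o3d.io.read_point_cloud(config.pcd1)

# register() voxelizes the inputs, extracts FCGF features, predicts correspondence
# weights, and returns a 4x4 transformation aligning pcd0 to pcd1
T = dgr.register(pcd0, pcd1)

pcd0.transform(T)
o3d.visualization.draw_geometries([pcd0, pcd1])
```

The pretrained FCGF models required for training are listed in the table below.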
87 | | Model | Normalized Feature | Dataset | Voxel Size | Feature Dimension | Link | 88 | |:-----------:|:-------------------:|:-------:|:-------------:|:-----------------:|:------:| 89 | | ResUNetBN2C | True | 3DMatch | 5cm (0.05) | 32 | [download](https://node1.chrischoy.org/data/publications/fcgf/2019-08-16_19-21-47.pth) | 90 | | ResUNetBN2C | True | KITTI | 30cm (0.3) | 32 | [download](https://node1.chrischoy.org/data/publications/fcgf/KITTI-v0.3-ResUNetBN2C-conv1-5-nout32.pth) | 91 | 92 | 93 | ### 3DMatch 94 | You may download the preprocessed data and train with these commands: 95 | ```shell 96 | ./scripts/download_3dmatch.sh /path/to/3dmatch 97 | export THREED_MATCH_DIR=/path/to/3dmatch; FCGF_WEIGHTS=/path/to/fcgf_3dmatch.pth ./scripts/train_3dmatch.sh 98 | ``` 99 | 100 | ### KITTI 101 | Follow the instructions on the [KITTI Odometry website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) to download the KITTI odometry training set. Then train with 102 | ```shell 103 | export KITTI_PATH=/path/to/kitti; FCGF_WEIGHTS=/path/to/fcgf_kitti.pth ./scripts/train_kitti.sh 104 | ``` 105 | 106 | ## Testing 107 | The 3DMatch test set is different from the train set and is available at the [download section](http://3dmatch.cs.princeton.edu/) of the official website. You may download and decompress these scenes to a new folder. 108 | 109 | To evaluate a trained model on 3DMatch or KITTI, use 110 | ```shell 111 | python -m scripts.test_3dmatch --threed_match_dir /path/to/3dmatch_test/ --weights /path/to/dgr_3dmatch.pth 112 | ``` 113 | and 114 | ```shell 115 | python -m scripts.test_kitti --kitti_dir /path/to/kitti/ --weights /path/to/dgr_kitti.pth 116 | ``` 117 | 118 | ## Generate figures 119 | We also provide the experimental results of the 3DMatch comparisons in `assets/results.npz`. To reproduce the figures presented in the paper, you may use 120 | ```shell 121 | python scripts/analyze_stats.py assets/results.npz 122 | ``` 123 | 124 | ## Citing our work 125 | Please cite the following papers if you use our code: 126 | 127 | ```latex 128 | @inproceedings{choy2020deep, 129 | title={Deep Global Registration}, 130 | author={Choy, Christopher and Dong, Wei and Koltun, Vladlen}, 131 | booktitle={CVPR}, 132 | year={2020} 133 | } 134 | 135 | @inproceedings{choy2019fully, 136 | title = {Fully Convolutional Geometric Features}, 137 | author = {Choy, Christopher and Park, Jaesik and Koltun, Vladlen}, 138 | booktitle = {ICCV}, 139 | year = {2019} 140 | } 141 | 142 | @inproceedings{choy20194d, 143 | title={4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks}, 144 | author={Choy, Christopher and Gwak, JunYoung and Savarese, Silvio}, 145 | booktitle={CVPR}, 146 | year={2019} 147 | } 148 | ``` 149 | 150 | ## Concurrent Works 151 | 152 | A number of 3D registration works were published concurrently.
153 | 154 | - Gojcic et al., [Learning Multiview 3D Point Cloud Registration, CVPR'20](https://github.com/zgojcic/3D_multiview_reg) 155 | - Wang et al., [PRNet: Self-Supervised Learning for Partial-to-Partial Registration, NeurIPS'19](https://github.com/WangYueFt/prnet) 156 | - Yang et al., [TEASER: Fast and Certifiable Point Cloud Registration, arXiv'20](https://github.com/MIT-SPARK/TEASER-plusplus) 157 | -------------------------------------------------------------------------------- /assets/1min-oral.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/DeepGlobalRegistration/1f4871d4c616fa2d2dc6d04a4506bd7d6e593fb4/assets/1min-oral.png -------------------------------------------------------------------------------- /assets/comparison-3dmatch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/DeepGlobalRegistration/1f4871d4c616fa2d2dc6d04a4506bd7d6e593fb4/assets/comparison-3dmatch.png -------------------------------------------------------------------------------- /assets/demo_inputs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/DeepGlobalRegistration/1f4871d4c616fa2d2dc6d04a4506bd7d6e593fb4/assets/demo_inputs.png -------------------------------------------------------------------------------- /assets/demo_outputs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/DeepGlobalRegistration/1f4871d4c616fa2d2dc6d04a4506bd7d6e593fb4/assets/demo_outputs.png -------------------------------------------------------------------------------- /assets/dgr.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/DeepGlobalRegistration/1f4871d4c616fa2d2dc6d04a4506bd7d6e593fb4/assets/dgr.gif -------------------------------------------------------------------------------- /assets/frontier.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/DeepGlobalRegistration/1f4871d4c616fa2d2dc6d04a4506bd7d6e593fb4/assets/frontier.png -------------------------------------------------------------------------------- /assets/results.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/DeepGlobalRegistration/1f4871d4c616fa2d2dc6d04a4506bd7d6e593fb4/assets/results.npz -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import argparse 8 | 9 | arg_lists = [] 10 | parser = argparse.ArgumentParser() 11 | 12 | 13 | def add_argument_group(name): 14 | arg = parser.add_argument_group(name) 15 | arg_lists.append(arg) 16 | return arg 17 | 18 | 19 | def str2bool(v): 20 | return v.lower() in ('true', '1') 21 | 22 | 23 | # yapf: disable 24 | logging_arg = add_argument_group('Logging') 25 | logging_arg.add_argument('--out_dir', type=str, default='outputs') 26 | 27 | trainer_arg = add_argument_group('Trainer') 28 | trainer_arg.add_argument('--trainer', type=str, default='WeightedProcrustesTrainer') 29 | 30 | # Batch setting 31 | trainer_arg.add_argument('--batch_size', type=int, default=4) 32 | trainer_arg.add_argument('--val_batch_size', type=int, default=1) 33 | 34 | # Data loader configs 35 | trainer_arg.add_argument('--train_phase', type=str, default="train") 36 | trainer_arg.add_argument('--val_phase', type=str, default="val") 37 | trainer_arg.add_argument('--test_phase', type=str, default="test") 38 | 39 | # Data augmentation 40 | trainer_arg.add_argument('--use_random_scale', type=str2bool, default=False) 41 | trainer_arg.add_argument('--min_scale', type=float, default=0.8) 42 | trainer_arg.add_argument('--max_scale', type=float, default=1.2) 43 | 44 | trainer_arg.add_argument('--use_random_rotation', type=str2bool, default=True) 45 | trainer_arg.add_argument('--rotation_range', type=float, default=360) 46 | trainer_arg.add_argument( 47 | '--positive_pair_search_voxel_size_multiplier', type=float, default=1.5) 48 | 49 | trainer_arg.add_argument('--save_epoch_freq', type=int, default=1) 50 | trainer_arg.add_argument('--val_epoch_freq', type=int, default=1) 51 | 52 | trainer_arg.add_argument('--stat_freq', type=int, default=40, help='Frequency for writing stats to log') 53 | trainer_arg.add_argument('--test_valid', type=str2bool, default=True) 54 | trainer_arg.add_argument('--val_max_iter', type=int, default=400) 55 | 56 | 57 | trainer_arg.add_argument('--use_balanced_loss', type=str2bool, default=False) 58 | trainer_arg.add_argument('--inlier_direct_loss_weight', type=float, default=1.) 59 | trainer_arg.add_argument('--procrustes_loss_weight', type=float, default=1.) 
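# Note: in WeightedProcrustesTrainer (core/trainer.py), these weights combine roughly as
#   loss = procrustes_loss_weight * mean(rotation_error + trans_weight * translation_error)
#          + inlier_direct_loss_weight * BCE(inlier_logits, correspondence_labels)
# i.e. a pose loss computed from the Weighted Procrustes output plus a per-correspondence
# inlier classification loss.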
60 | trainer_arg.add_argument('--trans_weight', type=float, default=1) 61 | 62 | trainer_arg.add_argument('--eval_registration', type=str2bool, default=True) 63 | trainer_arg.add_argument('--clip_weight_thresh', type=float, default=0.05, help='Weight threshold for detecting inliers') 64 | trainer_arg.add_argument('--best_val_metric', type=str, default='succ_rate') 65 | 66 | # Inlier detection trainer 67 | inlier_arg = add_argument_group('Inlier') 68 | inlier_arg.add_argument('--inlier_model', type=str, default='ResUNetBN2C') 69 | inlier_arg.add_argument('--inlier_feature_type', type=str, default='ones') 70 | inlier_arg.add_argument('--inlier_conv1_kernel_size', type=int, default=3) 71 | inlier_arg.add_argument('--inlier_knn', type=int, default=1) 72 | inlier_arg.add_argument('--knn_search_method', type=str, default='gpu') 73 | inlier_arg.add_argument('--inlier_use_direct_loss', type=str2bool, default=True) 74 | 75 | # Feature specific configurations 76 | feat_arg = add_argument_group('feat') 77 | feat_arg.add_argument('--feat_model', type=str, default='SimpleNetBN2C') 78 | feat_arg.add_argument('--feat_model_n_out', type=int, default=16, help='Feature dimension') 79 | feat_arg.add_argument('--feat_conv1_kernel_size', type=int, default=3) 80 | feat_arg.add_argument('--normalize_feature', type=str2bool, default=True) 81 | feat_arg.add_argument('--use_xyz_feature', type=str2bool, default=False) 82 | feat_arg.add_argument('--dist_type', type=str, default='L2') 83 | 84 | # Optimizer arguments 85 | opt_arg = add_argument_group('Optimizer') 86 | opt_arg.add_argument('--optimizer', type=str, default='SGD') 87 | opt_arg.add_argument('--max_epoch', type=int, default=100) 88 | opt_arg.add_argument('--lr', type=float, default=1e-1) 89 | opt_arg.add_argument('--momentum', type=float, default=0.8) 90 | opt_arg.add_argument('--sgd_momentum', type=float, default=0.9) 91 | opt_arg.add_argument('--sgd_dampening', type=float, default=0.1) 92 | opt_arg.add_argument('--adam_beta1', type=float, default=0.9) 93 | opt_arg.add_argument('--adam_beta2', type=float, default=0.999) 94 | opt_arg.add_argument('--weight_decay', type=float, default=1e-4) 95 | opt_arg.add_argument('--iter_size', type=int, default=1, help='accumulate gradient') 96 | opt_arg.add_argument('--bn_momentum', type=float, default=0.05) 97 | opt_arg.add_argument('--exp_gamma', type=float, default=0.99) 98 | opt_arg.add_argument('--scheduler', type=str, default='ExpLR') 99 | opt_arg.add_argument('--num_train_iter', type=int, default=-1, help='train N iter if positive') 100 | opt_arg.add_argument('--icp_cache_path', type=str, default="icp") 101 | 102 | # Misc 103 | misc_arg = add_argument_group('Misc') 104 | misc_arg.add_argument('--use_gpu', type=str2bool, default=True) 105 | misc_arg.add_argument('--weights', type=str, default=None) 106 | misc_arg.add_argument('--weights_dir', type=str, default=None) 107 | misc_arg.add_argument('--resume', type=str, default=None) 108 | misc_arg.add_argument('--resume_dir', type=str, default=None) 109 | misc_arg.add_argument('--train_num_workers', type=int, default=2) 110 | misc_arg.add_argument('--val_num_workers', type=int, default=1) 111 | misc_arg.add_argument('--test_num_workers', type=int, default=2) 112 | misc_arg.add_argument('--fast_validation', type=str2bool, default=False) 113 | misc_arg.add_argument('--nn_max_n', type=int, default=250, help='The maximum number of features to find nearest neighbors in batch') 114 | 115 | # Dataset specific configurations 116 | data_arg = add_argument_group('Data') 117 | 
data_arg.add_argument('--dataset', type=str, default='ThreeDMatchPairDataset03') 118 | data_arg.add_argument('--voxel_size', type=float, default=0.025) 119 | data_arg.add_argument('--threed_match_dir', type=str, default='.') 120 | data_arg.add_argument('--kitti_dir', type=str, default=None, help="Path to the KITTI odometry dataset. This path should contain /dataset/sequences.") 121 | data_arg.add_argument('--kitti_max_time_diff', type=int, default=3, help='Max time difference between pairs (non-inclusive)') 122 | data_arg.add_argument('--kitti_date', type=str, default='2011_09_26') 123 | 124 | # Evaluation 125 | eval_arg = add_argument_group('Evaluation') 126 | eval_arg.add_argument('--hit_ratio_thresh', type=float, default=0.1) 127 | eval_arg.add_argument('--success_rte_thresh', type=float, default=0.3, help='Success if the RTE is below this (m)') 128 | eval_arg.add_argument('--success_rre_thresh', type=float, default=15, help='Success if the RRE is below this (deg)') 129 | eval_arg.add_argument('--test_random_crop', action='store_true') 130 | eval_arg.add_argument('--test_random_rotation', type=str2bool, default=False) 131 | 132 | # Demo 133 | demo_arg = add_argument_group('Demo') 134 | demo_arg.add_argument('--pcd0', default="redkitchen_000.ply", type=str) 135 | demo_arg.add_argument('--pcd1', default="redkitchen_010.ply", type=str) 136 | # yapf: enable 137 | 138 | 139 | def get_config(): 140 | args = parser.parse_args() 141 | return args 142 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | -------------------------------------------------------------------------------- /core/correspondence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code.
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import copy 8 | import numpy as np 9 | 10 | import open3d as o3d 11 | import torch 12 | 13 | 14 | def _hash(arr, M=None): 15 | if isinstance(arr, np.ndarray): 16 | N, D = arr.shape 17 | else: 18 | N, D = len(arr[0]), len(arr) 19 | 20 | hash_vec = np.zeros(N, dtype=np.int64) 21 | for d in range(D): 22 | if isinstance(arr, np.ndarray): 23 | hash_vec += arr[:, d] * M**d 24 | else: 25 | hash_vec += arr[d] * M**d 26 | return hash_vec 27 | 28 | 29 | def find_correct_correspondence(pos_pairs, pred_pairs, hash_seed=None, len_batch=None): 30 | assert len(pos_pairs) == len(pred_pairs) 31 | if hash_seed is None: 32 | assert len(len_batch) == len(pos_pairs) 33 | 34 | corrects = [] 35 | for i, pos_pred in enumerate(zip(pos_pairs, pred_pairs)): 36 | pos_pair, pred_pair = pos_pred 37 | if isinstance(pos_pair, torch.Tensor): 38 | pos_pair = pos_pair.numpy() 39 | if isinstance(pred_pair, torch.Tensor): 40 | pred_pair = pred_pair.numpy() 41 | 42 | if hash_seed is None: 43 | N0, N1 = len_batch[i] 44 | _hash_seed = max(N0, N1) 45 | else: 46 | _hash_seed = hash_seed 47 | 48 | pos_keys = _hash(pos_pair, _hash_seed) 49 | pred_keys = _hash(pred_pair, _hash_seed) 50 | 51 | corrects.append(np.isin(pred_keys, pos_keys, assume_unique=False)) 52 | 53 | return np.hstack(corrects) 54 | 55 | -------------------------------------------------------------------------------- /core/deep_global_registration.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import os 8 | import sys 9 | import math 10 | import logging 11 | import open3d as o3d 12 | import numpy as np 13 | import time 14 | import torch 15 | import copy 16 | import MinkowskiEngine as ME 17 | 18 | sys.path.append('.') 19 | from model import load_model 20 | 21 | from core.registration import GlobalRegistration 22 | from core.knn import find_knn_gpu 23 | 24 | from util.timer import Timer 25 | from util.pointcloud import make_open3d_point_cloud 26 | 27 | 28 | # Feature-based registrations in Open3D 29 | def registration_ransac_based_on_feature_matching(pcd0, pcd1, feats0, feats1, 30 | distance_threshold, num_iterations): 31 | assert feats0.shape[1] == feats1.shape[1] 32 | 33 | source_feat = o3d.registration.Feature() 34 | source_feat.resize(feats0.shape[1], len(feats0)) 35 | source_feat.data = feats0.astype('d').transpose() 36 | 37 | target_feat = o3d.registration.Feature() 38 | target_feat.resize(feats1.shape[1], len(feats1)) 39 | target_feat.data = feats1.astype('d').transpose() 40 | 41 | result = o3d.registration.registration_ransac_based_on_feature_matching( 42 | pcd0, pcd1, source_feat, target_feat, distance_threshold, 43 | o3d.registration.TransformationEstimationPointToPoint(False), 4, 44 | [o3d.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)], 45 | o3d.registration.RANSACConvergenceCriteria(num_iterations, 1000)) 46 | 47 | return result.transformation 48 | 49 | 50 | def registration_ransac_based_on_correspondence(pcd0, pcd1, idx0, idx1, 51 | distance_threshold, num_iterations): 52 | corres = np.stack((idx0, idx1), axis=1) 53 | corres = o3d.utility.Vector2iVector(corres) 54 | 55 | result = o3d.pipelines.registration.registration_ransac_based_on_correspondence( 56 | source = pcd0, 57 | target = pcd1, 58 | corres = corres, 59 | max_correspondence_distance = distance_threshold, 60 | estimation_method = o3d.pipelines.registration.TransformationEstimationPointToPoint(False), 61 | ransac_n = 4, 62 | criteria = o3d.pipelines.registration.RANSACConvergenceCriteria(4000000, num_iterations)) 63 | 64 | return result.transformation 65 | 66 | 67 | class DeepGlobalRegistration: 68 | def __init__(self, config, device=torch.device('cuda')): 69 | # Basic config 70 | self.config = config 71 | self.clip_weight_thresh = self.config.clip_weight_thresh 72 | self.device = device 73 | 74 | # Safeguard 75 | self.safeguard_method = 'correspondence' # correspondence, feature_matching 76 | 77 | # Final tuning 78 | self.use_icp = True 79 | 80 | # Misc 81 | self.feat_timer = Timer() 82 | self.reg_timer = Timer() 83 | 84 | # Model config loading 85 | print("=> loading checkpoint '{}'".format(config.weights)) 86 | assert os.path.exists(config.weights) 87 | 88 | state = torch.load(config.weights) 89 | network_config = state['config'] 90 | self.network_config = network_config 91 | self.config.inlier_feature_type = network_config.inlier_feature_type 92 | self.voxel_size = network_config.voxel_size 93 | print(f'=> Setting voxel size to {self.voxel_size}') 94 | 95 | # FCGF network initialization 96 | num_feats = 1 97 | try: 98 | FCGFModel = load_model(network_config['feat_model']) 99 | self.fcgf_model = FCGFModel( 100 | num_feats, 101 | 
network_config['feat_model_n_out'], 102 | bn_momentum=network_config['bn_momentum'], 103 | conv1_kernel_size=network_config['feat_conv1_kernel_size'], 104 | normalize_feature=network_config['normalize_feature']) 105 | 106 | except KeyError: # legacy pretrained models 107 | FCGFModel = load_model(network_config['model']) 108 | self.fcgf_model = FCGFModel(num_feats, 109 | network_config['model_n_out'], 110 | bn_momentum=network_config['bn_momentum'], 111 | conv1_kernel_size=network_config['conv1_kernel_size'], 112 | normalize_feature=network_config['normalize_feature']) 113 | 114 | self.fcgf_model.load_state_dict(state['state_dict']) 115 | self.fcgf_model = self.fcgf_model.to(device) 116 | self.fcgf_model.eval() 117 | 118 | # Inlier network initialization 119 | num_feats = 6 if network_config.inlier_feature_type == 'coords' else 1 120 | InlierModel = load_model(network_config['inlier_model']) 121 | self.inlier_model = InlierModel( 122 | num_feats, 123 | 1, 124 | bn_momentum=network_config['bn_momentum'], 125 | conv1_kernel_size=network_config['inlier_conv1_kernel_size'], 126 | normalize_feature=False, 127 | D=6) 128 | 129 | self.inlier_model.load_state_dict(state['state_dict_inlier']) 130 | self.inlier_model = self.inlier_model.to(self.device) 131 | self.inlier_model.eval() 132 | print("=> loading finished") 133 | 134 | def preprocess(self, pcd): 135 | ''' 136 | Stage 0: preprocess raw input point cloud 137 | Input: raw point cloud 138 | Output: voxelized point cloud with 139 | - xyz: unique point cloud with one point per voxel 140 | - coords: coords after voxelization 141 | - feats: dummy feature placeholder for general sparse convolution 142 | ''' 143 | if isinstance(pcd, o3d.geometry.PointCloud): 144 | xyz = np.array(pcd.points) 145 | elif isinstance(pcd, np.ndarray): 146 | xyz = pcd 147 | else: 148 | raise Exception('Unrecognized pcd type') 149 | 150 | # Voxelization: 151 | # Maintain double type for xyz to improve numerical accuracy in quantization 152 | _, sel = ME.utils.sparse_quantize(xyz / self.voxel_size, return_index=True) 153 | npts = len(sel) 154 | 155 | xyz = torch.from_numpy(xyz[sel]).to(self.device) 156 | 157 | # ME standard batch coordinates 158 | coords = ME.utils.batched_coordinates([torch.floor(xyz / self.voxel_size).int()], device=self.device) 159 | feats = torch.ones(npts, 1) 160 | 161 | return xyz.float(), coords, feats 162 | 163 | def fcgf_feature_extraction(self, feats, coords): 164 | ''' 165 | Step 1: extract fast and accurate FCGF feature per point 166 | ''' 167 | sinput = ME.SparseTensor(feats, coordinates=coords, device=self.device) 168 | 169 | return self.fcgf_model(sinput).F 170 | 171 | def fcgf_feature_matching(self, feats0, feats1): 172 | ''' 173 | Step 2: coarsely match FCGF features to generate initial correspondences 174 | ''' 175 | nns = find_knn_gpu(feats0, 176 | feats1, 177 | nn_max_n=self.network_config.nn_max_n, 178 | knn=1, 179 | return_distance=False) 180 | corres_idx0 = torch.arange(len(nns)).long().squeeze().to(self.device) 181 | corres_idx1 = nns.long().squeeze() 182 | 183 | return corres_idx0, corres_idx1 184 | 185 | def inlier_feature_generation(self, xyz0, xyz1, coords0, coords1, fcgf_feats0, 186 | fcgf_feats1, corres_idx0, corres_idx1): 187 | ''' 188 | Step 3: generate features for inlier prediction 189 | ''' 190 | assert len(corres_idx0) == len(corres_idx1) 191 | 192 | feat_type = self.config.inlier_feature_type 193 | assert feat_type in ['ones', 'feats', 'coords'] 194 | 195 | corres_idx0 = corres_idx0.to(self.device) 196 | corres_idx1 = 
corres_idx1.to(self.device) 197 | 198 | if feat_type == 'ones': 199 | feat = torch.ones((len(corres_idx0), 1)).float() 200 | elif feat_type == 'feats': 201 | feat = torch.cat((fcgf_feats0[corres_idx0], fcgf_feats1[corres_idx1]), dim=1) 202 | elif feat_type == 'coords': 203 | feat = torch.cat((torch.cos(xyz0[corres_idx0]), torch.cos(xyz1[corres_idx1])), 204 | dim=1) 205 | else: # should never reach here 206 | raise TypeError('Undefined feature type') 207 | 208 | return feat 209 | 210 | def inlier_prediction(self, inlier_feats, coords): 211 | ''' 212 | Step 4: predict inlier likelihood 213 | ''' 214 | sinput = ME.SparseTensor(inlier_feats, coordinates=coords, device=self.device) 215 | soutput = self.inlier_model(sinput) 216 | 217 | return soutput.F 218 | 219 | def safeguard_registration(self, pcd0, pcd1, idx0, idx1, feats0, feats1, 220 | distance_threshold, num_iterations): 221 | if self.safeguard_method == 'correspondence': 222 | T = registration_ransac_based_on_correspondence(pcd0, 223 | pcd1, 224 | idx0.cpu().numpy(), 225 | idx1.cpu().numpy(), 226 | distance_threshold, 227 | num_iterations=num_iterations) 228 | elif self.safeguard_method == 'fcgf_feature_matching': 229 | T = registration_ransac_based_on_fcgf_feature_matching(pcd0, pcd1, 230 | feats0.cpu().numpy(), 231 | feats1.cpu().numpy(), 232 | distance_threshold, 233 | num_iterations) 234 | else: 235 | raise ValueError('Undefined') 236 | return T 237 | 238 | def register(self, xyz0, xyz1, inlier_thr=0.00): 239 | ''' 240 | Main algorithm of DeepGlobalRegistration 241 | ''' 242 | self.reg_timer.tic() 243 | with torch.no_grad(): 244 | # Step 0: voxelize and generate sparse input 245 | xyz0, coords0, feats0 = self.preprocess(xyz0) 246 | xyz1, coords1, feats1 = self.preprocess(xyz1) 247 | 248 | # Step 1: Feature extraction 249 | self.feat_timer.tic() 250 | fcgf_feats0 = self.fcgf_feature_extraction(feats0, coords0) 251 | fcgf_feats1 = self.fcgf_feature_extraction(feats1, coords1) 252 | self.feat_timer.toc() 253 | 254 | # Step 2: Coarse correspondences 255 | corres_idx0, corres_idx1 = self.fcgf_feature_matching(fcgf_feats0, fcgf_feats1) 256 | 257 | # Step 3: Inlier feature generation 258 | # coords[corres_idx0]: 1D temporal + 3D spatial coord 259 | # coords[corres_idx1, 1:]: 3D spatial coord 260 | # => 1D temporal + 6D spatial coord 261 | inlier_coords = torch.cat((coords0[corres_idx0], coords1[corres_idx1, 1:]), 262 | dim=1).int() 263 | inlier_feats = self.inlier_feature_generation(xyz0, xyz1, coords0, coords1, 264 | fcgf_feats0, fcgf_feats1, 265 | corres_idx0, corres_idx1) 266 | 267 | # Step 4: Inlier likelihood estimation and truncation 268 | logit = self.inlier_prediction(inlier_feats.contiguous(), coords=inlier_coords) 269 | weights = logit.sigmoid() 270 | if self.clip_weight_thresh > 0: 271 | weights[weights < self.clip_weight_thresh] = 0 272 | wsum = weights.sum().item() 273 | 274 | # Step 5: Registration. 
Note: torch's gradient may be required at this stage 275 | # > Case 0: Weighted Procrustes + Robust Refinement 276 | wsum_threshold = max(200, len(weights) * 0.05) 277 | sign = '>=' if wsum >= wsum_threshold else '<' 278 | print(f'=> Weighted sum {wsum:.2f} {sign} threshold {wsum_threshold}') 279 | 280 | T = np.identity(4) 281 | if wsum >= wsum_threshold: 282 | try: 283 | rot, trans, opt_output = GlobalRegistration(xyz0[corres_idx0], 284 | xyz1[corres_idx1], 285 | weights=weights.detach(), 286 | break_threshold_ratio=1e-4, 287 | quantization_size=2 * 288 | self.voxel_size, 289 | verbose=False) 290 | T[0:3, 0:3] = rot.detach().cpu().numpy() 291 | T[0:3, 3] = trans.detach().cpu().numpy() 292 | dgr_time = self.reg_timer.toc() 293 | print(f'=> DGR takes {dgr_time:.2} s') 294 | 295 | except RuntimeError: 296 | # Will directly go to Safeguard 297 | print('###############################################') 298 | print('# WARNING: SVD failed, weights sum: ', wsum) 299 | print('# Falling back to Safeguard') 300 | print('###############################################') 301 | 302 | else: 303 | # > Case 1: Safeguard RANSAC + (Optional) ICP 304 | pcd0 = make_open3d_point_cloud(xyz0) 305 | pcd1 = make_open3d_point_cloud(xyz1) 306 | T = self.safeguard_registration(pcd0, 307 | pcd1, 308 | corres_idx0, 309 | corres_idx1, 310 | feats0, 311 | feats1, 312 | 2 * self.voxel_size, 313 | num_iterations=80000) 314 | safeguard_time = self.reg_timer.toc() 315 | print(f'=> Safeguard takes {safeguard_time:.2} s') 316 | 317 | if self.use_icp: 318 | T = o3d.pipelines.registration.registration_icp( 319 | source=make_open3d_point_cloud(xyz0), 320 | target=make_open3d_point_cloud(xyz1), 321 | max_correspondence_distance=self.voxel_size * 2, 322 | init=T).transformation 323 | 324 | return T 325 | -------------------------------------------------------------------------------- /core/knn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch 8 | import numpy as np 9 | from scipy.spatial import cKDTree 10 | 11 | from core.metrics import pdist 12 | 13 | 14 | def find_knn_cpu(feat0, feat1, knn=1, return_distance=False): 15 | feat1tree = cKDTree(feat1) 16 | dists, nn_inds = feat1tree.query(feat0, k=knn, n_jobs=-1) 17 | if return_distance: 18 | return nn_inds, dists 19 | else: 20 | return nn_inds 21 | 22 | 23 | def find_knn_gpu(F0, F1, nn_max_n=-1, knn=1, return_distance=False): 24 | 25 | def knn_dist(f0, f1, knn=1, dist_type='L2'): 26 | knn_dists, knn_inds = [], [] 27 | with torch.no_grad(): 28 | dist = pdist(f0, f1, dist_type=dist_type) 29 | min_dist, ind = dist.min(dim=1, keepdim=True) 30 | 31 | knn_dists.append(min_dist) 32 | knn_inds.append(ind) 33 | 34 | if knn > 1: 35 | for k in range(knn - 1): 36 | NR, NC = dist.shape 37 | flat_ind = (torch.arange(NR) * NC).type_as(ind) + ind.squeeze() 38 | dist.view(-1)[flat_ind] = np.inf 39 | min_dist, ind = dist.min(dim=1, keepdim=True) 40 | 41 | knn_dists.append(min_dist) 42 | knn_inds.append(ind) 43 | 44 | min_dist = torch.cat(knn_dists, 1) 45 | ind = torch.cat(knn_inds, 1) 46 | 47 | return min_dist, ind 48 | 49 | # Too much memory if F0 or F1 large. Divide the F0 50 | if nn_max_n > 1: 51 | N = len(F0) 52 | C = int(np.ceil(N / nn_max_n)) 53 | stride = nn_max_n 54 | dists, inds = [], [] 55 | 56 | for i in range(C): 57 | with torch.no_grad(): 58 | dist, ind = knn_dist(F0[i * stride:(i + 1) * stride], F1, knn=knn, dist_type='L2') 59 | dists.append(dist) 60 | inds.append(ind) 61 | 62 | dists = torch.cat(dists) 63 | inds = torch.cat(inds) 64 | assert len(inds) == N 65 | 66 | else: 67 | dist = pdist(F0, F1, dist_type='SquareL2') 68 | min_dist, inds = dist.min(dim=1) 69 | dists = min_dist.detach().unsqueeze(1) #.cpu() 70 | # inds = inds.cpu() 71 | if return_distance: 72 | return inds, dists 73 | else: 74 | return inds 75 | 76 | 77 | def find_knn_batch(F0, 78 | F1, 79 | len_batch, 80 | return_distance=False, 81 | nn_max_n=-1, 82 | knn=1, 83 | search_method=None, 84 | concat_results=False): 85 | if search_method is None or search_method == 'gpu': 86 | return find_knn_gpu_batch( 87 | F0, 88 | F1, 89 | len_batch=len_batch, 90 | nn_max_n=nn_max_n, 91 | knn=knn, 92 | return_distance=return_distance, 93 | concat_results=concat_results) 94 | elif search_method == 'cpu': 95 | return find_knn_cpu_batch( 96 | F0, 97 | F1, 98 | len_batch=len_batch, 99 | knn=knn, 100 | return_distance=return_distance, 101 | concat_results=concat_results) 102 | else: 103 | raise ValueError(f'Search method {search_method} not defined') 104 | 105 | 106 | def find_knn_gpu_batch(F0, 107 | F1, 108 | len_batch, 109 | nn_max_n=-1, 110 | knn=1, 111 | return_distance=False, 112 | concat_results=False): 113 | dists, nns = [], [] 114 | start0, start1 = 0, 0 115 | for N0, N1 in len_batch: 116 | nn = find_knn_gpu( 117 | F0[start0:start0 + N0], 118 | F1[start1:start1 + N1], 119 | nn_max_n=nn_max_n, 120 | knn=knn, 121 | return_distance=return_distance) 122 | if return_distance: 123 | nn, dist = nn 124 | dists.append(dist) 125 | if concat_results: 126 | nns.append(nn + start1) 127 | else: 128 | nns.append(nn) 129 | start0 += N0 130 | start1 += N1 131 | 132 | if concat_results: 133 | nns = 
torch.cat(nns) 134 | if return_distance: 135 | dists = torch.cat(dists) 136 | 137 | if return_distance: 138 | return nns, dists 139 | else: 140 | return nns 141 | 142 | 143 | def find_knn_cpu_batch(F0, 144 | F1, 145 | len_batch, 146 | knn=1, 147 | return_distance=False, 148 | concat_results=False): 149 | if not isinstance(F0, np.ndarray): 150 | F0 = F0.detach().cpu().numpy() 151 | F1 = F1.detach().cpu().numpy() 152 | 153 | dists, nns = [], [] 154 | start0, start1 = 0, 0 155 | for N0, N1 in len_batch: 156 | nn = find_knn_cpu( 157 | F0[start0:start0 + N0], F1[start1:start1 + N1], return_distance=return_distance) 158 | if return_distance: 159 | nn, dist = nn 160 | dists.append(dist) 161 | if concat_results: 162 | nns.append(nn + start1) 163 | else: 164 | nns.append(nn + start1) 165 | start0 += N0 166 | start1 += N1 167 | 168 | if concat_results: 169 | nns = np.hstack(nns) 170 | if return_distance: 171 | dists = np.hstack(dists) 172 | 173 | if return_distance: 174 | return torch.from_numpy(nns), torch.from_numpy(dists) 175 | else: 176 | return torch.from_numpy(nns) 177 | -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch 8 | import torch.nn as nn 9 | 10 | import numpy as np 11 | 12 | 13 | class UnbalancedLoss(nn.Module): 14 | NUM_LABELS = 2 15 | 16 | def __init__(self): 17 | super().__init__() 18 | self.crit = nn.BCEWithLogitsLoss() 19 | 20 | def forward(self, logits, label): 21 | return self.crit(logits, label.to(torch.float)) 22 | 23 | 24 | class BalancedLoss(nn.Module): 25 | NUM_LABELS = 2 26 | 27 | def __init__(self): 28 | super().__init__() 29 | self.crit = nn.BCEWithLogitsLoss() 30 | 31 | def forward(self, logits, label): 32 | assert torch.all(label < self.NUM_LABELS) 33 | loss = torch.scalar_tensor(0.).to(logits) 34 | for i in range(self.NUM_LABELS): 35 | target_mask = label == i 36 | if torch.any(target_mask): 37 | loss += self.crit(logits[target_mask], label[target_mask].to( 38 | torch.float)) / self.NUM_LABELS 39 | return loss 40 | 41 | 42 | class HighDimSmoothL1Loss: 43 | 44 | def __init__(self, weights, quantization_size=1, eps=np.finfo(np.float32).eps): 45 | self.eps = eps 46 | self.quantization_size = quantization_size 47 | self.weights = weights 48 | if self.weights is not None: 49 | self.w1 = weights.sum() 50 | 51 | def __call__(self, X, Y): 52 | sq_dist = torch.sum(((X - Y) / self.quantization_size)**2, axis=1, keepdim=True) 53 | use_sq_half = 0.5 * (sq_dist < 1).float() 54 | 55 | loss = (0.5 - use_sq_half) * (torch.sqrt(sq_dist + self.eps) - 56 | 0.5) + use_sq_half * sq_dist 57 | 58 | if self.weights is None: 59 | return loss.mean() 60 | else: 61 | return (loss * self.weights).sum() / self.w1 62 | -------------------------------------------------------------------------------- /core/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei 
Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch 8 | import torch.functional as F 9 | 10 | 11 | def rotation_mat2angle(R): 12 | return torch.acos(torch.clamp((torch.trace(R) - 1) / 2, -0.9999, 0.9999)) 13 | 14 | 15 | def rotation_error(R1, R2): 16 | assert R1.shape == R2.shape 17 | return torch.acos(torch.clamp((torch.trace(torch.mm(R1.t(), R2)) - 1) / 2, -0.9999, 0.9999)) 18 | 19 | 20 | def translation_error(t1, t2): 21 | assert t1.shape == t2.shape 22 | return torch.sqrt(((t1 - t2)**2).sum()) 23 | 24 | 25 | def batch_rotation_error(rots1, rots2): 26 | r""" 27 | arccos( (tr(R_1^T R_2) - 1) / 2 ) 28 | rots1: B x 3 x 3 or B x 9 29 | rots1: B x 3 x 3 or B x 9 30 | """ 31 | assert len(rots1) == len(rots2) 32 | trace_r1Tr2 = (rots1.reshape(-1, 9) * rots2.reshape(-1, 9)).sum(1) 33 | side = (trace_r1Tr2 - 1) / 2 34 | return torch.acos(torch.clamp(side, min=-0.999, max=0.999)) 35 | 36 | 37 | def batch_translation_error(trans1, trans2): 38 | r""" 39 | trans1: B x 3 40 | trans2: B x 3 41 | """ 42 | assert len(trans1) == len(trans2) 43 | return torch.norm(trans1 - trans2, p=2, dim=1, keepdim=False) 44 | 45 | 46 | 47 | def eval_metrics(output, target): 48 | output = (F.sigmoid(output) > 0.5) 49 | target = target 50 | return torch.norm(output - target) 51 | 52 | 53 | def corr_dist(est, gth, xyz0, xyz1, weight=None, max_dist=1): 54 | xyz0_est = xyz0 @ est[:3, :3].t() + est[:3, 3] 55 | xyz0_gth = xyz0 @ gth[:3, :3].t() + gth[:3, 3] 56 | dists = torch.clamp(torch.sqrt(((xyz0_est - xyz0_gth).pow(2)).sum(1)), max=max_dist) 57 | if weight is not None: 58 | dists = weight * dists 59 | return dists.mean() 60 | 61 | 62 | def pdist(A, B, dist_type='L2'): 63 | if dist_type == 'L2': 64 | D2 = torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 65 | return torch.sqrt(D2 + 1e-7) 66 | elif dist_type == 'SquareL2': 67 | return torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 68 | else: 69 | raise NotImplementedError('Not implemented') 70 | 71 | 72 | def get_loss_fn(loss): 73 | if loss == 'corr_dist': 74 | return corr_dist 75 | else: 76 | raise ValueError(f'Loss {loss}, not defined') 77 | -------------------------------------------------------------------------------- /core/registration.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import numpy as np 8 | 9 | import torch 10 | import torch.optim as optim 11 | 12 | from core.knn import pdist 13 | from core.loss import HighDimSmoothL1Loss 14 | 15 | 16 | def ortho2rotation(poses): 17 | r""" 18 | poses: batch x 6 19 | """ 20 | def normalize_vector(v): 21 | r""" 22 | Batch x 3 23 | """ 24 | v_mag = torch.sqrt((v**2).sum(1, keepdim=True)) 25 | v_mag = torch.clamp(v_mag, min=1e-8) 26 | v = v / v_mag 27 | return v 28 | 29 | def cross_product(u, v): 30 | r""" 31 | u: batch x 3 32 | v: batch x 3 33 | """ 34 | i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1] 35 | j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2] 36 | k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0] 37 | 38 | i = i[:, None] 39 | j = j[:, None] 40 | k = k[:, None] 41 | return torch.cat((i, j, k), 1) 42 | 43 | def proj_u2a(u, a): 44 | r""" 45 | u: batch x 3 46 | a: batch x 3 47 | """ 48 | inner_prod = (u * a).sum(1, keepdim=True) 49 | norm2 = (u**2).sum(1, keepdim=True) 50 | norm2 = torch.clamp(norm2, min=1e-8) 51 | factor = inner_prod / norm2 52 | return factor * u 53 | 54 | x_raw = poses[:, 0:3] 55 | y_raw = poses[:, 3:6] 56 | 57 | x = normalize_vector(x_raw) 58 | y = normalize_vector(y_raw - proj_u2a(x, y_raw)) 59 | z = cross_product(x, y) 60 | 61 | x = x[:, :, None] 62 | y = y[:, :, None] 63 | z = z[:, :, None] 64 | return torch.cat((x, y, z), 2) 65 | 66 | 67 | def argmin_se3_squared_dist(X, Y): 68 | """ 69 | X: torch tensor N x 3 70 | Y: torch tensor N x 3 71 | """ 72 | # https://ieeexplore.ieee.org/document/88573 73 | assert len(X) == len(Y) 74 | mux = X.mean(0, keepdim=True) 75 | muy = Y.mean(0, keepdim=True) 76 | 77 | Sxy = (Y - muy).t().mm(X - mux) / len(X) 78 | U, D, V = Sxy.svd() 79 | # svd = gesvd.GESVDFunction() 80 | # U, S, V = svd.apply(Sxy) 81 | # S[-1, -1] = U.det() * V.det() 82 | S = torch.eye(3) 83 | if U.det() * V.det() < 0: 84 | S[-1, -1] = -1 85 | 86 | R = U.mm(S.mm(V.t())) 87 | t = muy.squeeze() - R.mm(mux.t()).squeeze() 88 | return R, t 89 | 90 | 91 | def weighted_procrustes(X, Y, w, eps): 92 | """ 93 | X: torch tensor N x 3 94 | Y: torch tensor N x 3 95 | w: torch tensor N 96 | """ 97 | # https://ieeexplore.ieee.org/document/88573 98 | assert len(X) == len(Y) 99 | W1 = torch.abs(w).sum() 100 | w_norm = w / (W1 + eps) 101 | mux = (w_norm * X).sum(0, keepdim=True) 102 | muy = (w_norm * Y).sum(0, keepdim=True) 103 | 104 | # Use CPU for small arrays 105 | Sxy = (Y - muy).t().mm(w_norm * (X - mux)).cpu().double() 106 | U, D, V = Sxy.svd() 107 | S = torch.eye(3).double() 108 | if U.det() * V.det() < 0: 109 | S[-1, -1] = -1 110 | 111 | R = U.mm(S.mm(V.t())).float() 112 | t = (muy.cpu().squeeze() - R.mm(mux.cpu().t()).squeeze()).float() 113 | return R, t 114 | 115 | 116 | class Transformation(torch.nn.Module): 117 | def __init__(self, R_init=None, t_init=None): 118 | torch.nn.Module.__init__(self) 119 | rot_init = torch.rand(1, 6) 120 | trans_init = torch.zeros(1, 3) 121 | if R_init is not None: 122 | rot_init[0, :3] = R_init[:, 0] 123 | rot_init[0, 3:] = R_init[:, 1] 124 | if t_init is not None: 125 | trans_init[0] = t_init 126 | 127 | self.rot6d = torch.nn.Parameter(rot_init) 128 | self.trans = torch.nn.Parameter(trans_init) 129 | 130 | def forward(self, points): 131 | rot_mat = 
ortho2rotation(self.rot6d) 132 | return points @ rot_mat[0].t() + self.trans 133 | 134 | 135 | def GlobalRegistration(points, 136 | trans_points, 137 | weights=None, 138 | max_iter=1000, 139 | verbose=False, 140 | stat_freq=20, 141 | max_break_count=20, 142 | break_threshold_ratio=1e-5, 143 | loss_fn=None, 144 | quantization_size=1): 145 | if isinstance(points, np.ndarray): 146 | points = torch.from_numpy(points).float() 147 | 148 | if isinstance(trans_points, np.ndarray): 149 | trans_points = torch.from_numpy(trans_points).float() 150 | 151 | if loss_fn is None: 152 | if weights is not None: 153 | weights.requires_grad = False 154 | loss_fn = HighDimSmoothL1Loss(weights, quantization_size) 155 | 156 | if weights is None: 157 | # Get the initialization using https://ieeexplore.ieee.org/document/88573 158 | R, t = argmin_se3_squared_dist(points, trans_points) 159 | else: 160 | R, t = weighted_procrustes(points, trans_points, weights, loss_fn.eps) 161 | transformation = Transformation(R, t).to(points.device) 162 | 163 | optimizer = optim.Adam(transformation.parameters(), lr=1e-1) 164 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.999) 165 | loss_prev = loss_fn(transformation(points), trans_points).item() 166 | break_counter = 0 167 | 168 | # Transform points 169 | for i in range(max_iter): 170 | new_points = transformation(points) 171 | loss = loss_fn(new_points, trans_points) 172 | if loss.item() < 1e-7: 173 | break 174 | 175 | optimizer.zero_grad() 176 | loss.backward() 177 | optimizer.step() 178 | scheduler.step() 179 | if i % stat_freq == 0 and verbose: 180 | print(i, scheduler.get_lr(), loss.item()) 181 | 182 | if abs(loss_prev - loss.item()) < loss_prev * break_threshold_ratio: 183 | break_counter += 1 184 | if break_counter >= max_break_count: 185 | break 186 | 187 | loss_prev = loss.item() 188 | 189 | rot6d = transformation.rot6d.detach() 190 | trans = transformation.trans.detach() 191 | 192 | opt_result = {'iterations': i, 'loss': loss.item(), 'break_count': break_counter} 193 | 194 | return ortho2rotation(rot6d)[0], trans, opt_result 195 | -------------------------------------------------------------------------------- /core/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | # Written by Chris Choy 8 | # Distributed under MIT License 9 | import time 10 | import os 11 | import os.path as osp 12 | import gc 13 | import logging 14 | import numpy as np 15 | import json 16 | 17 | import torch 18 | import torch.optim as optim 19 | import torch.nn.functional as F 20 | from tensorboardX import SummaryWriter 21 | 22 | from model import load_model 23 | from core.knn import find_knn_batch 24 | from core.correspondence import find_correct_correspondence 25 | from core.loss import UnbalancedLoss, BalancedLoss 26 | from core.metrics import batch_rotation_error, batch_translation_error 27 | import core.registration as GlobalRegistration 28 | 29 | from util.timer import Timer, AverageMeter 30 | from util.file import ensure_dir 31 | 32 | import MinkowskiEngine as ME 33 | 34 | eps = np.finfo(float).eps 35 | np2th = torch.from_numpy 36 | 37 | 38 | class WeightedProcrustesTrainer: 39 | def __init__(self, config, data_loader, val_data_loader=None): 40 | # occupancy only for 3D Match dataset. For ScanNet, use RGB 3 channels. 41 | num_feats = 3 if config.use_xyz_feature else 1 42 | 43 | # Feature model initialization 44 | if config.use_gpu and not torch.cuda.is_available(): 45 | logging.warning('Warning: There\'s no CUDA support on this machine, ' 46 | 'training is performed on CPU.') 47 | raise ValueError('GPU not available, but cuda flag set') 48 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 49 | 50 | self.config = config 51 | 52 | # Training config 53 | self.max_epoch = config.max_epoch 54 | self.start_epoch = 1 55 | self.checkpoint_dir = config.out_dir 56 | 57 | self.data_loader = data_loader 58 | self.train_data_loader_iter = self.data_loader.__iter__() 59 | 60 | self.iter_size = config.iter_size 61 | self.batch_size = data_loader.batch_size 62 | 63 | # Validation config 64 | self.val_max_iter = config.val_max_iter 65 | self.val_epoch_freq = config.val_epoch_freq 66 | self.best_val_metric = config.best_val_metric 67 | self.best_val_epoch = -np.inf 68 | self.best_val = -np.inf 69 | 70 | self.val_data_loader = val_data_loader 71 | self.test_valid = True if self.val_data_loader is not None else False 72 | 73 | # Logging 74 | self.log_step = int(np.sqrt(self.config.batch_size)) 75 | self.writer = SummaryWriter(config.out_dir) 76 | 77 | # Model 78 | FeatModel = load_model(config.feat_model) 79 | InlierModel = load_model(config.inlier_model) 80 | 81 | num_feats = 6 if self.config.inlier_feature_type == 'coords' else 1 82 | self.feat_model = FeatModel(num_feats, 83 | config.feat_model_n_out, 84 | bn_momentum=config.bn_momentum, 85 | conv1_kernel_size=config.feat_conv1_kernel_size, 86 | normalize_feature=config.normalize_feature).to( 87 | self.device) 88 | logging.info(self.feat_model) 89 | 90 | self.inlier_model = InlierModel(num_feats, 91 | 1, 92 | bn_momentum=config.bn_momentum, 93 | conv1_kernel_size=config.inlier_conv1_kernel_size, 94 | normalize_feature=False, 95 | D=6).to(self.device) 96 | logging.info(self.inlier_model) 97 | 98 | # Loss and optimizer 99 | self.clip_weight_thresh = self.config.clip_weight_thresh 100 | if self.config.use_balanced_loss: 101 | self.crit = BalancedLoss() 102 | else: 103 | self.crit = 
UnbalancedLoss() 104 | 105 | self.optimizer = getattr(optim, config.optimizer)(self.inlier_model.parameters(), 106 | lr=config.lr, 107 | momentum=config.momentum, 108 | weight_decay=config.weight_decay) 109 | self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.exp_gamma) 110 | 111 | # Output preparation 112 | ensure_dir(self.checkpoint_dir) 113 | json.dump(config, 114 | open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'), 115 | indent=4, 116 | sort_keys=False) 117 | 118 | self._load_weights(config) 119 | 120 | def train(self): 121 | """ 122 | Major interface 123 | Full training logic: train, valid, and save 124 | """ 125 | # Baseline random feature performance 126 | if self.test_valid: 127 | val_dict = self._valid_epoch() 128 | for k, v in val_dict.items(): 129 | self.writer.add_scalar(f'val/{k}', v, 0) 130 | 131 | # Train and valid 132 | for epoch in range(self.start_epoch, self.max_epoch + 1): 133 | lr = self.scheduler.get_lr() 134 | logging.info(f" Epoch: {epoch}, LR: {lr}") 135 | self._train_epoch(epoch) 136 | self._save_checkpoint(epoch) 137 | self.scheduler.step() 138 | 139 | if self.test_valid and epoch % self.val_epoch_freq == 0: 140 | val_dict = self._valid_epoch() 141 | for k, v in val_dict.items(): 142 | self.writer.add_scalar(f'val/{k}', v, epoch) 143 | 144 | if self.best_val < val_dict[self.best_val_metric]: 145 | logging.info( 146 | f'Saving the best val model with {self.best_val_metric}: {val_dict[self.best_val_metric]}' 147 | ) 148 | self.best_val = val_dict[self.best_val_metric] 149 | self.best_val_epoch = epoch 150 | self._save_checkpoint(epoch, 'best_val_checkpoint') 151 | 152 | else: 153 | logging.info( 154 | f'Current best val model with {self.best_val_metric}: {self.best_val} at epoch {self.best_val_epoch}' 155 | ) 156 | 157 | def _train_epoch(self, epoch): 158 | gc.collect() 159 | 160 | # Fix the feature model and train the inlier model 161 | self.feat_model.eval() 162 | self.inlier_model.train() 163 | 164 | # Epoch starts from 1 165 | total_loss, total_num = 0, 0.0 166 | data_loader = self.data_loader 167 | iter_size = self.iter_size 168 | 169 | # Meters for statistics 170 | average_valid_meter = AverageMeter() 171 | loss_meter = AverageMeter() 172 | data_meter = AverageMeter() 173 | regist_succ_meter = AverageMeter() 174 | regist_rte_meter = AverageMeter() 175 | regist_rre_meter = AverageMeter() 176 | 177 | # Timers for profiling 178 | data_timer = Timer() 179 | nn_timer = Timer() 180 | inlier_timer = Timer() 181 | total_timer = Timer() 182 | 183 | if self.config.num_train_iter > 0: 184 | num_train_iter = self.config.num_train_iter 185 | else: 186 | num_train_iter = len(data_loader) // iter_size 187 | start_iter = (epoch - 1) * num_train_iter 188 | 189 | tp, fp, tn, fn = 0, 0, 0, 0 190 | 191 | # Iterate over batches 192 | for curr_iter in range(num_train_iter): 193 | self.optimizer.zero_grad() 194 | 195 | batch_loss, data_time = 0, 0 196 | total_timer.tic() 197 | 198 | for iter_idx in range(iter_size): 199 | data_timer.tic() 200 | input_dict = self.get_data(self.train_data_loader_iter) 201 | data_time += data_timer.toc(average=False) 202 | 203 | # Initial inlier prediction with FCGF and KNN matching 204 | reg_coords, reg_feats, pred_pairs, is_correct, feat_time, nn_time = self.generate_inlier_input( 205 | xyz0=input_dict['pcd0'], 206 | xyz1=input_dict['pcd1'], 207 | iC0=input_dict['sinput0_C'], 208 | iC1=input_dict['sinput1_C'], 209 | iF0=input_dict['sinput0_F'], 210 | iF1=input_dict['sinput1_F'], 211 | 
len_batch=input_dict['len_batch'], 212 | pos_pairs=input_dict['correspondences']) 213 | nn_timer.update(nn_time) 214 | 215 | # Inlier prediction with 6D ConvNet 216 | inlier_timer.tic() 217 | reg_sinput = ME.SparseTensor(reg_feats.contiguous(), 218 | coords=reg_coords.int()).to(self.device) 219 | reg_soutput = self.inlier_model(reg_sinput) 220 | inlier_timer.toc() 221 | 222 | logits = reg_soutput.F 223 | weights = logits.sigmoid() 224 | 225 | # Truncate weights too low 226 | # For training, inplace modification is prohibited for backward 227 | if self.clip_weight_thresh > 0: 228 | weights_tmp = torch.zeros_like(weights) 229 | valid_mask = weights > self.clip_weight_thresh 230 | weights_tmp[valid_mask] = weights[valid_mask] 231 | weights = weights_tmp 232 | 233 | # Weighted Procrustes 234 | pred_rots, pred_trans, ws = self.weighted_procrustes(xyz0s=input_dict['pcd0'], 235 | xyz1s=input_dict['pcd1'], 236 | pred_pairs=pred_pairs, 237 | weights=weights) 238 | 239 | # Get batch registration loss 240 | gt_rots, gt_trans = self.decompose_rotation_translation(input_dict['T_gt']) 241 | rot_error = batch_rotation_error(pred_rots, gt_rots) 242 | trans_error = batch_translation_error(pred_trans, gt_trans) 243 | individual_loss = rot_error + self.config.trans_weight * trans_error 244 | 245 | # Select batches with at least 10 valid correspondences 246 | valid_mask = ws > 10 247 | num_valid = valid_mask.sum().item() 248 | average_valid_meter.update(num_valid) 249 | 250 | # Registration loss against registration GT 251 | loss = self.config.procrustes_loss_weight * individual_loss[valid_mask].mean() 252 | if not np.isfinite(loss.item()): 253 | max_val = loss.item() 254 | logging.info('Loss is infinite, abort ') 255 | continue 256 | 257 | # Direct inlier loss against nearest neighbor searched GT 258 | target = torch.from_numpy(is_correct).squeeze() 259 | if self.config.inlier_use_direct_loss: 260 | inlier_loss = self.config.inlier_direct_loss_weight * self.crit( 261 | logits.cpu().squeeze(), target.to(torch.float)) / iter_size 262 | loss += inlier_loss 263 | 264 | loss.backward() 265 | 266 | # Update statistics before backprop 267 | with torch.no_grad(): 268 | regist_rre_meter.update(rot_error.squeeze() * 180 / np.pi) 269 | regist_rte_meter.update(trans_error.squeeze()) 270 | 271 | success = (trans_error.squeeze() < self.config.success_rte_thresh) * ( 272 | rot_error.squeeze() * 180 / np.pi < self.config.success_rre_thresh) 273 | regist_succ_meter.update(success.float()) 274 | 275 | batch_loss += loss.mean().item() 276 | 277 | neg_target = (~target).to(torch.bool) 278 | pred = logits > 0 # todo thresh 279 | pred_on_pos, pred_on_neg = pred[target], pred[neg_target] 280 | tp += pred_on_pos.sum().item() 281 | fp += pred_on_neg.sum().item() 282 | tn += (~pred_on_neg).sum().item() 283 | fn += (~pred_on_pos).sum().item() 284 | 285 | # Check gradient and avoid backprop of inf values 286 | max_grad = torch.abs(self.inlier_model.final.kernel.grad).max().cpu().item() 287 | 288 | # Backprop only if gradient is finite 289 | if not np.isfinite(max_grad): 290 | self.optimizer.zero_grad() 291 | logging.info(f'Clearing the NaN gradient at iter {curr_iter}') 292 | else: 293 | self.optimizer.step() 294 | 295 | gc.collect() 296 | 297 | torch.cuda.empty_cache() 298 | 299 | total_loss += batch_loss 300 | total_num += 1.0 301 | total_timer.toc() 302 | data_meter.update(data_time) 303 | loss_meter.update(batch_loss) 304 | 305 | # Output to logs 306 | if curr_iter % self.config.stat_freq == 0: 307 | precision = tp / (tp + fp + 
eps) 308 | recall = tp / (tp + fn + eps) 309 | f1 = 2 * (precision * recall) / (precision + recall + eps) 310 | tpr = tp / (tp + fn + eps) 311 | tnr = tn / (tn + fp + eps) 312 | balanced_accuracy = (tpr + tnr) / 2 313 | 314 | correspondence_accuracy = is_correct.sum() / len(is_correct) 315 | 316 | stat = { 317 | 'loss': loss_meter.avg, 318 | 'precision': precision, 319 | 'recall': recall, 320 | 'tpr': tpr, 321 | 'tnr': tnr, 322 | 'balanced_accuracy': balanced_accuracy, 323 | 'f1': f1, 324 | 'num_valid': average_valid_meter.avg, 325 | } 326 | 327 | for k, v in stat.items(): 328 | self.writer.add_scalar(f'train/{k}', v, start_iter + curr_iter) 329 | 330 | logging.info(' '.join([ 331 | f"Train Epoch: {epoch} [{curr_iter}/{num_train_iter}],", 332 | f"Current Loss: {loss_meter.avg:.3e},", 333 | f"Correspondence acc: {correspondence_accuracy:.3e}", 334 | f", Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f},", 335 | f"TPR: {tpr:.4f}, TNR: {tnr:.4f}, BAcc: {balanced_accuracy:.4f}", 336 | f"RTE: {regist_rte_meter.avg:.3e}, RRE: {regist_rre_meter.avg:.3e},", 337 | f"Succ rate: {regist_succ_meter.avg:3e}", 338 | f"Avg num valid: {average_valid_meter.avg:3e}", 339 | f"\tData time: {data_meter.avg:.4f}, Train time: {total_timer.avg - data_meter.avg:.4f},", 340 | f"NN search time: {nn_timer.avg:.3e}, Total time: {total_timer.avg:.4f}" 341 | ])) 342 | 343 | loss_meter.reset() 344 | regist_rte_meter.reset() 345 | regist_rre_meter.reset() 346 | regist_succ_meter.reset() 347 | average_valid_meter.reset() 348 | data_meter.reset() 349 | total_timer.reset() 350 | 351 | tp, fp, tn, fn = 0, 0, 0, 0 352 | 353 | def _valid_epoch(self): 354 | # Change the network to evaluation mode 355 | self.feat_model.eval() 356 | self.inlier_model.eval() 357 | self.val_data_loader.dataset.reset_seed(0) 358 | 359 | num_data = 0 360 | loss_meter = AverageMeter() 361 | hit_ratio_meter = AverageMeter() 362 | regist_succ_meter = AverageMeter() 363 | regist_rte_meter = AverageMeter() 364 | regist_rre_meter = AverageMeter() 365 | data_timer = Timer() 366 | feat_timer = Timer() 367 | inlier_timer = Timer() 368 | nn_timer = Timer() 369 | dgr_timer = Timer() 370 | 371 | tot_num_data = len(self.val_data_loader.dataset) 372 | if self.val_max_iter > 0: 373 | tot_num_data = min(self.val_max_iter, tot_num_data) 374 | tot_num_data = int(tot_num_data / self.val_data_loader.batch_size) 375 | data_loader_iter = self.val_data_loader.__iter__() 376 | 377 | tp, fp, tn, fn = 0, 0, 0, 0 378 | for batch_idx in range(tot_num_data): 379 | data_timer.tic() 380 | input_dict = self.get_data(data_loader_iter) 381 | data_timer.toc() 382 | 383 | reg_coords, reg_feats, pred_pairs, is_correct, feat_time, nn_time = self.generate_inlier_input( 384 | xyz0=input_dict['pcd0'], 385 | xyz1=input_dict['pcd1'], 386 | iC0=input_dict['sinput0_C'], 387 | iC1=input_dict['sinput1_C'], 388 | iF0=input_dict['sinput0_F'], 389 | iF1=input_dict['sinput1_F'], 390 | len_batch=input_dict['len_batch'], 391 | pos_pairs=input_dict['correspondences']) 392 | feat_timer.update(feat_time) 393 | nn_timer.update(nn_time) 394 | 395 | hit_ratio_meter.update(is_correct.sum().item() / len(is_correct)) 396 | 397 | inlier_timer.tic() 398 | reg_sinput = ME.SparseTensor(reg_feats.contiguous(), 399 | coords=reg_coords.int()).to(self.device) 400 | reg_soutput = self.inlier_model(reg_sinput) 401 | inlier_timer.toc() 402 | 403 | dgr_timer.tic() 404 | logits = reg_soutput.F 405 | weights = logits.sigmoid() 406 | 407 | if self.clip_weight_thresh > 0: 408 | weights[weights < 
self.clip_weight_thresh] = 0 409 | 410 | # Weighted Procrustes 411 | pred_rots, pred_trans, ws = self.weighted_procrustes(xyz0s=input_dict['pcd0'], 412 | xyz1s=input_dict['pcd1'], 413 | pred_pairs=pred_pairs, 414 | weights=weights) 415 | dgr_timer.toc() 416 | 417 | valid_mask = ws > 10 418 | gt_rots, gt_trans = self.decompose_rotation_translation(input_dict['T_gt']) 419 | rot_error = batch_rotation_error(pred_rots, gt_rots) * 180 / np.pi 420 | trans_error = batch_translation_error(pred_trans, gt_trans) 421 | 422 | regist_rre_meter.update(rot_error.squeeze()) 423 | regist_rte_meter.update(trans_error.squeeze()) 424 | 425 | # Compute success 426 | success = (trans_error < self.config.success_rte_thresh) * ( 427 | rot_error < self.config.success_rre_thresh) * valid_mask 428 | regist_succ_meter.update(success.float()) 429 | 430 | target = torch.from_numpy(is_correct).squeeze() 431 | neg_target = (~target).to(torch.bool) 432 | pred = weights > 0.5 # TODO thresh 433 | pred_on_pos, pred_on_neg = pred[target], pred[neg_target] 434 | tp += pred_on_pos.sum().item() 435 | fp += pred_on_neg.sum().item() 436 | tn += (~pred_on_neg).sum().item() 437 | fn += (~pred_on_pos).sum().item() 438 | 439 | num_data += 1 440 | torch.cuda.empty_cache() 441 | 442 | if batch_idx % self.config.stat_freq == 0: 443 | precision = tp / (tp + fp + eps) 444 | recall = tp / (tp + fn + eps) 445 | f1 = 2 * (precision * recall) / (precision + recall + eps) 446 | tpr = tp / (tp + fn + eps) 447 | tnr = tn / (tn + fp + eps) 448 | balanced_accuracy = (tpr + tnr) / 2 449 | logging.info(' '.join([ 450 | f"Validation iter {num_data} / {tot_num_data} : Data Loading Time: {data_timer.avg:.3e},", 451 | f"NN search time: {nn_timer.avg:.3e}", 452 | f"Feature Extraction Time: {feat_timer.avg:.3e}, Inlier Time: {inlier_timer.avg:.3e},", 453 | f"Loss: {loss_meter.avg:.4f}, Hit Ratio: {hit_ratio_meter.avg:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, ", 454 | f"TPR: {tpr:.4f}, TNR: {tnr:.4f}, BAcc: {balanced_accuracy:.4f}, ", 455 | f"DGR RTE: {regist_rte_meter.avg:.3e}, DGR RRE: {regist_rre_meter.avg:.3e}, DGR Time: {dgr_timer.avg:.3e}", 456 | f"DGR Succ rate: {regist_succ_meter.avg:3e}", 457 | ])) 458 | data_timer.reset() 459 | 460 | precision = tp / (tp + fp + eps) 461 | recall = tp / (tp + fn + eps) 462 | f1 = 2 * (precision * recall) / (precision + recall + eps) 463 | tpr = tp / (tp + fn + eps) 464 | tnr = tn / (tn + fp + eps) 465 | balanced_accuracy = (tpr + tnr) / 2 466 | 467 | logging.info(' '.join([ 468 | f"Feature Extraction Time: {feat_timer.avg:.3e}, NN search time: {nn_timer.avg:.3e}", 469 | f"Inlier Time: {inlier_timer.avg:.3e}, Final Loss: {loss_meter.avg}, ", 470 | f"Loss: {loss_meter.avg}, Hit Ratio: {hit_ratio_meter.avg:.4f}, Precision: {precision}, Recall: {recall}, F1: {f1}, ", 471 | f"TPR: {tpr}, TNR: {tnr}, BAcc: {balanced_accuracy}, ", 472 | f"RTE: {regist_rte_meter.avg:.3e}, RRE: {regist_rre_meter.avg:.3e}, DGR Time: {dgr_timer.avg:.3e}", 473 | f"DGR Succ rate: {regist_succ_meter.avg:3e}", 474 | ])) 475 | 476 | stat = { 477 | 'loss': loss_meter.avg, 478 | 'precision': precision, 479 | 'recall': recall, 480 | 'tpr': tpr, 481 | 'tnr': tnr, 482 | 'balanced_accuracy': balanced_accuracy, 483 | 'f1': f1, 484 | 'regist_rte': regist_rte_meter.avg, 485 | 'regist_rre': regist_rre_meter.avg, 486 | 'succ_rate': regist_succ_meter.avg 487 | } 488 | 489 | return stat 490 | 491 | def _load_weights(self, config): 492 | if config.resume is None and config.weights: 493 | logging.info("=> loading weights for inlier 
model '{}'".format(config.weights)) 494 | checkpoint = torch.load(config.weights) 495 | self.feat_model.load_state_dict(checkpoint['state_dict']) 496 | logging.info("=> Loaded base model weights from '{}'".format(config.weights)) 497 | if 'state_dict_inlier' in checkpoint: 498 | self.inlier_model.load_state_dict(checkpoint['state_dict_inlier']) 499 | logging.info("=> Loaded inlier weights from '{}'".format(config.weights)) 500 | else: 501 | logging.warn("Inlier weight not found in '{}'".format(config.weights)) 502 | 503 | if config.resume is not None: 504 | if osp.isfile(config.resume): 505 | logging.info("=> loading checkpoint '{}'".format(config.resume)) 506 | state = torch.load(config.resume) 507 | 508 | self.start_epoch = state['epoch'] 509 | self.feat_model.load_state_dict(state['state_dict']) 510 | self.feat_model = self.feat_model.to(self.device) 511 | self.scheduler.load_state_dict(state['scheduler']) 512 | self.optimizer.load_state_dict(state['optimizer']) 513 | 514 | if 'best_val' in state.keys(): 515 | self.best_val = state['best_val'] 516 | self.best_val_epoch = state['best_val_epoch'] 517 | self.best_val_metric = state['best_val_metric'] 518 | 519 | if 'state_dict_inlier' in state: 520 | self.inlier_model.load_state_dict(state['state_dict_inlier']) 521 | self.inlier_model = self.inlier_model.to(self.device) 522 | else: 523 | logging.warn("Inlier weights not found in '{}'".format(config.resume)) 524 | else: 525 | logging.warn("Inlier weights does not exist at '{}'".format(config.resume)) 526 | 527 | def _save_checkpoint(self, epoch, filename='checkpoint'): 528 | """ 529 | Saving checkpoints 530 | 531 | :param epoch: current epoch number 532 | :param log: logging information of the epoch 533 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 534 | """ 535 | print('_save_checkpoint from inlier_trainer') 536 | state = { 537 | 'epoch': epoch, 538 | 'state_dict': self.feat_model.state_dict(), 539 | 'state_dict_inlier': self.inlier_model.state_dict(), 540 | 'optimizer': self.optimizer.state_dict(), 541 | 'scheduler': self.scheduler.state_dict(), 542 | 'config': self.config, 543 | 'best_val': self.best_val, 544 | 'best_val_epoch': self.best_val_epoch, 545 | 'best_val_metric': self.best_val_metric 546 | } 547 | filename = os.path.join(self.checkpoint_dir, f'{filename}.pth') 548 | logging.info("Saving checkpoint: {} ...".format(filename)) 549 | torch.save(state, filename) 550 | 551 | def get_data(self, iterator): 552 | while True: 553 | try: 554 | input_data = iterator.next() 555 | except ValueError as e: 556 | logging.info('Skipping an empty batch') 557 | continue 558 | 559 | return input_data 560 | 561 | def decompose_by_length(self, tensor, reference_tensors): 562 | decomposed_tensors = [] 563 | start_ind = 0 564 | for r in reference_tensors: 565 | N = len(r) 566 | decomposed_tensors.append(tensor[start_ind:start_ind + N]) 567 | start_ind += N 568 | return decomposed_tensors 569 | 570 | def decompose_rotation_translation(self, Ts): 571 | Ts = Ts.float() 572 | Rs = Ts[:, :3, :3] 573 | ts = Ts[:, :3, 3] 574 | 575 | Rs.require_grad = False 576 | ts.require_grad = False 577 | 578 | return Rs, ts 579 | 580 | def weighted_procrustes(self, xyz0s, xyz1s, pred_pairs, weights): 581 | decomposed_weights = self.decompose_by_length(weights, pred_pairs) 582 | RT = [] 583 | ws = [] 584 | 585 | for xyz0, xyz1, pred_pair, w in zip(xyz0s, xyz1s, pred_pairs, decomposed_weights): 586 | xyz0.requires_grad = False 587 | xyz1.requires_grad = False 588 | ws.append(w.sum().item()) 
589 | predT = GlobalRegistration.weighted_procrustes( 590 | X=xyz0[pred_pair[:, 0]].to(self.device), 591 | Y=xyz1[pred_pair[:, 1]].to(self.device), 592 | w=w, 593 | eps=np.finfo(np.float32).eps) 594 | RT.append(predT) 595 | 596 | Rs, ts = list(zip(*RT)) 597 | Rs = torch.stack(Rs, 0) 598 | ts = torch.stack(ts, 0) 599 | ws = torch.Tensor(ws) 600 | return Rs, ts, ws 601 | 602 | def generate_inlier_features(self, xyz0, xyz1, C0, C1, F0, F1, pair_ind0, pair_ind1): 603 | """ 604 | Assume that the indices 0 and indices 1 gives the pairs in the 605 | (downsampled) correspondences. 606 | """ 607 | assert len(pair_ind0) == len(pair_ind1) 608 | reg_feat_type = self.config.inlier_feature_type 609 | assert reg_feat_type in ['ones', 'coords', 'counts', 'feats'] 610 | 611 | # Move coordinates and indices to the device 612 | if 'coords' in reg_feat_type: 613 | C0 = C0.to(self.device) 614 | C1 = C1.to(self.device) 615 | 616 | # TODO: change it to append the features and then concat at last 617 | if reg_feat_type == 'ones': 618 | reg_feat = torch.ones((len(pair_ind0), 1)).to(torch.float32) 619 | elif reg_feat_type == 'feats': 620 | reg_feat = torch.cat((F0[pair_ind0], F1[pair_ind1]), dim=1) 621 | elif reg_feat_type == 'coords': 622 | reg_feat = torch.cat((torch.cos(torch.cat( 623 | xyz0, 0)[pair_ind0]), torch.cos(torch.cat(xyz1, 0)[pair_ind1])), 624 | dim=1) 625 | else: 626 | raise ValueError('Inlier feature type not defined') 627 | 628 | return reg_feat 629 | 630 | def generate_inlier_input(self, xyz0, xyz1, iC0, iC1, iF0, iF1, len_batch, pos_pairs): 631 | # pairs consist of (xyz1 index, xyz0 index) 632 | stime = time.time() 633 | sinput0 = ME.SparseTensor(iF0, coords=iC0).to(self.device) 634 | oF0 = self.feat_model(sinput0).F 635 | 636 | sinput1 = ME.SparseTensor(iF1, coords=iC1).to(self.device) 637 | oF1 = self.feat_model(sinput1).F 638 | feat_time = time.time() - stime 639 | 640 | stime = time.time() 641 | pred_pairs = self.find_pairs(oF0, oF1, len_batch) 642 | nn_time = time.time() - stime 643 | 644 | is_correct = find_correct_correspondence(pos_pairs, pred_pairs, len_batch=len_batch) 645 | 646 | cat_pred_pairs = [] 647 | start_inds = torch.zeros((1, 2)).long() 648 | for lens, pred_pair in zip(len_batch, pred_pairs): 649 | cat_pred_pairs.append(pred_pair + start_inds) 650 | start_inds += torch.LongTensor(lens) 651 | 652 | cat_pred_pairs = torch.cat(cat_pred_pairs, 0) 653 | pred_pair_inds0, pred_pair_inds1 = cat_pred_pairs.t() 654 | reg_coords = torch.cat((iC0[pred_pair_inds0], iC1[pred_pair_inds1, 1:]), 1) 655 | reg_feats = self.generate_inlier_features(xyz0, xyz1, iC0, iC1, oF0, oF1, 656 | pred_pair_inds0, pred_pair_inds1).float() 657 | 658 | return reg_coords, reg_feats, pred_pairs, is_correct, feat_time, nn_time 659 | 660 | def find_pairs(self, F0, F1, len_batch): 661 | nn_batch = find_knn_batch(F0, 662 | F1, 663 | len_batch, 664 | nn_max_n=self.config.nn_max_n, 665 | knn=self.config.inlier_knn, 666 | return_distance=False, 667 | search_method=self.config.knn_search_method) 668 | 669 | pred_pairs = [] 670 | for nns, lens in zip(nn_batch, len_batch): 671 | pred_pair_ind0, pred_pair_ind1 = torch.arange( 672 | len(nns)).long()[:, None], nns.long().cpu() 673 | nn_pairs = [] 674 | for j in range(nns.shape[1]): 675 | nn_pairs.append( 676 | torch.cat((pred_pair_ind0.cpu(), pred_pair_ind1[:, j].unsqueeze(1)), 1)) 677 | 678 | pred_pairs.append(torch.cat(nn_pairs, 0)) 679 | return pred_pairs 680 | -------------------------------------------------------------------------------- /dataloader/base_loader.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | # 8 | # Written by Chris Choy 9 | # Distributed under MIT License 10 | import os 11 | import logging 12 | import random 13 | import torch 14 | import torch.utils.data 15 | import numpy as np 16 | 17 | import dataloader.transforms as t 18 | from dataloader.inf_sampler import InfSampler 19 | 20 | import MinkowskiEngine as ME 21 | import open3d as o3d 22 | 23 | 24 | class CollationFunctionFactory: 25 | def __init__(self, concat_correspondences=True, collation_type='default'): 26 | self.concat_correspondences = concat_correspondences 27 | if collation_type == 'default': 28 | self.collation_fn = self.collate_default 29 | elif collation_type == 'collate_pair': 30 | self.collation_fn = self.collate_pair_fn 31 | else: 32 | raise ValueError(f'collation_type {collation_type} not found') 33 | 34 | def __call__(self, list_data): 35 | return self.collation_fn(list_data) 36 | 37 | def collate_default(self, list_data): 38 | return list_data 39 | 40 | def collate_pair_fn(self, list_data): 41 | N = len(list_data) 42 | list_data = [data for data in list_data if data is not None] 43 | if N != len(list_data): 44 | logging.info(f"Retain {len(list_data)} from {N} data.") 45 | if len(list_data) == 0: 46 | raise ValueError('No data in the batch') 47 | 48 | xyz0, xyz1, coords0, coords1, feats0, feats1, matching_inds, trans, extra_packages = list( 49 | zip(*list_data)) 50 | matching_inds_batch, trans_batch, len_batch = [], [], [] 51 | 52 | coords_batch0 = ME.utils.batched_coordinates(coords0) 53 | coords_batch1 = ME.utils.batched_coordinates(coords1) 54 | trans_batch = torch.from_numpy(np.stack(trans)).float() 55 | 56 | curr_start_inds = torch.zeros((1, 2), dtype=torch.int32) 57 | for batch_id, _ in enumerate(coords0): 58 | # For scan2cad there will be empty matching_inds even after filtering 59 | # This check will skip these pairs while not affecting other datasets 60 | if (len(matching_inds[batch_id]) == 0): 61 | continue 62 | 63 | N0 = coords0[batch_id].shape[0] 64 | N1 = coords1[batch_id].shape[0] 65 | 66 | if self.concat_correspondences: 67 | matching_inds_batch.append( 68 | torch.IntTensor(matching_inds[batch_id]) + curr_start_inds) 69 | else: 70 | matching_inds_batch.append(torch.IntTensor(matching_inds[batch_id])) 71 | 72 | len_batch.append([N0, N1]) 73 | 74 | # Move the head 75 | curr_start_inds[0, 0] += N0 76 | curr_start_inds[0, 1] += N1 77 | 78 | # Concatenate all lists 79 | feats_batch0 = torch.cat(feats0, 0).float() 80 | feats_batch1 = torch.cat(feats1, 0).float() 81 | # xyz_batch0 = torch.cat(xyz0, 0).float() 82 | # xyz_batch1 = torch.cat(xyz1, 0).float() 83 | # trans_batch = torch.cat(trans_batch, 0).float() 84 | if self.concat_correspondences: 85 | matching_inds_batch = torch.cat(matching_inds_batch, 0).int() 86 | 87 | return { 88 | 'pcd0': xyz0, 89 | 'pcd1': xyz1, 90 | 'sinput0_C': coords_batch0, 91 | 'sinput0_F': feats_batch0, 92 | 'sinput1_C': coords_batch1, 93 | 'sinput1_F': feats_batch1, 94 | 'correspondences': 
matching_inds_batch, 95 | 'T_gt': trans_batch, 96 | 'len_batch': len_batch, 97 | 'extra_packages': extra_packages, 98 | } 99 | 100 | 101 | class PairDataset(torch.utils.data.Dataset): 102 | AUGMENT = None 103 | 104 | def __init__(self, 105 | phase, 106 | transform=None, 107 | random_rotation=True, 108 | random_scale=True, 109 | manual_seed=False, 110 | config=None): 111 | self.phase = phase 112 | self.files = [] 113 | self.data_objects = [] 114 | self.transform = transform 115 | self.voxel_size = config.voxel_size 116 | self.matching_search_voxel_size = \ 117 | config.voxel_size * config.positive_pair_search_voxel_size_multiplier 118 | 119 | self.random_scale = random_scale 120 | self.min_scale = config.min_scale 121 | self.max_scale = config.max_scale 122 | self.random_rotation = random_rotation 123 | self.rotation_range = config.rotation_range 124 | self.randg = np.random.RandomState() 125 | if manual_seed: 126 | self.reset_seed() 127 | 128 | def reset_seed(self, seed=0): 129 | logging.info(f"Resetting the data loader seed to {seed}") 130 | self.randg.seed(seed) 131 | 132 | def apply_transform(self, pts, trans): 133 | R = trans[:3, :3] 134 | T = trans[:3, 3] 135 | pts = pts @ R.T + T 136 | return pts 137 | 138 | def __len__(self): 139 | return len(self.files) 140 | -------------------------------------------------------------------------------- /dataloader/data_loaders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | from dataloader.threedmatch_loader import * 8 | from dataloader.kitti_loader import * 9 | 10 | ALL_DATASETS = [ 11 | ThreeDMatchPairDataset07, ThreeDMatchPairDataset05, ThreeDMatchPairDataset03, 12 | ThreeDMatchTrajectoryDataset, KITTIPairDataset, KITTINMPairDataset 13 | ] 14 | dataset_str_mapping = {d.__name__: d for d in ALL_DATASETS} 15 | 16 | 17 | def make_data_loader(config, phase, batch_size, num_workers=0, shuffle=None): 18 | assert phase in ['train', 'trainval', 'val', 'test'] 19 | if shuffle is None: 20 | shuffle = phase != 'test' 21 | 22 | if config.dataset not in dataset_str_mapping.keys(): 23 | logging.error(f'Dataset {config.dataset}, does not exists in ' + 24 | ', '.join(dataset_str_mapping.keys())) 25 | 26 | Dataset = dataset_str_mapping[config.dataset] 27 | 28 | use_random_scale = False 29 | use_random_rotation = False 30 | transforms = [] 31 | if phase in ['train', 'trainval']: 32 | use_random_rotation = config.use_random_rotation 33 | use_random_scale = config.use_random_scale 34 | transforms += [t.Jitter()] 35 | 36 | if phase in ['val', 'test']: 37 | use_random_rotation = config.test_random_rotation 38 | 39 | dset = Dataset(phase, 40 | transform=t.Compose(transforms), 41 | random_scale=use_random_scale, 42 | random_rotation=use_random_rotation, 43 | config=config) 44 | 45 | collation_fn = CollationFunctionFactory(concat_correspondences=False, 46 | collation_type='collate_pair') 47 | 48 | loader = torch.utils.data.DataLoader(dset, 49 | batch_size=batch_size, 50 | collate_fn=collation_fn, 51 | num_workers=num_workers, 52 | 
sampler=InfSampler(dset, shuffle)) 53 | 54 | return loader 55 | -------------------------------------------------------------------------------- /dataloader/inf_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class InfSampler(Sampler): 12 | """Samples elements randomly, without replacement. 13 | 14 | Arguments: 15 | data_source (Dataset): dataset to sample from 16 | """ 17 | 18 | def __init__(self, data_source, shuffle=False): 19 | self.data_source = data_source 20 | self.shuffle = shuffle 21 | self.reset_permutation() 22 | 23 | def reset_permutation(self): 24 | perm = len(self.data_source) 25 | if self.shuffle: 26 | perm = torch.randperm(perm) 27 | self._perm = perm.tolist() 28 | 29 | def __iter__(self): 30 | return self 31 | 32 | def __next__(self): 33 | if len(self._perm) == 0: 34 | self.reset_permutation() 35 | return self._perm.pop() 36 | 37 | def __len__(self): 38 | return len(self.data_source) 39 | -------------------------------------------------------------------------------- /dataloader/kitti_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import os 8 | import glob 9 | 10 | from dataloader.base_loader import * 11 | from dataloader.transforms import * 12 | from util.pointcloud import get_matching_indices, make_open3d_point_cloud 13 | 14 | kitti_cache = {} 15 | kitti_icp_cache = {} 16 | 17 | class KITTIPairDataset(PairDataset): 18 | AUGMENT = None 19 | DATA_FILES = { 20 | 'train': './dataloader/split/train_kitti.txt', 21 | 'val': './dataloader/split/val_kitti.txt', 22 | 'test': './dataloader/split/test_kitti.txt' 23 | } 24 | TEST_RANDOM_ROTATION = False 25 | IS_ODOMETRY = True 26 | 27 | def __init__(self, 28 | phase, 29 | transform=None, 30 | random_rotation=True, 31 | random_scale=True, 32 | manual_seed=False, 33 | config=None): 34 | # For evaluation, use the odometry dataset training following the 3DFeat eval method 35 | self.root = root = config.kitti_dir + '/dataset' 36 | random_rotation = self.TEST_RANDOM_ROTATION 37 | self.icp_path = config.icp_cache_path 38 | try: 39 | os.mkdir(self.icp_path) 40 | except OSError as error: 41 | pass 42 | PairDataset.__init__(self, phase, transform, random_rotation, random_scale, 43 | manual_seed, config) 44 | 45 | logging.info(f"Loading the subset {phase} from {root}") 46 | # Use the kitti root 47 | self.max_time_diff = max_time_diff = config.kitti_max_time_diff 48 | 49 | subset_names = open(self.DATA_FILES[phase]).read().split() 50 | for dirname in subset_names: 51 | drive_id = int(dirname) 52 | inames = self.get_all_scan_ids(drive_id) 53 | for start_time in inames: 54 | for time_diff in range(2, max_time_diff): 55 | pair_time = time_diff + start_time 56 | if pair_time in inames: 57 | self.files.append((drive_id, start_time, pair_time)) 58 | 59 | def get_all_scan_ids(self, drive_id): 60 | fnames = glob.glob(self.root + '/sequences/%02d/velodyne/*.bin' % drive_id) 61 | assert len( 62 | fnames) > 0, f"Make sure that the path {self.root} has drive id: {drive_id}" 63 | inames = [int(os.path.split(fname)[-1][:-4]) for fname in fnames] 64 | return inames 65 | 66 | @property 67 | def velo2cam(self): 68 | try: 69 | velo2cam = self._velo2cam 70 | except AttributeError: 71 | R = np.array([ 72 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04, 73 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02 74 | ]).reshape(3, 3) 75 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1) 76 | velo2cam = np.hstack([R, T]) 77 | self._velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T 78 | return self._velo2cam 79 | 80 | def get_video_odometry(self, drive, indices=None, ext='.txt', return_all=False): 81 | data_path = self.root + '/poses/%02d.txt' % drive 82 | if data_path not in kitti_cache: 83 | kitti_cache[data_path] = np.genfromtxt(data_path) 84 | if return_all: 85 | return kitti_cache[data_path] 86 | else: 87 | return kitti_cache[data_path][indices] 88 | 89 | def odometry_to_positions(self, odometry): 90 | T_w_cam0 = odometry.reshape(3, 4) 91 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1])) 92 | return T_w_cam0 93 | 94 | def rot3d(self, axis, angle): 95 | ei = np.ones(3, dtype='bool') 96 | ei[axis] = 0 97 | i = np.nonzero(ei)[0] 98 | m = np.eye(3) 99 | c, s = np.cos(angle), np.sin(angle) 100 | m[i[0], i[0]] = c 101 | m[i[0], i[1]] = 
-s 102 | m[i[1], i[0]] = s 103 | m[i[1], i[1]] = c 104 | return m 105 | 106 | def pos_transform(self, pos): 107 | x, y, z, rx, ry, rz, _ = pos[0] 108 | RT = np.eye(4) 109 | RT[:3, :3] = np.dot(np.dot(self.rot3d(0, rx), self.rot3d(1, ry)), self.rot3d(2, rz)) 110 | RT[:3, 3] = [x, y, z] 111 | return RT 112 | 113 | def get_position_transform(self, pos0, pos1, invert=False): 114 | T0 = self.pos_transform(pos0) 115 | T1 = self.pos_transform(pos1) 116 | return (np.dot(T1, np.linalg.inv(T0)).T if not invert else np.dot( 117 | np.linalg.inv(T1), T0).T) 118 | 119 | def _get_velodyne_fn(self, drive, t): 120 | fname = self.root + '/sequences/%02d/velodyne/%06d.bin' % (drive, t) 121 | return fname 122 | 123 | def __getitem__(self, idx): 124 | drive = self.files[idx][0] 125 | t0, t1 = self.files[idx][1], self.files[idx][2] 126 | all_odometry = self.get_video_odometry(drive, [t0, t1]) 127 | positions = [self.odometry_to_positions(odometry) for odometry in all_odometry] 128 | fname0 = self._get_velodyne_fn(drive, t0) 129 | fname1 = self._get_velodyne_fn(drive, t1) 130 | 131 | # XYZ and reflectance 132 | xyzr0 = np.fromfile(fname0, dtype=np.float32).reshape(-1, 4) 133 | xyzr1 = np.fromfile(fname1, dtype=np.float32).reshape(-1, 4) 134 | 135 | xyz0 = xyzr0[:, :3] 136 | xyz1 = xyzr1[:, :3] 137 | 138 | key = '%d_%d_%d' % (drive, t0, t1) 139 | filename = self.icp_path + '/' + key + '.npy' 140 | if key not in kitti_icp_cache: 141 | if not os.path.exists(filename): 142 | # work on the downsampled xyzs, 0.05m == 5cm 143 | sel0 = ME.utils.sparse_quantize(xyz0 / 0.05, return_index=True) 144 | sel1 = ME.utils.sparse_quantize(xyz1 / 0.05, return_index=True) 145 | 146 | M = (self.velo2cam @ positions[0].T @ np.linalg.inv(positions[1].T) 147 | @ np.linalg.inv(self.velo2cam)).T 148 | xyz0_t = self.apply_transform(xyz0[sel0], M) 149 | pcd0 = make_open3d_point_cloud(xyz0_t) 150 | pcd1 = make_open3d_point_cloud(xyz1[sel1]) 151 | reg = o3d.registration.registration_icp(pcd0, pcd1, 0.2, np.eye(4), 152 | o3d.registration.TransformationEstimationPointToPoint(), 153 | o3d.registration.ICPConvergenceCriteria(max_iteration=200)) 154 | pcd0.transform(reg.transformation) 155 | # pcd0.transform(M2) or self.apply_transform(xyz0, M2) 156 | M2 = M @ reg.transformation 157 | # o3d.draw_geometries([pcd0, pcd1]) 158 | # write to a file 159 | np.save(filename, M2) 160 | else: 161 | M2 = np.load(filename) 162 | kitti_icp_cache[key] = M2 163 | else: 164 | M2 = kitti_icp_cache[key] 165 | 166 | if self.random_rotation: 167 | T0 = sample_random_trans(xyz0, self.randg, np.pi / 4) 168 | T1 = sample_random_trans(xyz1, self.randg, np.pi / 4) 169 | trans = T1 @ M2 @ np.linalg.inv(T0) 170 | 171 | xyz0 = self.apply_transform(xyz0, T0) 172 | xyz1 = self.apply_transform(xyz1, T1) 173 | else: 174 | trans = M2 175 | 176 | matching_search_voxel_size = self.matching_search_voxel_size 177 | if self.random_scale and random.random() < 0.95: 178 | scale = self.min_scale + \ 179 | (self.max_scale - self.min_scale) * random.random() 180 | matching_search_voxel_size *= scale 181 | xyz0 = scale * xyz0 182 | xyz1 = scale * xyz1 183 | 184 | # Voxelization 185 | xyz0_th = torch.from_numpy(xyz0) 186 | xyz1_th = torch.from_numpy(xyz1) 187 | 188 | sel0 = ME.utils.sparse_quantize(xyz0_th / self.voxel_size, return_index=True) 189 | sel1 = ME.utils.sparse_quantize(xyz1_th / self.voxel_size, return_index=True) 190 | 191 | # Make point clouds using voxelized points 192 | pcd0 = make_open3d_point_cloud(xyz0[sel0]) 193 | pcd1 = make_open3d_point_cloud(xyz1[sel1]) 194 | 195 | # 
Get matches 196 | matches = get_matching_indices(pcd0, pcd1, trans, matching_search_voxel_size) 197 | if len(matches) < 1000: 198 | raise ValueError(f"Insufficient matches in {drive}, {t0}, {t1}") 199 | 200 | # Get features 201 | npts0 = len(sel0) 202 | npts1 = len(sel1) 203 | 204 | feats_train0, feats_train1 = [], [] 205 | 206 | unique_xyz0_th = xyz0_th[sel0] 207 | unique_xyz1_th = xyz1_th[sel1] 208 | 209 | feats_train0.append(torch.ones((npts0, 1))) 210 | feats_train1.append(torch.ones((npts1, 1))) 211 | 212 | feats0 = torch.cat(feats_train0, 1) 213 | feats1 = torch.cat(feats_train1, 1) 214 | 215 | coords0 = torch.floor(unique_xyz0_th / self.voxel_size) 216 | coords1 = torch.floor(unique_xyz1_th / self.voxel_size) 217 | 218 | if self.transform: 219 | coords0, feats0 = self.transform(coords0, feats0) 220 | coords1, feats1 = self.transform(coords1, feats1) 221 | 222 | extra_package = {'drive': drive, 't0': t0, 't1': t1} 223 | 224 | return (unique_xyz0_th.float(), 225 | unique_xyz1_th.float(), coords0.int(), coords1.int(), feats0.float(), 226 | feats1.float(), matches, trans, extra_package) 227 | 228 | 229 | class KITTINMPairDataset(KITTIPairDataset): 230 | r""" 231 | Generate KITTI pairs within N meter distance 232 | """ 233 | MIN_DIST = 10 234 | 235 | def __init__(self, 236 | phase, 237 | transform=None, 238 | random_rotation=True, 239 | random_scale=True, 240 | manual_seed=False, 241 | config=None): 242 | self.root = root = os.path.join(config.kitti_dir, 'dataset') 243 | self.icp_path = os.path.join(config.kitti_dir, config.icp_cache_path) 244 | try: 245 | os.mkdir(self.icp_path) 246 | except OSError as error: 247 | pass 248 | random_rotation = self.TEST_RANDOM_ROTATION 249 | PairDataset.__init__(self, phase, transform, random_rotation, random_scale, 250 | manual_seed, config) 251 | 252 | logging.info(f"Loading the subset {phase} from {root}") 253 | 254 | subset_names = open(self.DATA_FILES[phase]).read().split() 255 | for dirname in subset_names: 256 | drive_id = int(dirname) 257 | fnames = glob.glob(root + '/sequences/%02d/velodyne/*.bin' % drive_id) 258 | assert len(fnames) > 0, f"Make sure that the path {root} has data {dirname}" 259 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 260 | 261 | all_odo = self.get_video_odometry(drive_id, return_all=True) 262 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo]) 263 | Ts = all_pos[:, :3, 3] 264 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3))**2 265 | pdist = np.sqrt(pdist.sum(-1)) 266 | more_than_10 = pdist > self.MIN_DIST 267 | curr_time = inames[0] 268 | while curr_time in inames: 269 | # Find the min index 270 | next_time = np.where(more_than_10[curr_time][curr_time:curr_time + 100])[0] 271 | if len(next_time) == 0: 272 | curr_time += 1 273 | else: 274 | # Follow https://github.com/yewzijian/3DFeatNet/blob/master/scripts_data_processing/kitti/process_kitti_data.m#L44 275 | next_time = next_time[0] + curr_time - 1 276 | 277 | if next_time in inames: 278 | self.files.append((drive_id, curr_time, next_time)) 279 | curr_time = next_time + 1 280 | 281 | # Remove problematic sequence 282 | for item in [ 283 | (8, 15, 58), 284 | ]: 285 | if item in self.files: 286 | self.files.pop(self.files.index(item)) 287 | -------------------------------------------------------------------------------- /dataloader/split/test_3dmatch.txt: -------------------------------------------------------------------------------- 1 | 7-scenes-redkitchen 2 | sun3d-home_at-home_at_scan1_2013_jan_1 3 | 
sun3d-home_md-home_md_scan9_2012_sep_30 4 | sun3d-hotel_uc-scan3 5 | sun3d-hotel_umd-maryland_hotel1 6 | sun3d-hotel_umd-maryland_hotel3 7 | sun3d-mit_76_studyroom-76-1studyroom2 8 | sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika -------------------------------------------------------------------------------- /dataloader/split/test_kitti.txt: -------------------------------------------------------------------------------- 1 | 8 2 | 9 3 | 10 4 | -------------------------------------------------------------------------------- /dataloader/split/test_scan2cad.txt: -------------------------------------------------------------------------------- 1 | full_annotations_clean_test.json 2 | -------------------------------------------------------------------------------- /dataloader/split/train_3dmatch.txt: -------------------------------------------------------------------------------- 1 | sun3d-brown_bm_1-brown_bm_1 2 | sun3d-brown_cogsci_1-brown_cogsci_1 3 | sun3d-brown_cs_2-brown_cs2 4 | sun3d-brown_cs_3-brown_cs3 5 | sun3d-harvard_c3-hv_c3_1 6 | sun3d-harvard_c5-hv_c5_1 7 | sun3d-harvard_c6-hv_c6_1 8 | sun3d-harvard_c8-hv_c8_3 9 | sun3d-home_bksh-home_bksh_oct_30_2012_scan2_erika 10 | sun3d-hotel_nips2012-nips_4 11 | sun3d-hotel_sf-scan1 12 | sun3d-mit_32_d507-d507_2 13 | sun3d-mit_46_ted_lab1-ted_lab_2 14 | sun3d-mit_76_417-76-417b 15 | sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika 16 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika 17 | 7-scenes-chess 18 | 7-scenes-fire 19 | 7-scenes-office 20 | 7-scenes-pumpkin 21 | 7-scenes-stairs 22 | rgbd-scenes-v2-scene_01 23 | rgbd-scenes-v2-scene_02 24 | rgbd-scenes-v2-scene_03 25 | rgbd-scenes-v2-scene_04 26 | rgbd-scenes-v2-scene_05 27 | rgbd-scenes-v2-scene_06 28 | rgbd-scenes-v2-scene_07 29 | rgbd-scenes-v2-scene_08 30 | rgbd-scenes-v2-scene_09 31 | rgbd-scenes-v2-scene_11 32 | rgbd-scenes-v2-scene_12 33 | rgbd-scenes-v2-scene_13 34 | rgbd-scenes-v2-scene_14 35 | bundlefusion-apt0 36 | bundlefusion-apt1 37 | bundlefusion-apt2 38 | bundlefusion-copyroom 39 | bundlefusion-office1 40 | bundlefusion-office2 41 | bundlefusion-office3 42 | analysis-by-synthesis-apt1-kitchen 43 | analysis-by-synthesis-apt1-living 44 | analysis-by-synthesis-apt2-bed 45 | analysis-by-synthesis-apt2-living 46 | analysis-by-synthesis-apt2-luke 47 | analysis-by-synthesis-office2-5a 48 | analysis-by-synthesis-office2-5b 49 | -------------------------------------------------------------------------------- /dataloader/split/train_kitti.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | -------------------------------------------------------------------------------- /dataloader/split/train_scan2cad.txt: -------------------------------------------------------------------------------- 1 | full_annotations_clean_train.json 2 | -------------------------------------------------------------------------------- /dataloader/split/val_3dmatch.txt: -------------------------------------------------------------------------------- 1 | sun3d-brown_bm_4-brown_bm_4 2 | sun3d-harvard_c11-hv_c11_2 3 | 7-scenes-heads 4 | rgbd-scenes-v2-scene_10 5 | bundlefusion-office0 6 | analysis-by-synthesis-apt2-kitchen 7 | -------------------------------------------------------------------------------- /dataloader/split/val_kitti.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 7 3 | -------------------------------------------------------------------------------- 
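How these split files are consumed: the phase name ('train', 'val', 'test') selects one of the txt files above through each dataset's DATA_FILES mapping, and make_data_loader in dataloader/data_loaders.py wires the chosen dataset class, the pair collation function, and the infinite sampler together. The sketch below is illustrative only: it uses a plain namespace in place of the project's get_config() from config.py, the path and numeric values are made up, and only fields that actually appear in the loaders shown above are listed.

```
# Hedged sketch: building a KITTI training loader from the splits above.
# The config object is a stand-in; real options and defaults come from config.py.
from types import SimpleNamespace
from dataloader.data_loaders import make_data_loader

config = SimpleNamespace(
    dataset='KITTINMPairDataset',          # one of ALL_DATASETS in data_loaders.py
    kitti_dir='/path/to/kitti',            # hypothetical path containing 'dataset/'
    icp_cache_path='icp',                  # ICP-refined GT poses are cached per pair here
    voxel_size=0.3,                        # illustrative value
    positive_pair_search_voxel_size_multiplier=1.5,
    min_scale=0.8, max_scale=1.2, rotation_range=360,
    use_random_rotation=True, use_random_scale=False,
    test_random_rotation=False,
)
loader = make_data_loader(config, phase='train', batch_size=4, num_workers=2)
batch = next(iter(loader))   # dict with 'pcd0', 'pcd1', 'sinput0_C', 'sinput0_F', 'T_gt', ...
```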
/dataloader/split/val_scan2cad.txt: -------------------------------------------------------------------------------- 1 | full_annotations_clean_val.json 2 | -------------------------------------------------------------------------------- /dataloader/threedmatch_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import glob 8 | 9 | from dataloader.base_loader import * 10 | from dataloader.transforms import * 11 | 12 | from util.pointcloud import get_matching_indices, make_open3d_point_cloud 13 | from util.file import read_trajectory 14 | 15 | 16 | class IndoorPairDataset(PairDataset): 17 | ''' 18 | Train dataset 19 | ''' 20 | OVERLAP_RATIO = None 21 | AUGMENT = None 22 | 23 | def __init__(self, 24 | phase, 25 | transform=None, 26 | random_rotation=True, 27 | random_scale=True, 28 | manual_seed=False, 29 | config=None): 30 | PairDataset.__init__(self, phase, transform, random_rotation, random_scale, 31 | manual_seed, config) 32 | self.root = root = config.threed_match_dir 33 | self.use_xyz_feature = config.use_xyz_feature 34 | logging.info(f"Loading the subset {phase} from {root}") 35 | 36 | subset_names = open(self.DATA_FILES[phase]).read().split() 37 | for name in subset_names: 38 | fname = name + "*%.2f.txt" % self.OVERLAP_RATIO 39 | fnames_txt = glob.glob(root + "/" + fname) 40 | assert len(fnames_txt) > 0, f"Make sure that the path {root} has data {fname}" 41 | for fname_txt in fnames_txt: 42 | with open(fname_txt) as f: 43 | content = f.readlines() 44 | fnames = [x.strip().split() for x in content] 45 | for fname in fnames: 46 | self.files.append([fname[0], fname[1]]) 47 | 48 | def __getitem__(self, idx): 49 | file0 = os.path.join(self.root, self.files[idx][0]) 50 | file1 = os.path.join(self.root, self.files[idx][1]) 51 | data0 = np.load(file0) 52 | data1 = np.load(file1) 53 | xyz0 = data0["pcd"] 54 | xyz1 = data1["pcd"] 55 | matching_search_voxel_size = self.matching_search_voxel_size 56 | 57 | if self.random_scale and random.random() < 0.95: 58 | scale = self.min_scale + \ 59 | (self.max_scale - self.min_scale) * random.random() 60 | matching_search_voxel_size *= scale 61 | xyz0 = scale * xyz0 62 | xyz1 = scale * xyz1 63 | 64 | if self.random_rotation: 65 | T0 = sample_random_trans(xyz0, self.randg, self.rotation_range) 66 | T1 = sample_random_trans(xyz1, self.randg, self.rotation_range) 67 | trans = T1 @ np.linalg.inv(T0) 68 | 69 | xyz0 = self.apply_transform(xyz0, T0) 70 | xyz1 = self.apply_transform(xyz1, T1) 71 | else: 72 | trans = np.identity(4) 73 | 74 | # Voxelization 75 | xyz0_th = torch.from_numpy(xyz0) 76 | xyz1_th = torch.from_numpy(xyz1) 77 | 78 | sel0 = ME.utils.sparse_quantize(xyz0_th / self.voxel_size, return_index=True) 79 | sel1 = ME.utils.sparse_quantize(xyz1_th / self.voxel_size, return_index=True) 80 | 81 | # Make point clouds using voxelized points 82 | pcd0 = make_open3d_point_cloud(xyz0[sel0]) 83 | pcd1 = make_open3d_point_cloud(xyz1[sel1]) 84 | 85 | # Select features and points using the returned voxelized indices 86 | 
# 3DMatch color is not helpful 87 | # pcd0.colors = o3d.utility.Vector3dVector(color0[sel0]) 88 | # pcd1.colors = o3d.utility.Vector3dVector(color1[sel1]) 89 | 90 | # Get matches 91 | matches = get_matching_indices(pcd0, pcd1, trans, matching_search_voxel_size) 92 | 93 | # Get features 94 | npts0 = len(sel0) 95 | npts1 = len(sel1) 96 | 97 | feats_train0, feats_train1 = [], [] 98 | 99 | unique_xyz0_th = xyz0_th[sel0] 100 | unique_xyz1_th = xyz1_th[sel1] 101 | 102 | # xyz as feats 103 | if self.use_xyz_feature: 104 | feats_train0.append(unique_xyz0_th - unique_xyz0_th.mean(0)) 105 | feats_train1.append(unique_xyz1_th - unique_xyz1_th.mean(0)) 106 | else: 107 | feats_train0.append(torch.ones((npts0, 1))) 108 | feats_train1.append(torch.ones((npts1, 1))) 109 | 110 | feats0 = torch.cat(feats_train0, 1) 111 | feats1 = torch.cat(feats_train1, 1) 112 | 113 | coords0 = torch.floor(unique_xyz0_th / self.voxel_size) 114 | coords1 = torch.floor(unique_xyz1_th / self.voxel_size) 115 | 116 | if self.transform: 117 | coords0, feats0 = self.transform(coords0, feats0) 118 | coords1, feats1 = self.transform(coords1, feats1) 119 | 120 | extra_package = {'idx': idx, 'file0': file0, 'file1': file1} 121 | 122 | return (unique_xyz0_th.float(), 123 | unique_xyz1_th.float(), coords0.int(), coords1.int(), feats0.float(), 124 | feats1.float(), matches, trans, extra_package) 125 | 126 | 127 | class ThreeDMatchPairDataset03(IndoorPairDataset): 128 | OVERLAP_RATIO = 0.3 129 | DATA_FILES = { 130 | 'train': './dataloader/split/train_3dmatch.txt', 131 | 'val': './dataloader/split/val_3dmatch.txt', 132 | 'test': './dataloader/split/test_3dmatch.txt' 133 | } 134 | 135 | 136 | class ThreeDMatchPairDataset05(ThreeDMatchPairDataset03): 137 | OVERLAP_RATIO = 0.5 138 | 139 | 140 | class ThreeDMatchPairDataset07(ThreeDMatchPairDataset03): 141 | OVERLAP_RATIO = 0.7 142 | 143 | 144 | class ThreeDMatchTrajectoryDataset(PairDataset): 145 | ''' 146 | Test dataset 147 | ''' 148 | DATA_FILES = { 149 | 'train': './dataloader/split/train_3dmatch.txt', 150 | 'val': './dataloader/split/val_3dmatch.txt', 151 | 'test': './dataloader/split/test_3dmatch.txt' 152 | } 153 | 154 | def __init__(self, 155 | phase, 156 | transform=None, 157 | random_rotation=True, 158 | random_scale=True, 159 | manual_seed=False, 160 | scene_id=None, 161 | config=None, 162 | return_ply_names=False): 163 | 164 | PairDataset.__init__(self, phase, transform, random_rotation, random_scale, 165 | manual_seed, config) 166 | 167 | self.root = config.threed_match_dir 168 | 169 | subset_names = open(self.DATA_FILES[phase]).read().split() 170 | if scene_id is not None: 171 | subset_names = [subset_names[scene_id]] 172 | for sname in subset_names: 173 | traj_file = os.path.join(self.root, sname + '-evaluation/gt.log') 174 | assert os.path.exists(traj_file) 175 | traj = read_trajectory(traj_file) 176 | for ctraj in traj: 177 | i = ctraj.metadata[0] 178 | j = ctraj.metadata[1] 179 | T_gt = ctraj.pose 180 | self.files.append((sname, i, j, T_gt)) 181 | 182 | self.return_ply_names = return_ply_names 183 | 184 | def __getitem__(self, pair_index): 185 | sname, i, j, T_gt = self.files[pair_index] 186 | ply_name0 = os.path.join(self.root, sname, f'cloud_bin_{i}.ply') 187 | ply_name1 = os.path.join(self.root, sname, f'cloud_bin_{j}.ply') 188 | 189 | if self.return_ply_names: 190 | return sname, ply_name0, ply_name1, T_gt 191 | 192 | pcd0 = o3d.io.read_point_cloud(ply_name0) 193 | pcd1 = o3d.io.read_point_cloud(ply_name1) 194 | pcd0 = np.asarray(pcd0.points) 195 | pcd1 = 
np.asarray(pcd1.points) 196 | return sname, pcd0, pcd1, T_gt 197 | -------------------------------------------------------------------------------- /dataloader/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch 8 | import numpy as np 9 | import random 10 | from scipy.linalg import expm, norm 11 | 12 | 13 | # Rotation matrix along axis with angle theta 14 | def M(axis, theta): 15 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)) 16 | 17 | 18 | def sample_random_trans(pcd, randg, rotation_range=360): 19 | T = np.eye(4) 20 | R = M(randg.rand(3) - 0.5, rotation_range * np.pi / 180.0 * (randg.rand(1) - 0.5)) 21 | T[:3, :3] = R 22 | T[:3, 3] = R.dot(-np.mean(pcd, axis=0)) 23 | return T 24 | 25 | 26 | class Compose: 27 | def __init__(self, transforms): 28 | self.transforms = transforms 29 | 30 | def __call__(self, coords, feats): 31 | for transform in self.transforms: 32 | coords, feats = transform(coords, feats) 33 | return coords, feats 34 | 35 | 36 | class Jitter: 37 | def __init__(self, mu=0, sigma=0.01): 38 | self.mu = mu 39 | self.sigma = sigma 40 | 41 | def __call__(self, coords, feats): 42 | if random.random() < 0.95: 43 | feats += self.sigma * torch.randn(feats.shape[0], feats.shape[1]) 44 | if self.mu != 0: 45 | feats += self.mu 46 | return coords, feats 47 | 48 | 49 | class ChromaticShift: 50 | def __init__(self, mu=0, sigma=0.1): 51 | self.mu = mu 52 | self.sigma = sigma 53 | 54 | def __call__(self, coords, feats): 55 | if random.random() < 0.95: 56 | feats[:, :3] += torch.randn(self.mu, self.sigma, (1, 3)) 57 | return coords, feats 58 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import os 8 | from urllib.request import urlretrieve 9 | 10 | import open3d as o3d 11 | from core.deep_global_registration import DeepGlobalRegistration 12 | from config import get_config 13 | 14 | BASE_URL = "http://node2.chrischoy.org/data/" 15 | DOWNLOAD_LIST = [ 16 | (BASE_URL + "datasets/registration/", "redkitchen_000.ply"), 17 | (BASE_URL + "datasets/registration/", "redkitchen_010.ply"), 18 | (BASE_URL + "projects/DGR/", "ResUNetBN2C-feat32-3dmatch-v0.05.pth") 19 | ] 20 | 21 | # Check if the weights and file exist and download 22 | if not os.path.isfile('redkitchen_000.ply'): 23 | print('Downloading weights and pointcloud files...') 24 | for f in DOWNLOAD_LIST: 25 | print(f"Downloading {f}") 26 | urlretrieve(f[0] + f[1], f[1]) 27 | 28 | if __name__ == '__main__': 29 | config = get_config() 30 | if config.weights is None: 31 | config.weights = DOWNLOAD_LIST[-1][-1] 32 | 33 | # preprocessing 34 | pcd0 = o3d.io.read_point_cloud(config.pcd0) 35 | pcd0.estimate_normals() 36 | pcd1 = o3d.io.read_point_cloud(config.pcd1) 37 | pcd1.estimate_normals() 38 | 39 | # registration 40 | dgr = DeepGlobalRegistration(config) 41 | T01 = dgr.register(pcd0, pcd1) 42 | 43 | o3d.visualization.draw_geometries([pcd0, pcd1]) 44 | 45 | pcd0.transform(T01) 46 | print(T01) 47 | 48 | o3d.visualization.draw_geometries([pcd0, pcd1]) 49 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import logging 8 | import model.simpleunet as simpleunets 9 | import model.resunet as resunets 10 | import model.pyramidnet as pyramids 11 | 12 | MODELS = [] 13 | 14 | 15 | def add_models(module): 16 | MODELS.extend([getattr(module, a) for a in dir(module) if 'Net' in a or 'MLP' in a]) 17 | 18 | 19 | add_models(simpleunets) 20 | add_models(resunets) 21 | add_models(pyramids) 22 | 23 | 24 | def load_model(name): 25 | '''Creates and returns an instance of the model given its class name. 26 | ''' 27 | # Find the model class from its name 28 | all_models = MODELS 29 | mdict = {model.__name__: model for model in all_models} 30 | if name not in mdict: 31 | logging.info(f'Invalid model index. You put {name}. 
Options are:') 32 | # Display a list of valid model names 33 | for model in all_models: 34 | logging.info('\t* {}'.format(model.__name__)) 35 | return None 36 | NetClass = mdict[name] 37 | 38 | return NetClass 39 | -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch.nn as nn 8 | import MinkowskiEngine as ME 9 | 10 | 11 | def get_norm(norm_type, num_feats, bn_momentum=0.05, dimension=-1): 12 | if norm_type == 'BN': 13 | return ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum) 14 | elif norm_type == 'IN': 15 | return ME.MinkowskiInstanceNorm(num_feats) 16 | elif norm_type == 'INBN': 17 | return nn.Sequential( 18 | ME.MinkowskiInstanceNorm(num_feats), 19 | ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum)) 20 | else: 21 | raise ValueError(f'Type {norm_type}, not defined') 22 | 23 | 24 | def get_nonlinearity(non_type): 25 | if non_type == 'ReLU': 26 | return ME.MinkowskiReLU() 27 | elif non_type == 'ELU': 28 | # return ME.MinkowskiInstanceNorm(num_feats, dimension=dimension) 29 | return ME.MinkowskiELU() 30 | else: 31 | raise ValueError(f'Type {non_type}, not defined') 32 | -------------------------------------------------------------------------------- /model/pyramidnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
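`model/__init__.py` above registers every class whose name contains `Net` or `MLP` and exposes it through `load_model`, which returns the class itself (or `None` after logging the valid names). A hedged usage sketch — `ResUNetBN2C` is the feature backbone referenced by the pretrained weights in `demo.py`, and the constructor arguments below follow the pattern of the networks in this repo, so treat them as assumptions rather than a fixed API:

```python
# Assumes MinkowskiEngine is installed and the script is run from the repository root.
from model import load_model

NetClass = load_model('ResUNetBN2C')      # returns a class, not an instance
if NetClass is None:
    raise ValueError('Unknown model name')

feat_net = NetClass(in_channels=1,        # FCGF-style constant one-channel input features
                    out_channels=32,
                    bn_momentum=0.05,
                    conv1_kernel_size=7,
                    normalize_feature=True,
                    D=3)
print(type(feat_net).__name__)            # 'ResUNetBN2C'
```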
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch 8 | import torch.nn as nn 9 | import MinkowskiEngine as ME 10 | from model.common import get_norm, get_nonlinearity 11 | 12 | from model.residual_block import get_block, conv, conv_tr, conv_norm_non 13 | 14 | 15 | class PyramidModule(ME.MinkowskiNetwork): 16 | NONLINEARITY = 'ELU' 17 | NORM_TYPE = 'BN' 18 | REGION_TYPE = ME.RegionType.HYPER_CUBE 19 | 20 | def __init__(self, 21 | inc, 22 | outc, 23 | inner_inc, 24 | inner_outc, 25 | inner_module=None, 26 | depth=1, 27 | bn_momentum=0.05, 28 | dimension=-1): 29 | ME.MinkowskiNetwork.__init__(self, dimension) 30 | self.depth = depth 31 | 32 | self.conv = nn.Sequential( 33 | conv_norm_non( 34 | inc, 35 | inner_inc, 36 | 3, 37 | 2, 38 | dimension, 39 | region_type=self.REGION_TYPE, 40 | norm_type=self.NORM_TYPE, 41 | nonlinearity=self.NONLINEARITY), *[ 42 | get_block( 43 | self.NORM_TYPE, 44 | inner_inc, 45 | inner_inc, 46 | bn_momentum=bn_momentum, 47 | region_type=self.REGION_TYPE, 48 | dimension=dimension) for d in range(depth) 49 | ]) 50 | self.inner_module = inner_module 51 | self.convtr = nn.Sequential( 52 | conv_tr( 53 | in_channels=inner_outc, 54 | out_channels=inner_outc, 55 | kernel_size=3, 56 | stride=2, 57 | dilation=1, 58 | has_bias=False, 59 | region_type=self.REGION_TYPE, 60 | dimension=dimension), 61 | get_norm( 62 | self.NORM_TYPE, inner_outc, bn_momentum=bn_momentum, dimension=dimension), 63 | get_nonlinearity(self.NONLINEARITY)) 64 | 65 | self.cat_conv = conv_norm_non( 66 | inner_outc + inc, 67 | outc, 68 | 1, 69 | 1, 70 | dimension, 71 | norm_type=self.NORM_TYPE, 72 | nonlinearity=self.NONLINEARITY) 73 | 74 | def forward(self, x): 75 | y = self.conv(x) 76 | if self.inner_module: 77 | y = self.inner_module(y) 78 | y = self.convtr(y) 79 | y = ME.cat(x, y) 80 | return self.cat_conv(y) 81 | 82 | 83 | class PyramidModuleINBN(PyramidModule): 84 | NORM_TYPE = 'INBN' 85 | 86 | 87 | class PyramidNet(ME.MinkowskiNetwork): 88 | NORM_TYPE = 'BN' 89 | NONLINEARITY = 'ELU' 90 | PYRAMID_MODULE = PyramidModule 91 | CHANNELS = [32, 64, 128, 128] 92 | TR_CHANNELS = [64, 128, 128, 128] 93 | DEPTHS = [1, 1, 1, 1] 94 | # None b1, b2, b3, btr3, btr2 95 | # 1 2 3 -3 -2 -1 96 | REGION_TYPE = ME.RegionType.HYPER_CUBE 97 | 98 | # To use the model, must call initialize_coords before forward pass. 
99 | # Once data is processed, call clear to reset the model before calling initialize_coords 100 | def __init__(self, 101 | in_channels=3, 102 | out_channels=32, 103 | bn_momentum=0.1, 104 | conv1_kernel_size=3, 105 | normalize_feature=False, 106 | D=3): 107 | ME.MinkowskiNetwork.__init__(self, D) 108 | self.conv1_kernel_size = conv1_kernel_size 109 | self.normalize_feature = normalize_feature 110 | 111 | self.initialize_network(in_channels, out_channels, bn_momentum, D) 112 | 113 | def initialize_network(self, in_channels, out_channels, bn_momentum, dimension): 114 | NORM_TYPE = self.NORM_TYPE 115 | NONLINEARITY = self.NONLINEARITY 116 | CHANNELS = self.CHANNELS 117 | TR_CHANNELS = self.TR_CHANNELS 118 | DEPTHS = self.DEPTHS 119 | REGION_TYPE = self.REGION_TYPE 120 | 121 | self.conv = conv_norm_non( 122 | in_channels, 123 | CHANNELS[0], 124 | kernel_size=self.conv1_kernel_size, 125 | stride=1, 126 | dimension=dimension, 127 | bn_momentum=bn_momentum, 128 | region_type=REGION_TYPE, 129 | norm_type=NORM_TYPE, 130 | nonlinearity=NONLINEARITY) 131 | 132 | pyramid = None 133 | for d in range(len(DEPTHS) - 1, 0, -1): 134 | pyramid = self.PYRAMID_MODULE( 135 | CHANNELS[d - 1], 136 | TR_CHANNELS[d - 1], 137 | CHANNELS[d], 138 | TR_CHANNELS[d], 139 | pyramid, 140 | DEPTHS[d], 141 | dimension=dimension) 142 | self.pyramid = pyramid 143 | self.final = nn.Sequential( 144 | conv_norm_non( 145 | TR_CHANNELS[0], 146 | TR_CHANNELS[0], 147 | kernel_size=3, 148 | stride=1, 149 | dimension=dimension), 150 | conv(TR_CHANNELS[0], out_channels, 1, 1, dimension=dimension)) 151 | 152 | def forward(self, x): 153 | out = self.conv(x) 154 | out = self.pyramid(out) 155 | out = self.final(out) 156 | 157 | if self.normalize_feature: 158 | return ME.SparseTensor( 159 | out.F / (torch.norm(out.F, p=2, dim=1, keepdim=True) + 1e-8), 160 | coords_key=out.coords_key, 161 | coords_manager=out.coords_man) 162 | else: 163 | return out 164 | 165 | 166 | class PyramidNet6(PyramidNet): 167 | CHANNELS = [32, 64, 128, 192, 256, 256] 168 | TR_CHANNELS = [64, 128, 192, 192, 256, 256] 169 | DEPTHS = [1, 1, 1, 1, 1, 1] 170 | 171 | 172 | class PyramidNet6NoBlock(PyramidNet6): 173 | DEPTHS = [0, 0, 0, 0, 0, 0] 174 | 175 | 176 | class PyramidNet6INBN(PyramidNet6): 177 | NORM_TYPE = 'INBN' 178 | PYRAMID_MODULE = PyramidModuleINBN 179 | 180 | 181 | class PyramidNet6INBNNoBlock(PyramidNet6INBN): 182 | NORM_TYPE = 'INBN' 183 | 184 | 185 | class PyramidNet8(PyramidNet): 186 | CHANNELS = [32, 64, 128, 128, 192, 192, 256, 256] 187 | TR_CHANNELS = [64, 128, 128, 192, 192, 192, 256, 256] 188 | DEPTHS = [1, 1, 1, 1, 1, 1, 1, 1] 189 | 190 | 191 | class PyramidNet8INBN(PyramidNet8): 192 | NORM_TYPE = 'INBN' 193 | PYRAMID_MODULE = PyramidModuleINBN 194 | -------------------------------------------------------------------------------- /model/residual_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
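Running one of these sparse networks requires voxelizing the input points into a `ME.SparseTensor` first. The sketch below pushes a toy cloud through `PyramidNet`; the quantization helpers and keyword names follow MinkowskiEngine 0.5-style APIs and may differ in other versions, and `normalize_feature` is left at its default because the normalization branch above uses older `SparseTensor` keywords:

```python
import numpy as np
import torch
import MinkowskiEngine as ME
from model.pyramidnet import PyramidNet

voxel_size = 0.05
xyz = np.random.rand(5000, 3).astype(np.float32)     # toy cloud inside a 1 m cube

# Quantize to unique voxels and prepend the batch index column.
coords, inds = ME.utils.sparse_quantize(xyz / voxel_size, return_index=True)
coords = ME.utils.batched_coordinates([coords])
feats = torch.ones(len(inds), 1)                      # constant 1-channel features

net = PyramidNet(in_channels=1, out_channels=32, D=3)
with torch.no_grad():
    sout = net(ME.SparseTensor(feats, coordinates=coords))
print(sout.F.shape)                                   # (num_voxels, 32)
```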
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch.nn as nn 8 | 9 | from model.common import get_norm, get_nonlinearity 10 | 11 | import MinkowskiEngine as ME 12 | import MinkowskiEngine.MinkowskiFunctional as MEF 13 | 14 | 15 | def conv(in_channels, 16 | out_channels, 17 | kernel_size=3, 18 | stride=1, 19 | dilation=1, 20 | has_bias=False, 21 | region_type=0, 22 | dimension=3): 23 | if not isinstance(region_type, ME.RegionType): 24 | if region_type == 0: 25 | region_type = ME.RegionType.HYPER_CUBE 26 | elif region_type == 1: 27 | region_type = ME.RegionType.HYPER_CROSS 28 | else: 29 | raise ValueError('Unsupported region type') 30 | 31 | kernel_generator = ME.KernelGenerator( 32 | kernel_size=kernel_size, 33 | stride=stride, 34 | dilation=dilation, 35 | region_type=region_type, 36 | dimension=dimension) 37 | 38 | return ME.MinkowskiConvolution( 39 | in_channels, 40 | out_channels, 41 | kernel_size=kernel_size, 42 | stride=stride, 43 | kernel_generator=kernel_generator, 44 | dimension=dimension) 45 | 46 | 47 | def conv_tr(in_channels, 48 | out_channels, 49 | kernel_size, 50 | stride=1, 51 | dilation=1, 52 | has_bias=False, 53 | region_type=ME.RegionType.HYPER_CUBE, 54 | dimension=-1): 55 | assert dimension > 0, 'Dimension must be a positive integer' 56 | kernel_generator = ME.KernelGenerator( 57 | kernel_size, 58 | stride, 59 | dilation, 60 | is_transpose=True, 61 | region_type=region_type, 62 | dimension=dimension) 63 | 64 | kernel_generator = ME.KernelGenerator( 65 | kernel_size, 66 | stride, 67 | dilation, 68 | is_transpose=True, 69 | region_type=region_type, 70 | dimension=dimension) 71 | 72 | return ME.MinkowskiConvolutionTranspose( 73 | in_channels=in_channels, 74 | out_channels=out_channels, 75 | kernel_size=kernel_size, 76 | stride=stride, 77 | dilation=dilation, 78 | bias=has_bias, 79 | kernel_generator=kernel_generator, 80 | dimension=dimension) 81 | 82 | 83 | class BasicBlockBase(nn.Module): 84 | expansion = 1 85 | NORM_TYPE = 'BN' 86 | 87 | def __init__(self, 88 | inplanes, 89 | planes, 90 | stride=1, 91 | dilation=1, 92 | downsample=None, 93 | bn_momentum=0.1, 94 | region_type=0, 95 | D=3): 96 | super(BasicBlockBase, self).__init__() 97 | 98 | self.conv1 = conv( 99 | inplanes, 100 | planes, 101 | kernel_size=3, 102 | stride=stride, 103 | dilation=dilation, 104 | region_type=region_type, 105 | dimension=D) 106 | self.norm1 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, dimension=D) 107 | self.conv2 = conv( 108 | planes, 109 | planes, 110 | kernel_size=3, 111 | stride=1, 112 | dilation=dilation, 113 | region_type=region_type, 114 | dimension=D) 115 | self.norm2 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, dimension=D) 116 | self.downsample = downsample 117 | 118 | def forward(self, x): 119 | residual = x 120 | 121 | out = self.conv1(x) 122 | out = self.norm1(out) 123 | out = MEF.relu(out) 124 | 125 | out = self.conv2(out) 126 | out = self.norm2(out) 127 | 128 | if self.downsample is not None: 129 | residual = self.downsample(x) 130 | 131 | out += residual 132 | out = MEF.relu(out) 133 | 134 | return out 135 | 136 | 137 | class BasicBlockBN(BasicBlockBase): 138 | NORM_TYPE = 'BN' 139 | 140 | 141 | class BasicBlockIN(BasicBlockBase): 142 | NORM_TYPE = 'IN' 143 | 
144 | 145 | class BasicBlockINBN(BasicBlockBase): 146 | NORM_TYPE = 'INBN' 147 | 148 | 149 | def get_block(norm_type, 150 | inplanes, 151 | planes, 152 | stride=1, 153 | dilation=1, 154 | downsample=None, 155 | bn_momentum=0.1, 156 | region_type=0, 157 | dimension=3): 158 | if norm_type == 'BN': 159 | Block = BasicBlockBN 160 | elif norm_type == 'IN': 161 | Block = BasicBlockIN 162 | elif norm_type == 'INBN': 163 | Block = BasicBlockINBN 164 | else: 165 | raise ValueError(f'Type {norm_type}, not defined') 166 | 167 | return Block(inplanes, planes, stride, dilation, downsample, bn_momentum, region_type, 168 | dimension) 169 | 170 | 171 | def conv_norm_non(inc, 172 | outc, 173 | kernel_size, 174 | stride, 175 | dimension, 176 | bn_momentum=0.05, 177 | region_type=ME.RegionType.HYPER_CUBE, 178 | norm_type='BN', 179 | nonlinearity='ELU'): 180 | return nn.Sequential( 181 | conv( 182 | in_channels=inc, 183 | out_channels=outc, 184 | kernel_size=kernel_size, 185 | stride=stride, 186 | dilation=1, 187 | has_bias=False, 188 | region_type=region_type, 189 | dimension=dimension), 190 | get_norm(norm_type, outc, bn_momentum=bn_momentum, dimension=dimension), 191 | get_nonlinearity(nonlinearity)) 192 | -------------------------------------------------------------------------------- /model/simpleunet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import torch 8 | import MinkowskiEngine as ME 9 | import MinkowskiEngine.MinkowskiFunctional as MEF 10 | from model.common import get_norm 11 | 12 | 13 | class SimpleNet(ME.MinkowskiNetwork): 14 | NORM_TYPE = None 15 | CHANNELS = [None, 32, 64, 128] 16 | TR_CHANNELS = [None, 32, 32, 64] 17 | 18 | # To use the model, must call initialize_coords before forward pass. 
19 | # Once data is processed, call clear to reset the model before calling initialize_coords 20 | def __init__(self, 21 | in_channels=3, 22 | out_channels=32, 23 | bn_momentum=0.1, 24 | conv1_kernel_size=3, 25 | normalize_feature=False, 26 | D=3): 27 | super(SimpleNet, self).__init__(D) 28 | NORM_TYPE = self.NORM_TYPE 29 | CHANNELS = self.CHANNELS 30 | TR_CHANNELS = self.TR_CHANNELS 31 | self.normalize_feature = normalize_feature 32 | self.conv1 = ME.MinkowskiConvolution( 33 | in_channels=in_channels, 34 | out_channels=CHANNELS[1], 35 | kernel_size=conv1_kernel_size, 36 | stride=1, 37 | dilation=1, 38 | has_bias=False, 39 | dimension=D) 40 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) 41 | 42 | self.conv2 = ME.MinkowskiConvolution( 43 | in_channels=CHANNELS[1], 44 | out_channels=CHANNELS[2], 45 | kernel_size=3, 46 | stride=2, 47 | dilation=1, 48 | has_bias=False, 49 | dimension=D) 50 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) 51 | 52 | self.conv3 = ME.MinkowskiConvolution( 53 | in_channels=CHANNELS[2], 54 | out_channels=CHANNELS[3], 55 | kernel_size=3, 56 | stride=2, 57 | dilation=1, 58 | has_bias=False, 59 | dimension=D) 60 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) 61 | 62 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 63 | in_channels=CHANNELS[3], 64 | out_channels=TR_CHANNELS[3], 65 | kernel_size=3, 66 | stride=2, 67 | dilation=1, 68 | has_bias=False, 69 | dimension=D) 70 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 71 | 72 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 73 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 74 | out_channels=TR_CHANNELS[2], 75 | kernel_size=3, 76 | stride=2, 77 | dilation=1, 78 | has_bias=False, 79 | dimension=D) 80 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 81 | 82 | self.conv1_tr = ME.MinkowskiConvolution( 83 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 84 | out_channels=TR_CHANNELS[1], 85 | kernel_size=3, 86 | stride=1, 87 | dilation=1, 88 | has_bias=False, 89 | dimension=D) 90 | self.norm1_tr = get_norm(NORM_TYPE, TR_CHANNELS[1], bn_momentum=bn_momentum, D=D) 91 | 92 | self.final = ME.MinkowskiConvolution( 93 | in_channels=TR_CHANNELS[1], 94 | out_channels=out_channels, 95 | kernel_size=1, 96 | stride=1, 97 | dilation=1, 98 | has_bias=True, 99 | dimension=D) 100 | 101 | def forward(self, x): 102 | out_s1 = self.conv1(x) 103 | out_s1 = self.norm1(out_s1) 104 | out = MEF.relu(out_s1) 105 | 106 | out_s2 = self.conv2(out) 107 | out_s2 = self.norm2(out_s2) 108 | out = MEF.relu(out_s2) 109 | 110 | out_s4 = self.conv3(out) 111 | out_s4 = self.norm3(out_s4) 112 | out = MEF.relu(out_s4) 113 | 114 | out = self.conv3_tr(out) 115 | out = self.norm3_tr(out) 116 | out_s2_tr = MEF.relu(out) 117 | 118 | out = ME.cat(out_s2_tr, out_s2) 119 | 120 | out = self.conv2_tr(out) 121 | out = self.norm2_tr(out) 122 | out_s1_tr = MEF.relu(out) 123 | 124 | out = ME.cat(out_s1_tr, out_s1) 125 | out = self.conv1_tr(out) 126 | out = self.norm1_tr(out) 127 | out = MEF.relu(out) 128 | 129 | out = self.final(out) 130 | 131 | if self.normalize_feature: 132 | return ME.SparseTensor( 133 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 134 | coords_key=out.coords_key, 135 | coords_manager=out.coords_man) 136 | else: 137 | return out 138 | 139 | 140 | class SimpleNetIN(SimpleNet): 141 | NORM_TYPE = 'IN' 142 | 143 | 144 | class SimpleNetBN(SimpleNet): 145 | NORM_TYPE = 'BN' 146 | 147 | 148 | class 
SimpleNetBNE(SimpleNetBN): 149 | CHANNELS = [None, 16, 32, 32] 150 | TR_CHANNELS = [None, 16, 16, 32] 151 | 152 | 153 | class SimpleNetINE(SimpleNetBNE): 154 | NORM_TYPE = 'IN' 155 | 156 | 157 | class SimpleNet2(ME.MinkowskiNetwork): 158 | NORM_TYPE = None 159 | CHANNELS = [None, 32, 64, 128, 256] 160 | TR_CHANNELS = [None, 32, 32, 64, 64] 161 | 162 | # To use the model, must call initialize_coords before forward pass. 163 | # Once data is processed, call clear to reset the model before calling initialize_coords 164 | def __init__(self, 165 | in_channels=3, 166 | out_channels=32, 167 | bn_momentum=0.1, 168 | conv1_kernel_size=3, 169 | normalize_feature=False, 170 | D=3): 171 | ME.MinkowskiNetwork.__init__(self, D) 172 | NORM_TYPE = self.NORM_TYPE 173 | CHANNELS = self.CHANNELS 174 | TR_CHANNELS = self.TR_CHANNELS 175 | self.normalize_feature = normalize_feature 176 | self.conv1 = ME.MinkowskiConvolution( 177 | in_channels=in_channels, 178 | out_channels=CHANNELS[1], 179 | kernel_size=conv1_kernel_size, 180 | stride=1, 181 | dilation=1, 182 | has_bias=False, 183 | dimension=D) 184 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, dimension=D) 185 | 186 | self.conv2 = ME.MinkowskiConvolution( 187 | in_channels=CHANNELS[1], 188 | out_channels=CHANNELS[2], 189 | kernel_size=3, 190 | stride=2, 191 | dilation=1, 192 | has_bias=False, 193 | dimension=D) 194 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 195 | 196 | self.conv3 = ME.MinkowskiConvolution( 197 | in_channels=CHANNELS[2], 198 | out_channels=CHANNELS[3], 199 | kernel_size=3, 200 | stride=2, 201 | dilation=1, 202 | has_bias=False, 203 | dimension=D) 204 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 205 | 206 | self.conv4 = ME.MinkowskiConvolution( 207 | in_channels=CHANNELS[3], 208 | out_channels=CHANNELS[4], 209 | kernel_size=3, 210 | stride=2, 211 | dilation=1, 212 | has_bias=False, 213 | dimension=D) 214 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, dimension=D) 215 | 216 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 217 | in_channels=CHANNELS[4], 218 | out_channels=TR_CHANNELS[4], 219 | kernel_size=3, 220 | stride=2, 221 | dilation=1, 222 | has_bias=False, 223 | dimension=D) 224 | self.norm4_tr = get_norm( 225 | NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, dimension=D) 226 | 227 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 228 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 229 | out_channels=TR_CHANNELS[3], 230 | kernel_size=3, 231 | stride=2, 232 | dilation=1, 233 | has_bias=False, 234 | dimension=D) 235 | self.norm3_tr = get_norm( 236 | NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 237 | 238 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 239 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 240 | out_channels=TR_CHANNELS[2], 241 | kernel_size=3, 242 | stride=2, 243 | dilation=1, 244 | has_bias=False, 245 | dimension=D) 246 | self.norm2_tr = get_norm( 247 | NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 248 | 249 | self.conv1_tr = ME.MinkowskiConvolution( 250 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 251 | out_channels=TR_CHANNELS[1], 252 | kernel_size=3, 253 | stride=1, 254 | dilation=1, 255 | has_bias=False, 256 | dimension=D) 257 | self.norm1_tr = get_norm( 258 | NORM_TYPE, TR_CHANNELS[1], bn_momentum=bn_momentum, dimension=D) 259 | 260 | self.final = ME.MinkowskiConvolution( 261 | in_channels=TR_CHANNELS[1], 262 | out_channels=out_channels, 263 | 
kernel_size=1, 264 | stride=1, 265 | dilation=1, 266 | has_bias=True, 267 | dimension=D) 268 | 269 | def forward(self, x): 270 | out_s1 = self.conv1(x) 271 | out_s1 = self.norm1(out_s1) 272 | out = MEF.relu(out_s1) 273 | 274 | out_s2 = self.conv2(out) 275 | out_s2 = self.norm2(out_s2) 276 | out = MEF.relu(out_s2) 277 | 278 | out_s4 = self.conv3(out) 279 | out_s4 = self.norm3(out_s4) 280 | out = MEF.relu(out_s4) 281 | 282 | out_s8 = self.conv4(out) 283 | out_s8 = self.norm4(out_s8) 284 | out = MEF.relu(out_s8) 285 | 286 | out = self.conv4_tr(out) 287 | out = self.norm4_tr(out) 288 | out_s4_tr = MEF.relu(out) 289 | 290 | out = ME.cat(out_s4_tr, out_s4) 291 | 292 | out = self.conv3_tr(out) 293 | out = self.norm3_tr(out) 294 | out_s2_tr = MEF.relu(out) 295 | 296 | out = ME.cat(out_s2_tr, out_s2) 297 | 298 | out = self.conv2_tr(out) 299 | out = self.norm2_tr(out) 300 | out_s1_tr = MEF.relu(out) 301 | 302 | out = ME.cat(out_s1_tr, out_s1) 303 | out = self.conv1_tr(out) 304 | out = self.norm1_tr(out) 305 | out = MEF.relu(out) 306 | 307 | out = self.final(out) 308 | 309 | if self.normalize_feature: 310 | return ME.SparseTensor( 311 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 312 | coords_key=out.coords_key, 313 | coords_manager=out.coords_man) 314 | else: 315 | return out 316 | 317 | 318 | class SimpleNetIN2(SimpleNet2): 319 | NORM_TYPE = 'IN' 320 | 321 | 322 | class SimpleNetBN2(SimpleNet2): 323 | NORM_TYPE = 'BN' 324 | 325 | 326 | class SimpleNetBN2B(SimpleNet2): 327 | NORM_TYPE = 'BN' 328 | CHANNELS = [None, 32, 64, 128, 256] 329 | TR_CHANNELS = [None, 64, 64, 64, 64] 330 | 331 | 332 | class SimpleNetBN2C(SimpleNet2): 333 | NORM_TYPE = 'BN' 334 | CHANNELS = [None, 32, 64, 128, 256] 335 | TR_CHANNELS = [None, 32, 64, 64, 128] 336 | 337 | 338 | class SimpleNetBN2D(SimpleNet2): 339 | NORM_TYPE = 'BN' 340 | CHANNELS = [None, 32, 64, 128, 256] 341 | TR_CHANNELS = [None, 32, 64, 64, 128] 342 | 343 | 344 | class SimpleNetBN2E(SimpleNet2): 345 | NORM_TYPE = 'BN' 346 | CHANNELS = [None, 16, 32, 64, 128] 347 | TR_CHANNELS = [None, 16, 32, 32, 64] 348 | 349 | 350 | class SimpleNetIN2E(SimpleNetBN2E): 351 | NORM_TYPE = 'IN' 352 | 353 | 354 | class SimpleNet3(ME.MinkowskiNetwork): 355 | NORM_TYPE = None 356 | CHANNELS = [None, 32, 64, 128, 256, 512] 357 | TR_CHANNELS = [None, 32, 32, 64, 64, 128] 358 | 359 | # To use the model, must call initialize_coords before forward pass. 
360 | # Once data is processed, call clear to reset the model before calling initialize_coords 361 | def __init__(self, 362 | in_channels=3, 363 | out_channels=32, 364 | bn_momentum=0.1, 365 | conv1_kernel_size=3, 366 | normalize_feature=False, 367 | D=3): 368 | ME.MinkowskiNetwork.__init__(self, D) 369 | NORM_TYPE = self.NORM_TYPE 370 | CHANNELS = self.CHANNELS 371 | TR_CHANNELS = self.TR_CHANNELS 372 | self.normalize_feature = normalize_feature 373 | self.conv1 = ME.MinkowskiConvolution( 374 | in_channels=in_channels, 375 | out_channels=CHANNELS[1], 376 | kernel_size=conv1_kernel_size, 377 | stride=1, 378 | dilation=1, 379 | has_bias=False, 380 | dimension=D) 381 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, dimension=D) 382 | 383 | self.conv2 = ME.MinkowskiConvolution( 384 | in_channels=CHANNELS[1], 385 | out_channels=CHANNELS[2], 386 | kernel_size=3, 387 | stride=2, 388 | dilation=1, 389 | has_bias=False, 390 | dimension=D) 391 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 392 | 393 | self.conv3 = ME.MinkowskiConvolution( 394 | in_channels=CHANNELS[2], 395 | out_channels=CHANNELS[3], 396 | kernel_size=3, 397 | stride=2, 398 | dilation=1, 399 | has_bias=False, 400 | dimension=D) 401 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 402 | 403 | self.conv4 = ME.MinkowskiConvolution( 404 | in_channels=CHANNELS[3], 405 | out_channels=CHANNELS[4], 406 | kernel_size=3, 407 | stride=2, 408 | dilation=1, 409 | has_bias=False, 410 | dimension=D) 411 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, dimension=D) 412 | 413 | self.conv5 = ME.MinkowskiConvolution( 414 | in_channels=CHANNELS[4], 415 | out_channels=CHANNELS[5], 416 | kernel_size=3, 417 | stride=2, 418 | dilation=1, 419 | has_bias=False, 420 | dimension=D) 421 | self.norm5 = get_norm(NORM_TYPE, CHANNELS[5], bn_momentum=bn_momentum, dimension=D) 422 | 423 | self.conv5_tr = ME.MinkowskiConvolutionTranspose( 424 | in_channels=CHANNELS[5], 425 | out_channels=TR_CHANNELS[5], 426 | kernel_size=3, 427 | stride=2, 428 | dilation=1, 429 | has_bias=False, 430 | dimension=D) 431 | self.norm5_tr = get_norm( 432 | NORM_TYPE, TR_CHANNELS[5], bn_momentum=bn_momentum, dimension=D) 433 | 434 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 435 | in_channels=CHANNELS[4] + TR_CHANNELS[5], 436 | out_channels=TR_CHANNELS[4], 437 | kernel_size=3, 438 | stride=2, 439 | dilation=1, 440 | has_bias=False, 441 | dimension=D) 442 | self.norm4_tr = get_norm( 443 | NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, dimension=D) 444 | 445 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 446 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 447 | out_channels=TR_CHANNELS[3], 448 | kernel_size=3, 449 | stride=2, 450 | dilation=1, 451 | has_bias=False, 452 | dimension=D) 453 | self.norm3_tr = get_norm( 454 | NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, dimension=D) 455 | 456 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 457 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 458 | out_channels=TR_CHANNELS[2], 459 | kernel_size=3, 460 | stride=2, 461 | dilation=1, 462 | has_bias=False, 463 | dimension=D) 464 | self.norm2_tr = get_norm( 465 | NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, dimension=D) 466 | 467 | self.conv1_tr = ME.MinkowskiConvolution( 468 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 469 | out_channels=TR_CHANNELS[1], 470 | kernel_size=1, 471 | stride=1, 472 | dilation=1, 473 | has_bias=True, 474 | dimension=D) 475 | 
476 | def forward(self, x): 477 | out_s1 = self.conv1(x) 478 | out_s1 = self.norm1(out_s1) 479 | out = MEF.relu(out_s1) 480 | 481 | out_s2 = self.conv2(out) 482 | out_s2 = self.norm2(out_s2) 483 | out = MEF.relu(out_s2) 484 | 485 | out_s4 = self.conv3(out) 486 | out_s4 = self.norm3(out_s4) 487 | out = MEF.relu(out_s4) 488 | 489 | out_s8 = self.conv4(out) 490 | out_s8 = self.norm4(out_s8) 491 | out = MEF.relu(out_s8) 492 | 493 | out_s16 = self.conv5(out) 494 | out_s16 = self.norm5(out_s16) 495 | out = MEF.relu(out_s16) 496 | 497 | out = self.conv5_tr(out) 498 | out = self.norm5_tr(out) 499 | out_s8_tr = MEF.relu(out) 500 | 501 | out = ME.cat(out_s8_tr, out_s8) 502 | 503 | out = self.conv4_tr(out) 504 | out = self.norm4_tr(out) 505 | out_s4_tr = MEF.relu(out) 506 | 507 | out = ME.cat(out_s4_tr, out_s4) 508 | 509 | out = self.conv3_tr(out) 510 | out = self.norm3_tr(out) 511 | out_s2_tr = MEF.relu(out) 512 | 513 | out = ME.cat(out_s2_tr, out_s2) 514 | 515 | out = self.conv2_tr(out) 516 | out = self.norm2_tr(out) 517 | out_s1_tr = MEF.relu(out) 518 | 519 | out = ME.cat(out_s1_tr, out_s1) 520 | out = self.conv1_tr(out) 521 | 522 | if self.normalize_feature: 523 | return ME.SparseTensor( 524 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 525 | coords_key=out.coords_key, 526 | coords_manager=out.coords_man) 527 | else: 528 | return out 529 | 530 | 531 | class SimpleNetIN3(SimpleNet3): 532 | NORM_TYPE = 'IN' 533 | 534 | 535 | class SimpleNetBN3(SimpleNet3): 536 | NORM_TYPE = 'BN' 537 | 538 | 539 | class SimpleNetBN3B(SimpleNet3): 540 | NORM_TYPE = 'BN' 541 | CHANNELS = [None, 32, 64, 128, 256, 512] 542 | TR_CHANNELS = [None, 32, 64, 64, 64, 128] 543 | 544 | 545 | class SimpleNetBN3C(SimpleNet3): 546 | NORM_TYPE = 'BN' 547 | CHANNELS = [None, 32, 64, 128, 256, 512] 548 | TR_CHANNELS = [None, 32, 32, 64, 128, 128] 549 | 550 | 551 | class SimpleNetBN3D(SimpleNet3): 552 | NORM_TYPE = 'BN' 553 | CHANNELS = [None, 32, 64, 128, 256, 512] 554 | TR_CHANNELS = [None, 32, 64, 64, 128, 128] 555 | 556 | 557 | class SimpleNetBN3E(SimpleNet3): 558 | NORM_TYPE = 'BN' 559 | CHANNELS = [None, 16, 32, 64, 128, 256] 560 | TR_CHANNELS = [None, 16, 32, 32, 64, 128] 561 | 562 | 563 | class SimpleNetIN3E(SimpleNetBN3E): 564 | NORM_TYPE = 'IN' 565 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ansi2html==1.8.0 2 | attrs==23.1.0 3 | certifi==2023.7.22 4 | charset-normalizer==3.1.0 5 | click==8.1.3 6 | cmake==3.26.4 7 | ConfigArgParse==1.5.3 8 | dash==2.11.0 9 | dash-core-components==2.0.0 10 | dash-html-components==2.0.0 11 | dash-table==5.0.0 12 | easydict==1.10 13 | fastjsonschema==2.17.1 14 | filelock==3.12.2 15 | Flask==2.2.5 16 | idna==3.4 17 | intel-openmp==2023.1.0 18 | itsdangerous==2.1.2 19 | Jinja2==3.1.2 20 | jsonschema==4.17.3 21 | jupyter_core==5.3.1 22 | lit==16.0.6 23 | MarkupSafe==2.1.3 24 | MinkowskiEngine==0.5.4 25 | mkl==2023.1.0 26 | mpmath==1.3.0 27 | nbformat==5.7.0 28 | nest-asyncio==1.5.6 29 | networkx==3.1 30 | numpy==1.25.0 31 | nvidia-cublas-cu11==11.10.3.66 32 | nvidia-cuda-cupti-cu11==11.7.101 33 | nvidia-cuda-nvrtc-cu11==11.7.99 34 | nvidia-cuda-runtime-cu11==11.7.99 35 | nvidia-cudnn-cu11==8.5.0.96 36 | nvidia-cufft-cu11==10.9.0.58 37 | nvidia-curand-cu11==10.2.10.91 38 | nvidia-cusolver-cu11==11.4.0.1 39 | nvidia-cusparse-cu11==11.7.4.91 40 | nvidia-nccl-cu11==2.14.3 41 | nvidia-nvtx-cu11==11.7.91 42 | open3d-cpu @ 
file:///home/ibrahim/tmp/Open3D/build/lib/python_package/pip_package/open3d_cpu-0.17.0%2B9238339b9-cp310-cp310-manylinux_2_37_x86_64.whl 43 | packaging==23.1 44 | Pillow==9.5.0 45 | platformdirs==3.8.0 46 | plotly==5.15.0 47 | pyrsistent==0.19.3 48 | requests==2.31.0 49 | retrying==1.3.4 50 | scipy==1.11.0 51 | six==1.16.0 52 | sympy==1.12 53 | tbb==2021.9.0 54 | tenacity==8.2.2 55 | torch==2.0.1 56 | torchaudio==2.0.2 57 | torchvision==0.15.2 58 | traitlets==5.9.0 59 | triton==2.0.0 60 | typing_extensions==4.6.3 61 | urllib3==2.0.3 62 | Werkzeug==2.2.3 63 | -------------------------------------------------------------------------------- /scripts/analyze_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | from matplotlib.patches import Rectangle 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import argparse 11 | 12 | PROPERTY_IDX_MAP = { 13 | 'Recall': 0, 14 | 'TE (m)': 1, 15 | 'RE (deg)': 2, 16 | 'log Time (s)': 3, 17 | 'Scene ID': 4 18 | } 19 | 20 | 21 | def analyze_by_pair(stats, rte_thresh, rre_thresh): 22 | ''' 23 | \input stats: (num_methods, num_pairs, num_pairwise_stats=5) 24 | \return valid mean_stats: (num_methods, 4) 25 | 4 properties: recall, rte, rre, time 26 | ''' 27 | num_methods, num_pairs, num_pairwise_stats = stats.shape 28 | pairwise_stats = np.zeros((num_methods, 4)) 29 | 30 | for m in range(num_methods): 31 | # Filter valid registrations by rte / rre thresholds 32 | mask_rte = stats[m, :, 1] < rte_thresh 33 | mask_rre = stats[m, :, 2] < rre_thresh 34 | mask_valid = mask_rte * mask_rre 35 | 36 | # Recall, RTE, RRE, Time 37 | pairwise_stats[m, 0] = mask_valid.mean() 38 | pairwise_stats[m, 1] = stats[m, mask_valid, 1].mean() 39 | pairwise_stats[m, 2] = stats[m, mask_valid, 2].mean() 40 | pairwise_stats[m, 3] = stats[m, mask_valid, 3].mean() 41 | 42 | return pairwise_stats 43 | 44 | 45 | def analyze_by_scene(stats, scene_id_list, rte_thresh=0.3, rre_thresh=15): 46 | ''' 47 | \input stats: (num_methods, num_pairs, num_pairwise_stats=5) 48 | \return scene_wise mean stats: (num_methods, num_scenes, 4) 49 | 4 properties: recall, rte, rre, time 50 | ''' 51 | num_methods, num_pairs, num_pairwise_stats = stats.shape 52 | num_scenes = len(scene_id_list) 53 | 54 | scene_wise_stats = np.zeros((num_methods, len(scene_id_list), 4)) 55 | 56 | for m in range(num_methods): 57 | # Filter valid registrations by rte / rre thresholds 58 | mask_rte = stats[m, :, 1] < rte_thresh 59 | mask_rre = stats[m, :, 2] < rre_thresh 60 | mask_valid = mask_rte * mask_rre 61 | 62 | for s in scene_id_list: 63 | mask_scene = stats[m, :, 4] == s 64 | 65 | # Valid registrations in the scene 66 | mask = mask_scene * mask_valid 67 | 68 | # Recall, RTE, RRE, Time 69 | scene_wise_stats[m, s, 0] = 0 if np.sum(mask_scene) == 0 else float( 70 | np.sum(mask)) / float(np.sum(mask_scene)) 71 | scene_wise_stats[m, s, 1] = stats[m, mask, 1].mean() 72 | scene_wise_stats[m, s, 2] = stats[m, mask, 2].mean() 73 | scene_wise_stats[m, s, 3] = stats[m, mask, 3].mean() 74 | 
75 | return scene_wise_stats 76 | 77 | 78 | def plot_precision_recall_curves(stats, method_names, rte_precisions, rre_precisions, 79 | output_postfix, cmap): 80 | ''' 81 | \input stats: (num_methods, num_pairs, 5) 82 | \input method_names: (num_methods) string, shown as xticks 83 | ''' 84 | num_methods, num_pairs, _ = stats.shape 85 | rre_precision_curves = np.zeros((num_methods, len(rre_precisions))) 86 | rte_precision_curves = np.zeros((num_methods, len(rte_precisions))) 87 | 88 | for i, rre_thresh in enumerate(rre_precisions): 89 | pairwise_stats = analyze_by_pair(stats, rte_thresh=np.inf, rre_thresh=rre_thresh) 90 | rre_precision_curves[:, i] = pairwise_stats[:, 0] 91 | 92 | for i, rte_thresh in enumerate(rte_precisions): 93 | pairwise_stats = analyze_by_pair(stats, rte_thresh=rte_thresh, rre_thresh=np.inf) 94 | rte_precision_curves[:, i] = pairwise_stats[:, 0] 95 | 96 | fig = plt.figure(figsize=(10, 3)) 97 | ax1 = fig.add_subplot(1, 2, 1, aspect=3.0 / np.max(rte_precisions)) 98 | ax2 = fig.add_subplot(1, 2, 2, aspect=3.0 / np.max(rre_precisions)) 99 | 100 | for m, name in enumerate(method_names): 101 | alpha = rre_precision_curves[m].mean() 102 | alpha = 1.0 if alpha > 0 else 0.0 103 | ax1.plot(rre_precisions, rre_precision_curves[m], color=cmap[m], alpha=alpha) 104 | ax2.plot(rte_precisions, rte_precision_curves[m], color=cmap[m], alpha=alpha) 105 | 106 | ax1.set_ylabel('Recall') 107 | ax1.set_xlabel('Rotation (deg)') 108 | ax1.set_ylim((0.0, 1.0)) 109 | 110 | ax2.set_xlabel('Translation (m)') 111 | ax2.set_ylim((0.0, 1.0)) 112 | ax2.legend(method_names, loc='center left', bbox_to_anchor=(1, 0.5)) 113 | ax1.grid() 114 | ax2.grid() 115 | 116 | plt.tight_layout() 117 | plt.savefig('{}_{}.png'.format('precision_recall', output_postfix)) 118 | 119 | plt.close(fig) 120 | 121 | 122 | def plot_scene_wise_stats(scene_wise_stats, method_names, scene_names, property_name, 123 | ylim, output_postfix, cmap): 124 | ''' 125 | \input scene_wise_stats: (num_methods, num_scenes, 4) 126 | \input method_names: (num_methods) string, shown as xticks 127 | \input scene_names: (num_scenes) string, shown as legends 128 | \input property_name: string, shown as ylabel 129 | ''' 130 | num_methods, num_scenes, _ = scene_wise_stats.shape 131 | assert len(method_names) == num_methods 132 | assert len(scene_names) == num_scenes 133 | 134 | # Initialize figure 135 | fig = plt.figure(figsize=(14, 3)) 136 | ax = fig.add_subplot(1, 1, 1) 137 | 138 | # Add some paddings 139 | w = 1.0 / (num_methods + 2) 140 | 141 | # Rightmost bar 142 | x = np.arange(0, num_scenes) - 0.5 * w * num_methods 143 | 144 | for m in range(num_methods): 145 | m_stats = scene_wise_stats[m, :, PROPERTY_IDX_MAP[property_name]] 146 | valid = not (np.logical_and.reduce(np.isnan(m_stats)) 147 | or np.logical_and.reduce(m_stats == 0)) 148 | alpha = 1.0 if valid else 0.0 149 | ax.bar(x + m * w, m_stats, w, color=cmap[m], alpha=alpha) 150 | 151 | plt.ylim(ylim) 152 | plt.xlim((0 - w * num_methods, num_scenes)) 153 | plt.ylabel(property_name) 154 | plt.xticks(np.arange(0, num_scenes), tuple(scene_names)) 155 | ax.legend(method_names, loc='center left', bbox_to_anchor=(1, 0.5)) 156 | 157 | plt.tight_layout() 158 | plt.grid() 159 | plt.savefig('{}_{}.png'.format(property_name, output_postfix)) 160 | plt.close(fig) 161 | 162 | 163 | def plot_pareto_frontier(pairwise_stats, method_names, cmap): 164 | recalls = pairwise_stats[:, 0] 165 | times = 1.0 / pairwise_stats[:, 3] 166 | 167 | ind = np.argsort(times) 168 | 169 | offset = 0.05 170 | 
plt.rcParams.update({'font.size': 30}) 171 | 172 | fig = plt.figure(figsize=(20, 12)) 173 | ax = fig.add_subplot(111) 174 | ax.set_xlabel('Number of registrations per second (log scale)') 175 | ax.set_xscale('log') 176 | xmin = np.power(10, -2.2) 177 | xmax = np.power(10, 1.5) 178 | ax.set_xlim(xmin, xmax) 179 | 180 | ax.set_ylabel('Registration recall') 181 | ax.set_ylim(-offset, 1) 182 | ax.set_yticks(np.arange(0, 1, step=0.2)) 183 | 184 | plots = [None for m in ind] 185 | max_gain = -1 186 | for m in ind[::-1]: 187 | # 8, 9: our methods 188 | if (recalls[m] > max_gain): 189 | max_gain = recalls[m] 190 | ax.add_patch( 191 | Rectangle((0, -offset), 192 | times[m], 193 | recalls[m] + offset, 194 | facecolor=(0.94, 0.94, 0.94))) 195 | 196 | plot, = ax.plot(times[m], recalls[m], 'o', c=colors[m], markersize=30) 197 | plots[m] = plot 198 | 199 | ax.legend(plots, method_names, loc='center left', bbox_to_anchor=(1, 0.5)) 200 | plt.tight_layout() 201 | plt.savefig('frontier.png') 202 | 203 | 204 | if __name__ == '__main__': 205 | ''' 206 | Input .npz file to analyze: 207 | \prop npz['stats']: (num_methods, num_pairs, num_pairwise_stats=5) 208 | 5 pairwise stats properties consist of 209 | - \bool success: decided by evaluation thresholds, will be ignored in this script 210 | - \float rte: relative translation error (in cm) 211 | - \float rre: relative rotation error (in deg) 212 | - \float time: registration time for the pair (in ms) 213 | - \int scene_id: specific for 3DMatch test sets (8 scenes in total) 214 | 215 | \prop npz['names']: (num_methods) 216 | Corresponding method name stored in string 217 | ''' 218 | 219 | # Setup fonts 220 | from matplotlib import rc 221 | rc('font', **{'family': 'sans-serif', 'sans-serif': ['Helvetica']}) 222 | rc('text', usetex=False) 223 | 224 | # Parse arguments 225 | parser = argparse.ArgumentParser() 226 | parser.add_argument('npz', help='path to the npz file') 227 | parser.add_argument('--output_postfix', default='', help='postfix of the output') 228 | parser.add_argument('--end_method_index', 229 | default=1000, 230 | type=int, 231 | help='reserved only for making slides') 232 | args = parser.parse_args() 233 | 234 | # Load npz file with aformentioned format 235 | npz = np.load(args.npz) 236 | stats = npz['stats'] 237 | 238 | # Reserved only for making slides, will be skipped by default 239 | stats[args.end_method_index:, :, 1] = np.inf 240 | stats[args.end_method_index:, :, 2] = np.inf 241 | 242 | method_names = npz['names'] 243 | scene_names = [ 244 | 'Kitchen', 'Home1', 'Home2', 'Hotel1', 'Hotel2', 'Hotel3', 'Study', 'Lab' 245 | ] 246 | 247 | cmap = plt.get_cmap('tab20b') 248 | colors = [cmap(i) for i in np.linspace(0, 1, len(method_names))] 249 | colors.reverse() 250 | 251 | # Plot scene-wise bar charts 252 | scene_wise_stats = analyze_by_scene(stats, 253 | range(len(scene_names)), 254 | rte_thresh=0.3, 255 | rre_thresh=15) 256 | 257 | plot_scene_wise_stats(scene_wise_stats, method_names, scene_names, 'Recall', 258 | (0.0, 1.0), args.output_postfix, colors) 259 | plot_scene_wise_stats(scene_wise_stats, method_names, scene_names, 'TE (m)', 260 | (0.0, 0.3), args.output_postfix, colors) 261 | plot_scene_wise_stats(scene_wise_stats, method_names, scene_names, 'RE (deg)', 262 | (0.0, 15.0), args.output_postfix, colors) 263 | 264 | # Plot rte/rre - recall curves 265 | plot_precision_recall_curves(stats, 266 | method_names, 267 | rre_precisions=np.arange(0, 15, 0.05), 268 | rte_precisions=np.arange(0, 0.3, 0.005), 269 | output_postfix=args.output_postfix, 
270 | cmap=colors) 271 | 272 | pairwise_stats = analyze_by_pair(stats, rte_thresh=0.3, rre_thresh=15) 273 | plot_pareto_frontier(pairwise_stats, method_names, cmap=colors) 274 | -------------------------------------------------------------------------------- /scripts/download_3dmatch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA_DIR=$1 4 | 5 | function download() { 6 | TMP_PATH="$DATA_DIR/tmp" 7 | echo "#################################" 8 | echo "Data Root Dir: ${DATA_DIR}" 9 | echo "Download Path: ${TMP_PATH}" 10 | echo "#################################" 11 | urls=( 12 | 'http://node2.chrischoy.org/data/datasets/registration/threedmatch.tgz' 13 | ) 14 | 15 | if [ ! -d "$TMP_PATH" ]; then 16 | echo ">> Create temporary directory: ${TMP_PATH}" 17 | mkdir -p "$TMP_PATH" 18 | fi 19 | cd "$TMP_PATH" 20 | 21 | echo ">> Start downloading" 22 | echo ${urls[@]} | xargs -n 1 -P 3 wget --no-check-certificate -q -c --show-progress 23 | 24 | echo ">> Unpack .tgz files" 25 | for filename in *.tgz 26 | do 27 | tar -xvzf $filename -C ../ 28 | done 29 | 30 | echo ">> Clear tmp directory" 31 | cd .. 32 | rm -rf ./tmp 33 | 34 | echo "#################################" 35 | echo "Done!" 36 | echo "#################################" 37 | } 38 | 39 | function main() { 40 | echo $DATA_DIR 41 | if [ -z "$DATA_DIR" ]; then 42 | echo "DATA_DIR is a required argument!" 43 | else 44 | download 45 | fi 46 | } 47 | 48 | main; 49 | -------------------------------------------------------------------------------- /scripts/test_3dmatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
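`scripts/analyze_stats.py` above expects an `.npz` whose `stats` array is `(num_methods, num_pairs, 5)` with a parallel `names` array. A toy file in that layout (every number is fabricated; the units follow the thresholds the script actually applies — meters, degrees, seconds — rather than the cm/ms mentioned in its docstring):

```python
import numpy as np

rng = np.random.RandomState(0)
num_methods, num_pairs = 2, 200

stats = np.zeros((num_methods, num_pairs, 5))
stats[..., 0] = rng.rand(num_methods, num_pairs) > 0.2       # success flag (recomputed by the script)
stats[..., 1] = rng.rand(num_methods, num_pairs) * 0.4       # RTE in meters
stats[..., 2] = rng.rand(num_methods, num_pairs) * 20.0      # RRE in degrees
stats[..., 3] = rng.rand(num_methods, num_pairs) * 2.0       # registration time in seconds
stats[..., 4] = rng.randint(0, 8, (num_methods, num_pairs))  # 3DMatch scene id (8 test scenes)

np.savez('toy_stats.npz', stats=stats, names=np.array(['DGR', 'Baseline']))
# Then: python scripts/analyze_stats.py toy_stats.npz --output_postfix toy
```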
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | # Run with python -m scripts.test_3dmatch_refactor 8 | import os 9 | import sys 10 | import math 11 | import logging 12 | import open3d as o3d 13 | import numpy as np 14 | import time 15 | import torch 16 | import copy 17 | 18 | sys.path.append('.') 19 | import MinkowskiEngine as ME 20 | from config import get_config 21 | from model import load_model 22 | 23 | from dataloader.data_loaders import ThreeDMatchTrajectoryDataset 24 | from core.knn import find_knn_gpu 25 | from core.deep_global_registration import DeepGlobalRegistration 26 | 27 | from util.timer import Timer 28 | from util.pointcloud import make_open3d_point_cloud 29 | 30 | o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Warning) 31 | ch = logging.StreamHandler(sys.stdout) 32 | logging.getLogger().setLevel(logging.INFO) 33 | logging.basicConfig(format='%(asctime)s %(message)s', 34 | datefmt='%m/%d %H:%M:%S', 35 | handlers=[ch]) 36 | 37 | # Criteria 38 | def rte_rre(T_pred, T_gt, rte_thresh, rre_thresh, eps=1e-16): 39 | if T_pred is None: 40 | return np.array([0, np.inf, np.inf]) 41 | 42 | rte = np.linalg.norm(T_pred[:3, 3] - T_gt[:3, 3]) 43 | rre = np.arccos( 44 | np.clip((np.trace(T_pred[:3, :3].T @ T_gt[:3, :3]) - 1) / 2, -1 + eps, 45 | 1 - eps)) * 180 / math.pi 46 | return np.array([rte < rte_thresh and rre < rre_thresh, rte, rre]) 47 | 48 | 49 | def analyze_stats(stats, mask, method_names): 50 | mask = (mask > 0).squeeze(1) 51 | stats = stats[:, mask, :] 52 | 53 | print('Total result mean') 54 | for i, method_name in enumerate(method_names): 55 | print(method_name) 56 | print(stats[i].mean(0)) 57 | 58 | print('Total successful result mean') 59 | for i, method_name in enumerate(method_names): 60 | sel = stats[i][:, 0] > 0 61 | sel_stats = stats[i][sel] 62 | print(method_name) 63 | print(sel_stats.mean(0)) 64 | 65 | 66 | def create_pcd(xyz, color): 67 | # n x 3 68 | n = xyz.shape[0] 69 | pcd = o3d.geometry.PointCloud() 70 | pcd.points = o3d.utility.Vector3dVector(xyz) 71 | pcd.colors = o3d.utility.Vector3dVector(np.tile(color, (n, 1))) 72 | pcd.estimate_normals( 73 | search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)) 74 | return pcd 75 | 76 | 77 | def draw_geometries_flip(pcds): 78 | pcds_transform = [] 79 | flip_transform = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 80 | for pcd in pcds: 81 | pcd_temp = copy.deepcopy(pcd) 82 | pcd_temp.transform(flip_transform) 83 | pcds_transform.append(pcd_temp) 84 | o3d.visualization.draw_geometries(pcds_transform) 85 | 86 | 87 | def evaluate(methods, method_names, data_loader, config, debug=False): 88 | 89 | tot_num_data = len(data_loader.dataset) 90 | data_loader_iter = iter(data_loader) 91 | 92 | # Accumulate success, rre, rte, time, sid 93 | mask = np.zeros((tot_num_data, 1)).astype(int) 94 | stats = np.zeros((len(methods), tot_num_data, 5)) 95 | 96 | dataset = data_loader.dataset 97 | subset_names = open(dataset.DATA_FILES[dataset.phase]).read().split() 98 | 99 | for batch_idx in range(tot_num_data): 100 | batch = data_loader_iter.next() 101 | 102 | # Skip too sparse point clouds 103 | sname, xyz0, xyz1, trans = batch[0] 104 | 105 | sid = subset_names.index(sname) 106 | T_gt = np.linalg.inv(trans) 107 | 108 
| for i, method in enumerate(methods): 109 | start = time.time() 110 | T = method.register(xyz0, xyz1) 111 | end = time.time() 112 | 113 | # Visualize 114 | if debug: 115 | print(method_names[i]) 116 | pcd0 = create_pcd(xyz0, np.array([1, 0.706, 0])) 117 | pcd1 = create_pcd(xyz1, np.array([0, 0.651, 0.929])) 118 | 119 | pcd0.transform(T) 120 | draw_geometries_flip([pcd0, pcd1]) 121 | pcd0.transform(np.linalg.inv(T)) 122 | 123 | stats[i, batch_idx, :3] = rte_rre(T, T_gt, config.success_rte_thresh, 124 | config.success_rre_thresh) 125 | stats[i, batch_idx, 3] = end - start 126 | stats[i, batch_idx, 4] = sid 127 | mask[batch_idx] = 1 128 | if stats[i, batch_idx, 0] == 0: 129 | print(f"{method_names[i]}: failed") 130 | 131 | if batch_idx % 10 == 9: 132 | print('Summary {} / {}'.format(batch_idx, tot_num_data)) 133 | analyze_stats(stats, mask, method_names) 134 | 135 | # Save results 136 | filename = f'3dmatch-stats_{method.__class__.__name__}' 137 | if os.path.isdir(config.out_dir): 138 | out_file = os.path.join(config.out_dir, filename) 139 | else: 140 | out_file = filename # save it on the current directory 141 | print(f'Saving the stats to {out_file}') 142 | np.savez(out_file, stats=stats, names=method_names) 143 | analyze_stats(stats, mask, method_names) 144 | 145 | # Analysis per scene 146 | for i, method in enumerate(methods): 147 | print(f'Scene-wise mean {method}') 148 | scene_vals = np.zeros((len(subset_names), 3)) 149 | for sid, sname in enumerate(subset_names): 150 | curr_scene = stats[i, :, 4] == sid 151 | scene_vals[sid] = (stats[i, curr_scene, :3]).mean(0) 152 | 153 | print('All scenes') 154 | print(scene_vals) 155 | print('Scene average') 156 | print(scene_vals.mean(0)) 157 | 158 | 159 | if __name__ == '__main__': 160 | config = get_config() 161 | print(config) 162 | 163 | dgr = DeepGlobalRegistration(config) 164 | 165 | methods = [dgr] 166 | method_names = ['DGR'] 167 | 168 | dset = ThreeDMatchTrajectoryDataset(phase='test', 169 | transform=None, 170 | random_scale=False, 171 | random_rotation=False, 172 | config=config) 173 | 174 | data_loader = torch.utils.data.DataLoader(dset, 175 | batch_size=1, 176 | shuffle=False, 177 | num_workers=1, 178 | collate_fn=lambda x: x, 179 | pin_memory=False, 180 | drop_last=True) 181 | 182 | evaluate(methods, method_names, data_loader, config, debug=False) 183 | -------------------------------------------------------------------------------- /scripts/test_kitti.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
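`rte_rre` above scores a predicted pose by its translation error norm and the rotation angle recovered from the trace of `R_pred^T R_gt`. A self-contained numeric check (the function is re-implemented here so the snippet does not pull in Open3D/MinkowskiEngine through `scripts.test_3dmatch`; the test poses are made up):

```python
import numpy as np
from scipy.spatial.transform import Rotation


def rte_rre(T_pred, T_gt, rte_thresh, rre_thresh, eps=1e-16):
    # Mirrors scripts/test_3dmatch.py: translation norm + rotation angle from the trace.
    rte = np.linalg.norm(T_pred[:3, 3] - T_gt[:3, 3])
    cos_angle = np.clip((np.trace(T_pred[:3, :3].T @ T_gt[:3, :3]) - 1) / 2, -1 + eps, 1 - eps)
    rre = np.degrees(np.arccos(cos_angle))
    return np.array([rte < rte_thresh and rre < rre_thresh, rte, rre])


T_gt = np.eye(4)
T_pred = np.eye(4)
T_pred[:3, :3] = Rotation.from_euler('z', 10, degrees=True).as_matrix()  # 10 deg yaw error
T_pred[:3, 3] = [0.2, 0.0, 0.0]                                          # 20 cm translation error

print(rte_rre(T_pred, T_gt, rte_thresh=0.3, rre_thresh=15))              # ~[1.  0.2  10.]
```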
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import os 8 | import sys 9 | import logging 10 | import argparse 11 | import numpy as np 12 | import open3d as o3d 13 | 14 | import torch 15 | 16 | from config import get_config 17 | 18 | from core.deep_global_registration import DeepGlobalRegistration 19 | 20 | from dataloader.kitti_loader import KITTINMPairDataset 21 | from dataloader.base_loader import CollationFunctionFactory 22 | from util.pointcloud import make_open3d_point_cloud, make_open3d_feature, pointcloud_to_spheres 23 | from util.timer import AverageMeter, Timer 24 | 25 | from scripts.test_3dmatch import rte_rre 26 | 27 | ch = logging.StreamHandler(sys.stdout) 28 | logging.getLogger().setLevel(logging.INFO) 29 | logging.basicConfig(format='%(asctime)s %(message)s', 30 | datefmt='%m/%d %H:%M:%S', 31 | handlers=[ch]) 32 | 33 | TE_THRESH = 0.6 # m 34 | RE_THRESH = 5 # deg 35 | VISUALIZE = False 36 | 37 | 38 | def visualize_pair(xyz0, xyz1, T, voxel_size): 39 | pcd0 = pointcloud_to_spheres(xyz0, 40 | voxel_size, 41 | np.array([0, 0, 1]), 42 | sphere_size=0.6) 43 | pcd1 = pointcloud_to_spheres(xyz1, 44 | voxel_size, 45 | np.array([0, 1, 0]), 46 | sphere_size=0.6) 47 | pcd0.transform(T) 48 | o3d.visualization.draw_geometries([pcd0, pcd1]) 49 | 50 | 51 | def analyze_stats(stats): 52 | print('Total result mean') 53 | print(stats.mean(0)) 54 | 55 | sel_stats = stats[stats[:, 0] > 0] 56 | print(sel_stats.mean(0)) 57 | 58 | 59 | def evaluate(config, data_loader, method): 60 | data_timer = Timer() 61 | 62 | test_iter = data_loader.__iter__() 63 | N = len(test_iter) 64 | 65 | stats = np.zeros((N, 5)) # bool succ, rte, rre, time, drive id 66 | 67 | for i in range(len(data_loader)): 68 | data_timer.tic() 69 | try: 70 | data_dict = test_iter.next() 71 | except ValueError as exc: 72 | pass 73 | data_timer.toc() 74 | 75 | drive = data_dict['extra_packages'][0]['drive'] 76 | xyz0, xyz1 = data_dict['pcd0'][0], data_dict['pcd1'][0] 77 | T_gt = data_dict['T_gt'][0].numpy() 78 | xyz0np, xyz1np = xyz0.numpy(), xyz1.numpy() 79 | 80 | T_pred = method.register(xyz0np, xyz1np) 81 | 82 | stats[i, :3] = rte_rre(T_pred, T_gt, TE_THRESH, RE_THRESH) 83 | stats[i, 3] = method.reg_timer.diff + method.feat_timer.diff 84 | stats[i, 4] = drive 85 | 86 | if stats[i, 0] == 0: 87 | logging.info(f"Failed with RTE: {stats[i, 1]}, RRE: {stats[i, 2]}") 88 | 89 | if i % 10 == 0: 90 | succ_rate, rte, rre, avg_time, _ = stats[:i + 1].mean(0) 91 | logging.info( 92 | f"{i} / {N}: Data time: {data_timer.avg}, Feat time: {method.feat_timer.avg}," 93 | + f" Reg time: {method.reg_timer.avg}, RTE: {rte}," + 94 | f" RRE: {rre}, Success: {succ_rate * 100} %") 95 | 96 | if VISUALIZE and i % 10 == 9: 97 | visualize_pair(xyz0, xyz1, T_pred, config.voxel_size) 98 | 99 | succ_rate, rte, rre, avg_time, _ = stats.mean(0) 100 | logging.info( 101 | f"Data time: {data_timer.avg}, Feat time: {method.feat_timer.avg}," + 102 | f" Reg time: {method.reg_timer.avg}, RTE: {rte}," + 103 | f" RRE: {rre}, Success: {succ_rate * 100} %") 104 | 105 | # Save results 106 | filename = f'kitti-stats_{method.__class__.__name__}' 107 | if config.out_filename is not None: 108 | filename += f'_{config.out_filename}' 109 | if isinstance(method, FCGFWrapper): 110 | filename += '_' + 
method.method 111 | if 'ransac' in method.method: 112 | filename += f'_{config.ransac_iter}' 113 | if os.path.isdir(config.out_dir): 114 | out_file = os.path.join(config.out_dir, filename) 115 | else: 116 | out_file = filename # save it on the current directory 117 | print(f'Saving the stats to {out_file}') 118 | np.savez(out_file, stats=stats) 119 | analyze_stats(stats) 120 | 121 | 122 | if __name__ == '__main__': 123 | config = get_config() 124 | 125 | dgr = DeepGlobalRegistration(config) 126 | 127 | dset = KITTINMPairDataset('test', 128 | transform=None, 129 | random_rotation=False, 130 | random_scale=False, 131 | config=config) 132 | 133 | data_loader = torch.utils.data.DataLoader( 134 | dset, 135 | batch_size=1, 136 | shuffle=False, 137 | num_workers=0, 138 | collate_fn=CollationFunctionFactory(concat_correspondences=False, 139 | collation_type='collate_pair'), 140 | pin_memory=False, 141 | drop_last=False) 142 | 143 | evaluate(config, data_loader, dgr) 144 | -------------------------------------------------------------------------------- /scripts/train_3dmatch.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export PATH_POSTFIX=$1 3 | export MISC_ARGS=$2 4 | 5 | export DATA_ROOT="./outputs/Experiment2" 6 | export DATASET=${DATASET:-ThreeDMatchPairDataset03} 7 | export THREED_MATCH_DIR=${THREED_MATCH_DIR} 8 | export MODEL=${MODEL:-ResUNetBN2C} 9 | export MODEL_N_OUT=${MODEL_N_OUT:-32} 10 | export FCGF_WEIGHTS=${FCGF_WEIGHTS:fcgf.pth} 11 | export INLIER_MODEL=${INLIER_MODEL:-ResUNetBNF} 12 | export OPTIMIZER=${OPTIMIZER:-SGD} 13 | export LR=${LR:-1e-1} 14 | export MAX_EPOCH=${MAX_EPOCH:-100} 15 | export BATCH_SIZE=${BATCH_SIZE:-8} 16 | export ITER_SIZE=${ITER_SIZE:-1} 17 | export VOXEL_SIZE=${VOXEL_SIZE:-0.05} 18 | export POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER=${POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER:-4} 19 | export CONV1_KERNEL_SIZE=${CONV1_KERNEL_SIZE:-7} 20 | export EXP_GAMMA=${EXP_GAMMA:-0.99} 21 | export RANDOM_SCALE=${RANDOM_SCALE:-True} 22 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 23 | export VERSION=$(git rev-parse HEAD) 24 | 25 | export OUT_DIR=${DATA_ROOT}/${DATASET}-v${VOXEL_SIZE}/${INLIER_MODEL}/${OPTIMIZER}-lr${LR}-e${MAX_EPOCH}-b${BATCH_SIZE}i${ITER_SIZE}-modelnout${MODEL_N_OUT}${PATH_POSTFIX}/${TIME} 26 | 27 | export PYTHONUNBUFFERED="True" 28 | 29 | echo $OUT_DIR 30 | 31 | mkdir -m 755 -p $OUT_DIR 32 | 33 | LOG=${OUT_DIR}/log_${TIME}.txt 34 | 35 | echo "Host: " $(hostname) | tee -a $LOG 36 | echo "Conda " $(which conda) | tee -a $LOG 37 | echo $(pwd) | tee -a $LOG 38 | echo "Version: " $VERSION | tee -a $LOG 39 | echo "Git diff" | tee -a $LOG 40 | echo "" | tee -a $LOG 41 | git diff | tee -a $LOG 42 | echo "" | tee -a $LOG 43 | nvidia-smi | tee -a $LOG 44 | 45 | # Training 46 | python train.py \ 47 | --weights ${FCGF_WEIGHTS} \ 48 | --dataset ${DATASET} \ 49 | --threed_match_dir ${THREED_MATCH_DIR} \ 50 | --feat_model ${MODEL} \ 51 | --feat_model_n_out ${MODEL_N_OUT} \ 52 | --feat_conv1_kernel_size ${CONV1_KERNEL_SIZE} \ 53 | --inlier_model ${INLIER_MODEL} \ 54 | --optimizer ${OPTIMIZER} \ 55 | --lr ${LR} \ 56 | --batch_size ${BATCH_SIZE} \ 57 | --val_batch_size ${BATCH_SIZE} \ 58 | --iter_size ${ITER_SIZE} \ 59 | --max_epoch ${MAX_EPOCH} \ 60 | --voxel_size ${VOXEL_SIZE} \ 61 | --out_dir ${OUT_DIR} \ 62 | --use_random_scale ${RANDOM_SCALE} \ 63 | --positive_pair_search_voxel_size_multiplier ${POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER} \ 64 | $MISC_ARGS 2>&1 | tee -a $LOG 65 | 66 | # Test 67 | python -m 
scripts.test_3dmatch \ 68 | $MISC_ARGS \ 69 | --threed_match_dir ${THREED_MATCH_DIR} \ 70 | --weights ${OUT_DIR}/best_val_checkpoint.pth \ 71 | 2>&1 | tee -a $LOG 72 | -------------------------------------------------------------------------------- /scripts/train_kitti.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export PATH_POSTFIX=$1 3 | export MISC_ARGS=$2 4 | 5 | export DATA_ROOT="./outputs/Experiment3" 6 | export DATASET=${DATASET:-KITTINMPairDataset} 7 | export KITTI_PATH=${KITTI_PATH} 8 | export MODEL=${MODEL:-ResUNetBN2C} 9 | export MODEL_N_OUT=${MODEL_N_OUT:-32} 10 | export FCGF_WEIGHTS=${FCGF_WEIGHTS} 11 | export INLIER_MODEL=${INLIER_MODEL:-ResUNetBN2C} 12 | export OPTIMIZER=${OPTIMIZER:-SGD} 13 | export LR=${LR:-1e-2} 14 | export MAX_EPOCH=${MAX_EPOCH:-100} 15 | export BATCH_SIZE=${BATCH_SIZE:-8} 16 | export ITER_SIZE=${ITER_SIZE:-1} 17 | export VOXEL_SIZE=${VOXEL_SIZE:-0.3} 18 | export POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER=${POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER:-4} 19 | export CONV1_KERNEL_SIZE=${CONV1_KERNEL_SIZE:-5} 20 | export EXP_GAMMA=${EXP_GAMMA:-0.99} 21 | export RANDOM_SCALE=${RANDOM_SCALE:-True} 22 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 23 | export VERSION=$(git rev-parse HEAD) 24 | 25 | export OUT_DIR=${DATA_ROOT}/${DATASET}-v${VOXEL_SIZE}/${INLIER_MODEL}/${OPTIMIZER}-lr${LR}-e${MAX_EPOCH}-b${BATCH_SIZE}i${ITER_SIZE}-modelnout${MODEL_N_OUT}${PATH_POSTFIX}/${TIME} 26 | 27 | export PYTHONUNBUFFERED="True" 28 | 29 | echo $OUT_DIR 30 | 31 | mkdir -m 755 -p $OUT_DIR 32 | 33 | LOG=${OUT_DIR}/log_${TIME}.txt 34 | 35 | echo "Host: " $(hostname) | tee -a $LOG 36 | echo "Conda " $(which conda) | tee -a $LOG 37 | echo $(pwd) | tee -a $LOG 38 | echo "Version: " $VERSION | tee -a $LOG 39 | echo "Git diff" | tee -a $LOG 40 | echo "" | tee -a $LOG 41 | git diff | tee -a $LOG 42 | echo "" | tee -a $LOG 43 | nvidia-smi | tee -a $LOG 44 | 45 | # Training 46 | python train.py \ 47 | --weights ${FCGF_WEIGHTS} \ 48 | --dataset ${DATASET} \ 49 | --feat_model ${MODEL} \ 50 | --feat_model_n_out ${MODEL_N_OUT} \ 51 | --feat_conv1_kernel_size ${CONV1_KERNEL_SIZE} \ 52 | --inlier_model ${INLIER_MODEL} \ 53 | --optimizer ${OPTIMIZER} \ 54 | --lr ${LR} \ 55 | --batch_size ${BATCH_SIZE} \ 56 | --val_batch_size ${BATCH_SIZE} \ 57 | --iter_size ${ITER_SIZE} \ 58 | --max_epoch ${MAX_EPOCH} \ 59 | --voxel_size ${VOXEL_SIZE} \ 60 | --out_dir ${OUT_DIR} \ 61 | --use_random_scale ${RANDOM_SCALE} \ 62 | --kitti_dir ${KITTI_PATH} \ 63 | --success_rte_thresh 2 \ 64 | --success_rre_thresh 5 \ 65 | --positive_pair_search_voxel_size_multiplier ${POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER} \ 66 | $MISC_ARGS 2>&1 | tee -a $LOG 67 | 68 | # Test 69 | python -m scripts.test_kitti \ 70 | --kitti_dir ${KITTI_PATH} \ 71 | --save_dir ${OUT_DIR} | tee -a $LOG 72 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
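The shell wrappers above are thin layers over `train.py`'s argparse flags, so the same run can be launched from Python by filling `sys.argv` before calling `get_config`. This is only a sketch: the paths are placeholders, and a real run will also want `--weights` (pretrained FCGF) plus the remaining flags the wrapper sets:

```python
import sys

# Placeholder flags mirroring scripts/train_kitti.sh; adjust paths before running.
sys.argv = [
    'train.py',
    '--dataset', 'KITTINMPairDataset',
    '--kitti_dir', '/path/to/kitti',
    '--voxel_size', '0.3',
    '--batch_size', '4',
    '--max_epoch', '1',
    '--out_dir', './outputs/kitti_debug',
]

from config import get_config
from train import main

main(get_config())
```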
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import open3d as o3d # prevent loading error 8 | 9 | import sys 10 | import json 11 | import logging 12 | import torch 13 | from easydict import EasyDict as edict 14 | 15 | from config import get_config 16 | 17 | from dataloader.data_loaders import make_data_loader 18 | 19 | from core.trainer import WeightedProcrustesTrainer 20 | 21 | ch = logging.StreamHandler(sys.stdout) 22 | logging.getLogger().setLevel(logging.INFO) 23 | logging.basicConfig(format='%(asctime)s %(message)s', 24 | datefmt='%m/%d %H:%M:%S', 25 | handlers=[ch]) 26 | 27 | torch.manual_seed(0) 28 | torch.cuda.manual_seed(0) 29 | 30 | logging.basicConfig(level=logging.INFO, format="") 31 | 32 | 33 | def main(config, resume=False): 34 | train_loader = make_data_loader(config, 35 | config.train_phase, 36 | config.batch_size, 37 | num_workers=config.train_num_workers, 38 | shuffle=True) 39 | 40 | if config.test_valid: 41 | val_loader = make_data_loader(config, 42 | config.val_phase, 43 | config.val_batch_size, 44 | num_workers=config.val_num_workers, 45 | shuffle=True) 46 | else: 47 | val_loader = None 48 | 49 | trainer = WeightedProcrustesTrainer( 50 | config=config, 51 | data_loader=train_loader, 52 | val_data_loader=val_loader, 53 | ) 54 | 55 | trainer.train() 56 | 57 | 58 | if __name__ == "__main__": 59 | logger = logging.getLogger() 60 | config = get_config() 61 | 62 | dconfig = vars(config) 63 | if config.resume_dir: 64 | resume_config = json.load(open(config.resume_dir + '/config.json', 'r')) 65 | for k in dconfig: 66 | if k not in ['resume_dir'] and k in resume_config: 67 | dconfig[k] = resume_config[k] 68 | dconfig['resume'] = resume_config['out_dir'] + '/checkpoint.pth' 69 | 70 | logging.info('===> Configurations') 71 | for k in dconfig: 72 | logging.info(' {}: {}'.format(k, dconfig[k])) 73 | 74 | # Convert to dict 75 | config = edict(dconfig) 76 | main(config) 77 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | -------------------------------------------------------------------------------- /util/file.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
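# (Editorial note, not part of the original file: this module provides small file-system helpers --
#  natural "alphanum" sorted directory listings and a parser for trajectory log files in which each
#  entry is a metadata line followed by a 4x4 pose matrix. A hedged usage sketch, where the log
#  file name is an assumption:
#    traj = read_trajectory('gt.log')   # list of CameraPose objects
#    T = traj[0].pose                   # 4x4 numpy array
#  The papers to cite are listed next.)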
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import os 8 | import re 9 | from os import listdir 10 | from os.path import isfile, isdir, join, splitext 11 | 12 | import numpy as np 13 | 14 | 15 | def read_txt(path): 16 | """Read txt file into lines. 17 | """ 18 | with open(path) as f: 19 | lines = f.readlines() 20 | lines = [x.strip() for x in lines] 21 | return lines 22 | 23 | 24 | def ensure_dir(path): 25 | if not os.path.exists(path): 26 | os.makedirs(path, mode=0o755) 27 | 28 | 29 | def sorted_alphanum(file_list_ordered): 30 | def convert(text): 31 | return int(text) if text.isdigit() else text 32 | 33 | def alphanum_key(key): 34 | return [convert(c) for c in re.split('([0-9]+)', key)] 35 | 36 | return sorted(file_list_ordered, key=alphanum_key) 37 | 38 | 39 | def get_file_list(path, extension=None): 40 | if extension is None: 41 | file_list = [join(path, f) for f in listdir(path) if isfile(join(path, f))] 42 | else: 43 | file_list = [ 44 | join(path, f) for f in listdir(path) 45 | if isfile(join(path, f)) and splitext(f)[1] == extension 46 | ] 47 | file_list = sorted_alphanum(file_list) 48 | return file_list 49 | 50 | 51 | def get_file_list_specific(path, color_depth, extension=None): 52 | if extension is None: 53 | file_list = [join(path, f) for f in listdir(path) if isfile(join(path, f))] 54 | else: 55 | file_list = [ 56 | join(path, f) for f in listdir(path) 57 | if isfile(join(path, f)) and color_depth in f and splitext(f)[1] == extension 58 | ] 59 | file_list = sorted_alphanum(file_list) 60 | return file_list 61 | 62 | 63 | def get_folder_list(path): 64 | folder_list = [join(path, f) for f in listdir(path) if isdir(join(path, f))] 65 | folder_list = sorted_alphanum(folder_list) 66 | return folder_list 67 | 68 | 69 | def read_trajectory(filename, dim=4): 70 | class CameraPose: 71 | def __init__(self, meta, mat): 72 | self.metadata = meta 73 | self.pose = mat 74 | 75 | def __str__(self): 76 | return 'metadata : ' + ' '.join(map(str, self.metadata)) + '\n' + \ 77 | "pose : " + "\n" + np.array_str(self.pose) 78 | 79 | traj = [] 80 | with open(filename, 'r') as f: 81 | metastr = f.readline() 82 | while metastr: 83 | metadata = list(map(int, metastr.split())) 84 | mat = np.zeros(shape=(dim, dim)) 85 | for i in range(dim): 86 | matstr = f.readline() 87 | mat[i, :] = np.fromstring(matstr, dtype=float, sep=' \t') 88 | traj.append(CameraPose(metadata, mat)) 89 | metastr = f.readline() 90 | return traj 91 | -------------------------------------------------------------------------------- /util/integration.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
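# (Editorial usage sketch, not part of the original file: this script fuses the RGB-D frames of one
#  3DMatch raw scene into TSDF fragment meshes. Invoked as, with the two paths below being
#  placeholder assumptions:
#    python util/integration.py /path/to/3dmatch-raw/scene-name /path/to/output
#  Every seq* folder of the scene is integrated in chunks of 50 frames, and each chunk is written
#  out as fragment-<id>.ply. The papers to cite are listed next.)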
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import open3d as o3d 8 | import argparse 9 | import os, sys 10 | import numpy as np 11 | 12 | 13 | def read_rgbd_image(color_file, depth_file, max_depth=4.5): 14 | ''' 15 | \return RGBD image 16 | ''' 17 | color = o3d.io.read_image(color_file) 18 | depth = o3d.io.read_image(depth_file) 19 | 20 | # We assume depth scale is always 1000.0 21 | rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth( 22 | color, depth, depth_trunc=max_depth, convert_rgb_to_intensity=False) 23 | return rgbd_image 24 | 25 | 26 | def read_pose(pose_file): 27 | ''' 28 | \return 4x4 np matrix 29 | ''' 30 | pose = np.loadtxt(pose_file) 31 | assert pose is not None 32 | return pose 33 | 34 | 35 | def read_intrinsics(intrinsic_file): 36 | ''' 37 | \return fx, fy, cx, cy 38 | ''' 39 | K = np.loadtxt(intrinsic_file) 40 | assert K is not None 41 | return K[0, 0], K[1, 1], K[0, 2], K[1, 2] 42 | 43 | 44 | def integrate_rgb_frames_for_fragment(color_files, 45 | depth_files, 46 | pose_files, 47 | seq_path, 48 | intrinsic, 49 | fragment_id, 50 | n_fragments, 51 | n_frames_per_fragment, 52 | voxel_length=0.008): 53 | volume = o3d.integration.ScalableTSDFVolume( 54 | voxel_length=voxel_length, 55 | sdf_trunc=0.04, 56 | color_type=o3d.integration.TSDFVolumeColorType.RGB8) 57 | 58 | start = fragment_id * n_frames_per_fragment 59 | end = min(start + n_frames_per_fragment, len(pose_files)) 60 | for i_abs in range(start, end): 61 | print("Fragment %03d / %03d :: integrate rgbd frame %d (%d of %d)." 
% 62 | (fragment_id, n_fragments - 1, i_abs, i_abs - start + 1, end - start)) 63 | 64 | rgbd = read_rgbd_image( 65 | os.path.join(seq_path, color_files[i_abs]), 66 | os.path.join(seq_path, depth_files[i_abs])) 67 | pose = read_pose(os.path.join(seq_path, pose_files[i_abs])) 68 | volume.integrate(rgbd, intrinsic, np.linalg.inv(pose)) 69 | 70 | mesh = volume.extract_triangle_mesh() 71 | return mesh 72 | 73 | 74 | def process_seq(seq_path, output_path, n_frames_per_fragment, display=False): 75 | files = os.listdir(seq_path) 76 | 77 | if 'intrinsics.txt' in files: 78 | fx, fy, cx, cy = read_intrinsics(os.path.join(seq_path, 'intrinsics.txt')) 79 | else: 80 | fx, fy, cx, cy = read_intrinsics(os.path.join(seq_path, '../camera-intrinsics.txt')) 81 | 82 | rgb_files = sorted(list(filter(lambda x: x.endswith('.color.png'), files))) 83 | depth_files = sorted(list(filter(lambda x: x.endswith('.depth.png'), files))) 84 | pose_files = sorted(list(filter(lambda x: x.endswith('.pose.txt'), files))) 85 | 86 | assert len(rgb_files) > 0 87 | assert len(rgb_files) == len(depth_files) 88 | assert len(rgb_files) == len(pose_files) 89 | 90 | # get width and height to prepare for intrinsics 91 | rgb_sample = o3d.io.read_image(os.path.join(seq_path, rgb_files[0])) 92 | width, height = rgb_sample.get_max_bound() 93 | intrinsic = o3d.camera.PinholeCameraIntrinsic(int(width), int(height), fx, fy, cx, cy) 94 | 95 | n_fragments = ((len(rgb_files) + n_frames_per_fragment - 1)) // n_frames_per_fragment 96 | 97 | for fragment_id in range(0, n_fragments): 98 | mesh = integrate_rgb_frames_for_fragment(rgb_files, depth_files, pose_files, 99 | seq_path, intrinsic, fragment_id, 100 | n_fragments, n_frames_per_fragment) 101 | if display: 102 | o3d.visualization.draw_geometries([mesh]) 103 | 104 | mesh_name = os.path.join(output_path, 'fragment-{}.ply'.format(fragment_id)) 105 | o3d.io.write_triangle_mesh(mesh_name, mesh) 106 | 107 | 108 | if __name__ == '__main__': 109 | parser = argparse.ArgumentParser( 110 | description='RGB-D integration for 3DMatch raw dataset') 111 | 112 | parser.add_argument( 113 | 'dataset', help='path to dataset that contains colors, depths and poses') 114 | parser.add_argument('output', help='path to output fragments') 115 | 116 | args = parser.parse_args() 117 | 118 | scene_name = args.dataset.split('/')[-1] 119 | if not os.path.exists(args.output): 120 | os.makedirs(args.output) 121 | 122 | output_scene_path = os.path.join(args.output, scene_name) 123 | if os.path.exists(output_scene_path): 124 | choice = input( 125 | 'Path {} already exists, continue? (Y / N)'.format(output_scene_path)) 126 | if choice != 'Y' and choice != 'y': 127 | print('abort') 128 | sys.exit(0) 129 | else: 130 | os.makedirs(output_scene_path) 131 | 132 | seqs = list(filter(lambda x: x.startswith('seq'), os.listdir(args.dataset))) 133 | for seq in seqs: 134 | output_seq_path = os.path.join(output_scene_path, seq) 135 | if not os.path.exists(output_seq_path): 136 | os.makedirs(output_seq_path) 137 | process_seq( 138 | os.path.join(args.dataset, seq), 139 | output_seq_path, 140 | n_frames_per_fragment=50, 141 | display=False) 142 | -------------------------------------------------------------------------------- /util/pointcloud.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code.
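# (Editorial note, not part of the original file: this module bundles Open3D helpers for building
#  point clouds and feature objects, computing the overlap ratio of a registered pair, and scoring
#  learned descriptors by nearest-neighbour hit ratio. A hedged sketch, where pcd0/pcd1 are
#  o3d.geometry.PointCloud objects, T_gt is a 4x4 ground-truth transform, and feat0/feat1 are
#  per-point descriptor arrays:
#    overlap = compute_overlap_ratio(pcd0, pcd1, T_gt, voxel_size=0.05)
#    hit_ratio = evaluate_feature_3dmatch(pcd0, pcd1, feat0, feat1, T_gt, inlier_thresh=0.1)
#  The papers to cite are listed next.)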
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import copy 8 | import numpy as np 9 | import math 10 | 11 | import open3d as o3d 12 | from core.knn import find_knn_cpu 13 | 14 | 15 | def make_open3d_point_cloud(xyz, color=None): 16 | pcd = o3d.geometry.PointCloud() 17 | pcd.points = o3d.utility.Vector3dVector(xyz.cpu().detach().numpy()) 18 | if color is not None: 19 | if len(color) != len(xyz): 20 | color = np.tile(color, (len(xyz), 1)) 21 | pcd.colors = o3d.utility.Vector3dVector(color) 22 | return pcd 23 | 24 | 25 | def make_open3d_feature(data, dim, npts): 26 | feature = o3d.registration.Feature() 27 | feature.resize(dim, npts) 28 | feature.data = data.cpu().numpy().astype('d').transpose() 29 | return feature 30 | 31 | 32 | def make_open3d_feature_from_numpy(data): 33 | assert isinstance(data, np.ndarray) 34 | assert data.ndim == 2 35 | 36 | feature = o3d.registration.Feature() 37 | feature.resize(data.shape[1], data.shape[0]) 38 | feature.data = data.astype('d').transpose() 39 | return feature 40 | 41 | 42 | def pointcloud_to_spheres(pcd, voxel_size, color, sphere_size=0.6): 43 | spheres = o3d.geometry.TriangleMesh() 44 | s = o3d.geometry.TriangleMesh.create_sphere(radius=voxel_size * sphere_size) 45 | s.compute_vertex_normals() 46 | s.paint_uniform_color(color) 47 | if isinstance(pcd, o3d.geometry.PointCloud): 48 | pcd = np.array(pcd.points) 49 | for i, p in enumerate(pcd): 50 | si = copy.deepcopy(s) 51 | trans = np.identity(4) 52 | trans[:3, 3] = p 53 | si.transform(trans) 54 | # si.paint_uniform_color(pcd.colors[i]) 55 | spheres += si 56 | return spheres 57 | 58 | 59 | def prepare_single_pointcloud(pcd, voxel_size): 60 | pcd.estimate_normals(o3d.KDTreeSearchParamHybrid(radius=voxel_size * 2.0, max_nn=30)) 61 | return pcd 62 | 63 | 64 | def prepare_pointcloud(filename, voxel_size): 65 | pcd = o3d.io.read_point_cloud(filename) 66 | T = get_random_transformation(pcd) 67 | pcd.transform(T) 68 | pcd_down = pcd.voxel_down_sample(voxel_size) 69 | return pcd_down, T 70 | 71 | 72 | def compute_overlap_ratio(pcd0, pcd1, trans, voxel_size): 73 | pcd0_down = pcd0.voxel_down_sample(voxel_size) 74 | pcd1_down = pcd1.voxel_down_sample(voxel_size) 75 | matching01 = get_matching_indices(pcd0_down, pcd1_down, trans, voxel_size, 1) 76 | matching10 = get_matching_indices(pcd1_down, pcd0_down, np.linalg.inv(trans), 77 | voxel_size, 1) 78 | overlap0 = len(matching01) / len(pcd0_down.points) 79 | overlap1 = len(matching10) / len(pcd1_down.points) 80 | return max(overlap0, overlap1) 81 | 82 | 83 | def get_matching_indices(source, target, trans, search_voxel_size, K=None): 84 | source_copy = copy.deepcopy(source) 85 | target_copy = copy.deepcopy(target) 86 | source_copy.transform(trans) 87 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy) 88 | 89 | match_inds = [] 90 | for i, point in enumerate(source_copy.points): 91 | [_, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size) 92 | if K is not None: 93 | idx = idx[:K] 94 | for j in idx: 95 | match_inds.append((i, j)) 96 | return match_inds 97 | 98 | 99 | def evaluate_feature(pcd0, pcd1, feat0, feat1, trans_gth, search_voxel_size): 100 | match_inds = get_matching_indices(pcd0, pcd1, trans_gth, search_voxel_size) 101 | pcd_tree = o3d.geometry.KDTreeFlann(feat1) 
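  # The KD-tree above is built over the target features, so the 1-NN query in the loop below
  # returns, for each ground-truth match, the target point whose descriptor is closest to the
  # source descriptor; the squared (clipped) 3D offset between that point and the true
  # correspondence is accumulated as the feature error.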
102 | dist = [] 103 | for ind in match_inds: 104 | k, idx, _ = pcd_tree.search_knn_vector_xd(feat0.data[:, ind[0]], 1) 105 | dist.append( 106 | np.clip(np.power(pcd1.points[ind[1]] - pcd1.points[idx[0]], 2), 107 | a_min=0.0, 108 | a_max=1.0)) 109 | return np.mean(dist) 110 | 111 | 112 | def valid_feat_ratio(pcd0, pcd1, feat0, feat1, trans_gth, thresh=0.1): 113 | pcd0_copy = copy.deepcopy(pcd0) 114 | pcd0_copy.transform(trans_gth) 115 | inds = find_knn_cpu(feat0, feat1) 116 | dist = np.sqrt(((np.array(pcd0_copy.points) - np.array(pcd1.points)[inds])**2).sum(1)) 117 | return np.mean(dist < thresh) 118 | 119 | 120 | def evaluate_feature_3dmatch(pcd0, pcd1, feat0, feat1, trans_gth, inlier_thresh=0.1): 121 | r"""Return the hit ratio (ratio of inlier correspondences and all correspondences). 122 | 123 | inliear_thresh is the inlier_threshold in meter. 124 | """ 125 | if len(pcd0.points) < len(pcd1.points): 126 | hit = valid_feat_ratio(pcd0, pcd1, feat0, feat1, trans_gth, inlier_thresh) 127 | else: 128 | hit = valid_feat_ratio(pcd1, pcd0, feat1, feat0, np.linalg.inv(trans_gth), 129 | inlier_thresh) 130 | return hit 131 | 132 | 133 | def get_matching_matrix(source, target, trans, voxel_size, debug_mode): 134 | source_copy = copy.deepcopy(source) 135 | target_copy = copy.deepcopy(target) 136 | source_copy.transform(trans) 137 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy) 138 | matching_matrix = np.zeros((len(source_copy.points), len(target_copy.points))) 139 | 140 | for i, point in enumerate(source_copy.points): 141 | [k, idx, _] = pcd_tree.search_radius_vector_3d(point, voxel_size * 1.5) 142 | if k >= 1: 143 | matching_matrix[i, idx[0]] = 1 # TODO: only the cloest? 144 | 145 | return matching_matrix 146 | 147 | 148 | def get_random_transformation(pcd_input): 149 | def rot_x(x): 150 | out = np.zeros((3, 3)) 151 | c = math.cos(x) 152 | s = math.sin(x) 153 | out[0, 0] = 1 154 | out[1, 1] = c 155 | out[1, 2] = -s 156 | out[2, 1] = s 157 | out[2, 2] = c 158 | return out 159 | 160 | def rot_y(x): 161 | out = np.zeros((3, 3)) 162 | c = math.cos(x) 163 | s = math.sin(x) 164 | out[0, 0] = c 165 | out[0, 2] = s 166 | out[1, 1] = 1 167 | out[2, 0] = -s 168 | out[2, 2] = c 169 | return out 170 | 171 | def rot_z(x): 172 | out = np.zeros((3, 3)) 173 | c = math.cos(x) 174 | s = math.sin(x) 175 | out[0, 0] = c 176 | out[0, 1] = -s 177 | out[1, 0] = s 178 | out[1, 1] = c 179 | out[2, 2] = 1 180 | return out 181 | 182 | pcd_output = copy.deepcopy(pcd_input) 183 | mean = np.mean(np.asarray(pcd_output.points), axis=0).transpose() 184 | xyz = np.random.uniform(0, 2 * math.pi, 3) 185 | R = np.dot(np.dot(rot_x(xyz[0]), rot_y(xyz[1])), rot_z(xyz[2])) 186 | T = np.zeros((4, 4)) 187 | T[:3, :3] = R 188 | T[:3, 3] = np.dot(-R, mean) 189 | T[3, 3] = 1 190 | return T 191 | 192 | 193 | def draw_registration_result(source, target, transformation): 194 | source_temp = copy.deepcopy(source) 195 | target_temp = copy.deepcopy(target) 196 | source_temp.paint_uniform_color([1, 0.706, 0]) 197 | target_temp.paint_uniform_color([0, 0.651, 0.929]) 198 | source_temp.transform(transformation) 199 | o3d.visualization.draw_geometries([source_temp, target_temp]) 200 | -------------------------------------------------------------------------------- /util/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu) and Wei Dong (weidong@andrew.cmu.edu) 2 | # 3 | # Please cite the following papers if you use any part of the code. 
4 | # - Christopher Choy, Wei Dong, Vladlen Koltun, Deep Global Registration, CVPR 2020 5 | # - Christopher Choy, Jaesik Park, Vladlen Koltun, Fully Convolutional Geometric Features, ICCV 2019 6 | # - Christopher Choy, JunYoung Gwak, Silvio Savarese, 4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural Networks, CVPR 2019 7 | import time 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class AverageMeter(object): 13 | """Computes and stores the average and current value""" 14 | 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0.0 22 | self.sq_sum = 0.0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | if isinstance(val, np.ndarray): 27 | n = val.size 28 | val = val.mean() 29 | elif isinstance(val, torch.Tensor): 30 | n = val.nelement() 31 | val = val.mean().item() 32 | self.val = val 33 | self.sum += val * n 34 | self.count += n 35 | self.avg = self.sum / self.count 36 | self.sq_sum += val**2 * n 37 | self.var = self.sq_sum / self.count - self.avg ** 2 38 | 39 | 40 | class Timer(AverageMeter): 41 | """A simple timer.""" 42 | 43 | def tic(self): 44 | # using time.time instead of time.clock because time.clock 45 | # does not normalize for multithreading 46 | self.start_time = time.time() 47 | 48 | def toc(self, average=True): 49 | self.diff = time.time() - self.start_time 50 | self.update(self.diff) 51 | if average: 52 | return self.avg 53 | else: 54 | return self.diff 55 | --------------------------------------------------------------------------------
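A minimal usage sketch for the timing utilities above (not part of the repository); it assumes the repository root is on `PYTHONPATH` so that `util.timer` is importable.

```python
# Minimal usage sketch for util/timer.py (assumption: run from the repository root).
import time

from util.timer import AverageMeter, Timer

timer = Timer()
meter = AverageMeter()

for step in range(5):
  timer.tic()
  time.sleep(0.01)                      # stand-in for one registration / training step
  elapsed = timer.toc(average=False)    # seconds spent in this step
  meter.update(elapsed)                 # track running mean and variance

print(f'last: {timer.diff:.4f}s  avg: {timer.avg:.4f}s  var: {meter.var:.2e}')
```

`toc(average=True)` returns the running average instead of the last measurement.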