├── .gitignore ├── LICENSE ├── README.md ├── assets ├── cloud_bin_21.pth ├── cloud_bin_34.pth ├── demo.png ├── dist_thresh.txt ├── fix_knn_feats.png ├── inlier_thresh.txt ├── results.png └── teaser_predator.jpg ├── common ├── colors.py ├── math │ ├── random.py │ ├── se3.py │ └── so3.py ├── math_torch │ └── se3.py ├── misc.py └── torch.py ├── configs ├── benchmarks │ ├── 3DLoMatch │ │ ├── 7-scenes-redkitchen │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ ├── sun3d-home_at-home_at_scan1_2013_jan_1 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ ├── sun3d-home_md-home_md_scan9_2012_sep_30 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ ├── sun3d-hotel_uc-scan3 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ ├── sun3d-hotel_umd-maryland_hotel1 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ ├── sun3d-hotel_umd-maryland_hotel3 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ ├── sun3d-mit_76_studyroom-76-1studyroom2 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ └── sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ └── 3DMatch │ │ ├── 7-scenes-redkitchen │ │ ├── gt.info │ │ ├── gt.log │ │ └── gt_overlap.log │ │ ├── sun3d-home_at-home_at_scan1_2013_jan_1 │ │ ├── gt.info │ │ ├── gt.log │ │ └── gt_overlap.log │ │ ├── sun3d-home_md-home_md_scan9_2012_sep_30 │ │ ├── gt.info │ │ ├── gt.log │ │ └── gt_overlap.log │ │ ├── sun3d-hotel_uc-scan3 │ │ ├── gt.info │ │ ├── gt.log │ │ └── gt_overlap.log │ │ ├── sun3d-hotel_umd-maryland_hotel1 │ │ ├── gt.info │ │ ├── gt.log │ │ └── gt_overlap.log │ │ ├── sun3d-hotel_umd-maryland_hotel3 │ │ ├── gt.info │ │ ├── gt.log │ │ └── gt_overlap.log │ │ ├── sun3d-mit_76_studyroom-76-1studyroom2 │ │ ├── gt.info │ │ ├── gt.log │ │ └── gt_overlap.log │ │ └── sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika │ │ ├── gt.info │ │ ├── gt.log │ │ └── gt_overlap.log ├── indoor │ ├── 3DLoMatch.pkl │ ├── 3DMatch.pkl │ ├── train_3dmatch.txt │ ├── train_info.pkl │ ├── val_3dmatch.txt │ └── val_info.pkl ├── kitti │ ├── test_kitti.txt │ ├── train_kitti.txt │ └── val_kitti.txt ├── modelnet │ ├── modelnet40_all.txt │ ├── modelnet40_half1.txt │ └── modelnet40_half2.txt ├── models.py ├── test │ ├── indoor.yaml │ ├── kitti.yaml │ └── modelnet.yaml └── train │ ├── indoor.yaml │ ├── kitti.yaml │ └── modelnet.yaml ├── cpp_wrappers ├── compile_wrappers.sh ├── cpp_neighbors │ ├── build.bat │ ├── neighbors │ │ ├── neighbors.cpp │ │ └── neighbors.h │ ├── setup.py │ └── wrapper.cpp ├── cpp_subsampling │ ├── build.bat │ ├── grid_subsampling │ │ ├── grid_subsampling.cpp │ │ └── grid_subsampling.h │ ├── setup.py │ └── wrapper.cpp └── cpp_utils │ ├── cloud │ ├── cloud.cpp │ └── cloud.h │ └── nanoflann │ └── nanoflann.hpp ├── datasets ├── __init__.py ├── dataloader.py ├── indoor.py ├── kitti.py ├── modelnet.py └── transforms.py ├── kernels ├── dispositions │ └── k_015_center_3D.ply └── kernel_points.py ├── lib ├── __init__.py ├── benchmark.py ├── benchmark_utils.py ├── loss.py ├── ply.py ├── tester.py ├── timer.py ├── trainer.py └── utils.py ├── main.py ├── models ├── __init__.py ├── architectures.py ├── blocks.py └── gcn.py ├── requirements.txt └── scripts ├── cal_overlap.py ├── demo.py ├── download_data_weight.sh └── evaluate_predator.py /.gitignore: -------------------------------------------------------------------------------- 1 | snapshot/ 2 | scripts.py 3 | *.so 4 | *.zip 5 | data/ 6 | weights/ 7 | __pycache__/ 8 | build/ 9 
| dump/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Shengyu Huang, Zan Gojcic 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## PREDATOR: Registration of 3D Point Clouds with Low Overlap (CVPR 2021, Oral) 2 | This repository is the official implementation of the paper: 3 | 4 | ### [PREDATOR: Registration of 3D Point Clouds with Low Overlap](https://arxiv.org/abs/2011.13005) 5 | 6 | \*[Shengyu Huang](https://shengyuh.github.io), \*[Zan Gojcic](https://zgojcic.github.io/), [Mikhail Usvyatsov](https://aelphy.github.io), [Andreas Wieser](https://gseg.igp.ethz.ch/people/group-head/prof-dr--andreas-wieser.html), [Konrad Schindler](https://prs.igp.ethz.ch/group/people/person-detail.schindler.html)\ 7 | |[ETH Zurich](https://igp.ethz.ch/) | \* Equal contribution 8 | 9 | For an implementation using the MinkowskiEngine backbone, please check [this](https://github.com/ShengyuH/OverlapPredator.Mink) 10 | 11 | For more information, please see the [project website](https://overlappredator.github.io) 12 | 13 | ![Predator_teaser](assets/teaser_predator.jpg?raw=true) 14 | 15 | 16 | 17 | ### Contact 18 | If you have any questions, please let us know: 19 | - Shengyu Huang {shengyu.huang@geod.baug.ethz.ch} 20 | - Zan Gojcic {zan.gojcic@geod.baug.ethz.ch} 21 | 22 | ## News 23 | - 2021-08-09: We've updated the arXiv version of our [paper](https://arxiv.org/abs/2011.13005) with improved performance! 24 | - 2021-06-02: Fixed a feature-gathering bug in the k-NN graph; please see the improved performance in this [issue](https://github.com/overlappredator/OverlapPredator/issues/15). Stay tuned for updates on other experiments! 25 | - 2021-05-31: Check out our video and poster on the [project page](https://overlappredator.github.io)! 26 | - 2021-03-25: The camera-ready version is on arXiv!
I also gave a talk on Predator (in Chinese); you can find the recording here: [Bilibili](https://www.bilibili.com/video/BV1UK4y1U7Gs), [Youtube](https://www.youtube.com/watch?v=AZQGJa6R_4I&t=1563s) 27 | - 2021-02-28: MinkowskiEngine-based PREDATOR [release](https://github.com/ShengyuH/OverlapPredator.Mink.git) 28 | - 2020-11-30: Code and paper release 29 | 30 | 31 | ## Instructions 32 | This code has been tested on 33 | - Python 3.8.5, PyTorch 1.7.1, CUDA 11.2, gcc 9.3.0, GeForce RTX 3090/GeForce GTX 1080Ti 34 | 35 | **Note**: We observed random data-loader crashes due to memory issues; if you observe similar issues, please consider reducing the number of workers or increasing CPU RAM. We have now released a sparse-convolution-based Predator, have a look [here](https://github.com/ShengyuH/OverlapPredator.Mink.git)! 36 | 37 | ### Requirements 38 | To create a virtual environment and install the required dependencies, please run: 39 | ```shell 40 | git clone https://github.com/overlappredator/OverlapPredator.git 41 | virtualenv predator; source predator/bin/activate 42 | cd OverlapPredator; pip install -r requirements.txt 43 | cd cpp_wrappers; sh compile_wrappers.sh; cd .. 44 | ``` 45 | in your working folder. 46 | 47 | ### Datasets and pretrained models 48 | For the KITTI dataset, please follow the instructions on the [KITTI Odometry website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) to download the KITTI odometry training set. 49 | 50 | We provide: 51 | - preprocessed 3DMatch pairwise datasets (voxel-grid subsampled fragments together with their ground truth transformation matrices) 52 | - raw dense 3DMatch datasets 53 | - the ModelNet dataset 54 | - pretrained models on 3DMatch, KITTI and ModelNet 55 | 56 | The preprocessed data and models can be downloaded by running: 57 | ```shell 58 | sh scripts/download_data_weight.sh 59 | ``` 60 | 61 | To download the raw dense 3DMatch data, please run: 62 | ```shell 63 | wget --no-check-certificate --show-progress https://share.phys.ethz.ch/~gsg/pairwise_reg/3dmatch.zip 64 | unzip 3dmatch.zip 65 | ``` 66 | 67 | The folder is organised as follows: 68 | 69 | - `3dmatch` 70 | - `train` 71 | - `7-scenes-chess` 72 | - `fragments` 73 | - `cloud_bin_*.ply` 74 | - ... 75 | - `poses` 76 | - `cloud_bin_*.txt` 77 | - ... 78 | - ... 79 | - `test` 80 | 81 | ### 3DMatch (Indoor) 82 | #### Train 83 | After creating the virtual environment and downloading the datasets, Predator can be trained using: 84 | ```shell 85 | python main.py configs/train/indoor.yaml 86 | ``` 87 | 88 | #### Evaluate 89 | For 3DMatch, to reproduce Table 2 in our main paper, we first extract features and overlap/matchability scores by running: 90 | ```shell 91 | python main.py configs/test/indoor.yaml 92 | ``` 93 | The features together with the scores will be saved to ```snapshot/indoor/3DMatch```. The estimation of the transformation parameters using RANSAC can then be carried out using: 94 | ```shell 95 | for N_POINTS in 250 500 1000 2500 5000 96 | do 97 | python scripts/evaluate_predator.py --source_path snapshot/indoor/3DMatch --n_points $N_POINTS --benchmark 3DMatch --exp_dir snapshot/indoor/est_traj --sampling prob 98 | done 99 | ``` 100 | Depending on the ```n_points``` used by RANSAC, this might take a few minutes. The final results are stored in ```snapshot/indoor/est_traj/{benchmark}_{n_points}_prob/result```. To evaluate PREDATOR on the 3DLoMatch benchmark, please also change ```3DMatch``` to ```3DLoMatch``` in ```configs/test/indoor.yaml```.
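With ```--sampling prob```, interest points are drawn with probability proportional to the predicted overlap/matchability scores rather than by simply taking the top-scoring points. A minimal numpy sketch of this sampling idea (the names ```pts```, ```scores```, and ```n_points``` are illustrative, not the actual variables used in ```evaluate_predator.py```):

```python
import numpy as np

def sample_interest_points(pts: np.ndarray, scores: np.ndarray, n_points: int):
    """Draw up to n_points indices with probability proportional to their scores."""
    n = min(n_points, pts.shape[0])
    probs = scores / scores.sum()  # turn non-negative scores into a distribution
    idx = np.random.choice(pts.shape[0], size=n, replace=False, p=probs)
    return pts[idx], idx
```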
101 | 102 | #### Demo 103 | We prepared a small demo, which demonstrates the whole Predator pipeline using two random fragments from the 3DMatch dataset. To carry out the demo, please run: 104 | ```shell 105 | python scripts/demo.py configs/test/indoor.yaml 106 | ``` 107 | 108 | The demo script will visualize the input point clouds, the inferred overlap regions, and the point clouds aligned with the estimated transformation parameters: 109 | 110 | ![demo](assets/demo.png) 111 | 112 | ### ModelNet (Synthetic) 113 | #### Train 114 | To train PREDATOR on ModelNet, please run: 115 | ``` 116 | python main.py configs/train/modelnet.yaml 117 | ``` 118 | 119 | We provide a small script to evaluate Predator on the ModelNet test set; please run: 120 | ``` 121 | python main.py configs/test/modelnet.yaml 122 | ``` 123 | The rotation and translation errors could be better/worse than the reported ones due to randomness in RANSAC. 124 | 125 | ### KITTI (Outdoor) 126 | We provide a small script to evaluate Predator on the KITTI test set. After configuring the KITTI dataset, please run: 127 | ``` 128 | python main.py configs/test/kitti.yaml 129 | ``` 130 | The results will be saved to the log file. 131 | 132 | 133 | ### Custom dataset 134 | We have a few tips for training/testing on a custom dataset: 135 | 136 | - If your data consists of similar indoor scenes, please run ```demo.py``` first to check the generalisation ability before retraining 137 | - Remember to voxel-downsample the data in your data loader; see ```kitti.py``` for reference and the short sketch at the end of this README 138 | 139 | ### Citation 140 | If you find this code useful for your work or use it in your project, please consider citing: 141 | 142 | ```bibtex 143 | @InProceedings{Huang_2021_CVPR, 144 | author = {Huang, Shengyu and Gojcic, Zan and Usvyatsov, Mikhail and Wieser, Andreas and Schindler, Konrad}, 145 | title = {Predator: Registration of 3D Point Clouds With Low Overlap}, 146 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 147 | month = {June}, 148 | year = {2021}, 149 | pages = {4267-4276} 150 | } 151 | ``` 152 | 153 | ### Acknowledgments 154 | In this project we use (parts of) the official implementations of the following works: 155 | 156 | - [FCGF](https://github.com/chrischoy/FCGF) (KITTI preprocessing) 157 | - [D3Feat](https://github.com/XuyangBai/D3Feat.pytorch) (KPConv backbone) 158 | - [3DSmoothNet](https://github.com/zgojcic/3DSmoothNet) (3DMatch preparation) 159 | - [MultiviewReg](https://github.com/zgojcic/3D_multiview_reg) (3DMatch benchmark) 160 | - [SuperGlue](https://github.com/magicleap/SuperGluePretrainedNetwork) (Transformer part) 161 | - [DGCNN](https://github.com/WangYueFt/dgcnn) (self-gnn) 162 | - [RPMNet](https://github.com/yewzijian/RPMNet) (ModelNet preprocessing and evaluation) 163 | 164 | We thank the respective authors for open-sourcing their methods. We would also like to thank the reviewers, especially reviewer 2 for his/her valuable inputs.
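As referenced in the custom-dataset tips above, here is a minimal voxel-downsampling sketch using Open3D. It is only an illustration, not the exact code from ```kitti.py```; the function name and the 0.025 default are placeholders, and the voxel size should match the one used at training time (e.g. the ```first_subsampling_dl``` value in the yaml configs):

```python
import numpy as np
import open3d as o3d

def voxel_downsample(points: np.ndarray, voxel_size: float = 0.025) -> np.ndarray:
    """Voxel-grid downsample an (N, 3) point array before feeding it to the model."""
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd = pcd.voxel_down_sample(voxel_size=voxel_size)
    return np.asarray(pcd.points)
```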
165 | -------------------------------------------------------------------------------- /assets/cloud_bin_21.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/assets/cloud_bin_21.pth -------------------------------------------------------------------------------- /assets/cloud_bin_34.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/assets/cloud_bin_34.pth -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/assets/demo.png -------------------------------------------------------------------------------- /assets/dist_thresh.txt: -------------------------------------------------------------------------------- 1 | method 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.1 0.11 0.12 0.13 0.14 0.15 0.16 0.17 0.18 0.19 0.2 2 | 3DMatch 0.34988788 1.735477 6.3130245 13.2243185 20.68582 28.49041 34.01513 40.062992 45.687683 52.00012 55.463696 59.081406 61.986275 65.45967 68.43626 70.95725 73.55156 76.07139 77.85846 80.27285 3 | FPFH 0.32518435 3.0640497 7.408897 12.879427 17.654156 21.894594 24.814766 28.676664 32.31617 35.307827 38.895893 41.363243 44.120644 47.001804 49.11531 52.33542 55.02593 56.853123 59.70479 61.781036 4 | SpinImages 0.2650882 1.4736623 3.4342945 5.703008 8.022694 9.878324 12.34205 15.342176 18.735947 21.89705 24.729364 28.054333 31.125975 33.943203 36.965153 40.642887 43.33119 46.57485 49.260128 52.152588 5 | SHOT 0.47008017 2.6745644 5.1028447 7.4799767 9.238545 11.927085 16.089037 18.879076 22.094217 24.974531 27.351418 29.76674 32.573856 34.916573 37.086315 40.306644 43.138355 45.73636 48.786892 50.972153 6 | Predator 1.598646 43.764183 72.060312 84.277961 90.004648 92.64463 94.014723 94.990602 95.96624 96.601441 96.806433 97.099785 97.35597 97.35597 97.380674 97.380674 97.611186 97.90688 98.178721 98.238818 7 | FCGF 4.135307 50.302743 76.513382 87.786066 92.033676 93.6785 94.900559 96.311261 96.912816 97.353521 97.738293 98.016699 98.573519 98.633615 98.676423 98.731733 98.854555 98.854555 98.854555 98.99945 8 | 3DSN 6.925493 42.819659 65.908609 78.549499 85.100889 88.986995 91.870439 93.402519 94.398583 95.019629 95.537113 95.72003 95.900319 96.135917 96.358351 96.631008 96.831329 96.911342 97.031535 97.249845 9 | D3Feat 1.326839 31.021349 61.608349 75.575873 82.991841 88.585381 91.007284 92.794678 94.556689 95.631483 96.580489 96.768193 96.8176 96.920505 97.085318 97.128126 97.695815 97.816008 97.985607 97.985607 10 | -------------------------------------------------------------------------------- /assets/fix_knn_feats.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/assets/fix_knn_feats.png -------------------------------------------------------------------------------- /assets/inlier_thresh.txt: -------------------------------------------------------------------------------- 1 | method 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.10 0.11 0.12 0.13 0.14 0.15 0.16 0.17 0.18 0.19 0.2 2 | 3DMatch 93.57423 84.989296 72.26616 60.18662 50.84514 40.988052 
35.264496 29.492676 24.875406 20.927317 18.053923 14.7144 12.210866 10.735968 9.550746 8.107913 7.4702764 6.172435 5.2574573 4.3098807 3 | FPFH 84.33316 68.71096 54.144688 43.4266 35.885315 29.24774 24.669365 21.693983 18.398415 15.19658 13.052581 10.593815 9.197111 7.121167 6.0142927 5.459413 4.7025414 3.7748396 3.2582376 2.9224794 4 | SpinImages 80.370415 57.09011 41.317715 30.377916 22.666761 16.668486 12.008004 9.49888 6.9450083 5.2102103 4.6000204 3.9652426 3.6801224 2.8298047 2.4588833 1.6771933 1.4722013 1.3379945 1.1484778 0.9681894 5 | SHOT 74.7338 55.754692 41.421146 31.38755 23.82493 18.377901 14.077372 11.229745 9.498036 7.2400875 6.259093 5.4715004 4.264607 3.6964087 3.070859 2.418819 2.0778346 1.5370839 1.5123804 1.4522841 6 | Predator 98.859889 97.909077 97.500866 97.125458 96.601441 95.734758 94.609147 94.146195 93.172484 92.713764 92.183001 91.71345 91.044471 90.592342 90.185631 89.322503 88.736845 88.36471 87.145778 86.287232 7 | FCGF 99.526474 98.940171 98.226632 97.884006 97.353521 96.627542 96.318829 96.029037 95.339938 94.679256 93.99646 93.517994 92.618115 92.43327 92.128233 91.267739 90.514984 90.144248 89.87309 89.340131 8 | 3DSN 99.053496 97.806905 97.138431 96.063319 95.019629 93.43627 92.676226 91.676057 89.926433 88.093154 86.428362 84.396438 83.500374 82.442937 80.94519 78.946443 77.480061 75.891795 74.209464 72.884469 9 | D3Feat 98.452534 97.660423 97.020136 96.768193 95.631483 94.327522 92.314021 90.721366 89.756974 88.743718 88.126642 87.331697 85.993739 84.35534 83.309026 82.154434 80.440724 79.07386 77.657352 75.881642 10 | -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/assets/results.png -------------------------------------------------------------------------------- /assets/teaser_predator.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/assets/teaser_predator.jpg -------------------------------------------------------------------------------- /common/colors.py: -------------------------------------------------------------------------------- 1 | """Useful color codes""" 2 | ORANGE = [239, 124, 0] 3 | BLUE = [0, 61, 124] -------------------------------------------------------------------------------- /common/math/random.py: -------------------------------------------------------------------------------- 1 | """Functions for random sampling""" 2 | import numpy as np 3 | 4 | 5 | def uniform_2_sphere(num: int = None): 6 | """Uniform sampling on a 2-sphere 7 | 8 | Source: https://gist.github.com/andrewbolster/10274979 9 | 10 | Args: 11 | num: Number of vectors to sample (or None if single) 12 | 13 | Returns: 14 | Random Vector (np.ndarray) of size (num, 3) with norm 1. 
15 | If num is None, the returned value will have size (3,) 16 | 17 | """ 18 | if num is not None: 19 | phi = np.random.uniform(0.0, 2 * np.pi, num) 20 | cos_theta = np.random.uniform(-1.0, 1.0, num) 21 | else: 22 | phi = np.random.uniform(0.0, 2 * np.pi) 23 | cos_theta = np.random.uniform(-1.0, 1.0) 24 | 25 | theta = np.arccos(cos_theta) 26 | x = np.sin(theta) * np.cos(phi) 27 | y = np.sin(theta) * np.sin(phi) 28 | z = np.cos(theta) 29 | 30 | return np.stack((x, y, z), axis=-1) 31 | 32 | 33 | if __name__ == '__main__': 34 | # Visualize sampling 35 | from vtk_visualizer.plot3d import plotxyz 36 | rand_2s = uniform_2_sphere(10000) 37 | plotxyz(rand_2s, block=True) 38 | -------------------------------------------------------------------------------- /common/math/se3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation 3 | 4 | 5 | def identity(): 6 | return np.eye(3, 4) 7 | 8 | 9 | def transform(g: np.ndarray, pts: np.ndarray): 10 | """ Applies the SE3 transform 11 | 12 | Args: 13 | g: SE3 transformation matrix of size ([B,] 3/4, 4) 14 | pts: Points to be transformed ([B,] N, 3) 15 | 16 | Returns: 17 | transformed points of size ([B,] N, 3) 18 | """ 19 | rot = g[..., :3, :3] # (3, 3) 20 | trans = g[..., :3, 3] # (3) 21 | 22 | transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) + trans[..., None, :] 23 | return transformed 24 | 25 | 26 | def inverse(g: np.ndarray): 27 | """Returns the inverse of the SE3 transform 28 | 29 | Args: 30 | g: ([B,] 3/4, 4) transform 31 | 32 | Returns: 33 | ([B,] 3/4, 4) matrix containing the inverse 34 | 35 | """ 36 | rot = g[..., :3, :3] # (3, 3) 37 | trans = g[..., :3, 3] # (3) 38 | 39 | inv_rot = np.swapaxes(rot, -1, -2) 40 | inverse_transform = np.concatenate([inv_rot, inv_rot @ -trans[..., None]], axis=-1) 41 | if g.shape[-2] == 4: 42 | inverse_transform = np.concatenate([inverse_transform, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 43 | 44 | return inverse_transform 45 | 46 | 47 | def concatenate(a: np.ndarray, b: np.ndarray): 48 | """ Concatenate two SE3 transforms 49 | 50 | Args: 51 | a: First transform ([B,] 3/4, 4) 52 | b: Second transform ([B,] 3/4, 4) 53 | 54 | Returns: 55 | a*b ([B, ] 3/4, 4) 56 | 57 | """ 58 | 59 | r_a, t_a = a[..., :3, :3], a[..., :3, 3] 60 | r_b, t_b = b[..., :3, :3], b[..., :3, 3] 61 | 62 | r_ab = r_a @ r_b 63 | t_ab = r_a @ t_b[..., None] + t_a[..., None] 64 | 65 | concatenated = np.concatenate([r_ab, t_ab], axis=-1) 66 | 67 | if a.shape[-2] == 4: 68 | concatenated = np.concatenate([concatenated, [[0.0, 0.0, 0.0, 1.0]]], axis=-2) 69 | 70 | return concatenated 71 | 72 | 73 | def from_xyzquat(xyzquat): 74 | """Constructs SE3 matrix from x, y, z, qx, qy, qz, qw 75 | 76 | Args: 77 | xyzquat: np.array (7,) containing translation and quaternion 78 | 79 | Returns: 80 | SE3 matrix (4, 4) 81 | """ 82 | rot = Rotation.from_quat(xyzquat[3:]) 83 | trans = rot.apply(-xyzquat[:3]) 84 | transform = np.concatenate([rot.as_dcm(), trans[:, None]], axis=1) 85 | transform = np.concatenate([transform, [[0.0, 0.0, 0.0, 1.0]]], axis=0) 86 | 87 | return transform -------------------------------------------------------------------------------- /common/math/so3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rotation-related functions for numpy arrays 3 | """ 4 | 5 | import numpy as np 6 | from scipy.spatial.transform import Rotation 7 | 8 | 9 | def dcm2euler(mats: np.ndarray, seq: str = 'zyx', degrees: bool = True): 10 |
"""Converts rotation matrix to euler angles 11 | 12 | Args: 13 | mats: (B, 3, 3) containing the B rotation matricecs 14 | seq: Sequence of euler rotations (default: 'zyx') 15 | degrees (bool): If true (default), will return in degrees instead of radians 16 | 17 | Returns: 18 | 19 | """ 20 | 21 | eulers = [] 22 | for i in range(mats.shape[0]): 23 | r = Rotation.from_dcm(mats[i]) 24 | eulers.append(r.as_euler(seq, degrees=degrees)) 25 | return np.stack(eulers) 26 | 27 | 28 | def transform(g: np.ndarray, pts: np.ndarray): 29 | """ Applies the SO3 transform 30 | 31 | Args: 32 | g: SO3 transformation matrix of size (3, 3) 33 | pts: Points to be transformed (N, 3) 34 | 35 | Returns: 36 | transformed points of size (N, 3) 37 | 38 | """ 39 | rot = g[:3, :3] # (3, 3) 40 | transformed = pts @ rot.transpose() 41 | return transformed 42 | -------------------------------------------------------------------------------- /common/math_torch/se3.py: -------------------------------------------------------------------------------- 1 | """ 3-d rigid body transformation group 2 | """ 3 | import torch 4 | 5 | 6 | def identity(batch_size): 7 | return torch.eye(3, 4)[None, ...].repeat(batch_size, 1, 1) 8 | 9 | 10 | def inverse(g): 11 | """ Returns the inverse of the SE3 transform 12 | 13 | Args: 14 | g: (B, 3/4, 4) transform 15 | 16 | Returns: 17 | (B, 3, 4) matrix containing the inverse 18 | 19 | """ 20 | # Compute inverse 21 | rot = g[..., 0:3, 0:3] 22 | trans = g[..., 0:3, 3] 23 | inverse_transform = torch.cat([rot.transpose(-1, -2), rot.transpose(-1, -2) @ -trans[..., None]], dim=-1) 24 | 25 | return inverse_transform 26 | 27 | 28 | def concatenate(a, b): 29 | """Concatenate two SE3 transforms, 30 | i.e. return a@b (but note that our SE3 is represented as a 3x4 matrix) 31 | 32 | Args: 33 | a: (B, 3/4, 4) 34 | b: (B, 3/4, 4) 35 | 36 | Returns: 37 | (B, 3/4, 4) 38 | """ 39 | 40 | rot1 = a[..., :3, :3] 41 | trans1 = a[..., :3, 3] 42 | rot2 = b[..., :3, :3] 43 | trans2 = b[..., :3, 3] 44 | 45 | rot_cat = rot1 @ rot2 46 | trans_cat = rot1 @ trans2[..., None] + trans1[..., None] 47 | concatenated = torch.cat([rot_cat, trans_cat], dim=-1) 48 | 49 | return concatenated 50 | 51 | 52 | def transform(g, a, normals=None): 53 | """ Applies the SE3 transform 54 | 55 | Args: 56 | g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4) 57 | a: Points to be transformed (N, 3) or (B, N, 3) 58 | normals: (Optional). If provided, normals will be transformed 59 | 60 | Returns: 61 | transformed points of size (N, 3) or (B, N, 3) 62 | 63 | """ 64 | R = g[..., :3, :3] # (B, 3, 3) 65 | p = g[..., :3, 3] # (B, 3) 66 | 67 | if len(g.size()) == len(a.size()): 68 | b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :] 69 | else: 70 | raise NotImplementedError 71 | b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p # No batch. 
Not checked 72 | 73 | if normals is not None: 74 | rotated_normals = normals @ R.transpose(-1, -2) 75 | return b, rotated_normals 76 | 77 | else: 78 | return b 79 | -------------------------------------------------------------------------------- /common/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc utilities 3 | """ 4 | 5 | import argparse 6 | from datetime import datetime 7 | import logging 8 | import os 9 | import shutil 10 | import subprocess 11 | import sys 12 | 13 | import coloredlogs 14 | import git 15 | 16 | 17 | _logger = logging.getLogger() 18 | 19 | 20 | def print_info(opt, log_dir=None): 21 | """ Logs source code configuration 22 | """ 23 | _logger.info('Command: {}'.format(' '.join(sys.argv))) 24 | 25 | # Print commit ID 26 | try: 27 | repo = git.Repo(search_parent_directories=True) 28 | git_sha = repo.head.object.hexsha 29 | git_date = datetime.fromtimestamp(repo.head.object.committed_date).strftime('%Y-%m-%d') 30 | git_message = repo.head.object.message 31 | _logger.info('Source is from Commit {} ({}): {}'.format(git_sha[:8], git_date, git_message.strip())) 32 | 33 | # Also create diff file in the log directory 34 | if log_dir is not None: 35 | with open(os.path.join(log_dir, 'compareHead.diff'), 'w') as fid: 36 | subprocess.run(['git', 'diff'], stdout=fid) 37 | 38 | except git.exc.InvalidGitRepositoryError: 39 | pass 40 | 41 | # Arguments 42 | arg_str = ['{}: {}'.format(key, value) for key, value in vars(opt).items()] 43 | arg_str = ', '.join(arg_str) 44 | #_logger.info('Arguments: {}'.format(arg_str)) 45 | 46 | 47 | def prepare_logger(opt: argparse.Namespace, log_path: str = None): 48 | """Creates logging directory, and installs colorlogs 49 | 50 | Args: 51 | opt: Program arguments, should include --dev and --logdir flag. 52 | See get_parent_parser() 53 | log_path: Logging path (optional). 
This serves to overwrite the settings in 54 | argparse namespace 55 | 56 | Returns: 57 | logger (logging.Logger) 58 | log_path (str): Logging directory 59 | """ 60 | 61 | if log_path is None: 62 | if opt.dev: 63 | log_path = '../logdev' 64 | shutil.rmtree(log_path, ignore_errors=True) 65 | else: 66 | datetime_str = datetime.now().strftime('%y%m%d_%H%M%S') 67 | if opt.name is not None: 68 | log_path = os.path.join(opt.logdir, datetime_str + '_' + opt.name) 69 | else: 70 | log_path = os.path.join(opt.logdir, datetime_str) 71 | 72 | os.makedirs(log_path, exist_ok=True) 73 | logger = logging.getLogger() 74 | coloredlogs.install(level='INFO', logger=logger) 75 | file_handler = logging.FileHandler('{}/log.txt'.format(log_path)) 76 | log_formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s - %(message)s') 77 | file_handler.setFormatter(log_formatter) 78 | logger.addHandler(file_handler) 79 | print_info(opt, log_path) 80 | logger.info('Output and logs will be saved to {}'.format(log_path)) 81 | 82 | return logger, log_path 83 | -------------------------------------------------------------------------------- /common/torch.py: -------------------------------------------------------------------------------- 1 | """PyTorch related utility functions 2 | """ 3 | 4 | import logging 5 | import os 6 | import pdb 7 | import shutil 8 | import sys 9 | import time 10 | import traceback 11 | 12 | import numpy as np 13 | import torch 14 | from torch.optim.optimizer import Optimizer 15 | 16 | 17 | def dict_all_to_device(tensor_dict, device): 18 | """Sends everything into a certain device """ 19 | for k in tensor_dict: 20 | if isinstance(tensor_dict[k], torch.Tensor): 21 | tensor_dict[k] = tensor_dict[k].to(device) 22 | 23 | 24 | def to_numpy(tensor): 25 | """Wrapper around .detach().cpu().numpy() """ 26 | if isinstance(tensor, torch.Tensor): 27 | return tensor.detach().cpu().numpy() 28 | elif isinstance(tensor, np.ndarray): 29 | return tensor 30 | else: 31 | raise NotImplementedError 32 | 33 | 34 | class CheckPointManager(object): 35 | """Manager for saving/managing pytorch checkpoints. 
36 | 37 | Provides functionality similar to tf.Saver such as 38 | max_to_keep and keep_checkpoint_every_n_hours 39 | """ 40 | def __init__(self, save_path: str = None, max_to_keep=5, keep_checkpoint_every_n_hours=10000.0): 41 | 42 | if max_to_keep <= 0: 43 | raise ValueError('max_to_keep must be at least 1') 44 | 45 | self._max_to_keep = max_to_keep 46 | self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours 47 | 48 | self._ckpt_dir = os.path.dirname(save_path) 49 | self._save_path = save_path + '-{}.pth' if save_path is not None else None 50 | self._logger = logging.getLogger(self.__class__.__name__) 51 | self._checkpoints_fname = os.path.join(self._ckpt_dir, 'checkpoints.txt') 52 | 53 | self._checkpoints_permanent = [] # Will not be deleted 54 | self._checkpoints_buffer = [] # Those which might still be deleted 55 | self._next_save_time = time.time() 56 | self._best_score = -float('inf') 57 | self._best_step = None 58 | 59 | os.makedirs(self._ckpt_dir, exist_ok=True) 60 | self._update_checkpoints_file() 61 | 62 | def _save_checkpoint(self, step, model, optimizer, score): 63 | save_name = self._save_path.format(step) 64 | state = {'state_dict': model.state_dict(), 65 | 'optimizer': optimizer.state_dict(), 66 | 'step': step} 67 | torch.save(state, save_name) 68 | self._logger.info('Saved checkpoint: {}'.format(save_name)) 69 | 70 | self._checkpoints_buffer.append((save_name, time.time())) 71 | 72 | if score > self._best_score: 73 | best_save_name = self._save_path.format('best') 74 | shutil.copyfile(save_name, best_save_name) 75 | self._best_score = score 76 | self._best_step = step 77 | self._logger.info('Checkpoint is current best, score={:.3g}'.format(self._best_score)) 78 | 79 | def _remove_old_checkpoints(self): 80 | while len(self._checkpoints_buffer) > self._max_to_keep: 81 | to_remove = self._checkpoints_buffer.pop(0) 82 | 83 | if to_remove[1] > self._next_save_time: 84 | self._checkpoints_permanent.append(to_remove) 85 | self._next_save_time = to_remove[1] + self._keep_checkpoint_every_n_hours * 3600 86 | else: 87 | os.remove(to_remove[0]) 88 | 89 | def _update_checkpoints_file(self): 90 | checkpoints = [os.path.basename(c[0]) for c in self._checkpoints_permanent + self._checkpoints_buffer] 91 | with open(self._checkpoints_fname, 'w') as fid: 92 | fid.write('\n'.join(checkpoints)) 93 | fid.write('\nBest step: {}'.format(self._best_step)) 94 | 95 | def save(self, model: torch.nn.Module, optimizer: Optimizer, step: int, score: float = 0.0): 96 | """Save model checkpoint to file 97 | 98 | Args: 99 | model: Torch model 100 | optimizer: Torch optimizer 101 | step (int): Step, model will be saved as model-[step].pth 102 | score (float, optional): To determine which model is the best 103 | """ 104 | if self._save_path is None: 105 | raise AssertionError('Checkpoint manager must be initialized with save path for save().') 106 | 107 | self._save_checkpoint(step, model, optimizer, score) 108 | self._remove_old_checkpoints() 109 | self._update_checkpoints_file() 110 | 111 | def load(self, save_path, model: torch.nn.Module = None, optimizer: Optimizer = None): 112 | """Loads saved model from file 113 | 114 | Args: 115 | save_path: Path to saved model (.pth). 
If a directory is provided instead, model-best.pth is used 116 | model: Torch model to restore weights to 117 | optimizer: Optimizer 118 | """ 119 | if os.path.isdir(save_path): 120 | save_path = os.path.join(save_path, 'model-best.pth') 121 | 122 | state = torch.load(save_path) 123 | 124 | step = 0 125 | if 'step' in state: 126 | step = state['step'] 127 | 128 | if 'state_dict' in state and model is not None: 129 | model.load_state_dict(state['state_dict']) 130 | 131 | if 'optimizer' in state and optimizer is not None: 132 | optimizer.load_state_dict(state['optimizer']) 133 | 134 | self._logger.info('Loaded models from {}'.format(save_path)) 135 | return step 136 | 137 | 138 | class TorchDebugger(torch.autograd.detect_anomaly): 139 | """Enters debugger when anomaly detected""" 140 | def __enter__(self) -> None: 141 | super().__enter__() 142 | 143 | def __exit__(self, type, value, trace): 144 | super().__exit__() 145 | if isinstance(value, RuntimeError): 146 | traceback.print_tb(trace) 147 | print(value) 148 | if sys.gettrace() is None: 149 | pdb.set_trace() 150 | -------------------------------------------------------------------------------- /configs/benchmarks/3DLoMatch/sun3d-hotel_umd-maryland_hotel3/gt_overlap.log: -------------------------------------------------------------------------------- 1 | 0,1,0.5325 2 | 0,2,0.0683 3 | 0,3,0.0000 4 | 0,4,0.0000 5 | 0,5,0.0000 6 | 0,6,0.0000 7 | 0,7,0.0000 8 | 0,8,0.0000 9 | 0,9,0.0000 10 | 0,10,0.0269 11 | 0,11,0.1804 12 | 0,12,0.3909 13 | 0,13,0.3089 14 | 0,14,0.0004 15 | 0,15,0.0000 16 | 0,16,0.0000 17 | 0,17,0.0000 18 | 0,18,0.0000 19 | 0,19,0.0000 20 | 0,20,0.0000 21 | 0,21,0.0000 22 | 0,22,0.0263 23 | 0,23,0.0105 24 | 0,24,0.0098 25 | 0,25,0.0002 26 | 0,26,0.0000 27 | 0,27,0.1104 28 | 0,28,0.0002 29 | 0,29,0.0000 30 | 0,30,0.0000 31 | 0,31,0.0000 32 | 0,32,0.0000 33 | 0,33,0.0000 34 | 0,34,0.0000 35 | 0,35,0.0000 36 | 0,36,0.0000 37 | 1,2,0.3132 38 | 1,3,0.0537 39 | 1,4,0.0000 40 | 1,5,0.0000 41 | 1,6,0.0000 42 | 1,7,0.0000 43 | 1,8,0.0000 44 | 1,9,0.0000 45 | 1,10,0.0000 46 | 1,11,0.1110 47 | 1,12,0.2444 48 | 1,13,0.1772 49 | 1,14,0.0000 50 | 1,15,0.0000 51 | 1,16,0.0000 52 | 1,17,0.0000 53 | 1,18,0.0000 54 | 1,19,0.0000 55 | 1,20,0.0000 56 | 1,21,0.0000 57 | 1,22,0.0000 58 | 1,23,0.0000 59 | 1,24,0.0000 60 | 1,25,0.0000 61 | 1,26,0.0000 62 | 1,27,0.0436 63 | 1,28,0.0000 64 | 1,29,0.0000 65 | 1,30,0.0000 66 | 1,31,0.0000 67 | 1,32,0.0000 68 | 1,33,0.0000 69 | 1,34,0.0000 70 | 1,35,0.0000 71 | 1,36,0.0000 72 | 2,3,0.4890 73 | 2,4,0.0686 74 | 2,5,0.0600 75 | 2,6,0.0000 76 | 2,7,0.0000 77 | 2,8,0.0000 78 | 2,9,0.0000 79 | 2,10,0.0000 80 | 2,11,0.0000 81 | 2,12,0.0000 82 | 2,13,0.0000 83 | 2,14,0.0000 84 | 2,15,0.0000 85 | 2,16,0.0000 86 | 2,17,0.0000 87 | 2,18,0.0000 88 | 2,19,0.0000 89 | 2,20,0.0000 90 | 2,21,0.0000 91 | 2,22,0.0000 92 | 2,23,0.0000 93 | 2,24,0.0000 94 | 2,25,0.0000 95 | 2,26,0.0000 96 | 2,27,0.0000 97 | 2,28,0.0000 98 | 2,29,0.0000 99 | 2,30,0.0000 100 | 2,31,0.0000 101 | 2,32,0.0000 102 | 2,33,0.0000 103 | 2,34,0.0000 104 | 2,35,0.0000 105 | 2,36,0.0000 106 | 3,4,0.5921 107 | 3,5,0.2586 108 | 3,6,0.0192 109 | 3,7,0.0000 110 | 3,8,0.0000 111 | 3,9,0.0000 112 | 3,10,0.0000 113 | 3,11,0.0000 114 | 3,12,0.0000 115 | 3,13,0.0000 116 | 3,14,0.0000 117 | 3,15,0.0000 118 | 3,16,0.0000 119 | 3,17,0.0000 120 | 3,18,0.0000 121 | 3,19,0.0000 122 | 3,20,0.0000 123 | 3,21,0.0000 124 | 3,22,0.0000 125 | 3,23,0.0000 126 | 3,24,0.0000 127 | 3,25,0.0000 128 | 3,26,0.0000 129 | 3,27,0.0000 130 | 3,28,0.0000 131 | 3,29,0.0000 132 | 
3,30,0.0000 133 | 3,31,0.0000 134 | 3,32,0.0000 135 | 3,33,0.0000 136 | 3,34,0.0000 137 | 3,35,0.0000 138 | 3,36,0.0000 139 | 4,5,0.3325 140 | 4,6,0.0784 141 | 4,7,0.0000 142 | 4,8,0.0000 143 | 4,9,0.0000 144 | 4,10,0.0000 145 | 4,11,0.0000 146 | 4,12,0.0000 147 | 4,13,0.0000 148 | 4,14,0.0000 149 | 4,15,0.0000 150 | 4,16,0.0000 151 | 4,17,0.0000 152 | 4,18,0.0000 153 | 4,19,0.0000 154 | 4,20,0.0000 155 | 4,21,0.0000 156 | 4,22,0.0000 157 | 4,23,0.0000 158 | 4,24,0.0000 159 | 4,25,0.0000 160 | 4,26,0.0000 161 | 4,27,0.0000 162 | 4,28,0.0000 163 | 4,29,0.0000 164 | 4,30,0.0000 165 | 4,31,0.0000 166 | 4,32,0.0000 167 | 4,33,0.0000 168 | 4,34,0.0000 169 | 4,35,0.0000 170 | 4,36,0.0000 171 | 5,6,0.4993 172 | 5,7,0.1245 173 | 5,8,0.0000 174 | 5,9,0.0000 175 | 5,10,0.0000 176 | 5,11,0.0000 177 | 5,12,0.0000 178 | 5,13,0.0000 179 | 5,14,0.0000 180 | 5,15,0.0000 181 | 5,16,0.0000 182 | 5,17,0.0000 183 | 5,18,0.0000 184 | 5,19,0.0000 185 | 5,20,0.0000 186 | 5,21,0.0000 187 | 5,22,0.0000 188 | 5,23,0.0000 189 | 5,24,0.0000 190 | 5,25,0.0000 191 | 5,26,0.0000 192 | 5,27,0.0000 193 | 5,28,0.0000 194 | 5,29,0.0000 195 | 5,30,0.0000 196 | 5,31,0.0000 197 | 5,32,0.0000 198 | 5,33,0.0000 199 | 5,34,0.0000 200 | 5,35,0.0000 201 | 5,36,0.0000 202 | 6,7,0.4663 203 | 6,8,0.0432 204 | 6,9,0.0000 205 | 6,10,0.0000 206 | 6,11,0.0000 207 | 6,12,0.0000 208 | 6,13,0.0000 209 | 6,14,0.0000 210 | 6,15,0.0000 211 | 6,16,0.0000 212 | 6,17,0.0000 213 | 6,18,0.0000 214 | 6,19,0.0000 215 | 6,20,0.0000 216 | 6,21,0.0000 217 | 6,22,0.0000 218 | 6,23,0.0000 219 | 6,24,0.0000 220 | 6,25,0.0000 221 | 6,26,0.0000 222 | 6,27,0.0000 223 | 6,28,0.0000 224 | 6,29,0.0000 225 | 6,30,0.0000 226 | 6,31,0.0000 227 | 6,32,0.0000 228 | 6,33,0.0000 229 | 6,34,0.0000 230 | 6,35,0.0000 231 | 6,36,0.0000 232 | 7,8,0.3918 233 | 7,9,0.0272 234 | 7,10,0.0796 235 | 7,11,0.0000 236 | 7,12,0.0000 237 | 7,13,0.0000 238 | 7,14,0.0605 239 | 7,15,0.0651 240 | 7,16,0.0000 241 | 7,17,0.0000 242 | 7,18,0.0000 243 | 7,19,0.0000 244 | 7,20,0.0000 245 | 7,21,0.0000 246 | 7,22,0.0000 247 | 7,23,0.0000 248 | 7,24,0.0000 249 | 7,25,0.0000 250 | 7,26,0.0000 251 | 7,27,0.0000 252 | 7,28,0.0000 253 | 7,29,0.0000 254 | 7,30,0.0000 255 | 7,31,0.0000 256 | 7,32,0.0000 257 | 7,33,0.0000 258 | 7,34,0.0000 259 | 7,35,0.0000 260 | 7,36,0.0000 261 | 8,9,0.4565 262 | 8,10,0.3677 263 | 8,11,0.0000 264 | 8,12,0.0000 265 | 8,13,0.0000 266 | 8,14,0.1760 267 | 8,15,0.4993 268 | 8,16,0.2940 269 | 8,17,0.0577 270 | 8,18,0.0006 271 | 8,19,0.0000 272 | 8,20,0.0000 273 | 8,21,0.0000 274 | 8,22,0.0000 275 | 8,23,0.0000 276 | 8,24,0.0000 277 | 8,25,0.0000 278 | 8,26,0.0000 279 | 8,27,0.0000 280 | 8,28,0.0000 281 | 8,29,0.0000 282 | 8,30,0.0000 283 | 8,31,0.0000 284 | 8,32,0.0000 285 | 8,33,0.0000 286 | 8,34,0.0000 287 | 8,35,0.0000 288 | 8,36,0.0000 289 | 9,10,0.2301 290 | 9,11,0.0546 291 | 9,12,0.0000 292 | 9,13,0.0274 293 | 9,14,0.0741 294 | 9,15,0.2451 295 | 9,16,0.1903 296 | 9,17,0.0907 297 | 9,18,0.0404 298 | 9,19,0.0087 299 | 9,20,0.0000 300 | 9,21,0.0000 301 | 9,22,0.0000 302 | 9,23,0.0000 303 | 9,24,0.0000 304 | 9,25,0.0000 305 | 9,26,0.0000 306 | 9,27,0.0000 307 | 9,28,0.0000 308 | 9,29,0.0000 309 | 9,30,0.0000 310 | 9,31,0.0000 311 | 9,32,0.0000 312 | 9,33,0.0000 313 | 9,34,0.0000 314 | 9,35,0.0000 315 | 9,36,0.0000 316 | 10,11,0.2857 317 | 10,12,0.0900 318 | 10,13,0.1489 319 | 10,14,0.5227 320 | 10,15,0.5015 321 | 10,16,0.4256 322 | 10,17,0.2176 323 | 10,18,0.0341 324 | 10,19,0.0106 325 | 10,20,0.0000 326 | 10,21,0.0000 327 | 10,22,0.0000 328 | 10,23,0.0005 329 | 
10,24,0.0018 330 | 10,25,0.0010 331 | 10,26,0.0000 332 | 10,27,0.0011 333 | 10,28,0.0007 334 | 10,29,0.0000 335 | 10,30,0.0000 336 | 10,31,0.0000 337 | 10,32,0.0000 338 | 10,33,0.0000 339 | 10,34,0.0000 340 | 10,35,0.0000 341 | 10,36,0.0000 342 | 11,12,0.6637 343 | 11,13,0.7383 344 | 11,14,0.2992 345 | 11,15,0.0000 346 | 11,16,0.0219 347 | 11,17,0.0121 348 | 11,18,0.0000 349 | 11,19,0.0000 350 | 11,20,0.0000 351 | 11,21,0.0006 352 | 11,22,0.0436 353 | 11,23,0.0238 354 | 11,24,0.0194 355 | 11,25,0.0047 356 | 11,26,0.0028 357 | 11,27,0.1267 358 | 11,28,0.0028 359 | 11,29,0.0000 360 | 11,30,0.0000 361 | 11,31,0.0000 362 | 11,32,0.0000 363 | 11,33,0.0000 364 | 11,34,0.0000 365 | 11,35,0.0000 366 | 11,36,0.0000 367 | 12,13,0.5171 368 | 12,14,0.0847 369 | 12,15,0.0000 370 | 12,16,0.0000 371 | 12,17,0.0000 372 | 12,18,0.0000 373 | 12,19,0.0000 374 | 12,20,0.0000 375 | 12,21,0.0008 376 | 12,22,0.0525 377 | 12,23,0.0422 378 | 12,24,0.0159 379 | 12,25,0.0030 380 | 12,26,0.0073 381 | 12,27,0.1169 382 | 12,28,0.0028 383 | 12,29,0.0093 384 | 12,30,0.0000 385 | 12,31,0.0000 386 | 12,32,0.0000 387 | 12,33,0.0000 388 | 12,34,0.0000 389 | 12,35,0.0000 390 | 12,36,0.0000 391 | 13,14,0.1723 392 | 13,15,0.0003 393 | 13,16,0.0000 394 | 13,17,0.0000 395 | 13,18,0.0000 396 | 13,19,0.0000 397 | 13,20,0.0000 398 | 13,21,0.0017 399 | 13,22,0.0478 400 | 13,23,0.0371 401 | 13,24,0.0210 402 | 13,25,0.0049 403 | 13,26,0.0000 404 | 13,27,0.1237 405 | 13,28,0.0030 406 | 13,29,0.0050 407 | 13,30,0.0000 408 | 13,31,0.0000 409 | 13,32,0.0000 410 | 13,33,0.0000 411 | 13,34,0.0000 412 | 13,35,0.0000 413 | 13,36,0.0000 414 | 14,15,0.3409 415 | 14,16,0.2692 416 | 14,17,0.2338 417 | 14,18,0.0061 418 | 14,19,0.0075 419 | 14,20,0.0079 420 | 14,21,0.0065 421 | 14,22,0.0011 422 | 14,23,0.0043 423 | 14,24,0.0063 424 | 14,25,0.0023 425 | 14,26,0.0000 426 | 14,27,0.0045 427 | 14,28,0.0027 428 | 14,29,0.0000 429 | 14,30,0.0000 430 | 14,31,0.0000 431 | 14,32,0.0000 432 | 14,33,0.0000 433 | 14,34,0.0000 434 | 14,35,0.0000 435 | 14,36,0.0002 436 | 15,16,0.6412 437 | 15,17,0.3784 438 | 15,18,0.0479 439 | 15,19,0.0572 440 | 15,20,0.0059 441 | 15,21,0.0000 442 | 15,22,0.0000 443 | 15,23,0.0000 444 | 15,24,0.0000 445 | 15,25,0.0000 446 | 15,26,0.0000 447 | 15,27,0.0000 448 | 15,28,0.0000 449 | 15,29,0.0000 450 | 15,30,0.0000 451 | 15,31,0.0000 452 | 15,32,0.0000 453 | 15,33,0.0000 454 | 15,34,0.0000 455 | 15,35,0.0000 456 | 15,36,0.0000 457 | 16,17,0.6605 458 | 16,18,0.1551 459 | 16,19,0.1564 460 | 16,20,0.0070 461 | 16,21,0.0037 462 | 16,22,0.0000 463 | 16,23,0.0000 464 | 16,24,0.0000 465 | 16,25,0.0000 466 | 16,26,0.0000 467 | 16,27,0.0000 468 | 16,28,0.0000 469 | 16,29,0.0000 470 | 16,30,0.0000 471 | 16,31,0.0000 472 | 16,32,0.0000 473 | 16,33,0.0000 474 | 16,34,0.0000 475 | 16,35,0.0003 476 | 16,36,0.0065 477 | 17,18,0.4449 478 | 17,19,0.5912 479 | 17,20,0.3402 480 | 17,21,0.0151 481 | 17,22,0.0000 482 | 17,23,0.0000 483 | 17,24,0.0000 484 | 17,25,0.0000 485 | 17,26,0.0000 486 | 17,27,0.0000 487 | 17,28,0.0000 488 | 17,29,0.0000 489 | 17,30,0.0000 490 | 17,31,0.0000 491 | 17,32,0.0000 492 | 17,33,0.0144 493 | 17,34,0.0610 494 | 17,35,0.0111 495 | 17,36,0.0755 496 | 18,19,0.5922 497 | 18,20,0.2836 498 | 18,21,0.0132 499 | 18,22,0.0000 500 | 18,23,0.0000 501 | 18,24,0.0000 502 | 18,25,0.0000 503 | 18,26,0.0000 504 | 18,27,0.0000 505 | 18,28,0.0000 506 | 18,29,0.0000 507 | 18,30,0.0000 508 | 18,31,0.0000 509 | 18,32,0.0000 510 | 18,33,0.0000 511 | 18,34,0.0055 512 | 18,35,0.0089 513 | 18,36,0.0382 514 | 19,20,0.4955 515 | 19,21,0.0496 516 | 
19,22,0.0000 517 | 19,23,0.0000 518 | 19,24,0.0000 519 | 19,25,0.0000 520 | 19,26,0.0000 521 | 19,27,0.0000 522 | 19,28,0.0000 523 | 19,29,0.0000 524 | 19,30,0.0000 525 | 19,31,0.0000 526 | 19,32,0.0000 527 | 19,33,0.0274 528 | 19,34,0.0596 529 | 19,35,0.0140 530 | 19,36,0.0950 531 | 20,21,0.3810 532 | 20,22,0.0238 533 | 20,23,0.0040 534 | 20,24,0.0045 535 | 20,25,0.0010 536 | 20,26,0.0000 537 | 20,27,0.0045 538 | 20,28,0.0062 539 | 20,29,0.0000 540 | 20,30,0.0053 541 | 20,31,0.0576 542 | 20,32,0.0000 543 | 20,33,0.1121 544 | 20,34,0.1527 545 | 20,35,0.0143 546 | 20,36,0.1041 547 | 21,22,0.5275 548 | 21,23,0.3011 549 | 21,24,0.1952 550 | 21,25,0.0320 551 | 21,26,0.0016 552 | 21,27,0.0899 553 | 21,28,0.3441 554 | 21,29,0.2048 555 | 21,30,0.2257 556 | 21,31,0.2364 557 | 21,32,0.0316 558 | 21,33,0.0744 559 | 21,34,0.0670 560 | 21,35,0.0050 561 | 21,36,0.0172 562 | 22,23,0.7330 563 | 22,24,0.2303 564 | 22,25,0.0266 565 | 22,26,0.0267 566 | 22,27,0.1673 567 | 22,28,0.3191 568 | 22,29,0.4485 569 | 22,30,0.2956 570 | 22,31,0.1013 571 | 22,32,0.0161 572 | 22,33,0.0044 573 | 22,34,0.0085 574 | 22,35,0.0046 575 | 22,36,0.0075 576 | 23,24,0.2493 577 | 23,25,0.0259 578 | 23,26,0.0426 579 | 23,27,0.2199 580 | 23,28,0.3316 581 | 23,29,0.5661 582 | 23,30,0.3621 583 | 23,31,0.0670 584 | 23,32,0.0031 585 | 23,33,0.0000 586 | 23,34,0.0000 587 | 23,35,0.0000 588 | 23,36,0.0000 589 | 24,25,0.2683 590 | 24,26,0.0284 591 | 24,27,0.4720 592 | 24,28,0.6725 593 | 24,29,0.3175 594 | 24,30,0.1250 595 | 24,31,0.0131 596 | 24,32,0.0000 597 | 24,33,0.0000 598 | 24,34,0.0000 599 | 24,35,0.0000 600 | 24,36,0.0000 601 | 25,26,0.2499 602 | 25,27,0.4611 603 | 25,28,0.2101 604 | 25,29,0.0000 605 | 25,30,0.0000 606 | 25,31,0.0000 607 | 25,32,0.0000 608 | 25,33,0.0000 609 | 25,34,0.0000 610 | 25,35,0.0000 611 | 25,36,0.0000 612 | 26,27,0.7114 613 | 26,28,0.0033 614 | 26,29,0.0000 615 | 26,30,0.0000 616 | 26,31,0.0000 617 | 26,32,0.0000 618 | 26,33,0.0000 619 | 26,34,0.0000 620 | 26,35,0.0000 621 | 26,36,0.0000 622 | 27,28,0.1769 623 | 27,29,0.0335 624 | 27,30,0.0000 625 | 27,31,0.0000 626 | 27,32,0.0000 627 | 27,33,0.0000 628 | 27,34,0.0000 629 | 27,35,0.0000 630 | 27,36,0.0000 631 | 28,29,0.5445 632 | 28,30,0.3390 633 | 28,31,0.1712 634 | 28,32,0.0000 635 | 28,33,0.0000 636 | 28,34,0.0000 637 | 28,35,0.0000 638 | 28,36,0.0000 639 | 29,30,0.6915 640 | 29,31,0.1352 641 | 29,32,0.0507 642 | 29,33,0.0016 643 | 29,34,0.0000 644 | 29,35,0.0000 645 | 29,36,0.0000 646 | 30,31,0.4260 647 | 30,32,0.2778 648 | 30,33,0.1563 649 | 30,34,0.0450 650 | 30,35,0.0000 651 | 30,36,0.0080 652 | 31,32,0.4631 653 | 31,33,0.5335 654 | 31,34,0.3379 655 | 31,35,0.1129 656 | 31,36,0.2214 657 | 32,33,0.6403 658 | 32,34,0.3848 659 | 32,35,0.3697 660 | 32,36,0.3271 661 | 33,34,0.7574 662 | 33,35,0.3338 663 | 33,36,0.4225 664 | 34,35,0.4932 665 | 34,36,0.5212 666 | 35,36,0.5585 667 | -------------------------------------------------------------------------------- /configs/benchmarks/3DMatch/sun3d-hotel_umd-maryland_hotel3/gt_overlap.log: -------------------------------------------------------------------------------- 1 | 0,1,0.5325 2 | 0,2,0.0683 3 | 0,3,0.0000 4 | 0,4,0.0000 5 | 0,5,0.0000 6 | 0,6,0.0000 7 | 0,7,0.0000 8 | 0,8,0.0000 9 | 0,9,0.0000 10 | 0,10,0.0269 11 | 0,11,0.1804 12 | 0,12,0.3909 13 | 0,13,0.3089 14 | 0,14,0.0004 15 | 0,15,0.0000 16 | 0,16,0.0000 17 | 0,17,0.0000 18 | 0,18,0.0000 19 | 0,19,0.0000 20 | 0,20,0.0000 21 | 0,21,0.0000 22 | 0,22,0.0263 23 | 0,23,0.0105 24 | 0,24,0.0098 25 | 0,25,0.0002 26 | 0,26,0.0000 27 | 0,27,0.1104 28 | 
0,28,0.0002 29 | 0,29,0.0000 30 | 0,30,0.0000 31 | 0,31,0.0000 32 | 0,32,0.0000 33 | 0,33,0.0000 34 | 0,34,0.0000 35 | 0,35,0.0000 36 | 0,36,0.0000 37 | 1,2,0.3132 38 | 1,3,0.0537 39 | 1,4,0.0000 40 | 1,5,0.0000 41 | 1,6,0.0000 42 | 1,7,0.0000 43 | 1,8,0.0000 44 | 1,9,0.0000 45 | 1,10,0.0000 46 | 1,11,0.1110 47 | 1,12,0.2444 48 | 1,13,0.1772 49 | 1,14,0.0000 50 | 1,15,0.0000 51 | 1,16,0.0000 52 | 1,17,0.0000 53 | 1,18,0.0000 54 | 1,19,0.0000 55 | 1,20,0.0000 56 | 1,21,0.0000 57 | 1,22,0.0000 58 | 1,23,0.0000 59 | 1,24,0.0000 60 | 1,25,0.0000 61 | 1,26,0.0000 62 | 1,27,0.0436 63 | 1,28,0.0000 64 | 1,29,0.0000 65 | 1,30,0.0000 66 | 1,31,0.0000 67 | 1,32,0.0000 68 | 1,33,0.0000 69 | 1,34,0.0000 70 | 1,35,0.0000 71 | 1,36,0.0000 72 | 2,3,0.4890 73 | 2,4,0.0686 74 | 2,5,0.0600 75 | 2,6,0.0000 76 | 2,7,0.0000 77 | 2,8,0.0000 78 | 2,9,0.0000 79 | 2,10,0.0000 80 | 2,11,0.0000 81 | 2,12,0.0000 82 | 2,13,0.0000 83 | 2,14,0.0000 84 | 2,15,0.0000 85 | 2,16,0.0000 86 | 2,17,0.0000 87 | 2,18,0.0000 88 | 2,19,0.0000 89 | 2,20,0.0000 90 | 2,21,0.0000 91 | 2,22,0.0000 92 | 2,23,0.0000 93 | 2,24,0.0000 94 | 2,25,0.0000 95 | 2,26,0.0000 96 | 2,27,0.0000 97 | 2,28,0.0000 98 | 2,29,0.0000 99 | 2,30,0.0000 100 | 2,31,0.0000 101 | 2,32,0.0000 102 | 2,33,0.0000 103 | 2,34,0.0000 104 | 2,35,0.0000 105 | 2,36,0.0000 106 | 3,4,0.5921 107 | 3,5,0.2586 108 | 3,6,0.0192 109 | 3,7,0.0000 110 | 3,8,0.0000 111 | 3,9,0.0000 112 | 3,10,0.0000 113 | 3,11,0.0000 114 | 3,12,0.0000 115 | 3,13,0.0000 116 | 3,14,0.0000 117 | 3,15,0.0000 118 | 3,16,0.0000 119 | 3,17,0.0000 120 | 3,18,0.0000 121 | 3,19,0.0000 122 | 3,20,0.0000 123 | 3,21,0.0000 124 | 3,22,0.0000 125 | 3,23,0.0000 126 | 3,24,0.0000 127 | 3,25,0.0000 128 | 3,26,0.0000 129 | 3,27,0.0000 130 | 3,28,0.0000 131 | 3,29,0.0000 132 | 3,30,0.0000 133 | 3,31,0.0000 134 | 3,32,0.0000 135 | 3,33,0.0000 136 | 3,34,0.0000 137 | 3,35,0.0000 138 | 3,36,0.0000 139 | 4,5,0.3325 140 | 4,6,0.0784 141 | 4,7,0.0000 142 | 4,8,0.0000 143 | 4,9,0.0000 144 | 4,10,0.0000 145 | 4,11,0.0000 146 | 4,12,0.0000 147 | 4,13,0.0000 148 | 4,14,0.0000 149 | 4,15,0.0000 150 | 4,16,0.0000 151 | 4,17,0.0000 152 | 4,18,0.0000 153 | 4,19,0.0000 154 | 4,20,0.0000 155 | 4,21,0.0000 156 | 4,22,0.0000 157 | 4,23,0.0000 158 | 4,24,0.0000 159 | 4,25,0.0000 160 | 4,26,0.0000 161 | 4,27,0.0000 162 | 4,28,0.0000 163 | 4,29,0.0000 164 | 4,30,0.0000 165 | 4,31,0.0000 166 | 4,32,0.0000 167 | 4,33,0.0000 168 | 4,34,0.0000 169 | 4,35,0.0000 170 | 4,36,0.0000 171 | 5,6,0.4993 172 | 5,7,0.1245 173 | 5,8,0.0000 174 | 5,9,0.0000 175 | 5,10,0.0000 176 | 5,11,0.0000 177 | 5,12,0.0000 178 | 5,13,0.0000 179 | 5,14,0.0000 180 | 5,15,0.0000 181 | 5,16,0.0000 182 | 5,17,0.0000 183 | 5,18,0.0000 184 | 5,19,0.0000 185 | 5,20,0.0000 186 | 5,21,0.0000 187 | 5,22,0.0000 188 | 5,23,0.0000 189 | 5,24,0.0000 190 | 5,25,0.0000 191 | 5,26,0.0000 192 | 5,27,0.0000 193 | 5,28,0.0000 194 | 5,29,0.0000 195 | 5,30,0.0000 196 | 5,31,0.0000 197 | 5,32,0.0000 198 | 5,33,0.0000 199 | 5,34,0.0000 200 | 5,35,0.0000 201 | 5,36,0.0000 202 | 6,7,0.4663 203 | 6,8,0.0432 204 | 6,9,0.0000 205 | 6,10,0.0000 206 | 6,11,0.0000 207 | 6,12,0.0000 208 | 6,13,0.0000 209 | 6,14,0.0000 210 | 6,15,0.0000 211 | 6,16,0.0000 212 | 6,17,0.0000 213 | 6,18,0.0000 214 | 6,19,0.0000 215 | 6,20,0.0000 216 | 6,21,0.0000 217 | 6,22,0.0000 218 | 6,23,0.0000 219 | 6,24,0.0000 220 | 6,25,0.0000 221 | 6,26,0.0000 222 | 6,27,0.0000 223 | 6,28,0.0000 224 | 6,29,0.0000 225 | 6,30,0.0000 226 | 6,31,0.0000 227 | 6,32,0.0000 228 | 6,33,0.0000 229 | 6,34,0.0000 230 | 6,35,0.0000 231 | 
6,36,0.0000 232 | 7,8,0.3918 233 | 7,9,0.0272 234 | 7,10,0.0796 235 | 7,11,0.0000 236 | 7,12,0.0000 237 | 7,13,0.0000 238 | 7,14,0.0605 239 | 7,15,0.0651 240 | 7,16,0.0000 241 | 7,17,0.0000 242 | 7,18,0.0000 243 | 7,19,0.0000 244 | 7,20,0.0000 245 | 7,21,0.0000 246 | 7,22,0.0000 247 | 7,23,0.0000 248 | 7,24,0.0000 249 | 7,25,0.0000 250 | 7,26,0.0000 251 | 7,27,0.0000 252 | 7,28,0.0000 253 | 7,29,0.0000 254 | 7,30,0.0000 255 | 7,31,0.0000 256 | 7,32,0.0000 257 | 7,33,0.0000 258 | 7,34,0.0000 259 | 7,35,0.0000 260 | 7,36,0.0000 261 | 8,9,0.4565 262 | 8,10,0.3677 263 | 8,11,0.0000 264 | 8,12,0.0000 265 | 8,13,0.0000 266 | 8,14,0.1760 267 | 8,15,0.4993 268 | 8,16,0.2940 269 | 8,17,0.0577 270 | 8,18,0.0006 271 | 8,19,0.0000 272 | 8,20,0.0000 273 | 8,21,0.0000 274 | 8,22,0.0000 275 | 8,23,0.0000 276 | 8,24,0.0000 277 | 8,25,0.0000 278 | 8,26,0.0000 279 | 8,27,0.0000 280 | 8,28,0.0000 281 | 8,29,0.0000 282 | 8,30,0.0000 283 | 8,31,0.0000 284 | 8,32,0.0000 285 | 8,33,0.0000 286 | 8,34,0.0000 287 | 8,35,0.0000 288 | 8,36,0.0000 289 | 9,10,0.2301 290 | 9,11,0.0546 291 | 9,12,0.0000 292 | 9,13,0.0274 293 | 9,14,0.0741 294 | 9,15,0.2451 295 | 9,16,0.1903 296 | 9,17,0.0907 297 | 9,18,0.0404 298 | 9,19,0.0087 299 | 9,20,0.0000 300 | 9,21,0.0000 301 | 9,22,0.0000 302 | 9,23,0.0000 303 | 9,24,0.0000 304 | 9,25,0.0000 305 | 9,26,0.0000 306 | 9,27,0.0000 307 | 9,28,0.0000 308 | 9,29,0.0000 309 | 9,30,0.0000 310 | 9,31,0.0000 311 | 9,32,0.0000 312 | 9,33,0.0000 313 | 9,34,0.0000 314 | 9,35,0.0000 315 | 9,36,0.0000 316 | 10,11,0.2857 317 | 10,12,0.0900 318 | 10,13,0.1489 319 | 10,14,0.5227 320 | 10,15,0.5015 321 | 10,16,0.4256 322 | 10,17,0.2176 323 | 10,18,0.0341 324 | 10,19,0.0106 325 | 10,20,0.0000 326 | 10,21,0.0000 327 | 10,22,0.0000 328 | 10,23,0.0005 329 | 10,24,0.0018 330 | 10,25,0.0010 331 | 10,26,0.0000 332 | 10,27,0.0011 333 | 10,28,0.0007 334 | 10,29,0.0000 335 | 10,30,0.0000 336 | 10,31,0.0000 337 | 10,32,0.0000 338 | 10,33,0.0000 339 | 10,34,0.0000 340 | 10,35,0.0000 341 | 10,36,0.0000 342 | 11,12,0.6637 343 | 11,13,0.7383 344 | 11,14,0.2992 345 | 11,15,0.0000 346 | 11,16,0.0219 347 | 11,17,0.0121 348 | 11,18,0.0000 349 | 11,19,0.0000 350 | 11,20,0.0000 351 | 11,21,0.0006 352 | 11,22,0.0436 353 | 11,23,0.0238 354 | 11,24,0.0194 355 | 11,25,0.0047 356 | 11,26,0.0028 357 | 11,27,0.1267 358 | 11,28,0.0028 359 | 11,29,0.0000 360 | 11,30,0.0000 361 | 11,31,0.0000 362 | 11,32,0.0000 363 | 11,33,0.0000 364 | 11,34,0.0000 365 | 11,35,0.0000 366 | 11,36,0.0000 367 | 12,13,0.5171 368 | 12,14,0.0847 369 | 12,15,0.0000 370 | 12,16,0.0000 371 | 12,17,0.0000 372 | 12,18,0.0000 373 | 12,19,0.0000 374 | 12,20,0.0000 375 | 12,21,0.0008 376 | 12,22,0.0525 377 | 12,23,0.0422 378 | 12,24,0.0159 379 | 12,25,0.0030 380 | 12,26,0.0073 381 | 12,27,0.1169 382 | 12,28,0.0028 383 | 12,29,0.0093 384 | 12,30,0.0000 385 | 12,31,0.0000 386 | 12,32,0.0000 387 | 12,33,0.0000 388 | 12,34,0.0000 389 | 12,35,0.0000 390 | 12,36,0.0000 391 | 13,14,0.1723 392 | 13,15,0.0003 393 | 13,16,0.0000 394 | 13,17,0.0000 395 | 13,18,0.0000 396 | 13,19,0.0000 397 | 13,20,0.0000 398 | 13,21,0.0017 399 | 13,22,0.0478 400 | 13,23,0.0371 401 | 13,24,0.0210 402 | 13,25,0.0049 403 | 13,26,0.0000 404 | 13,27,0.1237 405 | 13,28,0.0030 406 | 13,29,0.0050 407 | 13,30,0.0000 408 | 13,31,0.0000 409 | 13,32,0.0000 410 | 13,33,0.0000 411 | 13,34,0.0000 412 | 13,35,0.0000 413 | 13,36,0.0000 414 | 14,15,0.3409 415 | 14,16,0.2692 416 | 14,17,0.2338 417 | 14,18,0.0061 418 | 14,19,0.0075 419 | 14,20,0.0079 420 | 14,21,0.0065 421 | 14,22,0.0011 422 | 14,23,0.0043 
423 | 14,24,0.0063 424 | 14,25,0.0023 425 | 14,26,0.0000 426 | 14,27,0.0045 427 | 14,28,0.0027 428 | 14,29,0.0000 429 | 14,30,0.0000 430 | 14,31,0.0000 431 | 14,32,0.0000 432 | 14,33,0.0000 433 | 14,34,0.0000 434 | 14,35,0.0000 435 | 14,36,0.0002 436 | 15,16,0.6412 437 | 15,17,0.3784 438 | 15,18,0.0479 439 | 15,19,0.0572 440 | 15,20,0.0059 441 | 15,21,0.0000 442 | 15,22,0.0000 443 | 15,23,0.0000 444 | 15,24,0.0000 445 | 15,25,0.0000 446 | 15,26,0.0000 447 | 15,27,0.0000 448 | 15,28,0.0000 449 | 15,29,0.0000 450 | 15,30,0.0000 451 | 15,31,0.0000 452 | 15,32,0.0000 453 | 15,33,0.0000 454 | 15,34,0.0000 455 | 15,35,0.0000 456 | 15,36,0.0000 457 | 16,17,0.6605 458 | 16,18,0.1551 459 | 16,19,0.1564 460 | 16,20,0.0070 461 | 16,21,0.0037 462 | 16,22,0.0000 463 | 16,23,0.0000 464 | 16,24,0.0000 465 | 16,25,0.0000 466 | 16,26,0.0000 467 | 16,27,0.0000 468 | 16,28,0.0000 469 | 16,29,0.0000 470 | 16,30,0.0000 471 | 16,31,0.0000 472 | 16,32,0.0000 473 | 16,33,0.0000 474 | 16,34,0.0000 475 | 16,35,0.0003 476 | 16,36,0.0065 477 | 17,18,0.4449 478 | 17,19,0.5912 479 | 17,20,0.3402 480 | 17,21,0.0151 481 | 17,22,0.0000 482 | 17,23,0.0000 483 | 17,24,0.0000 484 | 17,25,0.0000 485 | 17,26,0.0000 486 | 17,27,0.0000 487 | 17,28,0.0000 488 | 17,29,0.0000 489 | 17,30,0.0000 490 | 17,31,0.0000 491 | 17,32,0.0000 492 | 17,33,0.0144 493 | 17,34,0.0610 494 | 17,35,0.0111 495 | 17,36,0.0755 496 | 18,19,0.5922 497 | 18,20,0.2836 498 | 18,21,0.0132 499 | 18,22,0.0000 500 | 18,23,0.0000 501 | 18,24,0.0000 502 | 18,25,0.0000 503 | 18,26,0.0000 504 | 18,27,0.0000 505 | 18,28,0.0000 506 | 18,29,0.0000 507 | 18,30,0.0000 508 | 18,31,0.0000 509 | 18,32,0.0000 510 | 18,33,0.0000 511 | 18,34,0.0055 512 | 18,35,0.0089 513 | 18,36,0.0382 514 | 19,20,0.4955 515 | 19,21,0.0496 516 | 19,22,0.0000 517 | 19,23,0.0000 518 | 19,24,0.0000 519 | 19,25,0.0000 520 | 19,26,0.0000 521 | 19,27,0.0000 522 | 19,28,0.0000 523 | 19,29,0.0000 524 | 19,30,0.0000 525 | 19,31,0.0000 526 | 19,32,0.0000 527 | 19,33,0.0274 528 | 19,34,0.0596 529 | 19,35,0.0140 530 | 19,36,0.0950 531 | 20,21,0.3810 532 | 20,22,0.0238 533 | 20,23,0.0040 534 | 20,24,0.0045 535 | 20,25,0.0010 536 | 20,26,0.0000 537 | 20,27,0.0045 538 | 20,28,0.0062 539 | 20,29,0.0000 540 | 20,30,0.0053 541 | 20,31,0.0576 542 | 20,32,0.0000 543 | 20,33,0.1121 544 | 20,34,0.1527 545 | 20,35,0.0143 546 | 20,36,0.1041 547 | 21,22,0.5275 548 | 21,23,0.3011 549 | 21,24,0.1952 550 | 21,25,0.0320 551 | 21,26,0.0016 552 | 21,27,0.0899 553 | 21,28,0.3441 554 | 21,29,0.2048 555 | 21,30,0.2257 556 | 21,31,0.2364 557 | 21,32,0.0316 558 | 21,33,0.0744 559 | 21,34,0.0670 560 | 21,35,0.0050 561 | 21,36,0.0172 562 | 22,23,0.7330 563 | 22,24,0.2303 564 | 22,25,0.0266 565 | 22,26,0.0267 566 | 22,27,0.1673 567 | 22,28,0.3191 568 | 22,29,0.4485 569 | 22,30,0.2956 570 | 22,31,0.1013 571 | 22,32,0.0161 572 | 22,33,0.0044 573 | 22,34,0.0085 574 | 22,35,0.0046 575 | 22,36,0.0075 576 | 23,24,0.2493 577 | 23,25,0.0259 578 | 23,26,0.0426 579 | 23,27,0.2199 580 | 23,28,0.3316 581 | 23,29,0.5661 582 | 23,30,0.3621 583 | 23,31,0.0670 584 | 23,32,0.0031 585 | 23,33,0.0000 586 | 23,34,0.0000 587 | 23,35,0.0000 588 | 23,36,0.0000 589 | 24,25,0.2683 590 | 24,26,0.0284 591 | 24,27,0.4720 592 | 24,28,0.6725 593 | 24,29,0.3175 594 | 24,30,0.1250 595 | 24,31,0.0131 596 | 24,32,0.0000 597 | 24,33,0.0000 598 | 24,34,0.0000 599 | 24,35,0.0000 600 | 24,36,0.0000 601 | 25,26,0.2499 602 | 25,27,0.4611 603 | 25,28,0.2101 604 | 25,29,0.0000 605 | 25,30,0.0000 606 | 25,31,0.0000 607 | 25,32,0.0000 608 | 25,33,0.0000 609 | 25,34,0.0000 
610 | 25,35,0.0000 611 | 25,36,0.0000 612 | 26,27,0.7114 613 | 26,28,0.0033 614 | 26,29,0.0000 615 | 26,30,0.0000 616 | 26,31,0.0000 617 | 26,32,0.0000 618 | 26,33,0.0000 619 | 26,34,0.0000 620 | 26,35,0.0000 621 | 26,36,0.0000 622 | 27,28,0.1769 623 | 27,29,0.0335 624 | 27,30,0.0000 625 | 27,31,0.0000 626 | 27,32,0.0000 627 | 27,33,0.0000 628 | 27,34,0.0000 629 | 27,35,0.0000 630 | 27,36,0.0000 631 | 28,29,0.5445 632 | 28,30,0.3390 633 | 28,31,0.1712 634 | 28,32,0.0000 635 | 28,33,0.0000 636 | 28,34,0.0000 637 | 28,35,0.0000 638 | 28,36,0.0000 639 | 29,30,0.6915 640 | 29,31,0.1352 641 | 29,32,0.0507 642 | 29,33,0.0016 643 | 29,34,0.0000 644 | 29,35,0.0000 645 | 29,36,0.0000 646 | 30,31,0.4260 647 | 30,32,0.2778 648 | 30,33,0.1563 649 | 30,34,0.0450 650 | 30,35,0.0000 651 | 30,36,0.0080 652 | 31,32,0.4631 653 | 31,33,0.5335 654 | 31,34,0.3379 655 | 31,35,0.1129 656 | 31,36,0.2214 657 | 32,33,0.6403 658 | 32,34,0.3848 659 | 32,35,0.3697 660 | 32,36,0.3271 661 | 33,34,0.7574 662 | 33,35,0.3338 663 | 33,36,0.4225 664 | 34,35,0.4932 665 | 34,36,0.5212 666 | 35,36,0.5585 667 | -------------------------------------------------------------------------------- /configs/indoor/3DLoMatch.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/configs/indoor/3DLoMatch.pkl -------------------------------------------------------------------------------- /configs/indoor/3DMatch.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/configs/indoor/3DMatch.pkl -------------------------------------------------------------------------------- /configs/indoor/train_3dmatch.txt: -------------------------------------------------------------------------------- 1 | 7-scenes-chess 2 | 7-scenes-fire 3 | 7-scenes-office 4 | 7-scenes-pumpkin 5 | 7-scenes-stairs 6 | analysis-by-synthesis-apt1-kitchen 7 | analysis-by-synthesis-apt1-living 8 | analysis-by-synthesis-apt2-bed 9 | analysis-by-synthesis-apt2-kitchen 10 | analysis-by-synthesis-apt2-living 11 | analysis-by-synthesis-apt2-luke 12 | analysis-by-synthesis-office2-5a 13 | analysis-by-synthesis-office2-5b 14 | bundlefusion-apt0_1 15 | bundlefusion-apt0_2 16 | bundlefusion-apt0_3 17 | bundlefusion-apt0_4 18 | bundlefusion-apt1_1 19 | bundlefusion-apt1_2 20 | bundlefusion-apt1_3 21 | bundlefusion-apt1_4 22 | bundlefusion-apt2_1 23 | bundlefusion-apt2_2 24 | bundlefusion-copyroom_1 25 | bundlefusion-copyroom_2 26 | bundlefusion-office1_1 27 | bundlefusion-office1_2 28 | bundlefusion-office2 29 | bundlefusion-office3 30 | rgbd-scenes-v2-scene_01 31 | rgbd-scenes-v2-scene_02 32 | rgbd-scenes-v2-scene_03 33 | rgbd-scenes-v2-scene_04 34 | rgbd-scenes-v2-scene_05 35 | rgbd-scenes-v2-scene_06 36 | rgbd-scenes-v2-scene_07 37 | rgbd-scenes-v2-scene_08 38 | rgbd-scenes-v2-scene_09 39 | rgbd-scenes-v2-scene_11 40 | rgbd-scenes-v2-scene_12 41 | rgbd-scenes-v2-scene_13 42 | rgbd-scenes-v2-scene_14 43 | sun3d-brown_bm_1-brown_bm_1_1 44 | sun3d-brown_bm_1-brown_bm_1_2 45 | sun3d-brown_bm_1-brown_bm_1_3 46 | sun3d-brown_cogsci_1-brown_cogsci_1 47 | sun3d-brown_cs_2-brown_cs2_1 48 | sun3d-brown_cs_2-brown_cs2_2 49 | sun3d-brown_cs_3-brown_cs3 50 | sun3d-harvard_c3-hv_c3_1 51 | sun3d-harvard_c5-hv_c5_1 52 | sun3d-harvard_c6-hv_c6_1 53 | sun3d-harvard_c8-hv_c8_3 54 | sun3d-hotel_nips2012-nips_4_1 55 | sun3d-hotel_nips2012-nips_4_2 
56 | sun3d-hotel_sf-scan1_1 57 | sun3d-hotel_sf-scan1_2 58 | sun3d-hotel_sf-scan1_3 59 | sun3d-hotel_sf-scan1_4 60 | sun3d-mit_32_d507-d507_2_1 61 | sun3d-mit_32_d507-d507_2_2 62 | sun3d-mit_46_ted_lab1-ted_lab_2_1 63 | sun3d-mit_46_ted_lab1-ted_lab_2_2 64 | sun3d-mit_46_ted_lab1-ted_lab_2_3 65 | sun3d-mit_46_ted_lab1-ted_lab_2_4 66 | sun3d-mit_76_417-76-417b_1 67 | sun3d-mit_76_417-76-417b_2_1 68 | sun3d-mit_76_417-76-417b_3 69 | sun3d-mit_76_417-76-417b_4 70 | sun3d-mit_76_417-76-417b_5 71 | sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika 72 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika_1 73 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika_2 74 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika_3 75 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika_4 76 | -------------------------------------------------------------------------------- /configs/indoor/train_info.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/configs/indoor/train_info.pkl -------------------------------------------------------------------------------- /configs/indoor/val_3dmatch.txt: -------------------------------------------------------------------------------- 1 | sun3d-brown_bm_4-brown_bm_4 2 | sun3d-harvard_c11-hv_c11_2 3 | 7-scenes-heads 4 | rgbd-scenes-v2-scene_10 5 | bundlefusion-office0_1 6 | bundlefusion-office0_2 7 | bundlefusion-office0_3 8 | analysis-by-synthesis-apt2-kitchen 9 | -------------------------------------------------------------------------------- /configs/indoor/val_info.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/configs/indoor/val_info.pkl -------------------------------------------------------------------------------- /configs/kitti/test_kitti.txt: -------------------------------------------------------------------------------- 1 | 8 2 | 9 3 | 10 -------------------------------------------------------------------------------- /configs/kitti/train_kitti.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 -------------------------------------------------------------------------------- /configs/kitti/val_kitti.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 7 -------------------------------------------------------------------------------- /configs/modelnet/modelnet40_all.txt: -------------------------------------------------------------------------------- 1 | airplane 2 | bathtub 3 | bed 4 | bench 5 | bookshelf 6 | bottle 7 | bowl 8 | car 9 | chair 10 | cone 11 | cup 12 | curtain 13 | desk 14 | door 15 | dresser 16 | flower_pot 17 | glass_box 18 | guitar 19 | keyboard 20 | lamp 21 | laptop 22 | mantel 23 | monitor 24 | night_stand 25 | person 26 | piano 27 | plant 28 | radio 29 | range_hood 30 | sink 31 | sofa 32 | stairs 33 | stool 34 | table 35 | tent 36 | toilet 37 | tv_stand 38 | vase 39 | wardrobe 40 | xbox 41 | -------------------------------------------------------------------------------- /configs/modelnet/modelnet40_half1.txt: -------------------------------------------------------------------------------- 1 | airplane 2 | bathtub 3 | bed 4 | bench 5 | bookshelf 6 | bottle 7 | bowl 8 | car 9 | chair 10 | cone 11 | cup 12 | curtain 13 | desk 14 | door 15 | 
dresser 16 | flower_pot 17 | glass_box 18 | guitar 19 | keyboard 20 | lamp 21 | -------------------------------------------------------------------------------- /configs/modelnet/modelnet40_half2.txt: -------------------------------------------------------------------------------- 1 | laptop 2 | mantel 3 | monitor 4 | night_stand 5 | person 6 | piano 7 | plant 8 | radio 9 | range_hood 10 | sink 11 | sofa 12 | stairs 13 | stool 14 | table 15 | tent 16 | toilet 17 | tv_stand 18 | vase 19 | wardrobe 20 | xbox 21 | -------------------------------------------------------------------------------- /configs/models.py: -------------------------------------------------------------------------------- 1 | architectures = dict() 2 | architectures['indoor'] = [ 3 | 'simple', 4 | 'resnetb', 5 | 'resnetb_strided', 6 | 'resnetb', 7 | 'resnetb', 8 | 'resnetb_strided', 9 | 'resnetb', 10 | 'resnetb', 11 | 'resnetb_strided', 12 | 'resnetb', 13 | 'resnetb', 14 | 'nearest_upsample', 15 | 'unary', 16 | 'nearest_upsample', 17 | 'unary', 18 | 'nearest_upsample', 19 | 'last_unary' 20 | ] 21 | 22 | architectures['kitti'] = [ 23 | 'simple', 24 | 'resnetb', 25 | 'resnetb_strided', 26 | 'resnetb', 27 | 'resnetb', 28 | 'resnetb_strided', 29 | 'resnetb', 30 | 'resnetb', 31 | 'resnetb_strided', 32 | 'resnetb', 33 | 'resnetb', 34 | 'nearest_upsample', 35 | 'unary', 36 | 'nearest_upsample', 37 | 'unary', 38 | 'nearest_upsample', 39 | 'last_unary' 40 | ] 41 | 42 | architectures['modelnet'] = [ 43 | 'simple', 44 | 'resnetb', 45 | 'resnetb', 46 | 'resnetb_strided', 47 | 'resnetb', 48 | 'resnetb', 49 | 'resnetb_strided', 50 | 'resnetb', 51 | 'resnetb', 52 | 'nearest_upsample', 53 | 'unary', 54 | 'unary', 55 | 'nearest_upsample', 56 | 'unary', 57 | 'last_unary' 58 | ] -------------------------------------------------------------------------------- /configs/test/indoor.yaml: -------------------------------------------------------------------------------- 1 | misc: 2 | exp_dir: indoor 3 | mode: test 4 | gpu_mode: True 5 | verbose: True 6 | verbose_freq: 1000 7 | snapshot_freq: 1 8 | pretrain: 'weights/indoor.pth' 9 | 10 | 11 | model: 12 | num_layers: 4 13 | in_points_dim: 3 14 | first_feats_dim: 128 15 | final_feats_dim: 32 16 | first_subsampling_dl: 0.025 17 | in_feats_dim: 1 18 | conv_radius: 2.5 19 | deform_radius: 5.0 20 | num_kernel_points: 15 21 | KP_extent: 2.0 22 | KP_influence: linear 23 | aggregation_mode: sum 24 | fixed_kernel_points: center 25 | use_batch_norm: True 26 | batch_norm_momentum: 0.02 27 | deformable: False 28 | modulated: False 29 | add_cross_score: True 30 | condition_feature: True 31 | 32 | overlap_attention_module: 33 | gnn_feats_dim: 256 34 | dgcnn_k: 10 35 | num_head: 4 36 | nets: ['self','cross','self'] 37 | 38 | loss: 39 | pos_margin: 0.1 40 | neg_margin: 1.4 41 | log_scale: 24 42 | pos_radius: 0.0375 43 | safe_radius: 0.1 44 | overlap_radius: 0.0375 45 | matchability_radius: 0.05 46 | w_circle_loss: 1.0 47 | w_overlap_loss: 1.0 48 | w_saliency_loss: 0.0 49 | max_points: 256 50 | 51 | optimiser: 52 | optimizer: SGD 53 | max_epoch: 40 54 | lr: 0.005 55 | weight_decay: 0.000001 56 | momentum: 0.98 57 | scheduler: ExpLR 58 | scheduler_gamma: 0.95 59 | scheduler_freq: 1 60 | iter_size: 1 61 | 62 | dataset: 63 | dataset: indoor 64 | benchmark: 3DMatch 65 | root: data/indoor 66 | batch_size: 1 67 | num_workers: 6 68 | augment_noise: 0.005 69 | train_info: configs/indoor/train_info.pkl 70 | val_info: configs/indoor/val_info.pkl 71 | 72 | demo: 73 | src_pcd: assets/cloud_bin_21.pth 74 | tgt_pcd: 
assets/cloud_bin_34.pth 75 | n_points: 1000 76 | 77 | -------------------------------------------------------------------------------- /configs/test/kitti.yaml: -------------------------------------------------------------------------------- 1 | misc: 2 | exp_dir: test 3 | mode: test 4 | gpu_mode: True 5 | verbose: True 6 | verbose_freq: 500 7 | snapshot_freq: 1 8 | pretrain: weights/kitti.pth 9 | 10 | model: 11 | num_layers: 4 12 | in_points_dim: 3 13 | first_feats_dim: 256 14 | final_feats_dim: 32 15 | first_subsampling_dl: 0.3 16 | in_feats_dim: 1 17 | conv_radius: 4.25 18 | deform_radius: 5.0 19 | num_kernel_points: 15 20 | KP_extent: 2.0 21 | KP_influence: linear 22 | aggregation_mode: sum 23 | fixed_kernel_points: center 24 | use_batch_norm: True 25 | batch_norm_momentum: 0.02 26 | deformable: False 27 | modulated: False 28 | add_cross_score: True 29 | condition_feature: True 30 | 31 | overlap_attention_module: 32 | gnn_feats_dim: 256 33 | dgcnn_k: 10 34 | num_head: 4 35 | nets: ['self','cross','self'] 36 | 37 | loss: 38 | pos_margin: 0.1 39 | neg_margin: 1.4 40 | log_scale: 40 41 | pos_radius: 0.21 42 | safe_radius: 0.75 43 | overlap_radius: 0.45 44 | matchability_radius: 0.3 45 | w_circle_loss: 1.0 46 | w_overlap_loss: 1.0 47 | w_saliency_loss: 0.0 48 | max_points: 512 49 | 50 | optimiser: 51 | optimizer: SGD 52 | max_epoch: 150 53 | lr: 0.05 54 | weight_decay: 0.000001 55 | momentum: 0.98 56 | scheduler: ExpLR 57 | scheduler_gamma: 0.95 58 | scheduler_freq: 1 59 | iter_size: 1 60 | 61 | dataset: 62 | dataset: kitti 63 | benchmark: odometryKITTI 64 | root: 65 | batch_size: 1 66 | num_workers: 6 67 | augment_noise: 0.01 68 | augment_shift_range: 2.0 69 | augment_scale_max: 1.2 70 | augment_scale_min: 0.8 71 | -------------------------------------------------------------------------------- /configs/test/modelnet.yaml: -------------------------------------------------------------------------------- 1 | misc: 2 | exp_dir: modelnet 3 | mode: test 4 | gpu_mode: True 5 | verbose: True 6 | verbose_freq: 1000 7 | snapshot_freq: 1 8 | pretrain: 'weights/modelnet.pth' 9 | 10 | 11 | model: 12 | num_layers: 3 13 | in_points_dim: 3 14 | first_feats_dim: 512 15 | final_feats_dim: 96 16 | first_subsampling_dl: 0.06 17 | in_feats_dim: 1 18 | conv_radius: 2.75 19 | deform_radius: 5.0 20 | num_kernel_points: 15 21 | KP_extent: 2.0 22 | KP_influence: linear 23 | aggregation_mode: sum 24 | fixed_kernel_points: center 25 | use_batch_norm: True 26 | batch_norm_momentum: 0.02 27 | deformable: False 28 | modulated: False 29 | add_cross_score: True 30 | condition_feature: True 31 | 32 | overlap_attention_module: 33 | gnn_feats_dim: 256 34 | dgcnn_k: 10 35 | num_head: 4 36 | nets: ['self','cross','self'] 37 | 38 | loss: 39 | pos_margin: 0.1 40 | neg_margin: 1.4 41 | log_scale: 64 42 | pos_radius: 0.018 43 | safe_radius: 0.06 44 | overlap_radius: 0.04 45 | matchability_radius: 0.04 46 | w_circle_loss: 1.0 47 | w_overlap_loss: 1.0 48 | w_saliency_loss: 0.0 49 | max_points: 384 50 | 51 | optimiser: 52 | optimizer: SGD 53 | max_epoch: 200 54 | lr: 0.01 55 | weight_decay: 0.000001 56 | momentum: 0.98 57 | scheduler: ExpLR 58 | scheduler_gamma: 0.99 59 | scheduler_freq: 1 60 | iter_size: 4 61 | 62 | dataset: 63 | dataset: modelnet 64 | benchmark: modelnet 65 | root: data/modelnet40_ply_hdf5_2048 66 | batch_size: 1 67 | num_workers: 4 68 | augment_noise: 0.005 69 | train_categoryfile: configs/modelnet/modelnet40_half1.txt 70 | val_categoryfile: configs/modelnet/modelnet40_half1.txt 71 | test_categoryfile: 
configs/modelnet/modelnet40_half2.txt 72 | partial: [0.7,0.7] # set to [0.5, 0.5] for ModelLoNet 73 | num_points: 1024 74 | noise_type: crop 75 | rot_mag: 45.0 76 | trans_mag: 0.5 77 | dataset_type: modelnet_hdf 78 | 79 | 80 | -------------------------------------------------------------------------------- /configs/train/indoor.yaml: -------------------------------------------------------------------------------- 1 | misc: 2 | exp_dir: indoor 3 | mode: train 4 | gpu_mode: True 5 | verbose: True 6 | verbose_freq: 1000 7 | snapshot_freq: 1 8 | pretrain: '' 9 | 10 | 11 | model: 12 | num_layers: 4 13 | in_points_dim: 3 14 | first_feats_dim: 128 15 | final_feats_dim: 32 16 | first_subsampling_dl: 0.025 17 | in_feats_dim: 1 18 | conv_radius: 2.5 19 | deform_radius: 5.0 20 | num_kernel_points: 15 21 | KP_extent: 2.0 22 | KP_influence: linear 23 | aggregation_mode: sum 24 | fixed_kernel_points: center 25 | use_batch_norm: True 26 | batch_norm_momentum: 0.02 27 | deformable: False 28 | modulated: False 29 | add_cross_score: True 30 | condition_feature: True 31 | 32 | overlap_attention_module: 33 | gnn_feats_dim: 256 34 | dgcnn_k: 10 35 | num_head: 4 36 | nets: ['self','cross','self'] 37 | 38 | loss: 39 | pos_margin: 0.1 40 | neg_margin: 1.4 41 | log_scale: 24 42 | pos_radius: 0.0375 43 | safe_radius: 0.1 44 | overlap_radius: 0.0375 45 | matchability_radius: 0.05 46 | w_circle_loss: 1.0 47 | w_overlap_loss: 1.0 48 | w_saliency_loss: 0.0 49 | max_points: 256 50 | 51 | optimiser: 52 | optimizer: SGD 53 | max_epoch: 40 54 | lr: 0.005 55 | weight_decay: 0.000001 56 | momentum: 0.98 57 | scheduler: ExpLR 58 | scheduler_gamma: 0.95 59 | scheduler_freq: 1 60 | iter_size: 1 61 | 62 | dataset: 63 | dataset: indoor 64 | benchmark: 3DMatch 65 | root: data/indoor 66 | batch_size: 1 67 | num_workers: 6 68 | augment_noise: 0.005 69 | train_info: configs/indoor/train_info.pkl 70 | val_info: configs/indoor/val_info.pkl 71 | 72 | demo: 73 | src_pcd: assets/cloud_bin_21.pth 74 | tgt_pcd: assets/cloud_bin_34.pth 75 | n_points: 1000 76 | 77 | -------------------------------------------------------------------------------- /configs/train/kitti.yaml: -------------------------------------------------------------------------------- 1 | misc: 2 | exp_dir: kitti 3 | mode: train 4 | gpu_mode: True 5 | verbose: True 6 | verbose_freq: 150 7 | snapshot_freq: 1 8 | pretrain: '' 9 | 10 | model: 11 | num_layers: 4 12 | in_points_dim: 3 13 | first_feats_dim: 256 14 | final_feats_dim: 32 15 | first_subsampling_dl: 0.3 16 | in_feats_dim: 1 17 | conv_radius: 4.25 18 | deform_radius: 5.0 19 | num_kernel_points: 15 20 | KP_extent: 2.0 21 | KP_influence: linear 22 | aggregation_mode: sum 23 | fixed_kernel_points: center 24 | use_batch_norm: True 25 | batch_norm_momentum: 0.02 26 | deformable: False 27 | modulated: False 28 | add_cross_score: True 29 | condition_feature: True 30 | 31 | overlap_attention_module: 32 | gnn_feats_dim: 256 33 | dgcnn_k: 10 34 | num_head: 4 35 | nets: ['self','cross','self'] 36 | 37 | loss: 38 | pos_margin: 0.1 39 | neg_margin: 1.4 40 | log_scale: 48 41 | pos_radius: 0.21 42 | safe_radius: 0.75 43 | overlap_radius: 0.45 44 | matchability_radius: 0.3 45 | w_circle_loss: 1.0 46 | w_overlap_loss: 1.0 47 | w_saliency_loss: 0.0 48 | max_points: 512 49 | 50 | optimiser: 51 | optimizer: SGD 52 | max_epoch: 150 53 | lr: 0.05 54 | weight_decay: 0.000001 55 | momentum: 0.98 56 | scheduler: ExpLR 57 | scheduler_gamma: 0.99 58 | scheduler_freq: 1 59 | iter_size: 1 60 | 61 | dataset: 62 | dataset: kitti 63 | benchmark: 
odometryKITTI 64 | root: 65 | batch_size: 1 66 | num_workers: 6 67 | augment_noise: 0.01 68 | augment_shift_range: 2.0 69 | augment_scale_max: 1.2 70 | augment_scale_min: 0.8 71 | -------------------------------------------------------------------------------- /configs/train/modelnet.yaml: -------------------------------------------------------------------------------- 1 | misc: 2 | exp_dir: modelnet 3 | mode: train 4 | gpu_mode: True 5 | verbose: True 6 | verbose_freq: 1000 7 | snapshot_freq: 1 8 | pretrain: '' 9 | 10 | model: 11 | num_layers: 3 12 | in_points_dim: 3 13 | first_feats_dim: 512 14 | final_feats_dim: 96 15 | first_subsampling_dl: 0.06 16 | in_feats_dim: 1 17 | conv_radius: 2.75 18 | deform_radius: 5.0 19 | num_kernel_points: 15 20 | KP_extent: 2.0 21 | KP_influence: linear 22 | aggregation_mode: sum 23 | fixed_kernel_points: center 24 | use_batch_norm: True 25 | batch_norm_momentum: 0.02 26 | deformable: False 27 | modulated: False 28 | add_cross_score: True 29 | condition_feature: True 30 | 31 | overlap_attention_module: 32 | gnn_feats_dim: 256 33 | dgcnn_k: 10 34 | num_head: 4 35 | nets: ['self','cross','self'] 36 | 37 | loss: 38 | pos_margin: 0.1 39 | neg_margin: 1.4 40 | log_scale: 64 41 | pos_radius: 0.018 42 | safe_radius: 0.06 43 | overlap_radius: 0.04 44 | matchability_radius: 0.04 45 | w_circle_loss: 1.0 46 | w_overlap_loss: 1.0 47 | w_saliency_loss: 0.0 48 | max_points: 384 49 | 50 | optimiser: 51 | optimizer: SGD 52 | max_epoch: 200 53 | lr: 0.01 54 | weight_decay: 0.000001 55 | momentum: 0.98 56 | scheduler: ExpLR 57 | scheduler_gamma: 0.99 58 | scheduler_freq: 1 59 | iter_size: 4 60 | 61 | dataset: 62 | dataset: modelnet 63 | benchmark: modelnet 64 | root: data/modelnet40_ply_hdf5_2048 65 | batch_size: 1 66 | num_workers: 4 67 | augment_noise: 0.005 68 | train_categoryfile: configs/modelnet/modelnet40_half1.txt 69 | val_categoryfile: configs/modelnet/modelnet40_half1.txt 70 | test_categoryfile: configs/modelnet/modelnet40_half2.txt 71 | partial: [0.7,0.7] 72 | num_points: 1024 73 | noise_type: crop 74 | rot_mag: 45.0 75 | trans_mag: 0.5 76 | dataset_type: modelnet_hdf 77 | 78 | 79 | -------------------------------------------------------------------------------- /cpp_wrappers/compile_wrappers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Compile cpp subsampling 4 | cd cpp_subsampling 5 | python3 setup.py build_ext --inplace 6 | cd .. 7 | 8 | # Compile cpp neighbors 9 | cd cpp_neighbors 10 | python3 setup.py build_ext --inplace 11 | cd .. 
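The shell script above builds the two CPython extensions in place (grid_subsampling and radius_neighbors, named in the two setup.py files below). As a minimal usage sketch of the radius search, assuming the build succeeded and the extension is importable from its build directory (module name, keyword arguments, and padding behaviour are taken from cpp_wrappers/cpp_neighbors/wrapper.cpp further down; the point counts and radius here are made-up illustration values):

import numpy as np
import radius_neighbors  # built in cpp_wrappers/cpp_neighbors by setup.py

# Two stacked point clouds of 1000 and 800 points, queried against themselves.
queries = np.random.rand(1800, 3).astype(np.float32)
supports = queries.copy()
q_batches = np.array([1000, 800], dtype=np.int32)
s_batches = np.array([1000, 800], dtype=np.int32)

# Returns an int32 array of shape (n_queries, max_neighbors); rows with fewer
# neighbors than the densest query are padded with len(supports).
neighbors = radius_neighbors.batch_query(queries, supports, q_batches, s_batches, radius=0.1)
print(neighbors.shape)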
-------------------------------------------------------------------------------- /cpp_wrappers/cpp_neighbors/build.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | py setup.py build_ext --inplace 3 | 4 | 5 | pause -------------------------------------------------------------------------------- /cpp_wrappers/cpp_neighbors/neighbors/neighbors.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "neighbors.h" 3 | 4 | 5 | void brute_neighbors(vector<PointXYZ>& queries, vector<PointXYZ>& supports, vector<int>& neighbors_indices, float radius, int verbose) 6 | { 7 | 8 | // Initialize variables 9 | // ****************** 10 | 11 | // square radius 12 | float r2 = radius * radius; 13 | 14 | // indices 15 | int i0 = 0; 16 | 17 | // Counting vector 18 | int max_count = 0; 19 | vector<vector<int>> tmp(queries.size()); 20 | 21 | // Search neighbors indices 22 | // *********************** 23 | 24 | for (auto& p0 : queries) 25 | { 26 | int i = 0; 27 | for (auto& p : supports) 28 | { 29 | if ((p0 - p).sq_norm() < r2) 30 | { 31 | tmp[i0].push_back(i); 32 | if (tmp[i0].size() > max_count) 33 | max_count = tmp[i0].size(); 34 | } 35 | i++; 36 | } 37 | i0++; 38 | } 39 | 40 | // Reserve the memory 41 | neighbors_indices.resize(queries.size() * max_count); 42 | i0 = 0; 43 | for (auto& inds : tmp) 44 | { 45 | for (int j = 0; j < max_count; j++) 46 | { 47 | if (j < inds.size()) 48 | neighbors_indices[i0 * max_count + j] = inds[j]; 49 | else 50 | neighbors_indices[i0 * max_count + j] = -1; 51 | } 52 | i0++; 53 | } 54 | 55 | return; 56 | } 57 | 58 | void ordered_neighbors(vector<PointXYZ>& queries, 59 | vector<PointXYZ>& supports, 60 | vector<int>& neighbors_indices, 61 | float radius) 62 | { 63 | 64 | // Initialize variables 65 | // ****************** 66 | 67 | // square radius 68 | float r2 = radius * radius; 69 | 70 | // indices 71 | int i0 = 0; 72 | 73 | // Counting vector 74 | int max_count = 0; 75 | float d2; 76 | vector<vector<int>> tmp(queries.size()); 77 | vector<vector<float>> dists(queries.size()); 78 | 79 | // Search neighbors indices 80 | // *********************** 81 | 82 | for (auto& p0 : queries) 83 | { 84 | int i = 0; 85 | for (auto& p : supports) 86 | { 87 | d2 = (p0 - p).sq_norm(); 88 | if (d2 < r2) 89 | { 90 | // Find order of the new point 91 | auto it = std::upper_bound(dists[i0].begin(), dists[i0].end(), d2); 92 | int index = std::distance(dists[i0].begin(), it); 93 | 94 | // Insert element 95 | dists[i0].insert(it, d2); 96 | tmp[i0].insert(tmp[i0].begin() + index, i); 97 | 98 | // Update max count 99 | if (tmp[i0].size() > max_count) 100 | max_count = tmp[i0].size(); 101 | } 102 | i++; 103 | } 104 | i0++; 105 | } 106 | 107 | // Reserve the memory 108 | neighbors_indices.resize(queries.size() * max_count); 109 | i0 = 0; 110 | for (auto& inds : tmp) 111 | { 112 | for (int j = 0; j < max_count; j++) 113 | { 114 | if (j < inds.size()) 115 | neighbors_indices[i0 * max_count + j] = inds[j]; 116 | else 117 | neighbors_indices[i0 * max_count + j] = -1; 118 | } 119 | i0++; 120 | } 121 | 122 | return; 123 | } 124 | 125 | void batch_ordered_neighbors(vector<PointXYZ>& queries, 126 | vector<PointXYZ>& supports, 127 | vector<int>& q_batches, 128 | vector<int>& s_batches, 129 | vector<int>& neighbors_indices, 130 | float radius) 131 | { 132 | 133 | // Initialize variables 134 | // ****************** 135 | 136 | // square radius 137 | float r2 = radius * radius; 138 | 139 | // indices 140 | int i0 = 0; 141 | 142 | // Counting vector 143 | int max_count = 0; 144 | float d2; 145 | vector<vector<int>> tmp(queries.size()); 146 | vector<vector<float>>
dists(queries.size()); 147 | 148 | // batch index 149 | int b = 0; 150 | int sum_qb = 0; 151 | int sum_sb = 0; 152 | 153 | 154 | // Search neighbors indices 155 | // *********************** 156 | 157 | for (auto& p0 : queries) 158 | { 159 | // Check if we changed batch 160 | if (i0 == sum_qb + q_batches[b]) 161 | { 162 | sum_qb += q_batches[b]; 163 | sum_sb += s_batches[b]; 164 | b++; 165 | } 166 | 167 | // Loop only over the supports of current batch 168 | vector<PointXYZ>::iterator p_it; 169 | int i = 0; 170 | for(p_it = supports.begin() + sum_sb; p_it < supports.begin() + sum_sb + s_batches[b]; p_it++ ) 171 | { 172 | d2 = (p0 - *p_it).sq_norm(); 173 | if (d2 < r2) 174 | { 175 | // Find order of the new point 176 | auto it = std::upper_bound(dists[i0].begin(), dists[i0].end(), d2); 177 | int index = std::distance(dists[i0].begin(), it); 178 | 179 | // Insert element 180 | dists[i0].insert(it, d2); 181 | tmp[i0].insert(tmp[i0].begin() + index, sum_sb + i); 182 | 183 | // Update max count 184 | if (tmp[i0].size() > max_count) 185 | max_count = tmp[i0].size(); 186 | } 187 | i++; 188 | } 189 | i0++; 190 | } 191 | 192 | // Reserve the memory 193 | neighbors_indices.resize(queries.size() * max_count); 194 | i0 = 0; 195 | for (auto& inds : tmp) 196 | { 197 | for (int j = 0; j < max_count; j++) 198 | { 199 | if (j < inds.size()) 200 | neighbors_indices[i0 * max_count + j] = inds[j]; 201 | else 202 | neighbors_indices[i0 * max_count + j] = supports.size(); 203 | } 204 | i0++; 205 | } 206 | 207 | return; 208 | } 209 | 210 | 211 | void batch_nanoflann_neighbors(vector<PointXYZ>& queries, 212 | vector<PointXYZ>& supports, 213 | vector<int>& q_batches, 214 | vector<int>& s_batches, 215 | vector<int>& neighbors_indices, 216 | float radius) 217 | { 218 | 219 | // Initialize variables 220 | // ****************** 221 | 222 | // indices 223 | int i0 = 0; 224 | 225 | // Square radius 226 | float r2 = radius * radius; 227 | 228 | // Counting vector 229 | int max_count = 0; 230 | float d2; 231 | vector<vector<pair<size_t, float>>> all_inds_dists(queries.size()); 232 | 233 | // batch index 234 | int b = 0; 235 | int sum_qb = 0; 236 | int sum_sb = 0; 237 | 238 | // Nanoflann related variables 239 | // *************************** 240 | 241 | // Cloud variable 242 | PointCloud current_cloud; 243 | 244 | // Tree parameters 245 | nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10 /* max leaf */); 246 | 247 | // KDTree type definition 248 | typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Simple_Adaptor<float, PointCloud> , 249 | PointCloud, 250 | 3 > my_kd_tree_t; 251 | 252 | // Pointer to trees 253 | my_kd_tree_t* index; 254 | 255 | // Build KDTree for the first batch element 256 | current_cloud.pts = vector<PointXYZ>(supports.begin() + sum_sb, supports.begin() + sum_sb + s_batches[b]); 257 | index = new my_kd_tree_t(3, current_cloud, tree_params); 258 | index->buildIndex(); 259 | 260 | 261 | // Search neighbors indices 262 | // *********************** 263 | 264 | // Search params 265 | nanoflann::SearchParams search_params; 266 | search_params.sorted = true; 267 | 268 | for (auto& p0 : queries) 269 | { 270 | 271 | // Check if we changed batch 272 | if (i0 == sum_qb + q_batches[b]) 273 | { 274 | sum_qb += q_batches[b]; 275 | sum_sb += s_batches[b]; 276 | b++; 277 | 278 | // Change the points 279 | current_cloud.pts.clear(); 280 | current_cloud.pts = vector<PointXYZ>(supports.begin() + sum_sb, supports.begin() + sum_sb + s_batches[b]); 281 | 282 | // Build KDTree of the current element of the batch 283 | delete index; 284 | index = new my_kd_tree_t(3, current_cloud, tree_params); 285 | index->buildIndex();
286 | } 287 | 288 | // Initial guess of neighbors size 289 | all_inds_dists[i0].reserve(max_count); 290 | 291 | // Find neighbors 292 | float query_pt[3] = { p0.x, p0.y, p0.z}; 293 | size_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params); 294 | 295 | // Update max count 296 | if (nMatches > max_count) 297 | max_count = nMatches; 298 | 299 | // Increment query idx 300 | i0++; 301 | } 302 | 303 | // Reserve the memory 304 | neighbors_indices.resize(queries.size() * max_count); 305 | i0 = 0; 306 | sum_sb = 0; 307 | sum_qb = 0; 308 | b = 0; 309 | for (auto& inds_dists : all_inds_dists) 310 | { 311 | // Check if we changed batch 312 | if (i0 == sum_qb + q_batches[b]) 313 | { 314 | sum_qb += q_batches[b]; 315 | sum_sb += s_batches[b]; 316 | b++; 317 | } 318 | 319 | for (int j = 0; j < max_count; j++) 320 | { 321 | if (j < inds_dists.size()) 322 | neighbors_indices[i0 * max_count + j] = inds_dists[j].first + sum_sb; 323 | else 324 | neighbors_indices[i0 * max_count + j] = supports.size(); 325 | } 326 | i0++; 327 | } 328 | 329 | delete index; 330 | 331 | return; 332 | } 333 | 334 | -------------------------------------------------------------------------------- /cpp_wrappers/cpp_neighbors/neighbors/neighbors.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "../../cpp_utils/cloud/cloud.h" 4 | #include "../../cpp_utils/nanoflann/nanoflann.hpp" 5 | 6 | #include <set> 7 | #include <cstdint> 8 | 9 | using namespace std; 10 | 11 | 12 | void ordered_neighbors(vector<PointXYZ>& queries, 13 | vector<PointXYZ>& supports, 14 | vector<int>& neighbors_indices, 15 | float radius); 16 | 17 | void batch_ordered_neighbors(vector<PointXYZ>& queries, 18 | vector<PointXYZ>& supports, 19 | vector<int>& q_batches, 20 | vector<int>& s_batches, 21 | vector<int>& neighbors_indices, 22 | float radius); 23 | 24 | void batch_nanoflann_neighbors(vector<PointXYZ>& queries, 25 | vector<PointXYZ>& supports, 26 | vector<int>& q_batches, 27 | vector<int>& s_batches, 28 | vector<int>& neighbors_indices, 29 | float radius); 30 | -------------------------------------------------------------------------------- /cpp_wrappers/cpp_neighbors/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup, Extension 2 | import numpy.distutils.misc_util 3 | 4 | # Adding OpenCV to project 5 | # ************************ 6 | 7 | # Adding sources of the project 8 | # ***************************** 9 | 10 | SOURCES = ["../cpp_utils/cloud/cloud.cpp", 11 | "neighbors/neighbors.cpp", 12 | "wrapper.cpp"] 13 | 14 | module = Extension(name="radius_neighbors", 15 | sources=SOURCES, 16 | extra_compile_args=['-std=c++11', 17 | '-D_GLIBCXX_USE_CXX11_ABI=0']) 18 | 19 | 20 | setup(ext_modules=[module], include_dirs=numpy.distutils.misc_util.get_numpy_include_dirs()) 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /cpp_wrappers/cpp_neighbors/wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include <Python.h> 2 | #include <numpy/arrayobject.h> 3 | #include "neighbors/neighbors.h" 4 | #include <string> 5 | 6 | 7 | 8 | // docstrings for our module 9 | // ************************* 10 | 11 | static char module_docstring[] = "This module provides two methods to compute radius neighbors from pointclouds or batch of pointclouds"; 12 | 13 | static char batch_query_docstring[] = "Method to get radius neighbors in a batch of stacked pointclouds"; 14 | 15 | 16 | // Declare the functions 17 | // ********************* 18 | 19 | static PyObject
*batch_neighbors(PyObject *self, PyObject *args, PyObject *keywds); 20 | 21 | 22 | // Specify the members of the module 23 | // ********************************* 24 | 25 | static PyMethodDef module_methods[] = 26 | { 27 | { "batch_query", (PyCFunction)batch_neighbors, METH_VARARGS | METH_KEYWORDS, batch_query_docstring }, 28 | {NULL, NULL, 0, NULL} 29 | }; 30 | 31 | 32 | // Initialize the module 33 | // ********************* 34 | 35 | static struct PyModuleDef moduledef = 36 | { 37 | PyModuleDef_HEAD_INIT, 38 | "radius_neighbors", // m_name 39 | module_docstring, // m_doc 40 | -1, // m_size 41 | module_methods, // m_methods 42 | NULL, // m_reload 43 | NULL, // m_traverse 44 | NULL, // m_clear 45 | NULL, // m_free 46 | }; 47 | 48 | PyMODINIT_FUNC PyInit_radius_neighbors(void) 49 | { 50 | import_array(); 51 | return PyModule_Create(&moduledef); 52 | } 53 | 54 | 55 | // Definition of the batch_neighbors method 56 | // ********************************** 57 | 58 | static PyObject* batch_neighbors(PyObject* self, PyObject* args, PyObject* keywds) 59 | { 60 | 61 | // Manage inputs 62 | // ************* 63 | 64 | // Args containers 65 | PyObject* queries_obj = NULL; 66 | PyObject* supports_obj = NULL; 67 | PyObject* q_batches_obj = NULL; 68 | PyObject* s_batches_obj = NULL; 69 | 70 | // Keywords containers 71 | static char* kwlist[] = { "queries", "supports", "q_batches", "s_batches", "radius", NULL }; 72 | float radius = 0.1; 73 | 74 | // Parse the input 75 | if (!PyArg_ParseTupleAndKeywords(args, keywds, "OOOO|$f", kwlist, &queries_obj, &supports_obj, &q_batches_obj, &s_batches_obj, &radius)) 76 | { 77 | PyErr_SetString(PyExc_RuntimeError, "Error parsing arguments"); 78 | return NULL; 79 | } 80 | 81 | 82 | // Interpret the input objects as numpy arrays. 83 | PyObject* queries_array = PyArray_FROM_OTF(queries_obj, NPY_FLOAT, NPY_IN_ARRAY); 84 | PyObject* supports_array = PyArray_FROM_OTF(supports_obj, NPY_FLOAT, NPY_IN_ARRAY); 85 | PyObject* q_batches_array = PyArray_FROM_OTF(q_batches_obj, NPY_INT, NPY_IN_ARRAY); 86 | PyObject* s_batches_array = PyArray_FROM_OTF(s_batches_obj, NPY_INT, NPY_IN_ARRAY); 87 | 88 | // Verify data was loaded correctly.
89 | if (queries_array == NULL) 90 | { 91 | Py_XDECREF(queries_array); 92 | Py_XDECREF(supports_array); 93 | Py_XDECREF(q_batches_array); 94 | Py_XDECREF(s_batches_array); 95 | PyErr_SetString(PyExc_RuntimeError, "Error converting query points to numpy arrays of type float32"); 96 | return NULL; 97 | } 98 | if (supports_array == NULL) 99 | { 100 | Py_XDECREF(queries_array); 101 | Py_XDECREF(supports_array); 102 | Py_XDECREF(q_batches_array); 103 | Py_XDECREF(s_batches_array); 104 | PyErr_SetString(PyExc_RuntimeError, "Error converting support points to numpy arrays of type float32"); 105 | return NULL; 106 | } 107 | if (q_batches_array == NULL) 108 | { 109 | Py_XDECREF(queries_array); 110 | Py_XDECREF(supports_array); 111 | Py_XDECREF(q_batches_array); 112 | Py_XDECREF(s_batches_array); 113 | PyErr_SetString(PyExc_RuntimeError, "Error converting query batches to numpy arrays of type int32"); 114 | return NULL; 115 | } 116 | if (s_batches_array == NULL) 117 | { 118 | Py_XDECREF(queries_array); 119 | Py_XDECREF(supports_array); 120 | Py_XDECREF(q_batches_array); 121 | Py_XDECREF(s_batches_array); 122 | PyErr_SetString(PyExc_RuntimeError, "Error converting support batches to numpy arrays of type int32"); 123 | return NULL; 124 | } 125 | 126 | // Check that the input arrays respect the dims 127 | if ((int)PyArray_NDIM(queries_array) != 2 || (int)PyArray_DIM(queries_array, 1) != 3) 128 | { 129 | Py_XDECREF(queries_array); 130 | Py_XDECREF(supports_array); 131 | Py_XDECREF(q_batches_array); 132 | Py_XDECREF(s_batches_array); 133 | PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : query.shape is not (N, 3)"); 134 | return NULL; 135 | } 136 | if ((int)PyArray_NDIM(supports_array) != 2 || (int)PyArray_DIM(supports_array, 1) != 3) 137 | { 138 | Py_XDECREF(queries_array); 139 | Py_XDECREF(supports_array); 140 | Py_XDECREF(q_batches_array); 141 | Py_XDECREF(s_batches_array); 142 | PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : support.shape is not (N, 3)"); 143 | return NULL; 144 | } 145 | if ((int)PyArray_NDIM(q_batches_array) > 1) 146 | { 147 | Py_XDECREF(queries_array); 148 | Py_XDECREF(supports_array); 149 | Py_XDECREF(q_batches_array); 150 | Py_XDECREF(s_batches_array); 151 | PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : queries_batches.shape is not (B,) "); 152 | return NULL; 153 | } 154 | if ((int)PyArray_NDIM(s_batches_array) > 1) 155 | { 156 | Py_XDECREF(queries_array); 157 | Py_XDECREF(supports_array); 158 | Py_XDECREF(q_batches_array); 159 | Py_XDECREF(s_batches_array); 160 | PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : supports_batches.shape is not (B,) "); 161 | return NULL; 162 | } 163 | if ((int)PyArray_DIM(q_batches_array, 0) != (int)PyArray_DIM(s_batches_array, 0)) 164 | { 165 | Py_XDECREF(queries_array); 166 | Py_XDECREF(supports_array); 167 | Py_XDECREF(q_batches_array); 168 | Py_XDECREF(s_batches_array); 169 | PyErr_SetString(PyExc_RuntimeError, "Wrong number of batch elements: different for queries and supports "); 170 | return NULL; 171 | } 172 | 173 | // Number of points 174 | int Nq = (int)PyArray_DIM(queries_array, 0); 175 | int Ns= (int)PyArray_DIM(supports_array, 0); 176 | 177 | // Number of batches 178 | int Nb = (int)PyArray_DIM(q_batches_array, 0); 179 | 180 | // Call the C++ function 181 | // ********************* 182 | 183 | // Convert PyArray to Cloud C++ class 184 | vector<PointXYZ> queries; 185 | vector<PointXYZ> supports; 186 | vector<int> q_batches; 187 | vector<int> s_batches; 188 | queries = vector<PointXYZ>((PointXYZ*)PyArray_DATA(queries_array),
(PointXYZ*)PyArray_DATA(queries_array) + Nq); 189 | supports = vector<PointXYZ>((PointXYZ*)PyArray_DATA(supports_array), (PointXYZ*)PyArray_DATA(supports_array) + Ns); 190 | q_batches = vector<int>((int*)PyArray_DATA(q_batches_array), (int*)PyArray_DATA(q_batches_array) + Nb); 191 | s_batches = vector<int>((int*)PyArray_DATA(s_batches_array), (int*)PyArray_DATA(s_batches_array) + Nb); 192 | 193 | // Create result containers 194 | vector<int> neighbors_indices; 195 | 196 | // Compute results 197 | //batch_ordered_neighbors(queries, supports, q_batches, s_batches, neighbors_indices, radius); 198 | batch_nanoflann_neighbors(queries, supports, q_batches, s_batches, neighbors_indices, radius); 199 | 200 | // Check result 201 | if (neighbors_indices.size() < 1) 202 | { 203 | PyErr_SetString(PyExc_RuntimeError, "Error"); 204 | return NULL; 205 | } 206 | 207 | // Manage outputs 208 | // ************** 209 | 210 | // Maximal number of neighbors 211 | int max_neighbors = neighbors_indices.size() / Nq; 212 | 213 | // Dimension of output containers 214 | npy_intp* neighbors_dims = new npy_intp[2]; 215 | neighbors_dims[0] = Nq; 216 | neighbors_dims[1] = max_neighbors; 217 | 218 | // Create output array 219 | PyObject* res_obj = PyArray_SimpleNew(2, neighbors_dims, NPY_INT); 220 | PyObject* ret = NULL; 221 | 222 | // Fill output array with values 223 | size_t size_in_bytes = Nq * max_neighbors * sizeof(int); 224 | memcpy(PyArray_DATA(res_obj), neighbors_indices.data(), size_in_bytes); 225 | 226 | // Merge results 227 | ret = Py_BuildValue("N", res_obj); 228 | 229 | // Clean up 230 | // ******** 231 | 232 | Py_XDECREF(queries_array); 233 | Py_XDECREF(supports_array); 234 | Py_XDECREF(q_batches_array); 235 | Py_XDECREF(s_batches_array); 236 | 237 | return ret; 238 | } 239 | -------------------------------------------------------------------------------- /cpp_wrappers/cpp_subsampling/build.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | py setup.py build_ext --inplace 3 | 4 | 5 | pause -------------------------------------------------------------------------------- /cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "grid_subsampling.h" 3 | 4 | 5 | void grid_subsampling(vector<PointXYZ>& original_points, 6 | vector<PointXYZ>& subsampled_points, 7 | vector<float>& original_features, 8 | vector<float>& subsampled_features, 9 | vector<int>& original_classes, 10 | vector<int>& subsampled_classes, 11 | float sampleDl, 12 | int verbose) { 13 | 14 | // Initialize variables 15 | // ****************** 16 | 17 | // Number of points in the cloud 18 | size_t N = original_points.size(); 19 | 20 | // Dimension of the features 21 | size_t fdim = original_features.size() / N; 22 | size_t ldim = original_classes.size() / N; 23 | 24 | // Limits of the cloud 25 | PointXYZ minCorner = min_point(original_points); 26 | PointXYZ maxCorner = max_point(original_points); 27 | PointXYZ originCorner = floor(minCorner * (1/sampleDl)) * sampleDl; 28 | 29 | // Dimensions of the grid 30 | size_t sampleNX = (size_t)floor((maxCorner.x - originCorner.x) / sampleDl) + 1; 31 | size_t sampleNY = (size_t)floor((maxCorner.y - originCorner.y) / sampleDl) + 1; 32 | //size_t sampleNZ = (size_t)floor((maxCorner.z - originCorner.z) / sampleDl) + 1; 33 | 34 | // Check if features and classes need to be processed 35 | bool use_feature = original_features.size() > 0; 36 | bool use_classes = original_classes.size() > 0; 37 | 38 | 39 | // Create the
sampled map 40 | // ********************** 41 | 42 | // Verbose parameters 43 | int i = 0; 44 | int nDisp = N / 100; 45 | 46 | // Initialize variables 47 | size_t iX, iY, iZ, mapIdx; 48 | unordered_map<size_t, SampledData> data; 49 | 50 | for (auto& p : original_points) 51 | { 52 | // Position of point in sample map 53 | iX = (size_t)floor((p.x - originCorner.x) / sampleDl); 54 | iY = (size_t)floor((p.y - originCorner.y) / sampleDl); 55 | iZ = (size_t)floor((p.z - originCorner.z) / sampleDl); 56 | mapIdx = iX + sampleNX*iY + sampleNX*sampleNY*iZ; 57 | 58 | // If not already created, create key 59 | if (data.count(mapIdx) < 1) 60 | data.emplace(mapIdx, SampledData(fdim, ldim)); 61 | 62 | // Fill the sample map 63 | if (use_feature && use_classes) 64 | data[mapIdx].update_all(p, original_features.begin() + i * fdim, original_classes.begin() + i * ldim); 65 | else if (use_feature) 66 | data[mapIdx].update_features(p, original_features.begin() + i * fdim); 67 | else if (use_classes) 68 | data[mapIdx].update_classes(p, original_classes.begin() + i * ldim); 69 | else 70 | data[mapIdx].update_points(p); 71 | 72 | // Display 73 | i++; 74 | if (verbose > 1 && i%nDisp == 0) 75 | std::cout << "\rSampled Map : " << std::setw(3) << i / nDisp << "%"; 76 | 77 | } 78 | 79 | // Divide for barycentre and transfer to a vector 80 | subsampled_points.reserve(data.size()); 81 | if (use_feature) 82 | subsampled_features.reserve(data.size() * fdim); 83 | if (use_classes) 84 | subsampled_classes.reserve(data.size() * ldim); 85 | for (auto& v : data) 86 | { 87 | subsampled_points.push_back(v.second.point * (1.0 / v.second.count)); 88 | if (use_feature) 89 | { 90 | float count = (float)v.second.count; 91 | transform(v.second.features.begin(), 92 | v.second.features.end(), 93 | v.second.features.begin(), 94 | [count](float f) { return f / count;}); 95 | subsampled_features.insert(subsampled_features.end(),v.second.features.begin(),v.second.features.end()); 96 | } 97 | if (use_classes) 98 | { 99 | for (int i = 0; i < ldim; i++) 100 | subsampled_classes.push_back(max_element(v.second.labels[i].begin(), v.second.labels[i].end(), 101 | [](const pair<int, int>&a, const pair<int, int>&b){return a.second < b.second;})->first); 102 | } 103 | } 104 | 105 | return; 106 | } 107 | 108 | 109 | void batch_grid_subsampling(vector<PointXYZ>& original_points, 110 | vector<PointXYZ>& subsampled_points, 111 | vector<float>& original_features, 112 | vector<float>& subsampled_features, 113 | vector<int>& original_classes, 114 | vector<int>& subsampled_classes, 115 | vector<int>& original_batches, 116 | vector<int>& subsampled_batches, 117 | float sampleDl, 118 | int max_p) 119 | { 120 | // Initialize variables 121 | // ****************** 122 | 123 | int b = 0; 124 | int sum_b = 0; 125 | 126 | // Number of points in the cloud 127 | size_t N = original_points.size(); 128 | 129 | // Dimension of the features 130 | size_t fdim = original_features.size() / N; 131 | size_t ldim = original_classes.size() / N; 132 | 133 | // Handle max_p = 0 134 | if (max_p < 1) 135 | max_p = N; 136 | 137 | // Loop over batches 138 | // ***************** 139 | 140 | for (b = 0; b < original_batches.size(); b++) 141 | { 142 | 143 | // Extract batch points features and labels 144 | vector<PointXYZ> b_o_points = vector<PointXYZ>(original_points.begin () + sum_b, 145 | original_points.begin () + sum_b + original_batches[b]); 146 | 147 | vector<float> b_o_features; 148 | if (original_features.size() > 0) 149 | { 150 | b_o_features = vector<float>(original_features.begin () + sum_b * fdim, 151 | original_features.begin () + (sum_b + original_batches[b]) * fdim); 152 | } 153 | 154 | vector<int>
b_o_classes; 155 | if (original_classes.size() > 0) 156 | { 157 | b_o_classes = vector<int>(original_classes.begin () + sum_b * ldim, 158 | original_classes.begin () + (sum_b + original_batches[b]) * ldim); 159 | } 160 | 161 | 162 | // Create result containers 163 | vector<PointXYZ> b_s_points; 164 | vector<float> b_s_features; 165 | vector<int> b_s_classes; 166 | 167 | // Compute subsampling on current batch 168 | grid_subsampling(b_o_points, 169 | b_s_points, 170 | b_o_features, 171 | b_s_features, 172 | b_o_classes, 173 | b_s_classes, 174 | sampleDl, 175 | 0); 176 | 177 | // Stack batches points features and labels 178 | // **************************************** 179 | 180 | // If too many points remove some 181 | if (b_s_points.size() <= max_p) 182 | { 183 | subsampled_points.insert(subsampled_points.end(), b_s_points.begin(), b_s_points.end()); 184 | 185 | if (original_features.size() > 0) 186 | subsampled_features.insert(subsampled_features.end(), b_s_features.begin(), b_s_features.end()); 187 | 188 | if (original_classes.size() > 0) 189 | subsampled_classes.insert(subsampled_classes.end(), b_s_classes.begin(), b_s_classes.end()); 190 | 191 | subsampled_batches.push_back(b_s_points.size()); 192 | } 193 | else 194 | { 195 | subsampled_points.insert(subsampled_points.end(), b_s_points.begin(), b_s_points.begin() + max_p); 196 | 197 | if (original_features.size() > 0) 198 | subsampled_features.insert(subsampled_features.end(), b_s_features.begin(), b_s_features.begin() + max_p * fdim); 199 | 200 | if (original_classes.size() > 0) 201 | subsampled_classes.insert(subsampled_classes.end(), b_s_classes.begin(), b_s_classes.begin() + max_p * ldim); 202 | 203 | subsampled_batches.push_back(max_p); 204 | } 205 | 206 | // Stack new batch lengths 207 | sum_b += original_batches[b]; 208 | } 209 | 210 | return; 211 | } 212 | -------------------------------------------------------------------------------- /cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "../../cpp_utils/cloud/cloud.h" 4 | 5 | #include <set> 6 | #include <cstdint> 7 | 8 | using namespace std; 9 | 10 | class SampledData 11 | { 12 | public: 13 | 14 | // Elements 15 | // ******** 16 | 17 | int count; 18 | PointXYZ point; 19 | vector<float> features; 20 | vector<unordered_map<int, int>> labels; 21 | 22 | 23 | // Methods 24 | // ******* 25 | 26 | // Constructor 27 | SampledData() 28 | { 29 | count = 0; 30 | point = PointXYZ(); 31 | } 32 | 33 | SampledData(const size_t fdim, const size_t ldim) 34 | { 35 | count = 0; 36 | point = PointXYZ(); 37 | features = vector<float>(fdim); 38 | labels = vector<unordered_map<int, int>>(ldim); 39 | } 40 | 41 | // Method Update 42 | void update_all(const PointXYZ p, vector<float>::iterator f_begin, vector<int>::iterator l_begin) 43 | { 44 | count += 1; 45 | point += p; 46 | transform (features.begin(), features.end(), f_begin, features.begin(), plus<float>()); 47 | int i = 0; 48 | for(vector<int>::iterator it = l_begin; it != l_begin + labels.size(); ++it) 49 | { 50 | labels[i][*it] += 1; 51 | i++; 52 | } 53 | return; 54 | } 55 | void update_features(const PointXYZ p, vector<float>::iterator f_begin) 56 | { 57 | count += 1; 58 | point += p; 59 | transform (features.begin(), features.end(), f_begin, features.begin(), plus<float>()); 60 | return; 61 | } 62 | void update_classes(const PointXYZ p, vector<int>::iterator l_begin) 63 | { 64 | count += 1; 65 | point += p; 66 | int i = 0; 67 | for(vector<int>::iterator it = l_begin; it != l_begin + labels.size(); ++it) 68 | { 69 | labels[i][*it] += 1; 70 | i++; 71 | } 72 | return; 73 | } 74 |
void update_points(const PointXYZ p) 75 | { 76 | count += 1; 77 | point += p; 78 | return; 79 | } 80 | }; 81 | 82 | void grid_subsampling(vector<PointXYZ>& original_points, 83 | vector<PointXYZ>& subsampled_points, 84 | vector<float>& original_features, 85 | vector<float>& subsampled_features, 86 | vector<int>& original_classes, 87 | vector<int>& subsampled_classes, 88 | float sampleDl, 89 | int verbose); 90 | 91 | void batch_grid_subsampling(vector<PointXYZ>& original_points, 92 | vector<PointXYZ>& subsampled_points, 93 | vector<float>& original_features, 94 | vector<float>& subsampled_features, 95 | vector<int>& original_classes, 96 | vector<int>& subsampled_classes, 97 | vector<int>& original_batches, 98 | vector<int>& subsampled_batches, 99 | float sampleDl, 100 | int max_p); 101 | 102 | -------------------------------------------------------------------------------- /cpp_wrappers/cpp_subsampling/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup, Extension 2 | import numpy.distutils.misc_util 3 | 4 | # Adding OpenCV to project 5 | # ************************ 6 | 7 | # Adding sources of the project 8 | # ***************************** 9 | 10 | SOURCES = ["../cpp_utils/cloud/cloud.cpp", 11 | "grid_subsampling/grid_subsampling.cpp", 12 | "wrapper.cpp"] 13 | 14 | module = Extension(name="grid_subsampling", 15 | sources=SOURCES, 16 | extra_compile_args=['-std=c++11', 17 | '-D_GLIBCXX_USE_CXX11_ABI=0']) 18 | 19 | 20 | setup(ext_modules=[module], include_dirs=numpy.distutils.misc_util.get_numpy_include_dirs()) 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /cpp_wrappers/cpp_utils/cloud/cloud.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // 3 | // 0==========================0 4 | // | Local feature test | 5 | // 0==========================0 6 | // 7 | // version 1.0 : 8 | // > 9 | // 10 | //--------------------------------------------------- 11 | // 12 | // Cloud source : 13 | // Define useful Functions/Methods 14 | // 15 | //---------------------------------------------------- 16 | // 17 | // Hugues THOMAS - 10/02/2017 18 | // 19 | 20 | 21 | #include "cloud.h" 22 | 23 | 24 | // Getters 25 | // ******* 26 | 27 | PointXYZ max_point(std::vector<PointXYZ> points) 28 | { 29 | // Initialize limits 30 | PointXYZ maxP(points[0]); 31 | 32 | // Loop over all points 33 | for (auto p : points) 34 | { 35 | if (p.x > maxP.x) 36 | maxP.x = p.x; 37 | 38 | if (p.y > maxP.y) 39 | maxP.y = p.y; 40 | 41 | if (p.z > maxP.z) 42 | maxP.z = p.z; 43 | } 44 | 45 | return maxP; 46 | } 47 | 48 | PointXYZ min_point(std::vector<PointXYZ> points) 49 | { 50 | // Initialize limits 51 | PointXYZ minP(points[0]); 52 | 53 | // Loop over all points 54 | for (auto p : points) 55 | { 56 | if (p.x < minP.x) 57 | minP.x = p.x; 58 | 59 | if (p.y < minP.y) 60 | minP.y = p.y; 61 | 62 | if (p.z < minP.z) 63 | minP.z = p.z; 64 | } 65 | 66 | return minP; 67 | } -------------------------------------------------------------------------------- /cpp_wrappers/cpp_utils/cloud/cloud.h: -------------------------------------------------------------------------------- 1 | // 2 | // 3 | // 0==========================0 4 | // | Local feature test | 5 | // 0==========================0 6 | // 7 | // version 1.0 : 8 | // > 9 | // 10 | //--------------------------------------------------- 11 | // 12 | // Cloud header 13 | // 14 | //---------------------------------------------------- 15 | // 16 | // Hugues THOMAS - 10/02/2017 17 | // 18 | 19 | 20 | # pragma once 21 |
22 | #include <vector> 23 | #include <unordered_map> 24 | #include <map> 25 | #include <algorithm> 26 | #include <numeric> 27 | #include <iostream> 28 | #include <iomanip> 29 | #include <cmath> 30 | 31 | #include <time.h> 32 | 33 | 34 | 35 | 36 | // Point class 37 | // *********** 38 | 39 | 40 | class PointXYZ 41 | { 42 | public: 43 | 44 | // Elements 45 | // ******** 46 | 47 | float x, y, z; 48 | 49 | 50 | // Methods 51 | // ******* 52 | 53 | // Constructor 54 | PointXYZ() { x = 0; y = 0; z = 0; } 55 | PointXYZ(float x0, float y0, float z0) { x = x0; y = y0; z = z0; } 56 | 57 | // array type accessor 58 | float operator [] (int i) const 59 | { 60 | if (i == 0) return x; 61 | else if (i == 1) return y; 62 | else return z; 63 | } 64 | 65 | // operations 66 | float dot(const PointXYZ P) const 67 | { 68 | return x * P.x + y * P.y + z * P.z; 69 | } 70 | 71 | float sq_norm() 72 | { 73 | return x*x + y*y + z*z; 74 | } 75 | 76 | PointXYZ cross(const PointXYZ P) const 77 | { 78 | return PointXYZ(y*P.z - z*P.y, z*P.x - x*P.z, x*P.y - y*P.x); 79 | } 80 | 81 | PointXYZ& operator+=(const PointXYZ& P) 82 | { 83 | x += P.x; 84 | y += P.y; 85 | z += P.z; 86 | return *this; 87 | } 88 | 89 | PointXYZ& operator-=(const PointXYZ& P) 90 | { 91 | x -= P.x; 92 | y -= P.y; 93 | z -= P.z; 94 | return *this; 95 | } 96 | 97 | PointXYZ& operator*=(const float& a) 98 | { 99 | x *= a; 100 | y *= a; 101 | z *= a; 102 | return *this; 103 | } 104 | }; 105 | 106 | 107 | // Point Operations 108 | // ***************** 109 | 110 | inline PointXYZ operator + (const PointXYZ A, const PointXYZ B) 111 | { 112 | return PointXYZ(A.x + B.x, A.y + B.y, A.z + B.z); 113 | } 114 | 115 | inline PointXYZ operator - (const PointXYZ A, const PointXYZ B) 116 | { 117 | return PointXYZ(A.x - B.x, A.y - B.y, A.z - B.z); 118 | } 119 | 120 | inline PointXYZ operator * (const PointXYZ P, const float a) 121 | { 122 | return PointXYZ(P.x * a, P.y * a, P.z * a); 123 | } 124 | 125 | inline PointXYZ operator * (const float a, const PointXYZ P) 126 | { 127 | return PointXYZ(P.x * a, P.y * a, P.z * a); 128 | } 129 | 130 | inline std::ostream& operator << (std::ostream& os, const PointXYZ P) 131 | { 132 | return os << "[" << P.x << ", " << P.y << ", " << P.z << "]"; 133 | } 134 | 135 | inline bool operator == (const PointXYZ A, const PointXYZ B) 136 | { 137 | return A.x == B.x && A.y == B.y && A.z == B.z; 138 | } 139 | 140 | inline PointXYZ floor(const PointXYZ P) 141 | { 142 | return PointXYZ(std::floor(P.x), std::floor(P.y), std::floor(P.z)); 143 | } 144 | 145 | 146 | PointXYZ max_point(std::vector<PointXYZ> points); 147 | PointXYZ min_point(std::vector<PointXYZ> points); 148 | 149 | 150 | struct PointCloud 151 | { 152 | 153 | std::vector<PointXYZ> pts; 154 | 155 | // Must return the number of data points 156 | inline size_t kdtree_get_point_count() const { return pts.size(); } 157 | 158 | // Returns the dim'th component of the idx'th point in the class: 159 | // Since this is inlined and the "dim" argument is typically an immediate value, the 160 | // "if/else's" are actually solved at compile time. 161 | inline float kdtree_get_pt(const size_t idx, const size_t dim) const 162 | { 163 | if (dim == 0) return pts[idx].x; 164 | else if (dim == 1) return pts[idx].y; 165 | else return pts[idx].z; 166 | } 167 | 168 | // Optional bounding-box computation: return false to default to a standard bbox computation loop. 169 | // Return true if the BBOX was already computed by the class and returned in "bb" so it can be avoided to redo it again. 170 | // Look at bb.size() to find out the expected dimensionality (e.g.
2 or 3 for point clouds) 171 | template <class BBOX> 172 | bool kdtree_get_bbox(BBOX& /* bb */) const { return false; } 173 | 174 | }; 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/indoor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Shengyu Huang 3 | Last modified: 30.11.2020 4 | """ 5 | 6 | import os,sys,glob,torch 7 | import numpy as np 8 | from scipy.spatial.transform import Rotation 9 | from torch.utils.data import Dataset 10 | import open3d as o3d 11 | from lib.benchmark_utils import to_o3d_pcd, to_tsfm, get_correspondences 12 | 13 | 14 | class IndoorDataset(Dataset): 15 | """ 16 | Load subsampled coordinates, relative rotation and translation 17 | Output(torch.Tensor): 18 | src_pcd: [N,3] 19 | tgt_pcd: [M,3] 20 | rot: [3,3] 21 | trans: [3,1] 22 | """ 23 | def __init__(self,infos,config,data_augmentation=True): 24 | super(IndoorDataset,self).__init__() 25 | self.infos = infos 26 | self.base_dir = config.root 27 | self.overlap_radius = config.overlap_radius 28 | self.data_augmentation=data_augmentation 29 | self.config = config 30 | 31 | self.rot_factor=1. 32 | self.augment_noise = config.augment_noise 33 | self.max_points = 30000 34 | 35 | def __len__(self): 36 | return len(self.infos['rot']) 37 | 38 | def __getitem__(self,item): 39 | # get transformation 40 | rot=self.infos['rot'][item] 41 | trans=self.infos['trans'][item] 42 | 43 | # get pointcloud 44 | src_path=os.path.join(self.base_dir,self.infos['src'][item]) 45 | tgt_path=os.path.join(self.base_dir,self.infos['tgt'][item]) 46 | src_pcd = torch.load(src_path) 47 | tgt_pcd = torch.load(tgt_path) 48 | 49 | # if we get too many points, we do some downsampling 50 | if(src_pcd.shape[0] > self.max_points): 51 | idx = np.random.permutation(src_pcd.shape[0])[:self.max_points] 52 | src_pcd = src_pcd[idx] 53 | if(tgt_pcd.shape[0] > self.max_points): 54 | idx = np.random.permutation(tgt_pcd.shape[0])[:self.max_points] 55 | tgt_pcd = tgt_pcd[idx] 56 | 57 | # data augmentation: random rotation and gaussian noise 58 | if self.data_augmentation: 59 | # rotate the point cloud 60 | euler_ab=np.random.rand(3)*np.pi*2/self.rot_factor # anglez, angley, anglex 61 | rot_ab= Rotation.from_euler('zyx', euler_ab).as_matrix() 62 | if(np.random.rand(1)[0]>0.5): 63 | src_pcd=np.matmul(rot_ab,src_pcd.T).T 64 | rot=np.matmul(rot,rot_ab.T) 65 | else: 66 | tgt_pcd=np.matmul(rot_ab,tgt_pcd.T).T 67 | rot=np.matmul(rot_ab,rot) 68 | trans=np.matmul(rot_ab,trans) 69 | 70 | src_pcd += (np.random.rand(src_pcd.shape[0],3) - 0.5) * self.augment_noise 71 | tgt_pcd += (np.random.rand(tgt_pcd.shape[0],3) - 0.5) * self.augment_noise 72 | 73 | if(trans.ndim==1): 74 | trans=trans[:,None] 75 | 76 | # get correspondence at fine level 77 | tsfm = to_tsfm(rot, trans) 78 | correspondences = get_correspondences(to_o3d_pcd(src_pcd), to_o3d_pcd(tgt_pcd), tsfm,self.overlap_radius) 79 | 80 | src_feats=np.ones_like(src_pcd[:,:1]).astype(np.float32) 81 | tgt_feats=np.ones_like(tgt_pcd[:,:1]).astype(np.float32) 82 | rot = rot.astype(np.float32) 83 | trans = trans.astype(np.float32) 84 | 85 | return src_pcd,tgt_pcd,src_feats,tgt_feats,rot,trans, correspondences,
src_pcd, tgt_pcd, torch.ones(1) -------------------------------------------------------------------------------- /datasets/kitti.py: -------------------------------------------------------------------------------- 1 | # Basic libs 2 | import os, time, glob, random, pickle, copy, torch 3 | import numpy as np 4 | import open3d 5 | from scipy.spatial.transform import Rotation 6 | 7 | # Dataset parent class 8 | from torch.utils.data import Dataset 9 | from lib.benchmark_utils import to_tsfm, to_o3d_pcd, get_correspondences 10 | 11 | 12 | class KITTIDataset(Dataset): 13 | """ 14 | We follow D3Feat to add data augmentation part. 15 | We first voxelize the pcd and get matches 16 | Then we apply data augmentation to pcds. KPConv runs over processed pcds, but later for loss computation, we use pcds before data augmentation 17 | """ 18 | DATA_FILES = { 19 | 'train': './configs/kitti/train_kitti.txt', 20 | 'val': './configs/kitti/val_kitti.txt', 21 | 'test': './configs/kitti/test_kitti.txt' 22 | } 23 | def __init__(self,config,split,data_augmentation=True): 24 | super(KITTIDataset,self).__init__() 25 | self.config = config 26 | self.root = os.path.join(config.root,'dataset') 27 | self.icp_path = os.path.join(config.root,'icp') 28 | if not os.path.exists(self.icp_path): 29 | os.makedirs(self.icp_path) 30 | self.voxel_size = config.first_subsampling_dl 31 | self.matching_search_voxel_size = config.overlap_radius 32 | self.data_augmentation = data_augmentation 33 | self.augment_noise = config.augment_noise 34 | self.IS_ODOMETRY = True 35 | self.max_corr = config.max_points 36 | self.augment_shift_range = config.augment_shift_range 37 | self.augment_scale_max = config.augment_scale_max 38 | self.augment_scale_min = config.augment_scale_min 39 | 40 | # Initiate containers 41 | self.files = [] 42 | self.kitti_icp_cache = {} 43 | self.kitti_cache = {} 44 | self.prepare_kitti_ply(split) 45 | self.split = split 46 | 47 | 48 | def prepare_kitti_ply(self, split): 49 | assert split in ['train','val','test'] 50 | 51 | subset_names = open(self.DATA_FILES[split]).read().split() 52 | for dirname in subset_names: 53 | drive_id = int(dirname) 54 | fnames = glob.glob(self.root + '/sequences/%02d/velodyne/*.bin' % drive_id) 55 | assert len(fnames) > 0, f"Make sure that the path {self.root} has data {dirname}" 56 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 57 | 58 | # get one-to-one distance by comparing the translation vector 59 | all_odo = self.get_video_odometry(drive_id, return_all=True) 60 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo]) 61 | Ts = all_pos[:, :3, 3] 62 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3)) ** 2 63 | pdist = np.sqrt(pdist.sum(-1)) 64 | 65 | ###################################### 66 | # D3Feat script to generate test pairs 67 | more_than_10 = pdist > 10 68 | curr_time = inames[0] 69 | while curr_time in inames: 70 | next_time = np.where(more_than_10[curr_time][curr_time:curr_time + 100])[0] 71 | if len(next_time) == 0: 72 | curr_time += 1 73 | else: 74 | next_time = next_time[0] + curr_time - 1 75 | 76 | if next_time in inames: 77 | self.files.append((drive_id, curr_time, next_time)) 78 | curr_time = next_time + 1 79 | 80 | # remove bad pairs 81 | if split=='test': 82 | self.files.remove((8, 15, 58)) 83 | print(f'Num_{split}: {len(self.files)}') 84 | 85 | 86 | def __len__(self): 87 | return len(self.files) 88 | 89 | 90 | def __getitem__(self, idx): 91 | drive = self.files[idx][0] 92 | t0, t1 = self.files[idx][1], 
self.files[idx][2] 93 | all_odometry = self.get_video_odometry(drive, [t0, t1]) 94 | positions = [self.odometry_to_positions(odometry) for odometry in all_odometry] 95 | fname0 = self._get_velodyne_fn(drive, t0) 96 | fname1 = self._get_velodyne_fn(drive, t1) 97 | 98 | # XYZ and reflectance 99 | xyzr0 = np.fromfile(fname0, dtype=np.float32).reshape(-1, 4) 100 | xyzr1 = np.fromfile(fname1, dtype=np.float32).reshape(-1, 4) 101 | 102 | xyz0 = xyzr0[:, :3] 103 | xyz1 = xyzr1[:, :3] 104 | 105 | # use ICP to refine the ground-truth pose; for ICP we don't voxelize the point clouds 106 | key = '%d_%d_%d' % (drive, t0, t1) 107 | filename = self.icp_path + '/' + key + '.npy' 108 | if key not in self.kitti_icp_cache: 109 | if not os.path.exists(filename): 110 | print('missing ICP file, recomputing it') 111 | M = (self.velo2cam @ positions[0].T @ np.linalg.inv(positions[1].T) 112 | @ np.linalg.inv(self.velo2cam)).T 113 | xyz0_t = self.apply_transform(xyz0, M) 114 | pcd0 = to_o3d_pcd(xyz0_t) 115 | pcd1 = to_o3d_pcd(xyz1) 116 | reg = open3d.registration.registration_icp(pcd0, pcd1, 0.2, np.eye(4), 117 | open3d.registration.TransformationEstimationPointToPoint(), 118 | open3d.registration.ICPConvergenceCriteria(max_iteration=200)) 119 | pcd0.transform(reg.transformation) 120 | M2 = M @ reg.transformation 121 | np.save(filename, M2) 122 | else: 123 | M2 = np.load(filename) 124 | self.kitti_icp_cache[key] = M2 125 | else: 126 | M2 = self.kitti_icp_cache[key] 127 | 128 | 129 | # refined pose is denoted as trans 130 | tsfm = M2 131 | rot = tsfm[:3,:3] 132 | trans = tsfm[:3,3][:,None] 133 | 134 | # voxelize the point clouds here 135 | pcd0 = to_o3d_pcd(xyz0) 136 | pcd1 = to_o3d_pcd(xyz1) 137 | pcd0 = pcd0.voxel_down_sample(self.voxel_size) 138 | pcd1 = pcd1.voxel_down_sample(self.voxel_size) 139 | src_pcd = np.array(pcd0.points) 140 | tgt_pcd = np.array(pcd1.points) 141 | 142 | # Get matches 143 | matching_inds = get_correspondences(pcd0, pcd1, tsfm, self.matching_search_voxel_size) 144 | if(matching_inds.size(0) < self.max_corr and self.split == 'train'): 145 | return self.__getitem__(np.random.choice(len(self.files),1)[0]) 146 | 147 | src_feats=np.ones_like(src_pcd[:,:1]).astype(np.float32) 148 | tgt_feats=np.ones_like(tgt_pcd[:,:1]).astype(np.float32) 149 | 150 | rot = rot.astype(np.float32) 151 | trans = trans.astype(np.float32) 152 | 153 | # add data augmentation 154 | src_pcd_input = copy.deepcopy(src_pcd) 155 | tgt_pcd_input = copy.deepcopy(tgt_pcd) 156 | if(self.data_augmentation): 157 | # add uniform noise 158 | src_pcd_input += (np.random.rand(src_pcd_input.shape[0],3) - 0.5) * self.augment_noise 159 | tgt_pcd_input += (np.random.rand(tgt_pcd_input.shape[0],3) - 0.5) * self.augment_noise 160 | 161 | # rotate the point cloud 162 | euler_ab=np.random.rand(3)*np.pi*2 # anglez, angley, anglex 163 | rot_ab= Rotation.from_euler('zyx', euler_ab).as_matrix() 164 | if(np.random.rand(1)[0]>0.5): 165 | src_pcd_input = np.dot(rot_ab, src_pcd_input.T).T 166 | else: 167 | tgt_pcd_input = np.dot(rot_ab, tgt_pcd_input.T).T 168 | 169 | # scale the pcd 170 | scale = self.augment_scale_min + (self.augment_scale_max - self.augment_scale_min) * random.random() 171 | src_pcd_input = src_pcd_input * scale 172 | tgt_pcd_input = tgt_pcd_input * scale 173 | 174 | # shift the pcd 175 | shift_src = np.random.uniform(-self.augment_shift_range, self.augment_shift_range, 3) 176 | shift_tgt = np.random.uniform(-self.augment_shift_range, self.augment_shift_range, 3) 177 | 178 | src_pcd_input = src_pcd_input + shift_src 179 |
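Note: the ICP refinement in KITTIDataset.__getitem__ above uses the legacy open3d.registration namespace, which was renamed to open3d.pipelines.registration in Open3D 0.10. A minimal sketch of the same point-to-point refinement on a newer Open3D release (same threshold and iteration cap as above; the wrapper name is ours, not the repo's):

import numpy as np
import open3d as o3d

def refine_icp(pcd0, pcd1, threshold=0.2, max_iter=200):
    # point-to-point ICP between two o3d.geometry.PointCloud objects,
    # mirroring the registration_icp call in __getitem__ above
    reg = o3d.pipelines.registration.registration_icp(
        pcd0, pcd1, threshold, np.eye(4),
        o3d.pipelines.registration.TransformationEstimationPointToPoint(),
        o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iter))
    return reg.transformation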
tgt_pcd_input = tgt_pcd_input + shift_tgt 180 | 181 | 182 | return src_pcd_input, tgt_pcd_input, src_feats, tgt_feats, rot, trans, matching_inds, src_pcd, tgt_pcd, torch.ones(1) 183 | 184 | 185 | def apply_transform(self, pts, trans): 186 | R = trans[:3, :3] 187 | T = trans[:3, 3] 188 | pts = pts @ R.T + T 189 | return pts 190 | 191 | @property 192 | def velo2cam(self): 193 | try: 194 | velo2cam = self._velo2cam 195 | except AttributeError: 196 | R = np.array([ 197 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04, 198 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02 199 | ]).reshape(3, 3) 200 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1) 201 | velo2cam = np.hstack([R, T]) 202 | self._velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T 203 | return self._velo2cam 204 | 205 | def get_video_odometry(self, drive, indices=None, ext='.txt', return_all=False): 206 | if self.IS_ODOMETRY: 207 | data_path = self.root + '/poses/%02d.txt' % drive 208 | if data_path not in self.kitti_cache: 209 | self.kitti_cache[data_path] = np.genfromtxt(data_path) 210 | if return_all: 211 | return self.kitti_cache[data_path] 212 | else: 213 | return self.kitti_cache[data_path][indices] 214 | 215 | def odometry_to_positions(self, odometry): 216 | if self.IS_ODOMETRY: 217 | T_w_cam0 = odometry.reshape(3, 4) 218 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1])) 219 | return T_w_cam0 220 | 221 | def _get_velodyne_fn(self, drive, t): 222 | if self.IS_ODOMETRY: 223 | fname = self.root + '/sequences/%02d/velodyne/%06d.bin' % (drive, t) 224 | return fname 225 | 226 | def get_position_transform(self, pos0, pos1, invert=False): 227 | T0 = self.pos_transform(pos0) 228 | T1 = self.pos_transform(pos1) 229 | return (np.dot(T1, np.linalg.inv(T0)).T if not invert else np.dot( 230 | np.linalg.inv(T1), T0).T) 231 | -------------------------------------------------------------------------------- /datasets/modelnet.py: -------------------------------------------------------------------------------- 1 | """Data loader 2 | """ 3 | import argparse, os, torch, h5py, torchvision 4 | from typing import List 5 | 6 | import numpy as np 7 | import open3d as o3d 8 | from torch.utils.data import Dataset 9 | 10 | import datasets.transforms as Transforms 11 | import common.math.se3 as se3 12 | from lib.benchmark_utils import get_correspondences, to_o3d_pcd, to_tsfm 13 | 14 | 15 | def get_train_datasets(args: argparse.Namespace): 16 | train_categories, val_categories = None, None 17 | if args.train_categoryfile: 18 | train_categories = [line.rstrip('\n') for line in open(args.train_categoryfile)] 19 | train_categories.sort() 20 | if args.val_categoryfile: 21 | val_categories = [line.rstrip('\n') for line in open(args.val_categoryfile)] 22 | val_categories.sort() 23 | 24 | train_transforms, val_transforms = get_transforms(args.noise_type, args.rot_mag, args.trans_mag, 25 | args.num_points, args.partial) 26 | train_transforms = torchvision.transforms.Compose(train_transforms) 27 | val_transforms = torchvision.transforms.Compose(val_transforms) 28 | 29 | if args.dataset_type == 'modelnet_hdf': 30 | train_data = ModelNetHdf(args, args.root, subset='train', categories=train_categories, 31 | transform=train_transforms) 32 | val_data = ModelNetHdf(args, args.root, subset='test', categories=val_categories, 33 | transform=val_transforms) 34 | else: 35 | raise NotImplementedError 36 | 37 | return train_data, val_data 38 | 39 | 40 | def get_test_datasets(args: argparse.Namespace): 41 | test_categories 
= None 42 | if args.test_categoryfile: 43 | test_categories = [line.rstrip('\n') for line in open(args.test_categoryfile)] 44 | test_categories.sort() 45 | 46 | _, test_transforms = get_transforms(args.noise_type, args.rot_mag, args.trans_mag, 47 | args.num_points, args.partial) 48 | test_transforms = torchvision.transforms.Compose(test_transforms) 49 | 50 | if args.dataset_type == 'modelnet_hdf': 51 | test_data = ModelNetHdf(args, args.root, subset='test', categories=test_categories, 52 | transform=test_transforms) 53 | else: 54 | raise NotImplementedError 55 | 56 | return test_data 57 | 58 | 59 | def get_transforms(noise_type: str, 60 | rot_mag: float = 45.0, trans_mag: float = 0.5, 61 | num_points: int = 1024, partial_p_keep: List = None): 62 | """Get the list of transformation to be used for training or evaluating RegNet 63 | 64 | Args: 65 | noise_type: Either 'clean', 'jitter', 'crop'. 66 | Depending on the option, some of the subsequent arguments may be ignored. 67 | rot_mag: Magnitude of rotation perturbation to apply to source, in degrees. 68 | Default: 45.0 (same as Deep Closest Point) 69 | trans_mag: Magnitude of translation perturbation to apply to source. 70 | Default: 0.5 (same as Deep Closest Point) 71 | num_points: Number of points to uniformly resample to. 72 | Note that this is with respect to the full point cloud. The number of 73 | points will be proportionally less if cropped 74 | partial_p_keep: Proportion to keep during cropping, [src_p, ref_p] 75 | Default: [0.7, 0.7], i.e. Crop both source and reference to ~70% 76 | 77 | Returns: 78 | train_transforms, test_transforms: Both contain list of transformations to be applied 79 | """ 80 | 81 | partial_p_keep = partial_p_keep if partial_p_keep is not None else [0.7, 0.7] 82 | 83 | if noise_type == "clean": 84 | # 1-1 correspondence for each point (resample first before splitting), no noise 85 | train_transforms = [Transforms.Resampler(num_points), 86 | Transforms.SplitSourceRef(), 87 | Transforms.RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 88 | Transforms.ShufflePoints()] 89 | 90 | test_transforms = [Transforms.SetDeterministic(), 91 | Transforms.FixedResampler(num_points), 92 | Transforms.SplitSourceRef(), 93 | Transforms.RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 94 | Transforms.ShufflePoints()] 95 | 96 | elif noise_type == "jitter": 97 | # Points randomly sampled (might not have perfect correspondence), gaussian noise to position 98 | train_transforms = [Transforms.SplitSourceRef(), 99 | Transforms.RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 100 | Transforms.Resampler(num_points), 101 | Transforms.RandomJitter(), 102 | Transforms.ShufflePoints()] 103 | 104 | test_transforms = [Transforms.SetDeterministic(), 105 | Transforms.SplitSourceRef(), 106 | Transforms.RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 107 | Transforms.Resampler(num_points), 108 | Transforms.RandomJitter(), 109 | Transforms.ShufflePoints()] 110 | 111 | elif noise_type == "crop": 112 | # Both source and reference point clouds cropped, plus same noise in "jitter" 113 | train_transforms = [Transforms.SplitSourceRef(), 114 | Transforms.RandomCrop(partial_p_keep), 115 | Transforms.RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 116 | Transforms.Resampler(num_points), 117 | Transforms.RandomJitter(), 118 | Transforms.ShufflePoints()] 119 | 120 | test_transforms = [Transforms.SetDeterministic(), 121 | Transforms.SplitSourceRef(), 122 | 
Transforms.RandomCrop(partial_p_keep), 123 | Transforms.RandomTransformSE3_euler(rot_mag=rot_mag, trans_mag=trans_mag), 124 | Transforms.Resampler(num_points), 125 | Transforms.RandomJitter(), 126 | Transforms.ShufflePoints()] 127 | else: 128 | raise NotImplementedError 129 | 130 | return train_transforms, test_transforms 131 | 132 | 133 | class ModelNetHdf(Dataset): 134 | def __init__(self, args, root: str, subset: str = 'train', categories: List = None, transform=None): 135 | """ModelNet40 dataset from PointNet. 136 | Automatically downloads the dataset if not available 137 | 138 | Args: 139 | root (str): Folder containing processed dataset 140 | subset (str): Dataset subset, either 'train' or 'test' 141 | categories (list): Categories to use 142 | transform (callable, optional): Optional transform to be applied 143 | on a sample. 144 | """ 145 | self.config = args 146 | self._root = root 147 | self.n_in_feats = args.in_feats_dim 148 | self.overlap_radius = args.overlap_radius 149 | 150 | 151 | if not os.path.exists(os.path.join(root)): 152 | self._download_dataset(root) 153 | 154 | with open(os.path.join(root, 'shape_names.txt')) as fid: 155 | self._classes = [l.strip() for l in fid] 156 | self._category2idx = {e[1]: e[0] for e in enumerate(self._classes)} 157 | self._idx2category = self._classes 158 | 159 | with open(os.path.join(root, '{}_files.txt'.format(subset))) as fid: 160 | h5_filelist = [line.strip() for line in fid] 161 | h5_filelist = [x.replace('data/modelnet40_ply_hdf5_2048/', '') for x in h5_filelist] 162 | h5_filelist = [os.path.join(self._root, f) for f in h5_filelist] 163 | 164 | if categories is not None: 165 | categories_idx = [self._category2idx[c] for c in categories] 166 | self._classes = categories 167 | else: 168 | categories_idx = None 169 | 170 | self._data, self._labels = self._read_h5_files(h5_filelist, categories_idx) 171 | # self._data, self._labels = self._data[:32], self._labels[:32, ...] 
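To make the pipeline above concrete, here is a small usage sketch of get_transforms with the 'crop' setting (the shapes and values are illustrative, assuming the xyz+normal layout produced by _read_h5_files below):

import numpy as np
import torchvision
from datasets.modelnet import get_transforms

train_tf, test_tf = get_transforms(noise_type='crop', rot_mag=45.0, trans_mag=0.5,
                                   num_points=1024, partial_p_keep=[0.7, 0.7])
train_tf = torchvision.transforms.Compose(train_tf)

# a toy sample in ModelNetHdf format: 2048 points with xyz + normals
sample = {'points': np.random.rand(2048, 6).astype(np.float32),
          'label': np.array(0, dtype=np.int64), 'idx': np.array(0, dtype=np.int32)}
out = train_tf(sample)
print(out['points_src'].shape, out['points_ref'].shape)  # roughly 70% of 1024 points each, per the docstring above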
172 | self._transform = transform 173 | 174 | def __getitem__(self, item): 175 | sample = {'points': self._data[item, :, :], 'label': self._labels[item], 'idx': np.array(item, dtype=np.int32)} 176 | 177 | if self._transform: 178 | sample = self._transform(sample) 179 | # transform to our format 180 | src_pcd = sample['points_src'][:,:3] 181 | tgt_pcd = sample['points_ref'][:,:3] 182 | rot = sample['transform_gt'][:,:3] 183 | trans = sample['transform_gt'][:,3][:,None] 184 | matching_inds = get_correspondences(to_o3d_pcd(src_pcd), to_o3d_pcd(tgt_pcd),to_tsfm(rot,trans),self.overlap_radius) 185 | 186 | if(self.n_in_feats == 1): 187 | src_feats=np.ones_like(src_pcd[:,:1]).astype(np.float32) 188 | tgt_feats=np.ones_like(tgt_pcd[:,:1]).astype(np.float32) 189 | elif(self.n_in_feats == 3): 190 | src_feats = src_pcd.astype(np.float32) 191 | tgt_feats = tgt_pcd.astype(np.float32) 192 | 193 | for k,v in sample.items(): 194 | if k not in ['deterministic','label','idx']: 195 | sample[k] = torch.from_numpy(v).unsqueeze(0) 196 | 197 | return src_pcd,tgt_pcd,src_feats,tgt_feats,rot,trans, matching_inds, src_pcd, tgt_pcd, sample 198 | 199 | def __len__(self): 200 | return self._data.shape[0] 201 | 202 | @property 203 | def classes(self): 204 | return self._classes 205 | 206 | @staticmethod 207 | def _read_h5_files(fnames, categories): 208 | 209 | all_data = [] 210 | all_labels = [] 211 | 212 | for fname in fnames: 213 | f = h5py.File(fname, mode='r') 214 | data = np.concatenate([f['data'][:], f['normal'][:]], axis=-1) 215 | labels = f['label'][:].flatten().astype(np.int64) 216 | 217 | if categories is not None: # Filter out unwanted categories 218 | mask = np.isin(labels, categories).flatten() 219 | data = data[mask, ...] 220 | labels = labels[mask, ...] 221 | 222 | all_data.append(data) 223 | all_labels.append(labels) 224 | 225 | all_data = np.concatenate(all_data, axis=0) 226 | all_labels = np.concatenate(all_labels, axis=0) 227 | return all_data, all_labels 228 | 229 | @staticmethod 230 | def _download_dataset(root: str): 231 | os.makedirs(root, exist_ok=True) 232 | 233 | www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 234 | zipfile = os.path.basename(www) 235 | os.system('wget {}'.format(www)) 236 | os.system('unzip {} -d .'.format(zipfile)) 237 | os.system('mv {} {}'.format(zipfile[:-4], os.path.dirname(root))) 238 | os.system('rm {}'.format(zipfile)) 239 | 240 | def to_category(self, i): 241 | return self._idx2category[i] 242 | -------------------------------------------------------------------------------- /kernels/dispositions/k_015_center_3D.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/kernels/dispositions/k_015_center_3D.ply -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/lib/__init__.py -------------------------------------------------------------------------------- /lib/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scripts for pairwise registration using different sampling methods 3 | 4 | Author: Shengyu Huang 5 | Last modified: 30.11.2020 6 | """ 7 | 8 | import os,re,sys,json,yaml,random, glob, argparse, torch, pickle 9 | from tqdm 
import tqdm 10 | import numpy as np 11 | from scipy.spatial.transform import Rotation 12 | import open3d as o3d 13 | from lib.benchmark import read_trajectory, read_pairs, read_trajectory_info, write_trajectory 14 | 15 | _EPS = 1e-7 # To prevent division by zero 16 | 17 | 18 | def fmr_wrt_distance(data,split,inlier_ratio_threshold=0.05): 19 | """ 20 | calculate feature match recall wrt distance threshold 21 | """ 22 | fmr_wrt_distance =[] 23 | for distance_threshold in range(1,21): 24 | inlier_ratios =[] 25 | distance_threshold /=100.0 26 | for idx in range(data.shape[0]): 27 | inlier_ratio = (data[idx] < distance_threshold).mean() 28 | inlier_ratios.append(inlier_ratio) 29 | fmr = 0 30 | for ele in split: 31 | fmr += (np.array(inlier_ratios[ele[0]:ele[1]]) > inlier_ratio_threshold).mean() 32 | fmr /= 8 # average over the 8 benchmark scenes 33 | fmr_wrt_distance.append(fmr*100) 34 | return fmr_wrt_distance 35 | 36 | def fmr_wrt_inlier_ratio(data, split, distance_threshold=0.1): 37 | """ 38 | calculate feature match recall wrt inlier ratio threshold 39 | """ 40 | fmr_wrt_inlier =[] 41 | for inlier_ratio_threshold in range(1,21): 42 | inlier_ratios =[] 43 | inlier_ratio_threshold /=100.0 44 | for idx in range(data.shape[0]): 45 | inlier_ratio = (data[idx] < distance_threshold).mean() 46 | inlier_ratios.append(inlier_ratio) 47 | 48 | fmr = 0 49 | for ele in split: 50 | fmr += (np.array(inlier_ratios[ele[0]:ele[1]]) > inlier_ratio_threshold).mean() 51 | fmr /= 8 # average over the 8 benchmark scenes 52 | fmr_wrt_inlier.append(fmr*100) 53 | 54 | return fmr_wrt_inlier 55 | 56 | 57 | def write_est_trajectory(gt_folder, exp_dir, tsfm_est): 58 | """ 59 | Write the estimated trajectories 60 | """ 61 | scene_names=sorted(os.listdir(gt_folder)) 62 | count=0 63 | for scene_name in scene_names: 64 | gt_pairs, gt_traj = read_trajectory(os.path.join(gt_folder,scene_name,'gt.log')) 65 | est_traj = [] 66 | for i in range(len(gt_pairs)): 67 | est_traj.append(tsfm_est[count]) 68 | count+=1 69 | 70 | # write the trajectory 71 | c_directory=os.path.join(exp_dir,scene_name) 72 | os.makedirs(c_directory) 73 | write_trajectory(np.array(est_traj),gt_pairs,os.path.join(c_directory,'est.log')) 74 | 75 | 76 | def to_tensor(array): 77 | """ 78 | Convert array to tensor 79 | """ 80 | if(not isinstance(array,torch.Tensor)): 81 | return torch.from_numpy(array).float() 82 | else: 83 | return array 84 | 85 | def to_array(tensor): 86 | """ 87 | Convert tensor to array 88 | """ 89 | if(not isinstance(tensor,np.ndarray)): 90 | if(tensor.device == torch.device('cpu')): 91 | return tensor.numpy() 92 | else: 93 | return tensor.cpu().numpy() 94 | else: 95 | return tensor 96 | 97 | def to_tsfm(rot,trans): 98 | tsfm = np.eye(4) 99 | tsfm[:3,:3]=rot 100 | tsfm[:3,3]=trans.flatten() 101 | return tsfm 102 | 103 | def to_o3d_pcd(xyz): 104 | """ 105 | Convert tensor/array to open3d PointCloud 106 | xyz: [N, 3] 107 | """ 108 | pcd = o3d.geometry.PointCloud() 109 | pcd.points = o3d.utility.Vector3dVector(to_array(xyz)) 110 | return pcd 111 | 112 | def to_o3d_feats(embedding): 113 | """ 114 | Convert tensor/array to open3d features 115 | embedding: [N, C] 116 | """ 117 | feats = o3d.registration.Feature() 118 | feats.data = to_array(embedding).T 119 | return feats 120 | 121 | def get_correspondences(src_pcd, tgt_pcd, trans, search_voxel_size, K=None): 122 | src_pcd.transform(trans) 123 | pcd_tree = o3d.geometry.KDTreeFlann(tgt_pcd) 124 | 125 | correspondences = [] 126 | for i, point in enumerate(src_pcd.points): 127 | [count, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size) 128 | if K is not None: 129 | idx = idx[:K] 130 | for j in idx: 131 | correspondences.append([i, j]) 132 | 133 | correspondences = np.array(correspondences) 134 | correspondences = torch.from_numpy(correspondences) 135 | return correspondences 136 | 137 | def get_blue(): 138 | """ 139 | Get color blue for rendering 140 | """ 141 | return [0, 0.651, 0.929] 142 | 143 | def get_yellow(): 144 | """ 145 | Get color yellow for rendering 146 | """ 147 | return [1, 0.706, 0] 148 | 149 | def random_sample(pcd, feats, N): 150 | """ 151 | Do random sampling to get exact N points and associated features 152 | pcd: [N,3] 153 | feats: [N,C] 154 | """ 155 | if(isinstance(pcd,torch.Tensor)): 156 | n1 = pcd.size(0) 157 | elif(isinstance(pcd, np.ndarray)): 158 | n1 = pcd.shape[0] 159 | 160 | if n1 == N: 161 | return pcd, feats 162 | 163 | if n1 > N: 164 | choice = np.random.permutation(n1)[:N] 165 | else: 166 | choice = np.random.choice(n1, N) 167 | 168 | return pcd[choice], feats[choice] 169 | 170 | def get_angle_deviation(R_pred,R_gt): 171 | """ 172 | Calculate the angle deviation between two rotation matrices 173 | The rotation error is between [0,180] 174 | Input: 175 | R_pred: [B,3,3] 176 | R_gt : [B,3,3] 177 | Return: 178 | degs: [B] 179 | """ 180 | R=np.matmul(R_pred,R_gt.transpose(0,2,1)) 181 | tr=np.trace(R,0,1,2) 182 | rads=np.arccos(np.clip((tr-1)/2,-1,1)) # clip to valid range 183 | degs=rads/np.pi*180 184 | 185 | return degs 186 | 187 | def ransac_pose_estimation(src_pcd, tgt_pcd, src_feat, tgt_feat, mutual = False, distance_threshold = 0.05, ransac_n = 3): 188 | """ 189 | RANSAC pose estimation with two checkers 190 | We follow D3Feat to set ransac_n = 3 for 3DMatch and ransac_n = 4 for KITTI. 191 | For 3DMatch dataset, we observe significant improvement after changing ransac_n from 4 to 3.
192 | """ 193 | if(mutual): 194 | if(torch.cuda.device_count()>=1): 195 | device = torch.device('cuda') 196 | else: 197 | device = torch.device('cpu') 198 | src_feat, tgt_feat = to_tensor(src_feat), to_tensor(tgt_feat) 199 | scores = torch.matmul(src_feat.to(device), tgt_feat.transpose(0,1).to(device)).cpu() 200 | selection = mutual_selection(scores[None,:,:])[0] 201 | row_sel, col_sel = np.where(selection) 202 | corrs = o3d.utility.Vector2iVector(np.array([row_sel,col_sel]).T) 203 | src_pcd = to_o3d_pcd(src_pcd) 204 | tgt_pcd = to_o3d_pcd(tgt_pcd) 205 | result_ransac = o3d.registration.registration_ransac_based_on_correspondence( 206 | source=src_pcd, target=tgt_pcd,corres=corrs, 207 | max_correspondence_distance=distance_threshold, 208 | estimation_method=o3d.registration.TransformationEstimationPointToPoint(False), 209 | ransac_n=4, 210 | criteria=o3d.registration.RANSACConvergenceCriteria(50000, 1000)) 211 | else: 212 | src_pcd = to_o3d_pcd(src_pcd) 213 | tgt_pcd = to_o3d_pcd(tgt_pcd) 214 | src_feats = to_o3d_feats(src_feat) 215 | tgt_feats = to_o3d_feats(tgt_feat) 216 | 217 | result_ransac = o3d.registration.registration_ransac_based_on_feature_matching( 218 | src_pcd, tgt_pcd, src_feats, tgt_feats,distance_threshold, 219 | o3d.registration.TransformationEstimationPointToPoint(False), ransac_n, 220 | [o3d.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), 221 | o3d.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)], 222 | o3d.registration.RANSACConvergenceCriteria(50000, 1000)) 223 | 224 | return result_ransac.transformation 225 | 226 | def get_inlier_ratio(src_pcd, tgt_pcd, src_feat, tgt_feat, rot, trans, inlier_distance_threshold = 0.1): 227 | """ 228 | Compute inlier ratios with and without mutual check, return both 229 | """ 230 | src_pcd = to_tensor(src_pcd) 231 | tgt_pcd = to_tensor(tgt_pcd) 232 | src_feat = to_tensor(src_feat) 233 | tgt_feat = to_tensor(tgt_feat) 234 | rot, trans = to_tensor(rot), to_tensor(trans) 235 | 236 | results =dict() 237 | results['w']=dict() 238 | results['wo']=dict() 239 | 240 | if(torch.cuda.device_count()>=1): 241 | device = torch.device('cuda') 242 | else: 243 | device = torch.device('cpu') 244 | 245 | src_pcd = (torch.matmul(rot, src_pcd.transpose(0,1)) + trans).transpose(0,1) 246 | scores = torch.matmul(src_feat.to(device), tgt_feat.transpose(0,1).to(device)).cpu() 247 | 248 | ######################################## 249 | # 1. calculate inlier ratios wo mutual check 250 | _, idx = scores.max(-1) 251 | dist = torch.norm(src_pcd- tgt_pcd[idx],dim=1) 252 | results['wo']['distance'] = dist.numpy() 253 | 254 | c_inlier_ratio = (dist < inlier_distance_threshold).float().mean() 255 | results['wo']['inlier_ratio'] = c_inlier_ratio 256 | 257 | ######################################## 258 | # 2. 
calculate inlier ratios w mutual check 259 | selection = mutual_selection(scores[None,:,:])[0] 260 | row_sel, col_sel = np.where(selection) 261 | dist = torch.norm(src_pcd[row_sel]- tgt_pcd[col_sel],dim=1) 262 | results['w']['distance'] = dist.numpy() 263 | 264 | c_inlier_ratio = (dist < inlier_distance_threshold).float().mean() 265 | results['w']['inlier_ratio'] = c_inlier_ratio 266 | 267 | return results 268 | 269 | 270 | def mutual_selection(score_mat): 271 | """ 272 | Return a {0,1} matrix, where an element is 1 if and only if it is the maximum along both its row and its column 273 | 274 | Args: 275 | score_mat: [B,N,N] (np.ndarray) 276 | Return: 277 | mutuals: [B,N,N] 278 | """ 279 | score_mat=to_array(score_mat) 280 | if(score_mat.ndim==2): 281 | score_mat=score_mat[None,:,:] 282 | 283 | mutuals=np.zeros_like(score_mat) 284 | for i in range(score_mat.shape[0]): # loop through the batch 285 | c_mat=score_mat[i] 286 | flag_row=np.zeros_like(c_mat) 287 | flag_column=np.zeros_like(c_mat) 288 | 289 | max_along_row=np.argmax(c_mat,1)[:,None] 290 | max_along_column=np.argmax(c_mat,0)[None,:] 291 | np.put_along_axis(flag_row,max_along_row,1,1) 292 | np.put_along_axis(flag_column,max_along_column,1,0) 293 | mutuals[i]=(flag_row.astype(bool)) & (flag_column.astype(bool)) # use builtin bool; the np.bool alias was removed in NumPy 1.24 294 | return mutuals.astype(bool) 295 | 296 | 297 | def get_scene_split(whichbenchmark): 298 | """ 299 | Just to check how many valid fragments each scene has 300 | """ 301 | assert whichbenchmark in ['3DMatch','3DLoMatch'] 302 | folder = f'configs/benchmarks/{whichbenchmark}/*/gt.log' 303 | 304 | scene_files=sorted(glob.glob(folder)) 305 | split=[] 306 | count=0 307 | for eachfile in scene_files: 308 | gt_pairs, gt_traj = read_trajectory(eachfile) 309 | split.append([count,count+len(gt_pairs)]) 310 | count+=len(gt_pairs) 311 | return split 312 | -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss functions 3 | 4 | Author: Shengyu Huang 5 | Last modified: 30.11.2020 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from lib.utils import square_distance 13 | from sklearn.metrics import precision_recall_fscore_support 14 | 15 | class MetricLoss(nn.Module): 16 | """ 17 | We evaluate both contrastive loss and circle loss 18 | """ 19 | def __init__(self,configs,log_scale=16, pos_optimal=0.1, neg_optimal=1.4): 20 | super(MetricLoss,self).__init__() 21 | self.log_scale = log_scale 22 | self.pos_optimal = pos_optimal 23 | self.neg_optimal = neg_optimal 24 | 25 | self.pos_margin = configs.pos_margin 26 | self.neg_margin = configs.neg_margin 27 | self.max_points = configs.max_points 28 | 29 | self.safe_radius = configs.safe_radius 30 | self.matchability_radius = configs.matchability_radius 31 | self.pos_radius = configs.pos_radius # just to take care of the numeric precision 32 | 33 | def get_circle_loss(self, coords_dist, feats_dist): 34 | """ 35 | Modified from: https://github.com/XuyangBai/D3Feat.pytorch 36 | """ 37 | pos_mask = coords_dist < self.pos_radius 38 | neg_mask = coords_dist > self.safe_radius 39 | 40 | ## get anchors that have both positive and negative pairs 41 | row_sel = ((pos_mask.sum(-1)>0) * (neg_mask.sum(-1)>0)).detach() 42 | col_sel = ((pos_mask.sum(-2)>0) * (neg_mask.sum(-2)>0)).detach() 43 | 44 | # get alpha for both positive and negative pairs 45 | pos_weight = feats_dist - 1e5 *
(~pos_mask).float() # mask the non-positive 46 | pos_weight = (pos_weight - self.pos_optimal) # mask the uninformative positive 47 | pos_weight = torch.max(torch.zeros_like(pos_weight), pos_weight).detach() 48 | 49 | neg_weight = feats_dist + 1e5 * (~neg_mask).float() # mask the non-negative 50 | neg_weight = (self.neg_optimal - neg_weight) # mask the uninformative negative 51 | neg_weight = torch.max(torch.zeros_like(neg_weight),neg_weight).detach() 52 | 53 | lse_pos_row = torch.logsumexp(self.log_scale * (feats_dist - self.pos_margin) * pos_weight,dim=-1) 54 | lse_pos_col = torch.logsumexp(self.log_scale * (feats_dist - self.pos_margin) * pos_weight,dim=-2) 55 | 56 | lse_neg_row = torch.logsumexp(self.log_scale * (self.neg_margin - feats_dist) * neg_weight,dim=-1) 57 | lse_neg_col = torch.logsumexp(self.log_scale * (self.neg_margin - feats_dist) * neg_weight,dim=-2) 58 | 59 | loss_row = F.softplus(lse_pos_row + lse_neg_row)/self.log_scale 60 | loss_col = F.softplus(lse_pos_col + lse_neg_col)/self.log_scale 61 | 62 | circle_loss = (loss_row[row_sel].mean() + loss_col[col_sel].mean()) / 2 63 | 64 | return circle_loss 65 | 66 | def get_recall(self,coords_dist,feats_dist): 67 | """ 68 | Get feature match recall, divided by number of true inliers 69 | """ 70 | pos_mask = coords_dist < self.pos_radius 71 | n_gt_pos = (pos_mask.sum(-1)>0).float().sum()+1e-12 72 | _, sel_idx = torch.min(feats_dist, -1) 73 | sel_dist = torch.gather(coords_dist,dim=-1,index=sel_idx[:,None])[pos_mask.sum(-1)>0] 74 | n_pred_pos = (sel_dist < self.pos_radius).float().sum() 75 | recall = n_pred_pos / n_gt_pos 76 | return recall 77 | 78 | def get_weighted_bce_loss(self, prediction, gt): 79 | loss = nn.BCELoss(reduction='none') 80 | 81 | class_loss = loss(prediction, gt) 82 | 83 | weights = torch.ones_like(gt) 84 | w_negative = gt.sum()/gt.size(0) 85 | w_positive = 1 - w_negative 86 | 87 | weights[gt >= 0.5] = w_positive 88 | weights[gt < 0.5] = w_negative 89 | w_class_loss = torch.mean(weights * class_loss) 90 | 91 | ####################################### 92 | # get classification precision and recall 93 | predicted_labels = prediction.detach().cpu().round().numpy() 94 | cls_precision, cls_recall, _, _ = precision_recall_fscore_support(gt.cpu().numpy(),predicted_labels, average='binary') 95 | 96 | return w_class_loss, cls_precision, cls_recall 97 | 98 | 99 | def forward(self, src_pcd, tgt_pcd, src_feats, tgt_feats, correspondence, rot, trans,scores_overlap,scores_saliency): 100 | """ 101 | Circle loss for metric learning, here we feed the positive pairs only 102 | Input: 103 | src_pcd: [N, 3] 104 | tgt_pcd: [M, 3] 105 | rot: [3, 3] 106 | trans: [3, 1] 107 | src_feats: [N, C] 108 | tgt_feats: [M, C] 109 | """ 110 | src_pcd = (torch.matmul(rot,src_pcd.transpose(0,1))+trans).transpose(0,1) 111 | stats=dict() 112 | 113 | src_idx = list(set(correspondence[:,0].int().tolist())) 114 | tgt_idx = list(set(correspondence[:,1].int().tolist())) 115 | 116 | ####################### 117 | # get BCE loss for overlap, here the ground truth label is obtained from correspondence information 118 | src_gt = torch.zeros(src_pcd.size(0)) 119 | src_gt[src_idx]=1. 120 | tgt_gt = torch.zeros(tgt_pcd.size(0)) 121 | tgt_gt[tgt_idx]=1. 
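The class weighting in get_weighted_bce_loss above balances the typically unequal numbers of overlapping and non-overlapping points: each class is weighted by the frequency of the other class. A toy check of that rule (made-up labels):

import torch

gt = torch.tensor([1., 1., 1., 0.])     # 3 positives, 1 negative
w_negative = gt.sum() / gt.size(0)      # 0.75
w_positive = 1 - w_negative             # 0.25
weights = torch.ones_like(gt)
weights[gt >= 0.5] = w_positive
weights[gt < 0.5] = w_negative
print(weights)                          # tensor([0.2500, 0.2500, 0.2500, 0.7500])
# each class contributes equally in total: 3 * 0.25 == 1 * 0.75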
122 | gt_labels = torch.cat((src_gt, tgt_gt)).to(torch.device('cuda')) 123 | 124 | class_loss, cls_precision, cls_recall = self.get_weighted_bce_loss(scores_overlap, gt_labels) 125 | stats['overlap_loss'] = class_loss 126 | stats['overlap_recall'] = cls_recall 127 | stats['overlap_precision'] = cls_precision 128 | 129 | ####################### 130 | # get BCE loss for saliency part, here we only supervise points in the overlap region 131 | src_feats_sel, src_pcd_sel = src_feats[src_idx], src_pcd[src_idx] 132 | tgt_feats_sel, tgt_pcd_sel = tgt_feats[tgt_idx], tgt_pcd[tgt_idx] 133 | scores = torch.matmul(src_feats_sel, tgt_feats_sel.transpose(0,1)) 134 | _, idx = scores.max(1) 135 | distance_1 = torch.norm(src_pcd_sel - tgt_pcd_sel[idx], p=2, dim=1) 136 | _, idx = scores.max(0) 137 | distance_2 = torch.norm(tgt_pcd_sel - src_pcd_sel[idx], p=2, dim=1) 138 | 139 | gt_labels = torch.cat(((distance_1 < self.matchability_radius).float(), (distance_2 < self.matchability_radius).float())) 140 | 141 | src_saliency_scores = scores_saliency[:src_pcd.size(0)][src_idx] 142 | tgt_saliency_scores = scores_saliency[src_pcd.size(0):][tgt_idx] 143 | scores_saliency = torch.cat((src_saliency_scores, tgt_saliency_scores)) 144 | 145 | class_loss, cls_precision, cls_recall = self.get_weighted_bce_loss(scores_saliency, gt_labels) 146 | stats['saliency_loss'] = class_loss 147 | stats['saliency_recall'] = cls_recall 148 | stats['saliency_precision'] = cls_precision 149 | 150 | 151 | ####################### 152 | # filter correspondences: if there are too many, randomly subsample 153 | # to max_points before computing the circle loss 154 | 155 | if(correspondence.size(0) > self.max_points): 156 | choice = np.random.permutation(correspondence.size(0))[:self.max_points] 157 | correspondence = correspondence[choice] 158 | src_idx = correspondence[:,0] 159 | tgt_idx = correspondence[:,1] 160 | src_pcd, tgt_pcd = src_pcd[src_idx], tgt_pcd[tgt_idx] 161 | src_feats, tgt_feats = src_feats[src_idx], tgt_feats[tgt_idx] 162 | 163 | ####################### 164 | # get L2 distance between source / target point cloud 165 | coords_dist = torch.sqrt(square_distance(src_pcd[None,:,:], tgt_pcd[None,:,:]).squeeze(0)) 166 | feats_dist = torch.sqrt(square_distance(src_feats[None,:,:], tgt_feats[None,:,:],normalised=True)).squeeze(0) 167 | 168 | ############################## 169 | # get FMR and circle loss 170 | ############################## 171 | recall = self.get_recall(coords_dist, feats_dist) 172 | circle_loss = self.get_circle_loss(coords_dist, feats_dist) 173 | 174 | stats['circle_loss']= circle_loss 175 | stats['recall']=recall 176 | 177 | return stats 178 | -------------------------------------------------------------------------------- /lib/ply.py: -------------------------------------------------------------------------------- 1 | # 2 | # 3 | # 0===============================0 4 | # | PLY files reader/writer | 5 | # 0===============================0 6 | # 7 | # 8 | # ---------------------------------------------------------------------------------------------------------------------- 9 | # 10 | # function to read/write .ply files 11 | # 12 | # ---------------------------------------------------------------------------------------------------------------------- 13 | # 14 | # Hugues THOMAS - 10/02/2017 15 | # 16 | 17 | 18 | # ---------------------------------------------------------------------------------------------------------------------- 19 | # 20 | # Imports and global variables 21 | # \**********************************/ 22 | # 23 | 24 | # Basic libs 25 | import numpy as np 26 | import sys 27 | 28 | # Define PLY types 29 | ply_dtypes = dict([ 30 | (b'int8', 'i1'), 31 | (b'char', 'i1'), 32 | (b'uint8', 'u1'), 33 | (b'uchar', 'u1'), 34 | (b'int16', 'i2'), 35 | (b'short', 'i2'), 36 | (b'uint16', 'u2'), 37 | (b'ushort', 'u2'), 38 | (b'int32', 'i4'), 39 | (b'int', 'i4'), 40 | (b'uint32', 'u4'), 41 | (b'uint', 'u4'), 42 | (b'float32', 'f4'), 43 | (b'float', 'f4'), 44 | (b'float64', 'f8'), 45 | (b'double', 'f8') 46 | ]) 47 | 48 | # Numpy reader format 49 | valid_formats = {'ascii': '', 'binary_big_endian': '>', 50 | 'binary_little_endian': '<'} 51 | 52 | 53 | # ----------------------------------------------------------------------------------------------------------------------
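The readers below turn header 'property' lines into a numpy structured dtype via the ply_dtypes table above; a toy illustration with hand-written header lines:

import numpy as np
from lib.ply import ply_dtypes

ext = '<'  # binary_little_endian, per valid_formats above
header_lines = [b'property float x', b'property float y', b'property uchar red']
properties = []
for line in header_lines:
    parts = line.split()
    properties.append((parts[2].decode(), ext + ply_dtypes[parts[1]]))
print(np.dtype(properties))  # [('x', '<f4'), ('y', '<f4'), ('red', 'u1')]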
54 | # 55 | # Functions 56 | # \***************/ 57 | # 58 | 59 | 60 | def parse_header(plyfile, ext): 61 | # Variables 62 | line = [] 63 | properties = [] 64 | num_points = None 65 | 66 | while b'end_header' not in line and line != b'': 67 | line = plyfile.readline() 68 | 69 | if b'element' in line: 70 | line = line.split() 71 | num_points = int(line[2]) 72 | 73 | elif b'property' in line: 74 | line = line.split() 75 | properties.append((line[2].decode(), ext + ply_dtypes[line[1]])) 76 | 77 | return num_points, properties 78 | 79 | 80 | def parse_mesh_header(plyfile, ext): 81 | # Variables 82 | line = [] 83 | vertex_properties = [] 84 | num_points = None 85 | num_faces = None 86 | current_element = None 87 | 88 | while b'end_header' not in line and line != b'': 89 | line = plyfile.readline() 90 | 91 | # Find point element 92 | if b'element vertex' in line: 93 | current_element = 'vertex' 94 | line = line.split() 95 | num_points = int(line[2]) 96 | 97 | elif b'element face' in line: 98 | current_element = 'face' 99 | line = line.split() 100 | num_faces = int(line[2]) 101 | 102 | elif b'property' in line: 103 | if current_element == 'vertex': 104 | line = line.split() 105 | vertex_properties.append((line[2].decode(), ext + ply_dtypes[line[1]])) 106 | elif current_element == 'face': 107 | if not line.startswith(b'property list uchar int'): 108 | raise ValueError('Unsupported faces property : ' + str(line)) 109 | 110 | return num_points, num_faces, vertex_properties 111 | 112 | 113 | def read_ply(filename, triangular_mesh=False): 114 | """ 115 | Read ".ply" files 116 | 117 | Parameters 118 | ---------- 119 | filename : string 120 | the name of the file to read. 121 | 122 | Returns 123 | ------- 124 | result : array 125 | data stored in the file 126 | 127 | Examples 128 | -------- 129 | Store data in file 130 | 131 | >>> points = np.random.rand(5, 3) 132 | >>> values = np.random.randint(2, size=5) 133 | >>> write_ply('example.ply', [points, values], ['x', 'y', 'z', 'values']) 134 | 135 | Read the file 136 | 137 | >>> data = read_ply('example.ply') 138 | >>> values = data['values'] 139 | array([0, 0, 1, 1, 0]) 140 | 141 | >>> points = np.vstack((data['x'], data['y'], data['z'])).T 142 | array([[ 0.466 0.595 0.324] 143 | [ 0.538 0.407 0.654] 144 | [ 0.850 0.018 0.988] 145 | [ 0.395 0.394 0.363] 146 | [ 0.873 0.996 0.092]]) 147 | 148 | """ 149 | 150 | with open(filename, 'rb') as plyfile: 151 | 152 | # Check if the file starts with ply 153 | if b'ply' not in plyfile.readline(): 154 | raise ValueError('The file does not start with the word ply') 155 | 156 | # get binary_little/big or ascii 157 | fmt = plyfile.readline().split()[1].decode() 158 | if fmt == "ascii": 159 | raise ValueError('The file is not binary') 160 | 161 | # get extension for building the numpy dtypes 162 | ext = valid_formats[fmt] 163 | 164 | # PointCloud reader vs mesh reader 165 | if triangular_mesh: 166 | 167 | # Parse header 168 | num_points, num_faces, properties = parse_mesh_header(plyfile, ext) 169 | 170 | # Get point data 171 | vertex_data = np.fromfile(plyfile, dtype=properties, count=num_points) 172 | 173 | # Get face data 174 | face_properties = [('k', ext + 'u1'), 175 | ('v1', ext + 'i4'), 176 | ('v2', ext + 'i4'), 177 | ('v3', ext + 'i4')] 178 | faces_data = np.fromfile(plyfile, dtype=face_properties, count=num_faces) 179 | 180 | # Return vertex data and concatenated faces 181 | faces = np.vstack((faces_data['v1'], faces_data['v2'], faces_data['v3'])).T 182 | data = [vertex_data, faces] 183 | 184 | else: 185 | 186 | # Parse header 187 | num_points, properties = parse_header(plyfile, ext) 188 | 189 | # Get data 190 | data = np.fromfile(plyfile, dtype=properties, count=num_points) 191 | 192 | return data 193 | 194 | 195 | def header_properties(field_list, field_names): 196 | # List of lines to write 197 | lines = [] 198 | 199 | # First line describing element vertex 200 | lines.append('element vertex %d' % field_list[0].shape[0]) 201 | 202 | # Properties lines 203 | i = 0 204 | for fields in field_list: 205 | for field in fields.T: 206 | lines.append('property %s %s' % (field.dtype.name, field_names[i])) 207 | i += 1 208 | 209 | return lines 210 | 211 | 212 | def write_ply(filename, field_list, field_names, triangular_faces=None): 213 | """ 214 | Write ".ply" files 215 | 216 | Parameters 217 | ---------- 218 | filename : string 219 | the name of the file to which the data is saved. A '.ply' extension will be appended to the 220 | file name if it does not already have one. 221 | 222 | field_list : list, tuple, numpy array 223 | the fields to be saved in the ply file. Either a numpy array, a list of numpy arrays or a 224 | tuple of numpy arrays. Each 1D numpy array and each column of 2D numpy arrays are considered 225 | as one field. 226 | 227 | field_names : list 228 | the name of each field as a list of strings. Has to be the same length as the number of 229 | fields. 230 | 231 | Examples 232 | -------- 233 | >>> points = np.random.rand(10, 3) 234 | >>> write_ply('example1.ply', points, ['x', 'y', 'z']) 235 | 236 | >>> values = np.random.randint(2, size=10) 237 | >>> write_ply('example2.ply', [points, values], ['x', 'y', 'z', 'values']) 238 | 239 | >>> colors = np.random.randint(255, size=(10,3), dtype=np.uint8) 240 | >>> field_names = ['x', 'y', 'z', 'red', 'green', 'blue', 'values'] 241 | >>> write_ply('example3.ply', [points, colors, values], field_names) 242 | 243 | """ 244 | 245 | # Format list input to the right form 246 | field_list = list(field_list) if (type(field_list) == list or type(field_list) == tuple) else list((field_list,)) 247 | for i, field in enumerate(field_list): 248 | if field.ndim < 2: 249 | field_list[i] = field.reshape(-1, 1) 250 | if field.ndim > 2: 251 | print('fields have more than 2 dimensions') 252 | return False 253 | 254 | # check all fields have the same number of data 255 | n_points = [field.shape[0] for field in field_list] 256 | if not np.all(np.equal(n_points, n_points[0])): 257 | print('wrong field dimensions') 258 | return False 259 | 260 | # Check if field_names and field_list have same nb of column 261 | n_fields = np.sum([field.shape[1] for field in field_list]) 262 | if (n_fields != len(field_names)): 263 | print('wrong number of field names') 264 | return False 265 | 266 | # Add extension if not there 267 | if not filename.endswith('.ply'): 268 | filename += '.ply' 269 | 270 | # open in text mode to write the header 271 | with open(filename, 'w') as plyfile: 272 | 273 | # First magical word 274 | header = ['ply'] 275 | 276 | # Encoding format 277 | header.append('format binary_' + sys.byteorder + '_endian 1.0') 278 | 279 | # Points properties description 280 | header.extend(header_properties(field_list, field_names)) 281 | 282 | # Add faces if needed 283 | if triangular_faces is not None: 284 | header.append('element face {:d}'.format(triangular_faces.shape[0])) 285 | header.append('property list uchar int vertex_indices') 286 | 287 | # End of header 288 | header.append('end_header') 289 | 290 | # Write all lines 291 | for line in header: 292 |
plyfile.write("%s\n" % line) 293 | 294 | # open in binary/append to use tofile 295 | with open(filename, 'ab') as plyfile: 296 | 297 | # Create a structured array 298 | i = 0 299 | type_list = [] 300 | for fields in field_list: 301 | for field in fields.T: 302 | type_list += [(field_names[i], field.dtype.str)] 303 | i += 1 304 | data = np.empty(field_list[0].shape[0], dtype=type_list) 305 | i = 0 306 | for fields in field_list: 307 | for field in fields.T: 308 | data[field_names[i]] = field 309 | i += 1 310 | 311 | data.tofile(plyfile) 312 | 313 | if triangular_faces is not None: 314 | triangular_faces = triangular_faces.astype(np.int32) 315 | type_list = [('k', 'uint8')] + [(str(ind), 'int32') for ind in range(3)] 316 | data = np.empty(triangular_faces.shape[0], dtype=type_list) 317 | data['k'] = np.full((triangular_faces.shape[0],), 3, dtype=np.uint8) 318 | data['0'] = triangular_faces[:, 0] 319 | data['1'] = triangular_faces[:, 1] 320 | data['2'] = triangular_faces[:, 2] 321 | data.tofile(plyfile) 322 | 323 | return True 324 | 325 | 326 | def describe_element(name, df): 327 | """ Takes the columns of the dataframe and builds a ply-like description 328 | 329 | Parameters 330 | ---------- 331 | name: str 332 | df: pandas DataFrame 333 | 334 | Returns 335 | ------- 336 | element: list[str] 337 | """ 338 | property_formats = {'f': 'float', 'u': 'uchar', 'i': 'int'} 339 | element = ['element ' + name + ' ' + str(len(df))] 340 | 341 | if name == 'face': 342 | element.append("property list uchar int points_indices") 343 | 344 | else: 345 | for i in range(len(df.columns)): 346 | # get first letter of dtype to infer format 347 | f = property_formats[str(df.dtypes[i])[0]] 348 | element.append('property ' + f + ' ' + df.columns.values[i]) 349 | 350 | return element 351 | 352 | -------------------------------------------------------------------------------- /lib/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0.0 14 | self.sq_sum = 0.0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.avg = self.sum / self.count 22 | self.sq_sum += val ** 2 * n 23 | self.var = self.sq_sum / self.count - self.avg ** 2 24 | 25 | 26 | class Timer(object): 27 | """A simple timer.""" 28 | 29 | def __init__(self): 30 | self.total_time = 0. 31 | self.calls = 0 32 | self.start_time = 0. 33 | self.diff = 0. 34 | self.avg = 0. 
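A short usage sketch for the AverageMeter above (the values are arbitrary):

from lib.timer import AverageMeter

meter = AverageMeter()
for loss in [0.9, 0.7, 0.5]:
    meter.update(loss)       # n=1, i.e. one sample per update
print(meter.val)             # 0.5  -> most recent value
print(meter.avg)             # ~0.7 -> running mean
print(meter.var)             # running variance, E[x^2] - E[x]^2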
35 | 36 | def reset(self): 37 | self.total_time = 0 38 | self.calls = 0 39 | self.start_time = 0 40 | self.diff = 0 41 | self.avg = 0 42 | 43 | def tic(self): 44 | # using time.time instead of time.clock because time time.clock 45 | # does not normalize for multithreading 46 | self.start_time = time.time() 47 | 48 | def toc(self, average=True): 49 | self.diff = time.time() - self.start_time 50 | self.total_time += self.diff 51 | self.calls += 1 52 | self.avg = self.total_time / self.calls 53 | if average: 54 | return self.avg 55 | else: 56 | return self.diff 57 | -------------------------------------------------------------------------------- /lib/trainer.py: -------------------------------------------------------------------------------- 1 | import time, os, torch,copy 2 | import numpy as np 3 | import torch.nn as nn 4 | from tensorboardX import SummaryWriter 5 | from lib.timer import Timer, AverageMeter 6 | from lib.utils import Logger,validate_gradient 7 | 8 | from tqdm import tqdm 9 | import torch.nn.functional as F 10 | import gc 11 | 12 | 13 | class Trainer(object): 14 | def __init__(self, args): 15 | self.config = args 16 | # parameters 17 | self.start_epoch = 1 18 | self.max_epoch = args.max_epoch 19 | self.save_dir = args.save_dir 20 | self.device = args.device 21 | self.verbose = args.verbose 22 | self.max_points = args.max_points 23 | 24 | self.model = args.model.to(self.device) 25 | self.optimizer = args.optimizer 26 | self.scheduler = args.scheduler 27 | self.scheduler_freq = args.scheduler_freq 28 | self.snapshot_freq = args.snapshot_freq 29 | self.snapshot_dir = args.snapshot_dir 30 | self.benchmark = args.benchmark 31 | self.iter_size = args.iter_size 32 | self.verbose_freq= args.verbose_freq 33 | 34 | self.w_circle_loss = args.w_circle_loss 35 | self.w_overlap_loss = args.w_overlap_loss 36 | self.w_saliency_loss = args.w_saliency_loss 37 | self.desc_loss = args.desc_loss 38 | 39 | self.best_loss = 1e5 40 | self.best_recall = -1e5 41 | self.writer = SummaryWriter(log_dir=args.tboard_dir) 42 | self.logger = Logger(args.snapshot_dir) 43 | self.logger.write(f'#parameters {sum([x.nelement() for x in self.model.parameters()])/1000000.} M\n') 44 | 45 | 46 | if (args.pretrain !=''): 47 | self._load_pretrain(args.pretrain) 48 | 49 | self.loader =dict() 50 | self.loader['train']=args.train_loader 51 | self.loader['val']=args.val_loader 52 | self.loader['test'] = args.test_loader 53 | 54 | with open(f'{args.snapshot_dir}/model','w') as f: 55 | f.write(str(self.model)) 56 | f.close() 57 | 58 | def _snapshot(self, epoch, name=None): 59 | state = { 60 | 'epoch': epoch, 61 | 'state_dict': self.model.state_dict(), 62 | 'optimizer': self.optimizer.state_dict(), 63 | 'scheduler': self.scheduler.state_dict(), 64 | 'best_loss': self.best_loss, 65 | 'best_recall': self.best_recall 66 | } 67 | if name is None: 68 | filename = os.path.join(self.save_dir, f'model_{epoch}.pth') 69 | else: 70 | filename = os.path.join(self.save_dir, f'model_{name}.pth') 71 | self.logger.write(f"Save model to {filename}\n") 72 | torch.save(state, filename) 73 | 74 | def _load_pretrain(self, resume): 75 | if os.path.isfile(resume): 76 | state = torch.load(resume) 77 | self.model.load_state_dict(state['state_dict']) 78 | self.start_epoch = state['epoch'] 79 | self.scheduler.load_state_dict(state['scheduler']) 80 | self.optimizer.load_state_dict(state['optimizer']) 81 | self.best_loss = state['best_loss'] 82 | self.best_recall = state['best_recall'] 83 | 84 | self.logger.write(f'Successfully load pretrained model from 
{resume}!\n') 85 | self.logger.write(f'Current best loss {self.best_loss}\n') 86 | self.logger.write(f'Current best recall {self.best_recall}\n') 87 | else: 88 | raise ValueError(f"=> no checkpoint found at '{resume}'") 89 | 90 | def _get_lr(self, group=0): 91 | return self.optimizer.param_groups[group]['lr'] 92 | 93 | def stats_dict(self): 94 | stats=dict() 95 | stats['circle_loss']=0. 96 | stats['recall']=0. # feature match recall, divided by number of ground truth pairs 97 | stats['saliency_loss'] = 0. 98 | stats['saliency_recall'] = 0. 99 | stats['saliency_precision'] = 0. 100 | stats['overlap_loss'] = 0. 101 | stats['overlap_recall']=0. 102 | stats['overlap_precision']=0. 103 | return stats 104 | 105 | def stats_meter(self): 106 | meters=dict() 107 | stats=self.stats_dict() 108 | for key,_ in stats.items(): 109 | meters[key]=AverageMeter() 110 | return meters 111 | 112 | 113 | def inference_one_batch(self, inputs, phase): 114 | assert phase in ['train','val','test'] 115 | ################################## 116 | # training 117 | if(phase == 'train'): 118 | self.model.train() 119 | ############################################### 120 | # forward pass 121 | feats, scores_overlap, scores_saliency = self.model(inputs) #[N1, C1], [N2, C2] 122 | pcd = inputs['points'][0] 123 | len_src = inputs['stack_lengths'][0][0] 124 | c_rot, c_trans = inputs['rot'], inputs['trans'] 125 | correspondence = inputs['correspondences'] 126 | 127 | src_pcd, tgt_pcd = inputs['src_pcd_raw'], inputs['tgt_pcd_raw'] 128 | src_feats, tgt_feats = feats[:len_src], feats[len_src:] 129 | 130 | ################################################### 131 | # get loss 132 | stats= self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats,correspondence, c_rot, c_trans, scores_overlap, scores_saliency) 133 | 134 | c_loss = stats['circle_loss'] * self.w_circle_loss + stats['overlap_loss'] * self.w_overlap_loss + stats['saliency_loss'] * self.w_saliency_loss 135 | 136 | c_loss.backward() 137 | 138 | else: 139 | self.model.eval() 140 | with torch.no_grad(): 141 | ############################################### 142 | # forward pass 143 | feats, scores_overlap, scores_saliency = self.model(inputs) #[N1, C1], [N2, C2] 144 | pcd = inputs['points'][0] 145 | len_src = inputs['stack_lengths'][0][0] 146 | c_rot, c_trans = inputs['rot'], inputs['trans'] 147 | correspondence = inputs['correspondences'] 148 | 149 | src_pcd, tgt_pcd = inputs['src_pcd_raw'], inputs['tgt_pcd_raw'] 150 | src_feats, tgt_feats = feats[:len_src], feats[len_src:] 151 | 152 | ################################################### 153 | # get loss 154 | stats= self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats,correspondence, c_rot, c_trans, scores_overlap, scores_saliency) 155 | 156 | 157 | ################################## 158 | # detach the gradients for loss terms 159 | stats['circle_loss'] = float(stats['circle_loss'].detach()) 160 | stats['overlap_loss'] = float(stats['overlap_loss'].detach()) 161 | stats['saliency_loss'] = float(stats['saliency_loss'].detach()) 162 | 163 | return stats 164 | 165 | 166 | def inference_one_epoch(self,epoch, phase): 167 | gc.collect() 168 | assert phase in ['train','val','test'] 169 | 170 | # init stats meter 171 | stats_meter = self.stats_meter() 172 | 173 | num_iter = int(len(self.loader[phase].dataset) // self.loader[phase].batch_size) 174 | c_loader_iter = self.loader[phase].__iter__() 175 | 176 | self.optimizer.zero_grad() 177 | for c_iter in tqdm(range(num_iter)): # loop through this epoch 178 | 
################################## 179 | # load inputs to device. 180 | inputs = c_loader_iter.next() 181 | for k, v in inputs.items(): 182 | if type(v) == list: 183 | inputs[k] = [item.to(self.device) for item in v] 184 | elif type(v) == dict: 185 | pass 186 | else: 187 | inputs[k] = v.to(self.device) 188 | try: 189 | ################################## 190 | # forward pass 191 | # with torch.autograd.detect_anomaly(): 192 | stats = self.inference_one_batch(inputs, phase) 193 | 194 | ################################################### 195 | # run optimisation 196 | if((c_iter+1) % self.iter_size == 0 and phase == 'train'): 197 | gradient_valid = validate_gradient(self.model) 198 | if(gradient_valid): 199 | self.optimizer.step() 200 | else: 201 | self.logger.write('gradient not valid\n') 202 | self.optimizer.zero_grad() 203 | 204 | ################################ 205 | # update to stats_meter 206 | for key,value in stats.items(): 207 | stats_meter[key].update(value) 208 | except Exception as inst: 209 | print(inst) 210 | 211 | torch.cuda.empty_cache() 212 | 213 | if (c_iter + 1) % self.verbose_freq == 0 and self.verbose: 214 | curr_iter = num_iter * (epoch - 1) + c_iter 215 | for key, value in stats_meter.items(): 216 | self.writer.add_scalar(f'{phase}/{key}', value.avg, curr_iter) 217 | 218 | message = f'{phase} Epoch: {epoch} [{c_iter+1:4d}/{num_iter}]' 219 | for key,value in stats_meter.items(): 220 | message += f'{key}: {value.avg:.2f}\t' 221 | 222 | self.logger.write(message + '\n') 223 | 224 | message = f'{phase} Epoch: {epoch}' 225 | for key,value in stats_meter.items(): 226 | message += f'{key}: {value.avg:.2f}\t' 227 | self.logger.write(message+'\n') 228 | 229 | return stats_meter 230 | 231 | 232 | def train(self): 233 | print('start training...') 234 | for epoch in range(self.start_epoch, self.max_epoch): 235 | self.inference_one_epoch(epoch,'train') 236 | self.scheduler.step() 237 | 238 | stats_meter = self.inference_one_epoch(epoch,'val') 239 | 240 | if stats_meter['circle_loss'].avg < self.best_loss: 241 | self.best_loss = stats_meter['circle_loss'].avg 242 | self._snapshot(epoch,'best_loss') 243 | if stats_meter['recall'].avg > self.best_recall: 244 | self.best_recall = stats_meter['recall'].avg 245 | self._snapshot(epoch,'best_recall') 246 | 247 | # we only add saliency loss when we get descent point-wise features 248 | if(stats_meter['recall'].avg>0.3): 249 | self.w_saliency_loss = 1. 250 | else: 251 | self.w_saliency_loss = 0. 
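Seen in isolation, the update rule in inference_one_epoch above is: accumulate gradients over iter_size batches, then step only if every gradient is finite. A condensed, self-contained sketch of that pattern (generic model/loader names, not this repo's classes; like the trainer above, it sums rather than averages the accumulated gradients):

import torch

def run_epoch(model, loader, optimizer, loss_fn, iter_size=4):
    model.train()
    optimizer.zero_grad()
    for i, (x, y) in enumerate(loader):
        loss_fn(model(x), y).backward()   # gradients accumulate across batches
        if (i + 1) % iter_size == 0:
            # mirrors validate_gradient(): skip the step on NaN/Inf gradients
            finite = all(torch.isfinite(p.grad).all()
                         for p in model.parameters() if p.grad is not None)
            if finite:
                optimizer.step()
            optimizer.zero_grad()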
252 | 
253 | # all epochs finished
254 | print("Training finished!")
255 | 
256 | 
257 | def eval(self):
258 | print('Start evaluating on the validation dataset...')
259 | stats_meter = self.inference_one_epoch(0,'val')
260 | 
261 | for key, value in stats_meter.items():
262 | print(key, value.avg)
263 | 
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | General utility functions
3 | 
4 | Author: Shengyu Huang
5 | Last modified: 30.11.2020
6 | """
7 | 
8 | import os,re,sys,json,yaml,random, argparse, torch, pickle
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | import numpy as np
13 | from scipy.spatial.transform import Rotation
14 | 
15 | from sklearn.neighbors import NearestNeighbors
16 | from scipy.spatial.distance import minkowski
17 | _EPS = 1e-7 # To prevent division by zero
18 | 
19 | 
20 | class Logger:
21 | def __init__(self, path):
22 | self.path = path
23 | self.fw = open(self.path+'/log','a')
24 | 
25 | def write(self, text):
26 | self.fw.write(text)
27 | self.fw.flush()
28 | 
29 | def close(self):
30 | self.fw.close()
31 | 
32 | def save_obj(obj, path):
33 | """
34 | save a dictionary to a pickle file
35 | """
36 | with open(path, 'wb') as f:
37 | pickle.dump(obj, f)
38 | 
39 | def load_obj(path):
40 | """
41 | read a dictionary from a pickle file
42 | """
43 | with open(path, 'rb') as f:
44 | return pickle.load(f)
45 | 
46 | def load_config(path):
47 | """
48 | Loads config file:
49 | 
50 | Args:
51 | path (str): path to the config file
52 | 
53 | Returns:
54 | config (dict): dictionary of configuration parameters, with all sub-dicts merged into one level
55 | 
56 | """
57 | with open(path,'r') as f:
58 | cfg = yaml.safe_load(f)
59 | 
60 | config = dict()
61 | for key, value in cfg.items():
62 | for k,v in value.items():
63 | config[k] = v
64 | 
65 | return config
66 | 
67 | 
68 | def setup_seed(seed):
69 | """
70 | fix random seed for deterministic training
71 | """
72 | torch.manual_seed(seed)
73 | torch.cuda.manual_seed_all(seed)
74 | np.random.seed(seed)
75 | random.seed(seed)
76 | torch.backends.cudnn.deterministic = True
77 | 
78 | def square_distance(src, dst, normalised = False):
79 | """
80 | Calculate the squared Euclidean distance between each pair of points.
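Uses the expansion ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2; when `normalised` is True
the squared norms are taken to be 1, so a constant 2 is added instead of the norm terms.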
81 | Args:
82 | src: source points, [B, N, C]
83 | dst: target points, [B, M, C]
84 | Returns:
85 | dist: per-point square distance, [B, N, M]
86 | """
87 | B, N, _ = src.shape
88 | _, M, _ = dst.shape
89 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
90 | if(normalised):
91 | dist += 2
92 | else:
93 | dist += torch.sum(src ** 2, dim=-1)[:, :, None]
94 | dist += torch.sum(dst ** 2, dim=-1)[:, None, :]
95 | 
96 | dist = torch.clamp(dist, min=1e-12, max=None)
97 | return dist
98 | 
99 | 
100 | def validate_gradient(model):
101 | """
102 | Check that all gradients are free of NaNs and Infs
103 | """
104 | for name, param in model.named_parameters():
105 | if param.grad is not None:
106 | if torch.any(torch.isnan(param.grad)):
107 | return False
108 | if torch.any(torch.isinf(param.grad)):
109 | return False
110 | return True
111 | 
112 | 
113 | def natural_key(string_):
114 | """
115 | Sort strings by numbers in the name
116 | """
117 | return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_)]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os, torch, time, shutil, json, glob, argparse
2 | import numpy as np
3 | from easydict import EasyDict as edict
4 | 
5 | from datasets.dataloader import get_dataloader, get_datasets
6 | from models.architectures import KPFCNN
7 | from lib.utils import setup_seed, load_config
8 | from lib.tester import get_trainer
9 | from lib.loss import MetricLoss
10 | from configs.models import architectures
11 | 
12 | from torch import optim
13 | from torch import nn
14 | setup_seed(0)
15 | 
16 | 
17 | if __name__ == '__main__':
18 | # load configs
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument('config', type=str, help='Path to the config file.')
21 | args = parser.parse_args()
22 | config = load_config(args.config)
23 | config['snapshot_dir'] = 'snapshot/%s' % config['exp_dir']
24 | config['tboard_dir'] = 'snapshot/%s/tensorboard' % config['exp_dir']
25 | config['save_dir'] = 'snapshot/%s/checkpoints' % config['exp_dir']
26 | config = edict(config)
27 | 
28 | os.makedirs(config.snapshot_dir, exist_ok=True)
29 | os.makedirs(config.save_dir, exist_ok=True)
30 | os.makedirs(config.tboard_dir, exist_ok=True)
31 | json.dump(
32 | config,
33 | open(os.path.join(config.snapshot_dir, 'config.json'), 'w'),
34 | indent=4,
35 | )
36 | if config.gpu_mode:
37 | config.device = torch.device('cuda')
38 | else:
39 | config.device = torch.device('cpu')
40 | 
41 | # backup the files
42 | os.system(f'cp -r models {config.snapshot_dir}')
43 | os.system(f'cp -r datasets {config.snapshot_dir}')
44 | os.system(f'cp -r lib {config.snapshot_dir}')
45 | shutil.copy2('main.py',config.snapshot_dir)
46 | 
47 | 
48 | # model initialization
49 | config.architecture = architectures[config.dataset]
50 | config.model = KPFCNN(config)
51 | 
52 | # create optimizer
53 | if config.optimizer == 'SGD':
54 | config.optimizer = optim.SGD(
55 | config.model.parameters(),
56 | lr=config.lr,
57 | momentum=config.momentum,
58 | weight_decay=config.weight_decay,
59 | )
60 | elif config.optimizer == 'ADAM':
61 | config.optimizer = optim.Adam(
62 | config.model.parameters(),
63 | lr=config.lr,
64 | betas=(0.9, 0.999),
65 | weight_decay=config.weight_decay,
66 | )
67 | 
68 | # create learning rate scheduler
69 | config.scheduler = optim.lr_scheduler.ExponentialLR(
70 | config.optimizer,
71 | gamma=config.scheduler_gamma,
72 | )
73 | 
74 | # create
dataset and dataloader
75 | train_set, val_set, benchmark_set = get_datasets(config)
76 | config.train_loader, neighborhood_limits = get_dataloader(dataset=train_set,
77 | batch_size=config.batch_size,
78 | shuffle=True,
79 | num_workers=config.num_workers,
80 | )
81 | config.val_loader, _ = get_dataloader(dataset=val_set,
82 | batch_size=config.batch_size,
83 | shuffle=False,
84 | num_workers=1,
85 | neighborhood_limits=neighborhood_limits
86 | )
87 | config.test_loader, _ = get_dataloader(dataset=benchmark_set,
88 | batch_size=config.batch_size,
89 | shuffle=False,
90 | num_workers=1,
91 | neighborhood_limits=neighborhood_limits)
92 | 
93 | # create evaluation metrics
94 | config.desc_loss = MetricLoss(config)
95 | trainer = get_trainer(config)
96 | if(config.mode=='train'):
97 | trainer.train()
98 | elif(config.mode =='val'):
99 | trainer.eval()
100 | else:
101 | trainer.test()
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prs-eth/OverlapPredator/8c78f125fc58d62ad7d149adf1fed43ed54937e4/models/__init__.py
--------------------------------------------------------------------------------
/models/architectures.py:
--------------------------------------------------------------------------------
1 | from models.blocks import *
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from models.gcn import GCN
5 | from lib.utils import square_distance
6 | 
7 | 
8 | class KPFCNN(nn.Module):
9 | 
10 | def __init__(self, config):
11 | super(KPFCNN, self).__init__()
12 | 
13 | ############
14 | # Parameters
15 | ############
16 | # Current radius of convolution and feature dimension
17 | layer = 0
18 | r = config.first_subsampling_dl * config.conv_radius
19 | in_dim = config.in_feats_dim
20 | out_dim = config.first_feats_dim
21 | self.K = config.num_kernel_points
22 | self.epsilon = torch.nn.Parameter(torch.tensor(-5.0))
23 | self.final_feats_dim = config.final_feats_dim
24 | self.condition = config.condition_feature
25 | self.add_cross_overlap = config.add_cross_score
26 | 
27 | #####################
28 | # List Encoder blocks
29 | #####################
30 | # Save all block operations in a list of modules
31 | self.encoder_blocks = nn.ModuleList()
32 | self.encoder_skip_dims = []
33 | self.encoder_skips = []
34 | 
35 | # Loop over consecutive blocks
36 | for block_i, block in enumerate(config.architecture):
37 | 
38 | # Check equivariance
39 | if ('equivariant' in block) and (not out_dim % 3 == 0):
40 | raise ValueError('Equivariant block but features dimension is not a multiple of 3')
41 | 
42 | # Detect change to next layer for skip connection
43 | if np.any([tmp in block for tmp in ['pool', 'strided', 'upsample', 'global']]):
44 | self.encoder_skips.append(block_i)
45 | self.encoder_skip_dims.append(in_dim)
46 | 
47 | # Detect upsampling block to stop
48 | if 'upsample' in block:
49 | break
50 | 
51 | # Instantiate the block operation and append it to the encoder
52 | self.encoder_blocks.append(block_decider(block,
53 | r,
54 | in_dim,
55 | out_dim,
56 | layer,
57 | config))
58 | 
59 | # Update dimension of input from output
60 | if 'simple' in block:
61 | in_dim = out_dim // 2
62 | else:
63 | in_dim = out_dim
64 | 
65 | # Detect change to a subsampled layer
66 | if 'pool' in block or 'strided' in block:
67 | # Update radius and feature dimension for next layer
68 | layer += 1
69 | r *= 2
70 | out_dim *= 2
71 | 
72 | #####################
73 | #
bottleneck layer and GNN part
74 | #####################
75 | gnn_feats_dim = config.gnn_feats_dim
76 | self.bottle = nn.Conv1d(in_dim, gnn_feats_dim,kernel_size=1,bias=True)
77 | k=config.dgcnn_k
78 | num_head = config.num_head
79 | self.gnn = GCN(num_head,gnn_feats_dim, k, config.nets)
80 | self.proj_gnn = nn.Conv1d(gnn_feats_dim,gnn_feats_dim,kernel_size=1, bias=True)
81 | self.proj_score = nn.Conv1d(gnn_feats_dim,1,kernel_size=1,bias=True)
82 | 
83 | 
84 | #####################
85 | # List Decoder blocks
86 | #####################
87 | if self.add_cross_overlap:
88 | out_dim = gnn_feats_dim + 2
89 | else:
90 | out_dim = gnn_feats_dim + 1
91 | 
92 | # Save all block operations in a list of modules
93 | self.decoder_blocks = nn.ModuleList()
94 | self.decoder_concats = []
95 | 
96 | # Find first upsampling block
97 | start_i = 0
98 | for block_i, block in enumerate(config.architecture):
99 | if 'upsample' in block:
100 | start_i = block_i
101 | break
102 | 
103 | # Loop over consecutive blocks
104 | for block_i, block in enumerate(config.architecture[start_i:]):
105 | 
106 | # Add dimension of skip connection concat
107 | if block_i > 0 and 'upsample' in config.architecture[start_i + block_i - 1]:
108 | in_dim += self.encoder_skip_dims[layer]
109 | self.decoder_concats.append(block_i)
110 | 
111 | # Instantiate the block operation and append it to the decoder
112 | self.decoder_blocks.append(block_decider(block,
113 | r,
114 | in_dim,
115 | out_dim,
116 | layer,
117 | config))
118 | 
119 | # Update dimension of input from output
120 | in_dim = out_dim
121 | 
122 | # Detect change to a subsampled layer
123 | if 'upsample' in block:
124 | # Update radius and feature dimension for next layer
125 | layer -= 1
126 | r *= 0.5
127 | out_dim = out_dim // 2
128 | return
129 | 
130 | def regular_score(self,score):
131 | score = torch.where(torch.isnan(score), torch.zeros_like(score), score)
132 | score = torch.where(torch.isinf(score), torch.zeros_like(score), score)
133 | return score
134 | 
135 | 
136 | def forward(self, batch):
137 | # Get input features
138 | x = batch['features'].clone().detach()
139 | len_src_c = batch['stack_lengths'][-1][0]
140 | len_src_f = batch['stack_lengths'][0][0]
141 | pcd_c = batch['points'][-1]
142 | pcd_f = batch['points'][0]
143 | src_pcd_c, tgt_pcd_c = pcd_c[:len_src_c], pcd_c[len_src_c:]
144 | 
145 | sigmoid = nn.Sigmoid()
146 | #################################
147 | # 1. joint encoder part
148 | skip_x = []
149 | for block_i, block_op in enumerate(self.encoder_blocks):
150 | if block_i in self.encoder_skips:
151 | skip_x.append(x)
152 | x = block_op(x, batch)
153 | 
154 | #################################
155 | # 2. project the bottleneck features
156 | feats_c = x.transpose(0,1).unsqueeze(0) #[1, C, N]
157 | feats_c = self.bottle(feats_c) #[1, C, N]
158 | unconditioned_feats = feats_c.transpose(1,2).squeeze(0)
159 | 
160 | #################################
161 | # 3.
apply GNN to communicate the features and get overlap score
162 | src_feats_c, tgt_feats_c = feats_c[:,:,:len_src_c], feats_c[:,:,len_src_c:]
163 | src_feats_c, tgt_feats_c= self.gnn(src_pcd_c.unsqueeze(0).transpose(1,2), tgt_pcd_c.unsqueeze(0).transpose(1,2),src_feats_c, tgt_feats_c)
164 | feats_c = torch.cat([src_feats_c, tgt_feats_c], dim=-1)
165 | 
166 | feats_c = self.proj_gnn(feats_c)
167 | scores_c = self.proj_score(feats_c)
168 | 
169 | feats_gnn_norm = F.normalize(feats_c, p=2, dim=1).squeeze(0).transpose(0,1) #[N, C]
170 | feats_gnn_raw = feats_c.squeeze(0).transpose(0,1)
171 | scores_c_raw = scores_c.squeeze(0).transpose(0,1) #[N, 1]
172 | 
173 | ####################################
174 | # 4. decoder part
175 | src_feats_gnn, tgt_feats_gnn = feats_gnn_norm[:len_src_c], feats_gnn_norm[len_src_c:]
176 | inner_products = torch.matmul(src_feats_gnn, tgt_feats_gnn.transpose(0,1))
177 | 
178 | src_scores_c, tgt_scores_c = scores_c_raw[:len_src_c], scores_c_raw[len_src_c:]
179 | 
180 | temperature = torch.exp(self.epsilon) + 0.03
181 | s1 = torch.matmul(F.softmax(inner_products / temperature ,dim=1) ,tgt_scores_c)
182 | s2 = torch.matmul(F.softmax(inner_products.transpose(0,1) / temperature,dim=1),src_scores_c)
183 | scores_saliency = torch.cat((s1,s2),dim=0)
184 | 
185 | if(self.condition and self.add_cross_overlap):
186 | x = torch.cat([scores_c_raw,scores_saliency,feats_gnn_raw], dim=1)
187 | elif(self.condition and not self.add_cross_overlap):
188 | x = torch.cat([scores_c_raw,feats_gnn_raw], dim=1)
189 | elif(not self.condition and self.add_cross_overlap):
190 | x = torch.cat([scores_c_raw, scores_saliency, unconditioned_feats], dim = 1)
191 | elif(not self.condition and not self.add_cross_overlap):
192 | x = torch.cat([scores_c_raw, unconditioned_feats], dim = 1)
193 | 
194 | for block_i, block_op in enumerate(self.decoder_blocks):
195 | if block_i in self.decoder_concats:
196 | x = torch.cat([x, skip_x.pop()], dim=1)
197 | x = block_op(x, batch)
198 | feats_f = x[:,:self.final_feats_dim]
199 | scores_overlap = x[:,self.final_feats_dim]
200 | scores_saliency = x[:,self.final_feats_dim+1]
201 | 
202 | # safeguard the scores
203 | scores_overlap = torch.clamp(sigmoid(scores_overlap.view(-1)),min=0,max=1)
204 | scores_saliency = torch.clamp(sigmoid(scores_saliency.view(-1)),min=0,max=1)
205 | scores_overlap = self.regular_score(scores_overlap)
206 | scores_saliency = self.regular_score(scores_saliency)
207 | 
208 | # normalise point-wise features
209 | feats_f = F.normalize(feats_f, p=2, dim=1)
210 | 
211 | return feats_f, scores_overlap, scores_saliency
212 | 
--------------------------------------------------------------------------------
/models/gcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from copy import deepcopy
5 | import torch.utils.checkpoint as checkpoint
6 | from lib.utils import square_distance
7 | 
8 | 
9 | def get_graph_feature(coords, feats, k=10):
10 | """
11 | Apply KNN search based on coordinates, then concatenate each neighbour's features to the centroid features
12 | Input:
13 | coords: [B, 3, N]
14 | feats: [B, C, N]
15 | Return:
16 | feats_cat: [B, 2C, N, k]
17 | """
18 | # apply KNN search to build neighborhood
19 | B, C, N = feats.size()
20 | dist = square_distance(coords.transpose(1,2), coords.transpose(1,2))
21 | 
22 | idx = dist.topk(k=k+1, dim=-1, largest=False, sorted=True)[1] #[B, N, K+1], here we ignore the smallest element as it's the query itself
23 | idx
= idx[:,:,1:] #[B, N, K]
24 | 
25 | idx = idx.unsqueeze(1).repeat(1,C,1,1) #[B, C, N, K]
26 | all_feats = feats.unsqueeze(2).repeat(1, 1, N, 1) # [B, C, N, N], memory-heavy for large N
27 | 
28 | neighbor_feats = torch.gather(all_feats, dim=-1,index=idx) #[B, C, N, K]
29 | 
30 | # concatenate the features with centroid
31 | feats = feats.unsqueeze(-1).repeat(1,1,1,k)
32 | 
33 | feats_cat = torch.cat((feats, neighbor_feats-feats),dim=1)
34 | 
35 | return feats_cat
36 | 
37 | 
38 | 
39 | class SelfAttention(nn.Module):
40 | def __init__(self,feature_dim,k=10):
41 | super(SelfAttention, self).__init__()
42 | self.conv1 = nn.Conv2d(feature_dim*2, feature_dim, kernel_size=1, bias=False)
43 | self.in1 = nn.InstanceNorm2d(feature_dim)
44 | 
45 | self.conv2 = nn.Conv2d(feature_dim*2, feature_dim * 2, kernel_size=1, bias=False)
46 | self.in2 = nn.InstanceNorm2d(feature_dim * 2)
47 | 
48 | self.conv3 = nn.Conv2d(feature_dim * 4, feature_dim, kernel_size=1, bias=False)
49 | self.in3 = nn.InstanceNorm2d(feature_dim)
50 | 
51 | self.k = k
52 | 
53 | def forward(self, coords, features):
54 | """
55 | Here we take coordinates and features; feature aggregation is guided by the coordinates
56 | Input:
57 | coords: [B, 3, N]
58 | feats: [B, C, N]
59 | Output:
60 | feats: [B, C, N]
61 | """
62 | B, C, N = features.size()
63 | 
64 | x0 = features.unsqueeze(-1) #[B, C, N, 1]
65 | 
66 | x1 = get_graph_feature(coords, x0.squeeze(-1), self.k)
67 | x1 = F.leaky_relu(self.in1(self.conv1(x1)), negative_slope=0.2)
68 | x1 = x1.max(dim=-1,keepdim=True)[0]
69 | 
70 | x2 = get_graph_feature(coords, x1.squeeze(-1), self.k)
71 | x2 = F.leaky_relu(self.in2(self.conv2(x2)), negative_slope=0.2)
72 | x2 = x2.max(dim=-1, keepdim=True)[0]
73 | 
74 | x3 = torch.cat((x0,x1,x2),dim=1)
75 | x3 = F.leaky_relu(self.in3(self.conv3(x3)), negative_slope=0.2).view(B, -1, N)
76 | 
77 | return x3
78 | 
79 | 
80 | def MLP(channels: list, do_bn=True):
81 | """ Multi-layer perceptron """
82 | n = len(channels)
83 | layers = []
84 | for i in range(1, n):
85 | layers.append(
86 | nn.Conv1d(channels[i - 1], channels[i], kernel_size=1, bias=True))
87 | if i < (n-1):
88 | if do_bn:
89 | layers.append(nn.InstanceNorm1d(channels[i])) # note: despite the flag name, instance norm is used here
90 | layers.append(nn.ReLU())
91 | return nn.Sequential(*layers)
92 | 
93 | 
94 | def attention(query, key, value):
95 | dim = query.shape[1]
96 | scores = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
97 | prob = torch.nn.functional.softmax(scores, dim=-1)
98 | return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob
99 | 
100 | 
101 | class MultiHeadedAttention(nn.Module):
102 | """ Multi-head attention to increase model expressivity """
103 | def __init__(self, num_heads: int, d_model: int):
104 | super().__init__()
105 | assert d_model % num_heads == 0
106 | self.dim = d_model // num_heads
107 | self.num_heads = num_heads
108 | self.merge = nn.Conv1d(d_model, d_model, kernel_size=1)
109 | self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)])
110 | 
111 | def forward(self, query, key, value):
112 | batch_dim = query.size(0)
113 | query, key, value = [l(x).view(batch_dim, self.dim, self.num_heads, -1)
114 | for l, x in zip(self.proj, (query, key, value))]
115 | x, _ = attention(query, key, value)
116 | return self.merge(x.contiguous().view(batch_dim, self.dim*self.num_heads, -1))
117 | 
118 | 
119 | class AttentionalPropagation(nn.Module):
120 | def __init__(self, feature_dim: int, num_heads: int):
121 | super().__init__()
122 | self.attn = MultiHeadedAttention(num_heads, feature_dim)
123 | self.mlp = MLP([feature_dim*2, feature_dim*2, feature_dim])
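# zero-initialise the bias of the message-MLP's final layer below; this GNN block
# follows SuperGlue's AttentionalPropagation design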
124 | nn.init.constant_(self.mlp[-1].bias, 0.0)
125 | 
126 | def forward(self, x, source):
127 | message = self.attn(x, source, source)
128 | return self.mlp(torch.cat([x, message], dim=1))
129 | 
130 | 
131 | class GCN(nn.Module):
132 | """
133 | Alternate between self-attention and cross-attention
134 | Input:
135 | coords: [B, 3, N]
136 | feats: [B, C, N]
137 | Output:
138 | feats: [B, C, N]
139 | """
140 | def __init__(self, num_head: int, feature_dim: int, k: int, layer_names: list):
141 | super().__init__()
142 | self.layers=[]
143 | for atten_type in layer_names:
144 | if atten_type == 'cross':
145 | self.layers.append(AttentionalPropagation(feature_dim,num_head))
146 | elif atten_type == 'self':
147 | self.layers.append(SelfAttention(feature_dim, k))
148 | self.layers = nn.ModuleList(self.layers)
149 | self.names = layer_names
150 | 
151 | def forward(self, coords0, coords1, desc0, desc1):
152 | for layer, name in zip(self.layers, self.names):
153 | if name == 'cross':
154 | # desc0 = desc0 + checkpoint.checkpoint(layer, desc0, desc1)
155 | # desc1 = desc1 + checkpoint.checkpoint(layer, desc1, desc0)
156 | desc0 = desc0 + layer(desc0, desc1)
157 | desc1 = desc1 + layer(desc1, desc0)
158 | elif name == 'self':
159 | desc0 = layer(coords0, desc0)
160 | desc1 = layer(coords1, desc1)
161 | return desc0, desc1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.3.3
2 | numpy==1.19.4
3 | torch==1.7.1
4 | torchvision==0.8.2
5 | torchaudio==0.7.2
6 | nibabel==3.2.1
7 | tqdm==4.38.0
8 | open3d==0.10.0.0
9 | easydict==1.9
10 | scipy==1.5.4
11 | coloredlogs==15.0
12 | PyYAML==5.4.1
13 | scikit_learn==0.24.1
14 | tensorboardX==2.1
15 | vtk_visualizer==0.9.6
16 | h5py==3.2.1
17 | gitpython==3.1.17
--------------------------------------------------------------------------------
/scripts/cal_overlap.py:
--------------------------------------------------------------------------------
1 | """
2 | We use this script to calculate the overlap ratios for all the train/test fragment pairs
3 | """
4 | import os,sys,glob
5 | import open3d as o3d
6 | from lib.utils import natural_key
7 | import numpy as np
8 | from tqdm import tqdm
9 | import multiprocessing as mp
10 | 
11 | def determine_epsilon():
12 | """
13 | We follow Learning Compact Geometric Features to compute this hyperparameter, though we ended up not using it.
14 | """ 15 | base_dir='../dataset/3DMatch/test/*/03_Transformed/*.ply' 16 | files=sorted(glob.glob(base_dir),key=natural_key) 17 | etas=[] 18 | for eachfile in files: 19 | pcd=o3d.io.read_point_cloud(eachfile) 20 | pcd=pcd.voxel_down_sample(0.025) 21 | pcd_tree = o3d.geometry.KDTreeFlann(pcd) 22 | distances=[] 23 | for i, point in enumerate(pcd.points): 24 | [count,vec1, vec2] = pcd_tree.search_knn_vector_3d(point,2) 25 | distances.append(np.sqrt(vec2[1])) 26 | etai=np.median(distances) 27 | etas.append(etai) 28 | return np.median(etas) 29 | 30 | 31 | def get_overlap_ratio(source,target,threshold=0.03): 32 | """ 33 | We compute overlap ratio from source point cloud to target point cloud 34 | """ 35 | pcd_tree = o3d.geometry.KDTreeFlann(target) 36 | 37 | match_count=0 38 | for i, point in enumerate(source.points): 39 | [count, _, _] = pcd_tree.search_radius_vector_3d(point, threshold) 40 | if(count!=0): 41 | match_count+=1 42 | 43 | overlap_ratio = match_count / len(source.points) 44 | return overlap_ratio 45 | 46 | def cal_overlap_per_scene(c_folder): 47 | base_dir=os.path.join(c_folder,'03_Transformed') 48 | fragments=sorted(glob.glob(base_dir+'/*.ply'),key=natural_key) 49 | n_fragments=len(fragments) 50 | 51 | with open(f'{c_folder}/overlaps_ours.txt','w') as f: 52 | for i in tqdm(range(n_fragments-1)): 53 | for j in range(i+1,n_fragments): 54 | path1,path2=fragments[i],fragments[j] 55 | 56 | # load, downsample and transform 57 | pcd1=o3d.io.read_point_cloud(path1) 58 | pcd2=o3d.io.read_point_cloud(path2) 59 | pcd1=pcd1.voxel_down_sample(0.01) 60 | pcd2=pcd2.voxel_down_sample(0.01) 61 | 62 | # calculate overlap 63 | c_overlap = get_overlap_ratio(pcd1,pcd2) 64 | f.write(f'{i},{j},{c_overlap:.4f}\n') 65 | f.close() 66 | 67 | if __name__=='__main__': 68 | base_dir='your data folder' 69 | scenes = sorted(glob.glob(base_dir)) 70 | 71 | p = mp.Pool(processes=mp.cpu_count()) 72 | p.map(cal_overlap_mat,scenes) 73 | p.close() 74 | p.join() -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scripts for pairwise registration demo 3 | 4 | Author: Shengyu Huang 5 | Last modified: 22.02.2021 6 | """ 7 | import os, torch, time, shutil, json,glob,sys,copy, argparse 8 | import numpy as np 9 | from easydict import EasyDict as edict 10 | from torch.utils.data import Dataset 11 | from torch import optim, nn 12 | import open3d as o3d 13 | 14 | cwd = os.getcwd() 15 | sys.path.append(cwd) 16 | from datasets.indoor import IndoorDataset 17 | from datasets.dataloader import get_dataloader 18 | from models.architectures import KPFCNN 19 | from lib.utils import load_obj, setup_seed,natural_key, load_config 20 | from lib.benchmark_utils import ransac_pose_estimation, to_o3d_pcd, get_blue, get_yellow, to_tensor 21 | from lib.trainer import Trainer 22 | from lib.loss import MetricLoss 23 | import shutil 24 | setup_seed(0) 25 | 26 | 27 | class ThreeDMatchDemo(Dataset): 28 | """ 29 | Load subsampled coordinates, relative rotation and translation 30 | Output(torch.Tensor): 31 | src_pcd: [N,3] 32 | tgt_pcd: [M,3] 33 | rot: [3,3] 34 | trans: [3,1] 35 | """ 36 | def __init__(self,config, src_path, tgt_path): 37 | super(ThreeDMatchDemo,self).__init__() 38 | self.config = config 39 | self.src_path = src_path 40 | self.tgt_path = tgt_path 41 | 42 | def __len__(self): 43 | return 1 44 | 45 | def __getitem__(self,item): 46 | # get pointcloud 47 | src_pcd = 
torch.load(self.src_path).astype(np.float32) 48 | tgt_pcd = torch.load(self.tgt_path).astype(np.float32) 49 | 50 | 51 | #src_pcd = o3d.io.read_point_cloud(self.src_path) 52 | #tgt_pcd = o3d.io.read_point_cloud(self.tgt_path) 53 | #src_pcd = src_pcd.voxel_down_sample(0.025) 54 | #tgt_pcd = tgt_pcd.voxel_down_sample(0.025) 55 | #src_pcd = np.array(src_pcd.points).astype(np.float32) 56 | #tgt_pcd = np.array(tgt_pcd.points).astype(np.float32) 57 | 58 | 59 | src_feats=np.ones_like(src_pcd[:,:1]).astype(np.float32) 60 | tgt_feats=np.ones_like(tgt_pcd[:,:1]).astype(np.float32) 61 | 62 | # fake the ground truth information 63 | rot = np.eye(3).astype(np.float32) 64 | trans = np.ones((3,1)).astype(np.float32) 65 | correspondences = torch.ones(1,2).long() 66 | 67 | return src_pcd,tgt_pcd,src_feats,tgt_feats,rot,trans, correspondences, src_pcd, tgt_pcd, torch.ones(1) 68 | 69 | def lighter(color, percent): 70 | '''assumes color is rgb between (0, 0, 0) and (1,1,1)''' 71 | color = np.array(color) 72 | white = np.array([1, 1, 1]) 73 | vector = white-color 74 | return color + vector * percent 75 | 76 | 77 | def draw_registration_result(src_raw, tgt_raw, src_overlap, tgt_overlap, src_saliency, tgt_saliency, tsfm): 78 | ######################################## 79 | # 1. input point cloud 80 | src_pcd_before = to_o3d_pcd(src_raw) 81 | tgt_pcd_before = to_o3d_pcd(tgt_raw) 82 | src_pcd_before.paint_uniform_color(get_yellow()) 83 | tgt_pcd_before.paint_uniform_color(get_blue()) 84 | src_pcd_before.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.3, max_nn=50)) 85 | tgt_pcd_before.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.3, max_nn=50)) 86 | 87 | ######################################## 88 | # 2. overlap colors 89 | rot, trans = to_tensor(tsfm[:3,:3]), to_tensor(tsfm[:3,3][:,None]) 90 | src_overlap = src_overlap[:,None].repeat(1,3).numpy() 91 | tgt_overlap = tgt_overlap[:,None].repeat(1,3).numpy() 92 | src_overlap_color = lighter(get_yellow(), 1 - src_overlap) 93 | tgt_overlap_color = lighter(get_blue(), 1 - tgt_overlap) 94 | src_pcd_overlap = copy.deepcopy(src_pcd_before) 95 | src_pcd_overlap.transform(tsfm) 96 | tgt_pcd_overlap = copy.deepcopy(tgt_pcd_before) 97 | src_pcd_overlap.colors = o3d.utility.Vector3dVector(src_overlap_color) 98 | tgt_pcd_overlap.colors = o3d.utility.Vector3dVector(tgt_overlap_color) 99 | 100 | ######################################## 101 | # 3. 
draw registrations
102 | src_pcd_after = copy.deepcopy(src_pcd_before)
103 | src_pcd_after.transform(tsfm)
104 | 
105 | vis1 = o3d.visualization.Visualizer()
106 | vis1.create_window(window_name='Input', width=960, height=540, left=0, top=0)
107 | vis1.add_geometry(src_pcd_before)
108 | vis1.add_geometry(tgt_pcd_before)
109 | 
110 | vis2 = o3d.visualization.Visualizer()
111 | vis2.create_window(window_name='Inferred overlap region', width=960, height=540, left=0, top=600)
112 | vis2.add_geometry(src_pcd_overlap)
113 | vis2.add_geometry(tgt_pcd_overlap)
114 | 
115 | vis3 = o3d.visualization.Visualizer()
116 | vis3.create_window(window_name='Our registration', width=960, height=540, left=960, top=0)
117 | vis3.add_geometry(src_pcd_after)
118 | vis3.add_geometry(tgt_pcd_before)
119 | 
120 | while True:
121 | vis1.update_geometry(src_pcd_before)
122 | vis3.update_geometry(tgt_pcd_before)
123 | if not vis1.poll_events():
124 | break
125 | vis1.update_renderer()
126 | 
127 | vis2.update_geometry(src_pcd_overlap)
128 | vis2.update_geometry(tgt_pcd_overlap)
129 | if not vis2.poll_events():
130 | break
131 | vis2.update_renderer()
132 | 
133 | vis3.update_geometry(src_pcd_after)
134 | vis3.update_geometry(tgt_pcd_before)
135 | if not vis3.poll_events():
136 | break
137 | vis3.update_renderer()
138 | 
139 | vis1.destroy_window()
140 | vis2.destroy_window()
141 | vis3.destroy_window()
142 | 
143 | 
144 | def main(config, demo_loader):
145 | config.model.eval()
146 | c_loader_iter = demo_loader.__iter__()
147 | with torch.no_grad():
148 | inputs = next(c_loader_iter)
149 | ##################################
150 | # load inputs to device.
151 | for k, v in inputs.items():
152 | if type(v) == list:
153 | inputs[k] = [item.to(config.device) for item in v]
154 | else:
155 | inputs[k] = v.to(config.device)
156 | 
157 | ###############################################
158 | # forward pass
159 | feats, scores_overlap, scores_saliency = config.model(inputs) #[N1, C1], [N2, C2]
160 | pcd = inputs['points'][0]
161 | len_src = inputs['stack_lengths'][0][0]
162 | c_rot, c_trans = inputs['rot'], inputs['trans']
163 | correspondence = inputs['correspondences']
164 | 
165 | src_pcd, tgt_pcd = pcd[:len_src], pcd[len_src:]
166 | src_raw = copy.deepcopy(src_pcd)
167 | tgt_raw = copy.deepcopy(tgt_pcd)
168 | src_feats, tgt_feats = feats[:len_src].detach().cpu(), feats[len_src:].detach().cpu()
169 | src_overlap, src_saliency = scores_overlap[:len_src].detach().cpu(), scores_saliency[:len_src].detach().cpu()
170 | tgt_overlap, tgt_saliency = scores_overlap[len_src:].detach().cpu(), scores_saliency[len_src:].detach().cpu()
171 | 
172 | ########################################
173 | # do probabilistic sampling guided by the score
174 | src_scores = src_overlap * src_saliency
175 | tgt_scores = tgt_overlap * tgt_saliency
176 | 
177 | if(src_pcd.size(0) > config.n_points):
178 | idx = np.arange(src_pcd.size(0))
179 | probs = (src_scores / src_scores.sum()).numpy().flatten()
180 | idx = np.random.choice(idx, size= config.n_points, replace=False, p=probs)
181 | src_pcd, src_feats = src_pcd[idx], src_feats[idx]
182 | if(tgt_pcd.size(0) > config.n_points):
183 | idx = np.arange(tgt_pcd.size(0))
184 | probs = (tgt_scores / tgt_scores.sum()).numpy().flatten()
185 | idx = np.random.choice(idx, size= config.n_points, replace=False, p=probs)
186 | tgt_pcd, tgt_feats = tgt_pcd[idx], tgt_feats[idx]
187 | 
188 | ########################################
189 | # run ransac and draw registration
190 | tsfm = ransac_pose_estimation(src_pcd, tgt_pcd, 
src_feats, tgt_feats, mutual=False)
191 | draw_registration_result(src_raw, tgt_raw, src_overlap, tgt_overlap, src_saliency, tgt_saliency, tsfm)
192 | 
193 | 
194 | if __name__ == '__main__':
195 | # load configs
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument('config', type=str, help='Path to the config file.')
198 | args = parser.parse_args()
199 | config = load_config(args.config)
200 | config = edict(config)
201 | if config.gpu_mode:
202 | config.device = torch.device('cuda')
203 | else:
204 | config.device = torch.device('cpu')
205 | 
206 | # model initialization
207 | config.architecture = [
208 | 'simple',
209 | 'resnetb',
210 | ]
211 | for i in range(config.num_layers-1):
212 | config.architecture.append('resnetb_strided')
213 | config.architecture.append('resnetb')
214 | config.architecture.append('resnetb')
215 | for i in range(config.num_layers-2):
216 | config.architecture.append('nearest_upsample')
217 | config.architecture.append('unary')
218 | config.architecture.append('nearest_upsample')
219 | config.architecture.append('last_unary')
220 | config.model = KPFCNN(config).to(config.device)
221 | 
222 | # create dataset and dataloader
223 | info_train = load_obj(config.train_info)
224 | train_set = IndoorDataset(info_train,config,data_augmentation=True)
225 | demo_set = ThreeDMatchDemo(config, config.src_pcd, config.tgt_pcd)
226 | 
227 | _, neighborhood_limits = get_dataloader(dataset=train_set,
228 | batch_size=config.batch_size,
229 | shuffle=True,
230 | num_workers=config.num_workers,
231 | )
232 | demo_loader, _ = get_dataloader(dataset=demo_set,
233 | batch_size=config.batch_size,
234 | shuffle=False,
235 | num_workers=1,
236 | neighborhood_limits=neighborhood_limits)
237 | 
238 | # load pretrained weights
239 | assert config.pretrain is not None
240 | state = torch.load(config.pretrain)
241 | config.model.load_state_dict(state['state_dict'])
242 | 
243 | # do pose estimation
244 | main(config, demo_loader)
245 | 
--------------------------------------------------------------------------------
/scripts/download_data_weight.sh:
--------------------------------------------------------------------------------
1 | wget --no-check-certificate --show-progress https://share.phys.ethz.ch/~gseg/Predator/data.zip
2 | wget --no-check-certificate --show-progress https://share.phys.ethz.ch/~gseg/Predator/weights.zip
3 | unzip data.zip
4 | unzip weights.zip
5 | rm data.zip
6 | rm weights.zip
7 | 
--------------------------------------------------------------------------------
/scripts/evaluate_predator.py:
--------------------------------------------------------------------------------
1 | """
2 | Scripts for pairwise registration with RANSAC and our probabilistic sampling
3 | 
4 | Author: Shengyu Huang
5 | Last modified: 30.11.2020
6 | """
7 | 
8 | import torch, os, sys, glob
9 | cwd = os.getcwd()
10 | sys.path.append(cwd)
11 | from tqdm import tqdm
12 | import numpy as np
13 | from lib.utils import load_obj, natural_key,setup_seed
14 | from lib.benchmark_utils import ransac_pose_estimation, get_inlier_ratio, get_scene_split, write_est_trajectory
15 | import open3d as o3d
16 | from lib.benchmark import read_trajectory, write_trajectory, benchmark
17 | import argparse
18 | setup_seed(0)
19 | 
20 | def sample_interest_points(method, scores, N):
21 | """
22 | We can do random sampling, probabilistic sampling, or top-k sampling
23 | """
24 | assert method in ['prob','topk', 'random']
25 | n = scores.size(0)
26 | if n < N:
27 | choice = np.random.choice(n, N) # fewer points than requested: sample with replacement
28 | else:
29 | if method
== 'random': 30 | choice = np.random.permutation(n)[:N] 31 | elif method =='topk': 32 | choice = torch.topk(scores, N, dim=0)[1] 33 | elif method =='prob': 34 | idx = np.arange(n) 35 | probs = (scores / scores.sum()).numpy().flatten() 36 | choice = np.random.choice(idx, size= N, replace=False, p=probs) 37 | 38 | return choice 39 | 40 | 41 | 42 | def benchmark_predator(feats_scores,n_points,exp_dir,whichbenchmark,sample_method,ransac_with_mutual=False, inlier_ratio_threshold = 0.05): 43 | gt_folder = f'configs/benchmarks/{whichbenchmark}' 44 | exp_dir = f'{exp_dir}/{whichbenchmark}_{n_points}_{sample_method}' 45 | if(not os.path.exists(exp_dir)): 46 | os.makedirs(exp_dir) 47 | print(exp_dir) 48 | 49 | results = dict() 50 | results['w_mutual'] = {'inlier_ratios':[], 'distances':[]} 51 | results['wo_mutual'] = {'inlier_ratios':[], 'distances':[]} 52 | tsfm_est = [] 53 | for eachfile in tqdm(feats_scores): 54 | ######################################## 55 | # 1. take the input point clouds 56 | data = torch.load(eachfile) 57 | len_src = data['len_src'] 58 | pcd = data['pcd'] 59 | feats = data['feats'] 60 | rot, trans = data['rot'], data['trans'] 61 | saliency, overlap = data['saliency'], data['overlaps'] 62 | 63 | src_pcd = pcd[:len_src] 64 | tgt_pcd = pcd[len_src:] 65 | src_feats = feats[:len_src] 66 | tgt_feats = feats[len_src:] 67 | src_overlap, src_saliency = overlap[:len_src], saliency[:len_src] 68 | tgt_overlap, tgt_saliency = overlap[len_src:], saliency[len_src:] 69 | 70 | ######################################## 71 | # 2. do probabilistic sampling guided by the score 72 | src_scores = src_overlap * src_saliency 73 | tgt_scores = tgt_overlap * tgt_saliency 74 | 75 | 76 | src_idx = sample_interest_points(sample_method, src_scores,n_points) 77 | tgt_idx = sample_interest_points(sample_method, tgt_scores,n_points) 78 | 79 | src_pcd, src_feats = src_pcd[src_idx], src_feats[src_idx] 80 | tgt_pcd, tgt_feats = tgt_pcd[tgt_idx], tgt_feats[tgt_idx] 81 | 82 | ######################################## 83 | # 3. run ransac 84 | tsfm_est.append(ransac_pose_estimation(src_pcd, tgt_pcd, src_feats, tgt_feats, mutual=ransac_with_mutual)) 85 | 86 | ######################################## 87 | # 4. 
calculate inlier ratios
88 | inlier_ratio_results = get_inlier_ratio(src_pcd, tgt_pcd, src_feats, tgt_feats, rot, trans)
89 | 
90 | results['w_mutual']['inlier_ratios'].append(inlier_ratio_results['w']['inlier_ratio'])
91 | results['w_mutual']['distances'].append(inlier_ratio_results['w']['distance'])
92 | results['wo_mutual']['inlier_ratios'].append(inlier_ratio_results['wo']['inlier_ratio'])
93 | results['wo_mutual']['distances'].append(inlier_ratio_results['wo']['distance'])
94 | 
95 | tsfm_est = np.array(tsfm_est)
96 | 
97 | ########################################
98 | # write the estimated trajectories
99 | write_est_trajectory(gt_folder, exp_dir, tsfm_est)
100 | 
101 | ########################################
102 | # evaluate the results; FMR and inlier ratios are averaged twice (first within each scene, then across scenes)
103 | benchmark(exp_dir, gt_folder)
104 | split = get_scene_split(whichbenchmark)
105 | 
106 | for key in ['w_mutual','wo_mutual']:
107 | inliers =[]
108 | fmrs = []
109 | 
110 | for ele in split:
111 | c_inliers = results[key]['inlier_ratios'][ele[0]:ele[1]]
112 | inliers.append(np.mean(c_inliers))
113 | fmrs.append((np.array(c_inliers) > inlier_ratio_threshold).mean())
114 | 
115 | with open(os.path.join(exp_dir,'result'),'a') as f:
116 | f.write(f'Inlier ratio {key}: {np.mean(inliers):.3f} : +- {np.std(inliers):.3f}\n')
117 | f.write(f'Feature match recall {key}: {np.mean(fmrs):.3f} : +- {np.std(fmrs):.3f}\n')
118 | f.close()
119 | 
120 | 
121 | if __name__=='__main__':
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument(
124 | '--source_path', default=None, type=str, help='path to precomputed features and scores')
125 | parser.add_argument(
126 | '--benchmark', default='3DLoMatch', type=str, help='[3DMatch, 3DLoMatch]')
127 | parser.add_argument(
128 | '--n_points', default=1000, type=int, help='number of points used by RANSAC')
129 | parser.add_argument(
130 | '--exp_dir', default='est_traj', type=str, help='directory for the final results')
131 | parser.add_argument(
132 | '--sampling', default='prob', type=str, help='interest point sampling')
133 | args = parser.parse_args()
134 | 
135 | feats_scores = sorted(glob.glob(f'{args.source_path}/*.pth'), key=natural_key)
136 | 
137 | benchmark_predator(feats_scores, args.n_points, args.exp_dir, args.benchmark, args.sampling)
138 | 
--------------------------------------------------------------------------------
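For readers skimming the evaluation pipeline above: the score-guided sampling used in
sample_interest_points and in scripts/demo.py boils down to a weighted draw without
replacement over the per-point overlap * saliency scores. A minimal, self-contained
sketch of that step (the helper name prob_sample is illustrative, not part of the repo):

import numpy as np
import torch

def prob_sample(scores: torch.Tensor, n_points: int) -> np.ndarray:
    """Draw n_points indices, biased towards high overlap*saliency scores."""
    probs = (scores / scores.sum()).numpy().flatten()  # normalise scores into a distribution
    return np.random.choice(len(probs), size=n_points, replace=False, p=probs)

# usage: keep 1000 source points, favouring predicted-overlap regions
scores = torch.rand(5000)  # stands in for src_overlap * src_saliency
idx = prob_sample(scores, 1000)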