├── LICENSE ├── README.md ├── assets ├── 3dmatch.png ├── cast.jpg ├── kitti.png └── nuscenes.png ├── calibrate.py ├── config ├── 3dmatch.json ├── eth.json ├── kitti.json └── nuscenes.json ├── data ├── 3dmatch_list │ ├── benchmark │ │ ├── 3DLoMatch │ │ │ ├── 7-scenes-redkitchen │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-home_at-home_at_scan1_2013_jan_1 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-home_md-home_md_scan9_2012_sep_30 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_uc-scan3 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_umd-maryland_hotel1 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_umd-maryland_hotel3 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-mit_76_studyroom-76-1studyroom2 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ └── sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ └── 3DMatch │ │ │ ├── 7-scenes-redkitchen │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-home_at-home_at_scan1_2013_jan_1 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-home_md-home_md_scan9_2012_sep_30 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_uc-scan3 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_umd-maryland_hotel1 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_umd-maryland_hotel3 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-mit_76_studyroom-76-1studyroom2 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ └── sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ ├── train.pkl │ └── val.pkl ├── eth_data.py ├── gen_kitti_data.py ├── indoor_data.py ├── kitti_data.py ├── kitti_list │ ├── test.txt │ ├── train.txt │ └── val.txt ├── nuscenes_data.py └── nuscenes_list │ ├── test.txt │ ├── train.txt │ └── val.txt ├── demo_3dmatch.py ├── demo_outdoor.py ├── engine ├── __init__.py ├── evaluator.py ├── losses.py ├── summary_board.py └── trainer.py ├── evaluate_IR_FMR.py ├── evaluate_RR.py ├── evaluate_eth.py ├── finetune.py ├── models ├── cast │ ├── __init__.py │ ├── cast.py │ ├── consistency.py │ ├── correspondence.py │ └── spot_attention.py ├── kpconv │ ├── __init__.py │ ├── backbone.py │ ├── dispositions │ │ └── k_015_center_3D.ply │ ├── kernel_points.py │ ├── kpconv.py │ └── modules.py ├── models │ ├── cast.py │ └── cast_eth.py ├── transformer │ ├── __init__.py │ ├── conditional_transformer.py │ ├── linear_transformer.py │ ├── output_layer.py │ ├── pe_transformer.py │ ├── positional_encoding.py │ ├── rpe_transformer.py │ └── vanilla_transformer.py └── utils.py ├── requirements.txt └── trainval.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Renlang Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to 
whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## A Consistency-Aware Spot-Guided Transformer for Versatile and Hierarchical Point Cloud Registration
2 |
3 | Official PyTorch implementation of the paper ["A Consistency-Aware Spot-Guided Transformer for Versatile and Hierarchical Point Cloud Registration"](https://arxiv.org/abs/2410.10295), accepted by NeurIPS 2024 as a poster.
4 |
5 | ### 1. Introduction
6 |
7 | We present a novel consistency-aware spot-guided Transformer to achieve compact, consistent coarse matching and efficient, accurate pose estimation for point cloud registration. At the coarse matching stage, our consistency-aware self-attention enhances the feature representations with sparse sampling from the geometric compatibility graph. Additionally, our spot-guided cross-attention leverages local consistency to guide the cross-attention to confident spots without interfering with irrelevant areas. Based on these semi-dense and consistent coarse correspondences, a lightweight and scalable sparse-to-dense fine matching module empowered by local attention can achieve accurate pose estimation without optimal transport or hypothesis-and-selection pipelines. Our method achieves *state-of-the-art* accuracy, robustness, and efficiency for point cloud registration across different 3D sensors and scenarios.
8 |
9 | ![](assets/cast.jpg)
10 |
11 |
12 | ### 2. Installation
13 |
14 | Please use the following commands for installation.
15 |
16 | ```bash
17 | # It is recommended to create a new environment
18 | conda create -n cast python==3.7
19 | conda activate cast
20 |
21 | # Install packages and other dependencies
22 | pip install -r requirements.txt
23 |
24 | # If you are using CUDA 11.2 or newer, you can install `torch==1.7.1+cu110` or `torch==1.9.0+cu111`
25 | pip install torch==1.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
26 |
27 | # Install pytorch3d (feel free to download it to other directories)
28 | conda install openblas-devel -c anaconda
29 | wget https://github.com/facebookresearch/pytorch3d/archive/refs/tags/v0.6.2.zip
30 | mv v0.6.2.zip pytorch3d-0.6.2.zip
31 | unzip pytorch3d-0.6.2.zip
32 | cd pytorch3d-0.6.2
33 | pip install -e .
34 | cd ..
35 |
36 | # Install MinkowskiEngine (feel free to download it to other directories)
37 | git clone https://github.com/NVIDIA/MinkowskiEngine
38 | cd MinkowskiEngine
39 | python setup.py install --blas_include_dirs=${CONDA_PREFIX}/include --blas=openblas
40 |
41 | # Download pre-trained weights from release v1.0.0
42 | ```
43 |
44 | The code has been tested with Ubuntu 20.04, GCC 9.4.0, Python 3.7, PyTorch 1.9.0, CUDA 11.2, and PyTorch3D 0.6.2.
45 |
46 |
47 |
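If the installation succeeded, a quick import check should run without errors and report versions consistent with the list above. This is an optional sanity-check snippet, not a file shipped in this repository:

```python
# optional environment sanity check (not part of the repository)
import torch
import pytorch3d
import MinkowskiEngine as ME

print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("PyTorch3D:", pytorch3d.__version__)
print("MinkowskiEngine:", ME.__version__)
```
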
48 | ### 3. KITTI odometry
49 |
50 | #### Data preparation
51 |
52 | Download the data from the [KITTI official website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php). The data should be organized as follows:
53 | - `KITTI`
54 |     - `velodyne` (point clouds)
55 |         - `sequences`
56 |             - `00`
57 |                 - `velodyne`
58 |                     - `000000.bin`
59 |                     - ...
60 |             - ...
61 |     - `results` (poses)
62 |         - `00.txt`
63 |         - ...
64 |     - `sequences` (sensor calibration and time stamps)
65 |         - `00`
66 |             - `calib.txt`
67 |             - `times.txt`
68 |         - ...
69 |
70 | Please note that we have already generated the metadata for the pairwise point clouds via `./data/gen_kitti_data.py`; it is stored in `./data/kitti_list`. Feel free to use it directly or regenerate it yourself. The entry format is illustrated below.
71 |
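Each line in `./data/kitti_list/{train,val,test}.txt` is written by `refine_poses` in `./data/gen_kitti_data.py`: the sequence index, the two frame indices, and 12 floats holding the first three rows of the ICP-refined 4x4 transformation that aligns the second scan to the first. A minimal sketch of how one entry can be parsed back into a pose (the numeric values below are illustrative, not taken from the actual lists):

```python
import numpy as np

# one entry: "<seq> <frame_i> <frame_j> <12 floats = first 3 rows of the 4x4 pose>"
line = "0 0 64 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0"  # illustrative values

tokens = line.split()
seq, frame_i, frame_j = (int(t) for t in tokens[:3])

# rebuild the homogeneous transformation from the flattened 3x4 block
pose = np.eye(4, dtype=np.float32)
pose[:3, :] = np.asarray(tokens[3:], dtype=np.float32).reshape(3, 4)

print(seq, frame_i, frame_j)
print(pose)  # maps points of scan frame_j into the frame of scan frame_i
```
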
72 | #### Training
73 | After modifying the `data.root` item to your dataset path in `./config/kitti.json`, you can use the following command for training.
74 | ```bash
75 | python trainval.py --mode train --config ./config/kitti.json
76 | ```
77 |
78 | #### Testing
79 | After modifying the `data.root` item to your dataset path in `./config/kitti.json`, you can use the following command for testing.
80 | ```bash
81 | python trainval.py --mode test --config ./config/kitti.json --load_pretrained cast-epoch-39
82 | ```
83 |
84 | #### Qualitative results
85 | You can use the following commands for visualization:
86 | ```bash
87 | # visualize the keypoints
88 | python demo_outdoor.py --dataset kitti --mode keypts --load_pretrained cast-epoch-39 --split train --id 0
89 | # visualize the keypoint correspondences
90 | python demo_outdoor.py --dataset kitti --mode corr --load_pretrained cast-epoch-39 --split train --id 0
91 | # visualize the aligned point clouds after pose estimation
92 | python demo_outdoor.py --dataset kitti --mode reg --load_pretrained cast-epoch-39 --split train --id 0
93 | ```
94 | ![](assets/kitti.png)
95 |
96 |
97 | ### 4. nuScenes
98 |
99 | #### Data preparation
100 |
101 | Download the data from the [nuScenes official website](https://www.nuscenes.org/nuscenes#download). The data should be organized as follows:
102 | - `nuscenes`
103 |     - `samples`
104 |         - `LIDAR_TOP`
105 |             - `n008-2018-05-21-11-06-59-0400__LIDAR_TOP__1526915243047392.pcd.bin`
106 |             - ...
107 |
108 | #### Training
109 | After modifying the `data.root` item to your dataset path in `./config/nuscenes.json`, you can use the following command for training.
110 | ```bash
111 | python trainval.py --mode train --config ./config/nuscenes.json
112 | ```
113 |
114 | #### Testing
115 | After modifying the `data.root` item to your dataset path in `./config/nuscenes.json`, you can use the following command for testing.
116 | ```bash
117 | python trainval.py --mode test --config ./config/nuscenes.json --load_pretrained cast-epoch-03-26000
118 | ```
119 |
120 | #### Qualitative results
121 | ```bash
122 | # visualize the keypoints
123 | python demo_outdoor.py --dataset nuscenes --mode keypts --load_pretrained cast-epoch-03-26000 --split train --id 0
124 | # visualize the keypoint correspondences
125 | python demo_outdoor.py --dataset nuscenes --mode corr --load_pretrained cast-epoch-03-26000 --split train --id 0
126 | # visualize the aligned point clouds after pose estimation
127 | python demo_outdoor.py --dataset nuscenes --mode reg --load_pretrained cast-epoch-03-26000 --split train --id 0
128 | ```
129 | ![](assets/nuscenes.png)
130 |
131 |
132 | ### 5. 3DMatch and 3DLoMatch
133 |
134 | #### Data preparation
135 |
136 | The dataset can be downloaded from [PREDATOR](https://github.com/prs-eth/OverlapPredator) by running the following commands:
137 | ```bash
138 | wget --no-check-certificate --show-progress https://share.phys.ethz.ch/~gsg/pairwise_reg/3dmatch.zip
139 | unzip 3dmatch.zip
140 | ```
141 | The data should be organized as follows:
142 | - `3dmatch`
143 |     - `train`
144 |         - `7-scenes-chess`
145 |             - `fragments`
146 |                 - `cloud_bin_*.ply`
147 |                 - ...
148 |             - `poses`
149 |                 - `cloud_bin_*.txt`
150 |                 - ...
151 |         - ...
152 |     - `test`
153 |         - `7-scenes-redkitchen`
154 |             - `fragments`
155 |                 - `cloud_bin_*.ply`
156 |                 - ...
157 |             - `poses`
158 |                 - `cloud_bin_*.txt`
159 |                 - ...
160 |         - ...
161 |
162 | #### Training
163 | After modifying the `data.root` item to your dataset path in `./config/3dmatch.json`, you can use the following command for training.
164 | ```bash
165 | python trainval.py --mode train --config ./config/3dmatch.json
166 | ```
167 |
168 | #### Testing
169 | After modifying the `data.root` item to your dataset path in `./config/3dmatch.json`, you can use the following commands for testing.
170 | ```bash
171 | # evaluate the registration recall (CAST+RANSAC)
172 | ## for 3DMatch benchmark
173 | python evaluate_RR.py --benchmark 3DMatch --config ./config/3dmatch.json --load_pretrained cast-epoch-05 --ransac
174 | ## for 3DLoMatch benchmark
175 | python evaluate_RR.py --benchmark 3DLoMatch --config ./config/3dmatch.json --load_pretrained cast-epoch-05 --ransac
176 |
177 | # evaluate the registration recall (CAST)
178 | ## for 3DMatch benchmark
179 | python evaluate_RR.py --benchmark 3DMatch --config ./config/3dmatch.json --load_pretrained cast-epoch-05
180 | ## for 3DLoMatch benchmark
181 | python evaluate_RR.py --benchmark 3DLoMatch --config ./config/3dmatch.json --load_pretrained cast-epoch-05
182 |
183 | # evaluate IR, FMR, PIR, and PMR
184 | ## for 3DMatch benchmark
185 | python evaluate_IR_FMR.py --benchmark 3DMatch --config ./config/3dmatch.json --load_pretrained cast-epoch-05
186 | ## for 3DLoMatch benchmark
187 | python evaluate_IR_FMR.py --benchmark 3DLoMatch --config ./config/3dmatch.json --load_pretrained cast-epoch-05
188 | ```
189 |
190 |
191 | #### Qualitative results
192 | You can use the following command for visualization:
193 | ```bash
194 | python demo_3dmatch.py --split test --benchmark 3DMatch --id 0
195 | ```
196 | ![](assets/3dmatch.png)
197 |
198 |
199 |
200 | ### 6. Generalization and Adaptation to ETH
201 |
202 | #### Data preparation
203 |
204 | This dataset can be downloaded [here](https://share.phys.ethz.ch/~gsg/3DSmoothNet/data/ETH.rar); after unzipping it, the data is organized as follows (the `gt.log` format is illustrated after this list):
205 | - `ETH`
206 |     - `gazebo_summer`
207 |         - `gt.log`
208 |         - `overlapMatrix.csv`
209 |         - `Hokuyo_0.ply`
210 |         - `Hokuyo_?.ply` ...
211 |     - `gazebo_winter`
212 |     - `wood_autmn`
213 |     - `wood_summer`
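For reference, each `gt.log` is parsed by `ETHDataset.read_transformation_log` in `data/eth_data.py` as groups of five lines: a header line whose first two integers are the indices of the two fragments, followed by four lines containing the rows of the 4x4 ground-truth transformation. A simplified, standalone sketch of that parsing (the path below is a placeholder):

```python
import numpy as np

def read_gt_log(path):
    """Parse an ETH gt.log: 5 lines per pair (header line + 4x4 transformation)."""
    with open(path) as f:
        lines = [line.strip() for line in f if line.strip()]
    pairs, poses = [], []
    for i in range(0, len(lines), 5):
        header = lines[i].split()
        pairs.append((int(header[0]), int(header[1])))          # fragment indices
        rows = [lines[i + j].split() for j in range(1, 5)]      # 4 rows of the pose
        poses.append(np.array(rows, dtype=np.float32))
    return pairs, poses

# usage (placeholder path):
# pairs, poses = read_gt_log('/path/to/ETH/gazebo_summer/gt.log')
```
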
214 | #### Testing
215 | After modifying the `data.root` item to your dataset path in `./config/eth.json`, you can use the following command for testing.
216 | ```bash
217 | python evaluate_eth.py
218 | ```
219 | #### Training (Unsupervised Domain Adaptation)
220 | You can use the following command to fine-tune the network in an unsupervised manner.
221 | ```bash
222 | python finetune.py
223 | ```
224 |
225 | ## Acknowledgements
226 | We sincerely thank the area chair for the appreciation of our work, and we sincerely thank all of the reviewers for their constructive reviews and valuable suggestions. Meanwhile, we would like to thank the authors of [D3Feat](https://github.com/XuyangBai/D3Feat.pytorch), [PREDATOR](https://github.com/prs-eth/OverlapPredator), [RSKDD](https://github.com/lufan11223/RSKDD-Net), [PointDSC](https://github.com/XuyangBai/PointDSC), [HRegNet](https://github.com/ispc-lab/HRegNet2), [GeoTransformer](https://github.com/qinzheng93/GeoTransformer), and [REGTR](https://github.com/yewzijian/RegTR) for making their source code public.
227 |
--------------------------------------------------------------------------------
/assets/3dmatch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/assets/3dmatch.png
--------------------------------------------------------------------------------
/assets/cast.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/assets/cast.jpg
--------------------------------------------------------------------------------
/assets/kitti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/assets/kitti.png
--------------------------------------------------------------------------------
/assets/nuscenes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/assets/nuscenes.png
--------------------------------------------------------------------------------
/calibrate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from munch import Munch
5 | from data.indoor_data import IndoorDataset
6 | from data.kitti_data import KittiDataset
7 | from data.nuscenes_data import NuscenesDataset
8 | from models.utils import grid_subsample_gpu, radius_search_gpu
9 |
10 |
11 | def collate_fn(points, lengths, num_stages, voxel_size, radius, max_neighbor_limits):
12 |     neighbors_list = list()
13 |
14 |     for i in range(num_stages):
15 |         neighbors_list.append(radius_search_gpu(points, points, lengths, lengths, radius, max_neighbor_limits[i]))
16 |         radius, voxel_size = radius * 2., voxel_size * 2.
17 |         if i == num_stages - 1: break
18 |         points, lengths = grid_subsample_gpu(points, lengths, voxel_size)
19 |
20 |     return neighbors_list
21 |
22 |
23 | def calibrate_neighbors_stack_mode(
24 |     dataset, num_stages, voxel_size, search_radius, keep_ratio=0.8, sample_threshold=2000
25 | ):
26 |     # Compute an upper bound on the number of neighbors in a neighborhood
27 |     hist_n = int(np.ceil(4 / 3 * np.pi * (search_radius / voxel_size + 1) ** 3))
28 |     neighbor_hists = np.zeros((num_stages, hist_n), dtype=np.int32)
29 |     max_neighbor_limits = [hist_n] * num_stages
30 |
31 |     # Get histogram of neighborhood sizes in at most 1 epoch.
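    # For each sample, collate_fn returns the neighbor indices found within the search
    # radius at every stage (up to the current limit); indices below `neighbors.shape[0]`
    # are real neighbors, out-of-range indices are padding. The per-stage neighbor counts
    # are accumulated into histograms until every stage has seen more than `sample_threshold`
    # neighborhoods, and the cumulative sums at the end pick, per stage, the smallest limit
    # covering `keep_ratio` of all neighborhoods.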
32 | for i in range(len(dataset)): 33 | data_dict = dataset[i] 34 | points = torch.cat([data_dict[0], data_dict[1]], dim=0) 35 | lengths = torch.LongTensor([data_dict[0].shape[0], data_dict[1].shape[0]]) 36 | data_dict = collate_fn(points, lengths, num_stages, voxel_size, search_radius, max_neighbor_limits) 37 | counts = [np.sum(neighbors.numpy() < neighbors.shape[0], axis=1) for neighbors in data_dict] 38 | hists = [np.bincount(c, minlength=hist_n)[:hist_n] for c in counts] 39 | neighbor_hists += np.vstack(hists) 40 | if np.min(np.sum(neighbor_hists, axis=1)) > sample_threshold: break 41 | 42 | cum_sum = np.cumsum(neighbor_hists.T, axis=0) 43 | return np.sum(cum_sum < (keep_ratio * cum_sum[hist_n - 1, :]), axis=0) 44 | 45 | 46 | info = Munch() 47 | info.indoor = Munch() 48 | info.indoor.root = '/home/jacko/Downloads/3dmatch/' 49 | info.indoor.data_list = './data/3dmatch_list/' 50 | info.kitti = Munch() 51 | info.kitti.root = '/media/jacko/SSD/KITTI/velodyne/sequences/' 52 | info.kitti.data_list = './data/kitti_list/' 53 | info.nuscenes = Munch() 54 | info.nuscenes.root = '/media/jacko/SSD/nuscenes/' 55 | info.nuscenes.data_list = './data/nuscenes_list/' 56 | 57 | seed = 1 58 | random.seed(seed) 59 | np.random.seed(seed) 60 | torch.manual_seed(seed) 61 | 62 | dataset = KittiDataset(info.kitti.root, 'train', 30000, 0.3, info.kitti.data_list, 0.5) 63 | print('Create a data loader for KITTI with %d samples.'%len(dataset)) 64 | neighbor_limits = calibrate_neighbors_stack_mode(dataset, 5, 0.3, 0.3 * 3.0) 65 | print('Calibrate neighbors for KITTI: ', neighbor_limits) 66 | 67 | dataset = NuscenesDataset(info.nuscenes.root, 'train', 30000, 0.3, info.nuscenes.data_list, 0.5) 68 | print('Create a data loader for NuScenes with %d samples.'%len(dataset)) 69 | neighbor_limits = calibrate_neighbors_stack_mode(dataset, 5, 0.3, 0.3 * 3.25) 70 | print('Calibrate neighbors for NuScenes: ', neighbor_limits) 71 | 72 | dataset = IndoorDataset(info.indoor.root, 'train', None, 0.03, info.indoor.data_list, 0.5) 73 | print('Create a data loader for 3DMatch with %d samples.'%len(dataset)) 74 | neighbor_limits = calibrate_neighbors_stack_mode(dataset, 4, 0.03, 0.025 * 2.5) 75 | print('Calibrate neighbors for 3DMatch: ', neighbor_limits) 76 | 77 | ''' 78 | Outputs: 79 | Create a data loader for KITTI with 1358 samples. 80 | Calibrate neighbors for KITTI: [35 35 36 37 41] 81 | Create a data loader for NuScenes with 26696 samples. 82 | Calibrate neighbors for NuScenes: [19 26 36 45 46] 83 | Create a data loader for 3DMatch with 20642 samples. 
84 | Calibrate neighbors for 3DMatch: [26 21 22 25] 85 | ''' -------------------------------------------------------------------------------- /config/3dmatch.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 1, 3 | "log_steps": 200, 4 | "save_steps": 5000, 5 | "dataset": "3dmatch", 6 | "runname": "cast", 7 | "log_dir": "./logs/3dmatch/", 8 | "snapshot_dir": "./ckpt/3dmatch/", 9 | "data_list": "./data/3dmatch_list/", 10 | 11 | "data": { 12 | "root": "/home/jacko/Downloads/3dmatch/", 13 | "num_workers": 4, 14 | "npoints": 30000, 15 | "voxel_size": 0.025, 16 | "augment": 0.5 17 | }, 18 | 19 | "model": { 20 | "kpconv_layers": 4, 21 | "voxel_size": 0.025, 22 | "input_dim": 1, 23 | "init_dim": 64, 24 | "kernel_size": 15, 25 | "init_radius": 0.0625, 26 | "init_sigma": 0.05, 27 | "neighbor_limits": [26, 21, 22, 25], 28 | 29 | "k": 12, 30 | "spots": 4, 31 | "down_k": 4, 32 | "spot_k": 12, 33 | "dense_neighbors": 6, 34 | "input_dim_f": 512, 35 | "input_dim_c": 1024, 36 | "output_dim": 256, 37 | "hidden_dim": 128, 38 | "desc_dim": 32, 39 | "num_heads": 4, 40 | "dropout": null, 41 | "activation_fn": "relu", 42 | "sigma_d": 0.2, 43 | "sigma_a": 15, 44 | "angle_k": 3, 45 | "reduction_a": "max", 46 | "sigma_c": 0.15, 47 | "seed_threshold": 0.3, 48 | "seed_num": 48, 49 | "blocks": 3, 50 | 51 | "sigma_r": 0.075, 52 | "use_overlap_head": true, 53 | "overlap_threshold": 0.1, 54 | "keypoint_node_threshold": 0.1, 55 | "local_matching_radius": 0.15, 56 | "dual_normalization": true, 57 | 58 | "patch_k": 16, 59 | "num_neighbors": 4, 60 | "learnable_matcher": true, 61 | "filter_layers": 3, 62 | "filter_sigma_d": 0.1, 63 | 64 | "ransac": true, 65 | "ransac_filter": 0.1 66 | }, 67 | 68 | "optim": { 69 | "lr": 1e-4, 70 | "step_size": 5, 71 | "weight_decay": 1e-4, 72 | "gamma": 0.9, 73 | "max_epoch": 6, 74 | "clip_grad_norm": 0.5 75 | }, 76 | 77 | "loss": { 78 | "positive_margin": 0.1, 79 | "negative_margin": 1.4, 80 | "positive_optimal": 0.1, 81 | "negative_optimal": 1.4, 82 | "log_scale": 24, 83 | "positive_overlap": 0.1, 84 | 85 | "r_p": 0.05, 86 | "r_n": 0.06, 87 | 88 | "weight_det_loss": 1.0, 89 | "weight_spot_loss": 0.1, 90 | "weight_feat_loss": 0.5, 91 | "weight_desc_loss": 1.0, 92 | "weight_overlap_loss": 1.0, 93 | "weight_corr_loss": 10.0, 94 | "weight_trans_loss": 5.0, 95 | "weight_rot_loss": 20.0, 96 | "pretrain_feat_epochs": 1 97 | }, 98 | 99 | "eval": { 100 | "hit_threshold": 0.05, 101 | "acceptance_overlap": 0.0, 102 | "rmse_threshold": 0.2, 103 | "inlier_distance_threshold": 0.1, 104 | "inlier_ratio_threshold": 0.05 105 | } 106 | } -------------------------------------------------------------------------------- /config/eth.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 1, 3 | "log_steps": 200, 4 | "save_steps": 0, 5 | "dataset": "eth", 6 | "runname": "cast", 7 | "log_dir": "./logs/eth/", 8 | "snapshot_dir": "./ckpt/eth/", 9 | 10 | "data": { 11 | "root": "/media/jacko/SSD/ETH/", 12 | "num_workers": 4, 13 | "npoints": 30000, 14 | "voxel_size": 0.3, 15 | "augment": 1.0 16 | }, 17 | 18 | "model": { 19 | "kpconv_layers": 5, 20 | "voxel_size": 0.2, 21 | "input_dim": 1, 22 | "init_dim": 64, 23 | "kernel_size": 15, 24 | "init_radius": 0.9, 25 | "init_sigma": 0.6, 26 | "neighbor_limits": [73, 72, 80, 85, 81], 27 | 28 | "k": 12, 29 | "spots": 4, 30 | "down_k": 4, 31 | "spot_k": 12, 32 | "dense_neighbors": 6, 33 | "input_dim_f": 512, 34 | "input_dim_c": 2048, 35 | "output_dim": 256, 36 | "hidden_dim": 
128, 37 | "desc_dim": 32, 38 | "num_heads": 4, 39 | "dropout": null, 40 | "activation_fn": "relu", 41 | "sigma_d": 4.8, 42 | "sigma_a": 15, 43 | "angle_k": 3, 44 | "reduction_a": "max", 45 | "sigma_c": 1.8, 46 | "seed_threshold": 0.3, 47 | "seed_num": 48, 48 | "blocks": 3, 49 | 50 | "sigma_r": 0.8, 51 | "use_overlap_head": false, 52 | "overlap_threshold": 0.2, 53 | "keypoint_node_threshold": 1.5, 54 | "local_matching_radius": 0.75, 55 | "dual_normalization": true, 56 | 57 | "patch_k": 24, 58 | "num_neighbors": 4, 59 | "learnable_matcher": true, 60 | "filter_layers": 3, 61 | "filter_sigma_d": 1.0, 62 | "ransac_filter": 0.6, 63 | "ransac": true 64 | }, 65 | 66 | "optim": { 67 | "lr": 1e-4, 68 | "step_size": 5, 69 | "weight_decay": 1e-4, 70 | "gamma": 0.9, 71 | "max_epoch": 2, 72 | "clip_grad_norm": 0.5 73 | }, 74 | 75 | "loss": { 76 | "positive_margin": 0.1, 77 | "negative_margin": 1.4, 78 | "positive_optimal": 0.1, 79 | "negative_optimal": 1.4, 80 | "log_scale": 40, 81 | "positive_overlap": 0.2, 82 | 83 | "r_p": 0.45, 84 | "r_n": 0.6, 85 | 86 | "weight_det_loss": 1.0, 87 | "weight_spot_loss": 0.2, 88 | "weight_feat_loss": 0.2, 89 | "weight_desc_loss": 1.0, 90 | "weight_overlap_loss": 1.0, 91 | "weight_corr_loss": 1.0, 92 | "weight_trans_loss": 5.0, 93 | "weight_rot_loss": 20.0, 94 | "pretrain_feat_epochs": 10 95 | }, 96 | 97 | "eval": { 98 | "hit_threshold": 0.3, 99 | "acceptance_overlap": 0.0, 100 | "rre_threshold": 5.0, 101 | "rte_threshold": 2.0 102 | } 103 | } -------------------------------------------------------------------------------- /config/kitti.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 1, 3 | "log_steps": 200, 4 | "save_steps": 0, 5 | "dataset": "kitti", 6 | "runname": "cast", 7 | "log_dir": "./logs/kitti/", 8 | "snapshot_dir": "./ckpt/kitti/", 9 | "data_list": "./data/kitti_list/", 10 | 11 | "data": { 12 | "root": "/media/jacko/SSD/KITTI/velodyne/sequences/", 13 | "num_workers": 4, 14 | "npoints": 30000, 15 | "voxel_size": 0.3, 16 | "augment": 0.5 17 | }, 18 | 19 | "model": { 20 | "kpconv_layers": 5, 21 | "voxel_size": 0.3, 22 | "input_dim": 1, 23 | "init_dim": 64, 24 | "kernel_size": 15, 25 | "init_radius": 0.9, 26 | "init_sigma": 0.6, 27 | "neighbor_limits": [35, 35, 36, 37, 41], 28 | 29 | "k": 12, 30 | "spots": 4, 31 | "down_k": 4, 32 | "spot_k": 12, 33 | "dense_neighbors": 6, 34 | "input_dim_f": 512, 35 | "input_dim_c": 2048, 36 | "output_dim": 256, 37 | "hidden_dim": 128, 38 | "desc_dim": 32, 39 | "num_heads": 4, 40 | "dropout": null, 41 | "activation_fn": "relu", 42 | "sigma_d": 4.8, 43 | "sigma_a": 15, 44 | "angle_k": 3, 45 | "reduction_a": "max", 46 | "sigma_c": 1.8, 47 | "seed_threshold": 0.3, 48 | "seed_num": 48, 49 | "blocks": 3, 50 | 51 | "sigma_r": 1.2, 52 | "use_overlap_head": false, 53 | "overlap_threshold": 0.2, 54 | "keypoint_node_threshold": 1.8, 55 | "local_matching_radius": 0.75, 56 | "dual_normalization": true, 57 | 58 | "patch_k": 24, 59 | "num_neighbors": 4, 60 | "learnable_matcher": true, 61 | "filter_layers": 3, 62 | "filter_sigma_d": 1.0, 63 | "ransac": false 64 | }, 65 | 66 | "optim": { 67 | "lr": 1e-4, 68 | "step_size": 5, 69 | "weight_decay": 1e-4, 70 | "gamma": 0.9, 71 | "max_epoch": 40, 72 | "clip_grad_norm": 0.5 73 | }, 74 | 75 | "loss": { 76 | "positive_margin": 0.1, 77 | "negative_margin": 1.4, 78 | "positive_optimal": 0.1, 79 | "negative_optimal": 1.4, 80 | "log_scale": 40, 81 | "positive_overlap": 0.2, 82 | 83 | "r_p": 0.45, 84 | "r_n": 0.6, 85 | 86 | "weight_det_loss": 1.0, 87 | 
"weight_spot_loss": 0.2, 88 | "weight_feat_loss": 0.2, 89 | "weight_desc_loss": 1.0, 90 | "weight_overlap_loss": 1.0, 91 | "weight_corr_loss": 1.0, 92 | "weight_trans_loss": 5.0, 93 | "weight_rot_loss": 20.0, 94 | "pretrain_feat_epochs": 10 95 | }, 96 | 97 | "eval": { 98 | "hit_threshold": 0.3, 99 | "acceptance_overlap": 0.0, 100 | "rre_threshold": 5.0, 101 | "rte_threshold": 2.0 102 | } 103 | } -------------------------------------------------------------------------------- /config/nuscenes.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 1, 3 | "log_steps": 200, 4 | "save_steps": 2000, 5 | "dataset": "nuscenes", 6 | "runname": "cast", 7 | "log_dir": "./logs/nuscenes/", 8 | "snapshot_dir": "./ckpt/nuscenes/", 9 | "data_list": "./data/nuscenes_list/", 10 | 11 | "data": { 12 | "root": "/media/jacko/SSD//nuscenes/", 13 | "num_workers": 4, 14 | "npoints": 20000, 15 | "voxel_size": 0.3, 16 | "augment": 0.5 17 | }, 18 | 19 | "model": { 20 | "kpconv_layers": 5, 21 | "voxel_size": 0.3, 22 | "input_dim": 1, 23 | "init_dim": 64, 24 | "kernel_size": 15, 25 | "init_radius": 0.975, 26 | "init_sigma": 0.6, 27 | "neighbor_limits": [19, 26, 36, 45, 46], 28 | 29 | "k": 12, 30 | "spots": 4, 31 | "down_k": 4, 32 | "spot_k": 12, 33 | "dense_neighbors": 6, 34 | "input_dim_f": 512, 35 | "input_dim_c": 2048, 36 | "output_dim": 256, 37 | "hidden_dim": 128, 38 | "desc_dim": 32, 39 | "num_heads": 4, 40 | "dropout": null, 41 | "activation_fn": "relu", 42 | "sigma_c": 1.8, 43 | "seed_threshold": 0.3, 44 | "seed_num": 48, 45 | "blocks": 3, 46 | 47 | "sigma_r": 1.2, 48 | "use_overlap_head": false, 49 | "overlap_threshold": 0.2, 50 | "keypoint_node_threshold": 1.8, 51 | "local_matching_radius": 1.0, 52 | "dual_normalization": true, 53 | 54 | "patch_k": 24, 55 | "num_neighbors": 4, 56 | "learnable_matcher": true, 57 | "filter_layers": 3, 58 | "filter_sigma_d": 1.0, 59 | "ransac": false 60 | }, 61 | 62 | "optim": { 63 | "lr": 1e-4, 64 | "step_size": 5, 65 | "weight_decay": 1e-4, 66 | "gamma": 0.9, 67 | "max_epoch": 3, 68 | "clip_grad_norm": 0.5 69 | }, 70 | 71 | "loss": { 72 | "positive_margin": 0.1, 73 | "negative_margin": 1.4, 74 | "positive_optimal": 0.1, 75 | "negative_optimal": 1.4, 76 | "log_scale": 40, 77 | "positive_overlap": 0.2, 78 | 79 | "r_p": 0.45, 80 | "r_n": 0.6, 81 | 82 | "weight_det_loss": 1.0, 83 | "weight_spot_loss": 0.1, 84 | "weight_feat_loss": 0.2, 85 | "weight_desc_loss": 1.0, 86 | "weight_overlap_loss": 1.0, 87 | "weight_corr_loss": 1.0, 88 | "weight_trans_loss": 5.0, 89 | "weight_rot_loss": 20.0, 90 | "pretrain_feat_epochs": 1 91 | }, 92 | 93 | "eval": { 94 | "hit_threshold": 0.3, 95 | "acceptance_overlap": 0.0, 96 | "rre_threshold": 5.0, 97 | "rte_threshold": 2.0 98 | } 99 | } -------------------------------------------------------------------------------- /data/3dmatch_list/benchmark/3DLoMatch/sun3d-hotel_umd-maryland_hotel3/gt_overlap.log: -------------------------------------------------------------------------------- 1 | 0,1,0.5325 2 | 0,2,0.0683 3 | 0,3,0.0000 4 | 0,4,0.0000 5 | 0,5,0.0000 6 | 0,6,0.0000 7 | 0,7,0.0000 8 | 0,8,0.0000 9 | 0,9,0.0000 10 | 0,10,0.0269 11 | 0,11,0.1804 12 | 0,12,0.3909 13 | 0,13,0.3089 14 | 0,14,0.0004 15 | 0,15,0.0000 16 | 0,16,0.0000 17 | 0,17,0.0000 18 | 0,18,0.0000 19 | 0,19,0.0000 20 | 0,20,0.0000 21 | 0,21,0.0000 22 | 0,22,0.0263 23 | 0,23,0.0105 24 | 0,24,0.0098 25 | 0,25,0.0002 26 | 0,26,0.0000 27 | 0,27,0.1104 28 | 0,28,0.0002 29 | 0,29,0.0000 30 | 0,30,0.0000 31 | 0,31,0.0000 32 | 0,32,0.0000 33 
| 0,33,0.0000 34 | 0,34,0.0000 35 | 0,35,0.0000 36 | 0,36,0.0000 37 | 1,2,0.3132 38 | 1,3,0.0537 39 | 1,4,0.0000 40 | 1,5,0.0000 41 | 1,6,0.0000 42 | 1,7,0.0000 43 | 1,8,0.0000 44 | 1,9,0.0000 45 | 1,10,0.0000 46 | 1,11,0.1110 47 | 1,12,0.2444 48 | 1,13,0.1772 49 | 1,14,0.0000 50 | 1,15,0.0000 51 | 1,16,0.0000 52 | 1,17,0.0000 53 | 1,18,0.0000 54 | 1,19,0.0000 55 | 1,20,0.0000 56 | 1,21,0.0000 57 | 1,22,0.0000 58 | 1,23,0.0000 59 | 1,24,0.0000 60 | 1,25,0.0000 61 | 1,26,0.0000 62 | 1,27,0.0436 63 | 1,28,0.0000 64 | 1,29,0.0000 65 | 1,30,0.0000 66 | 1,31,0.0000 67 | 1,32,0.0000 68 | 1,33,0.0000 69 | 1,34,0.0000 70 | 1,35,0.0000 71 | 1,36,0.0000 72 | 2,3,0.4890 73 | 2,4,0.0686 74 | 2,5,0.0600 75 | 2,6,0.0000 76 | 2,7,0.0000 77 | 2,8,0.0000 78 | 2,9,0.0000 79 | 2,10,0.0000 80 | 2,11,0.0000 81 | 2,12,0.0000 82 | 2,13,0.0000 83 | 2,14,0.0000 84 | 2,15,0.0000 85 | 2,16,0.0000 86 | 2,17,0.0000 87 | 2,18,0.0000 88 | 2,19,0.0000 89 | 2,20,0.0000 90 | 2,21,0.0000 91 | 2,22,0.0000 92 | 2,23,0.0000 93 | 2,24,0.0000 94 | 2,25,0.0000 95 | 2,26,0.0000 96 | 2,27,0.0000 97 | 2,28,0.0000 98 | 2,29,0.0000 99 | 2,30,0.0000 100 | 2,31,0.0000 101 | 2,32,0.0000 102 | 2,33,0.0000 103 | 2,34,0.0000 104 | 2,35,0.0000 105 | 2,36,0.0000 106 | 3,4,0.5921 107 | 3,5,0.2586 108 | 3,6,0.0192 109 | 3,7,0.0000 110 | 3,8,0.0000 111 | 3,9,0.0000 112 | 3,10,0.0000 113 | 3,11,0.0000 114 | 3,12,0.0000 115 | 3,13,0.0000 116 | 3,14,0.0000 117 | 3,15,0.0000 118 | 3,16,0.0000 119 | 3,17,0.0000 120 | 3,18,0.0000 121 | 3,19,0.0000 122 | 3,20,0.0000 123 | 3,21,0.0000 124 | 3,22,0.0000 125 | 3,23,0.0000 126 | 3,24,0.0000 127 | 3,25,0.0000 128 | 3,26,0.0000 129 | 3,27,0.0000 130 | 3,28,0.0000 131 | 3,29,0.0000 132 | 3,30,0.0000 133 | 3,31,0.0000 134 | 3,32,0.0000 135 | 3,33,0.0000 136 | 3,34,0.0000 137 | 3,35,0.0000 138 | 3,36,0.0000 139 | 4,5,0.3325 140 | 4,6,0.0784 141 | 4,7,0.0000 142 | 4,8,0.0000 143 | 4,9,0.0000 144 | 4,10,0.0000 145 | 4,11,0.0000 146 | 4,12,0.0000 147 | 4,13,0.0000 148 | 4,14,0.0000 149 | 4,15,0.0000 150 | 4,16,0.0000 151 | 4,17,0.0000 152 | 4,18,0.0000 153 | 4,19,0.0000 154 | 4,20,0.0000 155 | 4,21,0.0000 156 | 4,22,0.0000 157 | 4,23,0.0000 158 | 4,24,0.0000 159 | 4,25,0.0000 160 | 4,26,0.0000 161 | 4,27,0.0000 162 | 4,28,0.0000 163 | 4,29,0.0000 164 | 4,30,0.0000 165 | 4,31,0.0000 166 | 4,32,0.0000 167 | 4,33,0.0000 168 | 4,34,0.0000 169 | 4,35,0.0000 170 | 4,36,0.0000 171 | 5,6,0.4993 172 | 5,7,0.1245 173 | 5,8,0.0000 174 | 5,9,0.0000 175 | 5,10,0.0000 176 | 5,11,0.0000 177 | 5,12,0.0000 178 | 5,13,0.0000 179 | 5,14,0.0000 180 | 5,15,0.0000 181 | 5,16,0.0000 182 | 5,17,0.0000 183 | 5,18,0.0000 184 | 5,19,0.0000 185 | 5,20,0.0000 186 | 5,21,0.0000 187 | 5,22,0.0000 188 | 5,23,0.0000 189 | 5,24,0.0000 190 | 5,25,0.0000 191 | 5,26,0.0000 192 | 5,27,0.0000 193 | 5,28,0.0000 194 | 5,29,0.0000 195 | 5,30,0.0000 196 | 5,31,0.0000 197 | 5,32,0.0000 198 | 5,33,0.0000 199 | 5,34,0.0000 200 | 5,35,0.0000 201 | 5,36,0.0000 202 | 6,7,0.4663 203 | 6,8,0.0432 204 | 6,9,0.0000 205 | 6,10,0.0000 206 | 6,11,0.0000 207 | 6,12,0.0000 208 | 6,13,0.0000 209 | 6,14,0.0000 210 | 6,15,0.0000 211 | 6,16,0.0000 212 | 6,17,0.0000 213 | 6,18,0.0000 214 | 6,19,0.0000 215 | 6,20,0.0000 216 | 6,21,0.0000 217 | 6,22,0.0000 218 | 6,23,0.0000 219 | 6,24,0.0000 220 | 6,25,0.0000 221 | 6,26,0.0000 222 | 6,27,0.0000 223 | 6,28,0.0000 224 | 6,29,0.0000 225 | 6,30,0.0000 226 | 6,31,0.0000 227 | 6,32,0.0000 228 | 6,33,0.0000 229 | 6,34,0.0000 230 | 6,35,0.0000 231 | 6,36,0.0000 232 | 7,8,0.3918 233 | 7,9,0.0272 234 | 7,10,0.0796 235 | 7,11,0.0000 236 | 
7,12,0.0000 237 | 7,13,0.0000 238 | 7,14,0.0605 239 | 7,15,0.0651 240 | 7,16,0.0000 241 | 7,17,0.0000 242 | 7,18,0.0000 243 | 7,19,0.0000 244 | 7,20,0.0000 245 | 7,21,0.0000 246 | 7,22,0.0000 247 | 7,23,0.0000 248 | 7,24,0.0000 249 | 7,25,0.0000 250 | 7,26,0.0000 251 | 7,27,0.0000 252 | 7,28,0.0000 253 | 7,29,0.0000 254 | 7,30,0.0000 255 | 7,31,0.0000 256 | 7,32,0.0000 257 | 7,33,0.0000 258 | 7,34,0.0000 259 | 7,35,0.0000 260 | 7,36,0.0000 261 | 8,9,0.4565 262 | 8,10,0.3677 263 | 8,11,0.0000 264 | 8,12,0.0000 265 | 8,13,0.0000 266 | 8,14,0.1760 267 | 8,15,0.4993 268 | 8,16,0.2940 269 | 8,17,0.0577 270 | 8,18,0.0006 271 | 8,19,0.0000 272 | 8,20,0.0000 273 | 8,21,0.0000 274 | 8,22,0.0000 275 | 8,23,0.0000 276 | 8,24,0.0000 277 | 8,25,0.0000 278 | 8,26,0.0000 279 | 8,27,0.0000 280 | 8,28,0.0000 281 | 8,29,0.0000 282 | 8,30,0.0000 283 | 8,31,0.0000 284 | 8,32,0.0000 285 | 8,33,0.0000 286 | 8,34,0.0000 287 | 8,35,0.0000 288 | 8,36,0.0000 289 | 9,10,0.2301 290 | 9,11,0.0546 291 | 9,12,0.0000 292 | 9,13,0.0274 293 | 9,14,0.0741 294 | 9,15,0.2451 295 | 9,16,0.1903 296 | 9,17,0.0907 297 | 9,18,0.0404 298 | 9,19,0.0087 299 | 9,20,0.0000 300 | 9,21,0.0000 301 | 9,22,0.0000 302 | 9,23,0.0000 303 | 9,24,0.0000 304 | 9,25,0.0000 305 | 9,26,0.0000 306 | 9,27,0.0000 307 | 9,28,0.0000 308 | 9,29,0.0000 309 | 9,30,0.0000 310 | 9,31,0.0000 311 | 9,32,0.0000 312 | 9,33,0.0000 313 | 9,34,0.0000 314 | 9,35,0.0000 315 | 9,36,0.0000 316 | 10,11,0.2857 317 | 10,12,0.0900 318 | 10,13,0.1489 319 | 10,14,0.5227 320 | 10,15,0.5015 321 | 10,16,0.4256 322 | 10,17,0.2176 323 | 10,18,0.0341 324 | 10,19,0.0106 325 | 10,20,0.0000 326 | 10,21,0.0000 327 | 10,22,0.0000 328 | 10,23,0.0005 329 | 10,24,0.0018 330 | 10,25,0.0010 331 | 10,26,0.0000 332 | 10,27,0.0011 333 | 10,28,0.0007 334 | 10,29,0.0000 335 | 10,30,0.0000 336 | 10,31,0.0000 337 | 10,32,0.0000 338 | 10,33,0.0000 339 | 10,34,0.0000 340 | 10,35,0.0000 341 | 10,36,0.0000 342 | 11,12,0.6637 343 | 11,13,0.7383 344 | 11,14,0.2992 345 | 11,15,0.0000 346 | 11,16,0.0219 347 | 11,17,0.0121 348 | 11,18,0.0000 349 | 11,19,0.0000 350 | 11,20,0.0000 351 | 11,21,0.0006 352 | 11,22,0.0436 353 | 11,23,0.0238 354 | 11,24,0.0194 355 | 11,25,0.0047 356 | 11,26,0.0028 357 | 11,27,0.1267 358 | 11,28,0.0028 359 | 11,29,0.0000 360 | 11,30,0.0000 361 | 11,31,0.0000 362 | 11,32,0.0000 363 | 11,33,0.0000 364 | 11,34,0.0000 365 | 11,35,0.0000 366 | 11,36,0.0000 367 | 12,13,0.5171 368 | 12,14,0.0847 369 | 12,15,0.0000 370 | 12,16,0.0000 371 | 12,17,0.0000 372 | 12,18,0.0000 373 | 12,19,0.0000 374 | 12,20,0.0000 375 | 12,21,0.0008 376 | 12,22,0.0525 377 | 12,23,0.0422 378 | 12,24,0.0159 379 | 12,25,0.0030 380 | 12,26,0.0073 381 | 12,27,0.1169 382 | 12,28,0.0028 383 | 12,29,0.0093 384 | 12,30,0.0000 385 | 12,31,0.0000 386 | 12,32,0.0000 387 | 12,33,0.0000 388 | 12,34,0.0000 389 | 12,35,0.0000 390 | 12,36,0.0000 391 | 13,14,0.1723 392 | 13,15,0.0003 393 | 13,16,0.0000 394 | 13,17,0.0000 395 | 13,18,0.0000 396 | 13,19,0.0000 397 | 13,20,0.0000 398 | 13,21,0.0017 399 | 13,22,0.0478 400 | 13,23,0.0371 401 | 13,24,0.0210 402 | 13,25,0.0049 403 | 13,26,0.0000 404 | 13,27,0.1237 405 | 13,28,0.0030 406 | 13,29,0.0050 407 | 13,30,0.0000 408 | 13,31,0.0000 409 | 13,32,0.0000 410 | 13,33,0.0000 411 | 13,34,0.0000 412 | 13,35,0.0000 413 | 13,36,0.0000 414 | 14,15,0.3409 415 | 14,16,0.2692 416 | 14,17,0.2338 417 | 14,18,0.0061 418 | 14,19,0.0075 419 | 14,20,0.0079 420 | 14,21,0.0065 421 | 14,22,0.0011 422 | 14,23,0.0043 423 | 14,24,0.0063 424 | 14,25,0.0023 425 | 14,26,0.0000 426 | 14,27,0.0045 427 | 
14,28,0.0027 428 | 14,29,0.0000 429 | 14,30,0.0000 430 | 14,31,0.0000 431 | 14,32,0.0000 432 | 14,33,0.0000 433 | 14,34,0.0000 434 | 14,35,0.0000 435 | 14,36,0.0002 436 | 15,16,0.6412 437 | 15,17,0.3784 438 | 15,18,0.0479 439 | 15,19,0.0572 440 | 15,20,0.0059 441 | 15,21,0.0000 442 | 15,22,0.0000 443 | 15,23,0.0000 444 | 15,24,0.0000 445 | 15,25,0.0000 446 | 15,26,0.0000 447 | 15,27,0.0000 448 | 15,28,0.0000 449 | 15,29,0.0000 450 | 15,30,0.0000 451 | 15,31,0.0000 452 | 15,32,0.0000 453 | 15,33,0.0000 454 | 15,34,0.0000 455 | 15,35,0.0000 456 | 15,36,0.0000 457 | 16,17,0.6605 458 | 16,18,0.1551 459 | 16,19,0.1564 460 | 16,20,0.0070 461 | 16,21,0.0037 462 | 16,22,0.0000 463 | 16,23,0.0000 464 | 16,24,0.0000 465 | 16,25,0.0000 466 | 16,26,0.0000 467 | 16,27,0.0000 468 | 16,28,0.0000 469 | 16,29,0.0000 470 | 16,30,0.0000 471 | 16,31,0.0000 472 | 16,32,0.0000 473 | 16,33,0.0000 474 | 16,34,0.0000 475 | 16,35,0.0003 476 | 16,36,0.0065 477 | 17,18,0.4449 478 | 17,19,0.5912 479 | 17,20,0.3402 480 | 17,21,0.0151 481 | 17,22,0.0000 482 | 17,23,0.0000 483 | 17,24,0.0000 484 | 17,25,0.0000 485 | 17,26,0.0000 486 | 17,27,0.0000 487 | 17,28,0.0000 488 | 17,29,0.0000 489 | 17,30,0.0000 490 | 17,31,0.0000 491 | 17,32,0.0000 492 | 17,33,0.0144 493 | 17,34,0.0610 494 | 17,35,0.0111 495 | 17,36,0.0755 496 | 18,19,0.5922 497 | 18,20,0.2836 498 | 18,21,0.0132 499 | 18,22,0.0000 500 | 18,23,0.0000 501 | 18,24,0.0000 502 | 18,25,0.0000 503 | 18,26,0.0000 504 | 18,27,0.0000 505 | 18,28,0.0000 506 | 18,29,0.0000 507 | 18,30,0.0000 508 | 18,31,0.0000 509 | 18,32,0.0000 510 | 18,33,0.0000 511 | 18,34,0.0055 512 | 18,35,0.0089 513 | 18,36,0.0382 514 | 19,20,0.4955 515 | 19,21,0.0496 516 | 19,22,0.0000 517 | 19,23,0.0000 518 | 19,24,0.0000 519 | 19,25,0.0000 520 | 19,26,0.0000 521 | 19,27,0.0000 522 | 19,28,0.0000 523 | 19,29,0.0000 524 | 19,30,0.0000 525 | 19,31,0.0000 526 | 19,32,0.0000 527 | 19,33,0.0274 528 | 19,34,0.0596 529 | 19,35,0.0140 530 | 19,36,0.0950 531 | 20,21,0.3810 532 | 20,22,0.0238 533 | 20,23,0.0040 534 | 20,24,0.0045 535 | 20,25,0.0010 536 | 20,26,0.0000 537 | 20,27,0.0045 538 | 20,28,0.0062 539 | 20,29,0.0000 540 | 20,30,0.0053 541 | 20,31,0.0576 542 | 20,32,0.0000 543 | 20,33,0.1121 544 | 20,34,0.1527 545 | 20,35,0.0143 546 | 20,36,0.1041 547 | 21,22,0.5275 548 | 21,23,0.3011 549 | 21,24,0.1952 550 | 21,25,0.0320 551 | 21,26,0.0016 552 | 21,27,0.0899 553 | 21,28,0.3441 554 | 21,29,0.2048 555 | 21,30,0.2257 556 | 21,31,0.2364 557 | 21,32,0.0316 558 | 21,33,0.0744 559 | 21,34,0.0670 560 | 21,35,0.0050 561 | 21,36,0.0172 562 | 22,23,0.7330 563 | 22,24,0.2303 564 | 22,25,0.0266 565 | 22,26,0.0267 566 | 22,27,0.1673 567 | 22,28,0.3191 568 | 22,29,0.4485 569 | 22,30,0.2956 570 | 22,31,0.1013 571 | 22,32,0.0161 572 | 22,33,0.0044 573 | 22,34,0.0085 574 | 22,35,0.0046 575 | 22,36,0.0075 576 | 23,24,0.2493 577 | 23,25,0.0259 578 | 23,26,0.0426 579 | 23,27,0.2199 580 | 23,28,0.3316 581 | 23,29,0.5661 582 | 23,30,0.3621 583 | 23,31,0.0670 584 | 23,32,0.0031 585 | 23,33,0.0000 586 | 23,34,0.0000 587 | 23,35,0.0000 588 | 23,36,0.0000 589 | 24,25,0.2683 590 | 24,26,0.0284 591 | 24,27,0.4720 592 | 24,28,0.6725 593 | 24,29,0.3175 594 | 24,30,0.1250 595 | 24,31,0.0131 596 | 24,32,0.0000 597 | 24,33,0.0000 598 | 24,34,0.0000 599 | 24,35,0.0000 600 | 24,36,0.0000 601 | 25,26,0.2499 602 | 25,27,0.4611 603 | 25,28,0.2101 604 | 25,29,0.0000 605 | 25,30,0.0000 606 | 25,31,0.0000 607 | 25,32,0.0000 608 | 25,33,0.0000 609 | 25,34,0.0000 610 | 25,35,0.0000 611 | 25,36,0.0000 612 | 26,27,0.7114 613 | 26,28,0.0033 614 | 
26,29,0.0000 615 | 26,30,0.0000 616 | 26,31,0.0000 617 | 26,32,0.0000 618 | 26,33,0.0000 619 | 26,34,0.0000 620 | 26,35,0.0000 621 | 26,36,0.0000 622 | 27,28,0.1769 623 | 27,29,0.0335 624 | 27,30,0.0000 625 | 27,31,0.0000 626 | 27,32,0.0000 627 | 27,33,0.0000 628 | 27,34,0.0000 629 | 27,35,0.0000 630 | 27,36,0.0000 631 | 28,29,0.5445 632 | 28,30,0.3390 633 | 28,31,0.1712 634 | 28,32,0.0000 635 | 28,33,0.0000 636 | 28,34,0.0000 637 | 28,35,0.0000 638 | 28,36,0.0000 639 | 29,30,0.6915 640 | 29,31,0.1352 641 | 29,32,0.0507 642 | 29,33,0.0016 643 | 29,34,0.0000 644 | 29,35,0.0000 645 | 29,36,0.0000 646 | 30,31,0.4260 647 | 30,32,0.2778 648 | 30,33,0.1563 649 | 30,34,0.0450 650 | 30,35,0.0000 651 | 30,36,0.0080 652 | 31,32,0.4631 653 | 31,33,0.5335 654 | 31,34,0.3379 655 | 31,35,0.1129 656 | 31,36,0.2214 657 | 32,33,0.6403 658 | 32,34,0.3848 659 | 32,35,0.3697 660 | 32,36,0.3271 661 | 33,34,0.7574 662 | 33,35,0.3338 663 | 33,36,0.4225 664 | 34,35,0.4932 665 | 34,36,0.5212 666 | 35,36,0.5585 667 | -------------------------------------------------------------------------------- /data/3dmatch_list/benchmark/3DMatch/sun3d-hotel_umd-maryland_hotel3/gt_overlap.log: -------------------------------------------------------------------------------- 1 | 0,1,0.5325 2 | 0,2,0.0683 3 | 0,3,0.0000 4 | 0,4,0.0000 5 | 0,5,0.0000 6 | 0,6,0.0000 7 | 0,7,0.0000 8 | 0,8,0.0000 9 | 0,9,0.0000 10 | 0,10,0.0269 11 | 0,11,0.1804 12 | 0,12,0.3909 13 | 0,13,0.3089 14 | 0,14,0.0004 15 | 0,15,0.0000 16 | 0,16,0.0000 17 | 0,17,0.0000 18 | 0,18,0.0000 19 | 0,19,0.0000 20 | 0,20,0.0000 21 | 0,21,0.0000 22 | 0,22,0.0263 23 | 0,23,0.0105 24 | 0,24,0.0098 25 | 0,25,0.0002 26 | 0,26,0.0000 27 | 0,27,0.1104 28 | 0,28,0.0002 29 | 0,29,0.0000 30 | 0,30,0.0000 31 | 0,31,0.0000 32 | 0,32,0.0000 33 | 0,33,0.0000 34 | 0,34,0.0000 35 | 0,35,0.0000 36 | 0,36,0.0000 37 | 1,2,0.3132 38 | 1,3,0.0537 39 | 1,4,0.0000 40 | 1,5,0.0000 41 | 1,6,0.0000 42 | 1,7,0.0000 43 | 1,8,0.0000 44 | 1,9,0.0000 45 | 1,10,0.0000 46 | 1,11,0.1110 47 | 1,12,0.2444 48 | 1,13,0.1772 49 | 1,14,0.0000 50 | 1,15,0.0000 51 | 1,16,0.0000 52 | 1,17,0.0000 53 | 1,18,0.0000 54 | 1,19,0.0000 55 | 1,20,0.0000 56 | 1,21,0.0000 57 | 1,22,0.0000 58 | 1,23,0.0000 59 | 1,24,0.0000 60 | 1,25,0.0000 61 | 1,26,0.0000 62 | 1,27,0.0436 63 | 1,28,0.0000 64 | 1,29,0.0000 65 | 1,30,0.0000 66 | 1,31,0.0000 67 | 1,32,0.0000 68 | 1,33,0.0000 69 | 1,34,0.0000 70 | 1,35,0.0000 71 | 1,36,0.0000 72 | 2,3,0.4890 73 | 2,4,0.0686 74 | 2,5,0.0600 75 | 2,6,0.0000 76 | 2,7,0.0000 77 | 2,8,0.0000 78 | 2,9,0.0000 79 | 2,10,0.0000 80 | 2,11,0.0000 81 | 2,12,0.0000 82 | 2,13,0.0000 83 | 2,14,0.0000 84 | 2,15,0.0000 85 | 2,16,0.0000 86 | 2,17,0.0000 87 | 2,18,0.0000 88 | 2,19,0.0000 89 | 2,20,0.0000 90 | 2,21,0.0000 91 | 2,22,0.0000 92 | 2,23,0.0000 93 | 2,24,0.0000 94 | 2,25,0.0000 95 | 2,26,0.0000 96 | 2,27,0.0000 97 | 2,28,0.0000 98 | 2,29,0.0000 99 | 2,30,0.0000 100 | 2,31,0.0000 101 | 2,32,0.0000 102 | 2,33,0.0000 103 | 2,34,0.0000 104 | 2,35,0.0000 105 | 2,36,0.0000 106 | 3,4,0.5921 107 | 3,5,0.2586 108 | 3,6,0.0192 109 | 3,7,0.0000 110 | 3,8,0.0000 111 | 3,9,0.0000 112 | 3,10,0.0000 113 | 3,11,0.0000 114 | 3,12,0.0000 115 | 3,13,0.0000 116 | 3,14,0.0000 117 | 3,15,0.0000 118 | 3,16,0.0000 119 | 3,17,0.0000 120 | 3,18,0.0000 121 | 3,19,0.0000 122 | 3,20,0.0000 123 | 3,21,0.0000 124 | 3,22,0.0000 125 | 3,23,0.0000 126 | 3,24,0.0000 127 | 3,25,0.0000 128 | 3,26,0.0000 129 | 3,27,0.0000 130 | 3,28,0.0000 131 | 3,29,0.0000 132 | 3,30,0.0000 133 | 3,31,0.0000 134 | 3,32,0.0000 135 | 3,33,0.0000 136 | 
3,34,0.0000 137 | 3,35,0.0000 138 | 3,36,0.0000 139 | 4,5,0.3325 140 | 4,6,0.0784 141 | 4,7,0.0000 142 | 4,8,0.0000 143 | 4,9,0.0000 144 | 4,10,0.0000 145 | 4,11,0.0000 146 | 4,12,0.0000 147 | 4,13,0.0000 148 | 4,14,0.0000 149 | 4,15,0.0000 150 | 4,16,0.0000 151 | 4,17,0.0000 152 | 4,18,0.0000 153 | 4,19,0.0000 154 | 4,20,0.0000 155 | 4,21,0.0000 156 | 4,22,0.0000 157 | 4,23,0.0000 158 | 4,24,0.0000 159 | 4,25,0.0000 160 | 4,26,0.0000 161 | 4,27,0.0000 162 | 4,28,0.0000 163 | 4,29,0.0000 164 | 4,30,0.0000 165 | 4,31,0.0000 166 | 4,32,0.0000 167 | 4,33,0.0000 168 | 4,34,0.0000 169 | 4,35,0.0000 170 | 4,36,0.0000 171 | 5,6,0.4993 172 | 5,7,0.1245 173 | 5,8,0.0000 174 | 5,9,0.0000 175 | 5,10,0.0000 176 | 5,11,0.0000 177 | 5,12,0.0000 178 | 5,13,0.0000 179 | 5,14,0.0000 180 | 5,15,0.0000 181 | 5,16,0.0000 182 | 5,17,0.0000 183 | 5,18,0.0000 184 | 5,19,0.0000 185 | 5,20,0.0000 186 | 5,21,0.0000 187 | 5,22,0.0000 188 | 5,23,0.0000 189 | 5,24,0.0000 190 | 5,25,0.0000 191 | 5,26,0.0000 192 | 5,27,0.0000 193 | 5,28,0.0000 194 | 5,29,0.0000 195 | 5,30,0.0000 196 | 5,31,0.0000 197 | 5,32,0.0000 198 | 5,33,0.0000 199 | 5,34,0.0000 200 | 5,35,0.0000 201 | 5,36,0.0000 202 | 6,7,0.4663 203 | 6,8,0.0432 204 | 6,9,0.0000 205 | 6,10,0.0000 206 | 6,11,0.0000 207 | 6,12,0.0000 208 | 6,13,0.0000 209 | 6,14,0.0000 210 | 6,15,0.0000 211 | 6,16,0.0000 212 | 6,17,0.0000 213 | 6,18,0.0000 214 | 6,19,0.0000 215 | 6,20,0.0000 216 | 6,21,0.0000 217 | 6,22,0.0000 218 | 6,23,0.0000 219 | 6,24,0.0000 220 | 6,25,0.0000 221 | 6,26,0.0000 222 | 6,27,0.0000 223 | 6,28,0.0000 224 | 6,29,0.0000 225 | 6,30,0.0000 226 | 6,31,0.0000 227 | 6,32,0.0000 228 | 6,33,0.0000 229 | 6,34,0.0000 230 | 6,35,0.0000 231 | 6,36,0.0000 232 | 7,8,0.3918 233 | 7,9,0.0272 234 | 7,10,0.0796 235 | 7,11,0.0000 236 | 7,12,0.0000 237 | 7,13,0.0000 238 | 7,14,0.0605 239 | 7,15,0.0651 240 | 7,16,0.0000 241 | 7,17,0.0000 242 | 7,18,0.0000 243 | 7,19,0.0000 244 | 7,20,0.0000 245 | 7,21,0.0000 246 | 7,22,0.0000 247 | 7,23,0.0000 248 | 7,24,0.0000 249 | 7,25,0.0000 250 | 7,26,0.0000 251 | 7,27,0.0000 252 | 7,28,0.0000 253 | 7,29,0.0000 254 | 7,30,0.0000 255 | 7,31,0.0000 256 | 7,32,0.0000 257 | 7,33,0.0000 258 | 7,34,0.0000 259 | 7,35,0.0000 260 | 7,36,0.0000 261 | 8,9,0.4565 262 | 8,10,0.3677 263 | 8,11,0.0000 264 | 8,12,0.0000 265 | 8,13,0.0000 266 | 8,14,0.1760 267 | 8,15,0.4993 268 | 8,16,0.2940 269 | 8,17,0.0577 270 | 8,18,0.0006 271 | 8,19,0.0000 272 | 8,20,0.0000 273 | 8,21,0.0000 274 | 8,22,0.0000 275 | 8,23,0.0000 276 | 8,24,0.0000 277 | 8,25,0.0000 278 | 8,26,0.0000 279 | 8,27,0.0000 280 | 8,28,0.0000 281 | 8,29,0.0000 282 | 8,30,0.0000 283 | 8,31,0.0000 284 | 8,32,0.0000 285 | 8,33,0.0000 286 | 8,34,0.0000 287 | 8,35,0.0000 288 | 8,36,0.0000 289 | 9,10,0.2301 290 | 9,11,0.0546 291 | 9,12,0.0000 292 | 9,13,0.0274 293 | 9,14,0.0741 294 | 9,15,0.2451 295 | 9,16,0.1903 296 | 9,17,0.0907 297 | 9,18,0.0404 298 | 9,19,0.0087 299 | 9,20,0.0000 300 | 9,21,0.0000 301 | 9,22,0.0000 302 | 9,23,0.0000 303 | 9,24,0.0000 304 | 9,25,0.0000 305 | 9,26,0.0000 306 | 9,27,0.0000 307 | 9,28,0.0000 308 | 9,29,0.0000 309 | 9,30,0.0000 310 | 9,31,0.0000 311 | 9,32,0.0000 312 | 9,33,0.0000 313 | 9,34,0.0000 314 | 9,35,0.0000 315 | 9,36,0.0000 316 | 10,11,0.2857 317 | 10,12,0.0900 318 | 10,13,0.1489 319 | 10,14,0.5227 320 | 10,15,0.5015 321 | 10,16,0.4256 322 | 10,17,0.2176 323 | 10,18,0.0341 324 | 10,19,0.0106 325 | 10,20,0.0000 326 | 10,21,0.0000 327 | 10,22,0.0000 328 | 10,23,0.0005 329 | 10,24,0.0018 330 | 10,25,0.0010 331 | 10,26,0.0000 332 | 10,27,0.0011 333 | 
10,28,0.0007 334 | 10,29,0.0000 335 | 10,30,0.0000 336 | 10,31,0.0000 337 | 10,32,0.0000 338 | 10,33,0.0000 339 | 10,34,0.0000 340 | 10,35,0.0000 341 | 10,36,0.0000 342 | 11,12,0.6637 343 | 11,13,0.7383 344 | 11,14,0.2992 345 | 11,15,0.0000 346 | 11,16,0.0219 347 | 11,17,0.0121 348 | 11,18,0.0000 349 | 11,19,0.0000 350 | 11,20,0.0000 351 | 11,21,0.0006 352 | 11,22,0.0436 353 | 11,23,0.0238 354 | 11,24,0.0194 355 | 11,25,0.0047 356 | 11,26,0.0028 357 | 11,27,0.1267 358 | 11,28,0.0028 359 | 11,29,0.0000 360 | 11,30,0.0000 361 | 11,31,0.0000 362 | 11,32,0.0000 363 | 11,33,0.0000 364 | 11,34,0.0000 365 | 11,35,0.0000 366 | 11,36,0.0000 367 | 12,13,0.5171 368 | 12,14,0.0847 369 | 12,15,0.0000 370 | 12,16,0.0000 371 | 12,17,0.0000 372 | 12,18,0.0000 373 | 12,19,0.0000 374 | 12,20,0.0000 375 | 12,21,0.0008 376 | 12,22,0.0525 377 | 12,23,0.0422 378 | 12,24,0.0159 379 | 12,25,0.0030 380 | 12,26,0.0073 381 | 12,27,0.1169 382 | 12,28,0.0028 383 | 12,29,0.0093 384 | 12,30,0.0000 385 | 12,31,0.0000 386 | 12,32,0.0000 387 | 12,33,0.0000 388 | 12,34,0.0000 389 | 12,35,0.0000 390 | 12,36,0.0000 391 | 13,14,0.1723 392 | 13,15,0.0003 393 | 13,16,0.0000 394 | 13,17,0.0000 395 | 13,18,0.0000 396 | 13,19,0.0000 397 | 13,20,0.0000 398 | 13,21,0.0017 399 | 13,22,0.0478 400 | 13,23,0.0371 401 | 13,24,0.0210 402 | 13,25,0.0049 403 | 13,26,0.0000 404 | 13,27,0.1237 405 | 13,28,0.0030 406 | 13,29,0.0050 407 | 13,30,0.0000 408 | 13,31,0.0000 409 | 13,32,0.0000 410 | 13,33,0.0000 411 | 13,34,0.0000 412 | 13,35,0.0000 413 | 13,36,0.0000 414 | 14,15,0.3409 415 | 14,16,0.2692 416 | 14,17,0.2338 417 | 14,18,0.0061 418 | 14,19,0.0075 419 | 14,20,0.0079 420 | 14,21,0.0065 421 | 14,22,0.0011 422 | 14,23,0.0043 423 | 14,24,0.0063 424 | 14,25,0.0023 425 | 14,26,0.0000 426 | 14,27,0.0045 427 | 14,28,0.0027 428 | 14,29,0.0000 429 | 14,30,0.0000 430 | 14,31,0.0000 431 | 14,32,0.0000 432 | 14,33,0.0000 433 | 14,34,0.0000 434 | 14,35,0.0000 435 | 14,36,0.0002 436 | 15,16,0.6412 437 | 15,17,0.3784 438 | 15,18,0.0479 439 | 15,19,0.0572 440 | 15,20,0.0059 441 | 15,21,0.0000 442 | 15,22,0.0000 443 | 15,23,0.0000 444 | 15,24,0.0000 445 | 15,25,0.0000 446 | 15,26,0.0000 447 | 15,27,0.0000 448 | 15,28,0.0000 449 | 15,29,0.0000 450 | 15,30,0.0000 451 | 15,31,0.0000 452 | 15,32,0.0000 453 | 15,33,0.0000 454 | 15,34,0.0000 455 | 15,35,0.0000 456 | 15,36,0.0000 457 | 16,17,0.6605 458 | 16,18,0.1551 459 | 16,19,0.1564 460 | 16,20,0.0070 461 | 16,21,0.0037 462 | 16,22,0.0000 463 | 16,23,0.0000 464 | 16,24,0.0000 465 | 16,25,0.0000 466 | 16,26,0.0000 467 | 16,27,0.0000 468 | 16,28,0.0000 469 | 16,29,0.0000 470 | 16,30,0.0000 471 | 16,31,0.0000 472 | 16,32,0.0000 473 | 16,33,0.0000 474 | 16,34,0.0000 475 | 16,35,0.0003 476 | 16,36,0.0065 477 | 17,18,0.4449 478 | 17,19,0.5912 479 | 17,20,0.3402 480 | 17,21,0.0151 481 | 17,22,0.0000 482 | 17,23,0.0000 483 | 17,24,0.0000 484 | 17,25,0.0000 485 | 17,26,0.0000 486 | 17,27,0.0000 487 | 17,28,0.0000 488 | 17,29,0.0000 489 | 17,30,0.0000 490 | 17,31,0.0000 491 | 17,32,0.0000 492 | 17,33,0.0144 493 | 17,34,0.0610 494 | 17,35,0.0111 495 | 17,36,0.0755 496 | 18,19,0.5922 497 | 18,20,0.2836 498 | 18,21,0.0132 499 | 18,22,0.0000 500 | 18,23,0.0000 501 | 18,24,0.0000 502 | 18,25,0.0000 503 | 18,26,0.0000 504 | 18,27,0.0000 505 | 18,28,0.0000 506 | 18,29,0.0000 507 | 18,30,0.0000 508 | 18,31,0.0000 509 | 18,32,0.0000 510 | 18,33,0.0000 511 | 18,34,0.0055 512 | 18,35,0.0089 513 | 18,36,0.0382 514 | 19,20,0.4955 515 | 19,21,0.0496 516 | 19,22,0.0000 517 | 19,23,0.0000 518 | 19,24,0.0000 519 | 19,25,0.0000 520 | 
19,26,0.0000 521 | 19,27,0.0000 522 | 19,28,0.0000 523 | 19,29,0.0000 524 | 19,30,0.0000 525 | 19,31,0.0000 526 | 19,32,0.0000 527 | 19,33,0.0274 528 | 19,34,0.0596 529 | 19,35,0.0140 530 | 19,36,0.0950 531 | 20,21,0.3810 532 | 20,22,0.0238 533 | 20,23,0.0040 534 | 20,24,0.0045 535 | 20,25,0.0010 536 | 20,26,0.0000 537 | 20,27,0.0045 538 | 20,28,0.0062 539 | 20,29,0.0000 540 | 20,30,0.0053 541 | 20,31,0.0576 542 | 20,32,0.0000 543 | 20,33,0.1121 544 | 20,34,0.1527 545 | 20,35,0.0143 546 | 20,36,0.1041 547 | 21,22,0.5275 548 | 21,23,0.3011 549 | 21,24,0.1952 550 | 21,25,0.0320 551 | 21,26,0.0016 552 | 21,27,0.0899 553 | 21,28,0.3441 554 | 21,29,0.2048 555 | 21,30,0.2257 556 | 21,31,0.2364 557 | 21,32,0.0316 558 | 21,33,0.0744 559 | 21,34,0.0670 560 | 21,35,0.0050 561 | 21,36,0.0172 562 | 22,23,0.7330 563 | 22,24,0.2303 564 | 22,25,0.0266 565 | 22,26,0.0267 566 | 22,27,0.1673 567 | 22,28,0.3191 568 | 22,29,0.4485 569 | 22,30,0.2956 570 | 22,31,0.1013 571 | 22,32,0.0161 572 | 22,33,0.0044 573 | 22,34,0.0085 574 | 22,35,0.0046 575 | 22,36,0.0075 576 | 23,24,0.2493 577 | 23,25,0.0259 578 | 23,26,0.0426 579 | 23,27,0.2199 580 | 23,28,0.3316 581 | 23,29,0.5661 582 | 23,30,0.3621 583 | 23,31,0.0670 584 | 23,32,0.0031 585 | 23,33,0.0000 586 | 23,34,0.0000 587 | 23,35,0.0000 588 | 23,36,0.0000 589 | 24,25,0.2683 590 | 24,26,0.0284 591 | 24,27,0.4720 592 | 24,28,0.6725 593 | 24,29,0.3175 594 | 24,30,0.1250 595 | 24,31,0.0131 596 | 24,32,0.0000 597 | 24,33,0.0000 598 | 24,34,0.0000 599 | 24,35,0.0000 600 | 24,36,0.0000 601 | 25,26,0.2499 602 | 25,27,0.4611 603 | 25,28,0.2101 604 | 25,29,0.0000 605 | 25,30,0.0000 606 | 25,31,0.0000 607 | 25,32,0.0000 608 | 25,33,0.0000 609 | 25,34,0.0000 610 | 25,35,0.0000 611 | 25,36,0.0000 612 | 26,27,0.7114 613 | 26,28,0.0033 614 | 26,29,0.0000 615 | 26,30,0.0000 616 | 26,31,0.0000 617 | 26,32,0.0000 618 | 26,33,0.0000 619 | 26,34,0.0000 620 | 26,35,0.0000 621 | 26,36,0.0000 622 | 27,28,0.1769 623 | 27,29,0.0335 624 | 27,30,0.0000 625 | 27,31,0.0000 626 | 27,32,0.0000 627 | 27,33,0.0000 628 | 27,34,0.0000 629 | 27,35,0.0000 630 | 27,36,0.0000 631 | 28,29,0.5445 632 | 28,30,0.3390 633 | 28,31,0.1712 634 | 28,32,0.0000 635 | 28,33,0.0000 636 | 28,34,0.0000 637 | 28,35,0.0000 638 | 28,36,0.0000 639 | 29,30,0.6915 640 | 29,31,0.1352 641 | 29,32,0.0507 642 | 29,33,0.0016 643 | 29,34,0.0000 644 | 29,35,0.0000 645 | 29,36,0.0000 646 | 30,31,0.4260 647 | 30,32,0.2778 648 | 30,33,0.1563 649 | 30,34,0.0450 650 | 30,35,0.0000 651 | 30,36,0.0080 652 | 31,32,0.4631 653 | 31,33,0.5335 654 | 31,34,0.3379 655 | 31,35,0.1129 656 | 31,36,0.2214 657 | 32,33,0.6403 658 | 32,34,0.3848 659 | 32,35,0.3697 660 | 32,36,0.3271 661 | 33,34,0.7574 662 | 33,35,0.3338 663 | 33,36,0.4225 664 | 34,35,0.4932 665 | 34,36,0.5212 666 | 35,36,0.5585 667 | -------------------------------------------------------------------------------- /data/3dmatch_list/train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/data/3dmatch_list/train.pkl -------------------------------------------------------------------------------- /data/3dmatch_list/val.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/data/3dmatch_list/val.pkl -------------------------------------------------------------------------------- /data/eth_data.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import open3d as o3d 4 | import torch 5 | from torch.utils.data import Dataset 6 | from models.utils import generate_rand_rotm 7 | 8 | 9 | def read_eth_bin_voxel(filename, npoints=None, voxel_size=None) -> np.ndarray: 10 | scan = np.array(o3d.io.read_point_cloud(filename).voxel_down_sample(voxel_size=voxel_size).points) 11 | scan = scan.astype('float32') 12 | if npoints is None: 13 | return scan 14 | 15 | if scan.shape[0] >= npoints: 16 | sample_idx = np.random.choice(scan.shape[0], npoints, replace=False) 17 | scan = scan[sample_idx, :] 18 | return scan 19 | 20 | 21 | class ETHDataset(Dataset): 22 | def __init__(self, root, npoints, voxel_size, augment=0.0): 23 | super(ETHDataset, self).__init__() 24 | self.root = root 25 | self.npoints = npoints 26 | self.voxel_size = voxel_size 27 | self.augment = augment 28 | self.dataset = self.make_dataset() 29 | 30 | def make_dataset(self): 31 | scenes = [ 32 | 'gazebo_summer', 33 | 'gazebo_winter', 34 | 'wood_autmn', 35 | 'wood_summer', 36 | ] 37 | dataset = [] 38 | 39 | for seq in scenes: 40 | folder = os.path.join(self.root, seq) 41 | pairs, trans = self.read_transformation_log(folder) 42 | for pair, rela_pose in zip(pairs, trans): 43 | src_fn = os.path.join(folder, 'Hokuyo_%d.ply'%pair[0]) 44 | dst_fn = os.path.join(folder, 'Hokuyo_%d.ply'%pair[1]) 45 | data_dict = {'points1': src_fn, 'points2': dst_fn, 'Tr': rela_pose} 46 | dataset.append(data_dict) 47 | 48 | return dataset 49 | 50 | def __getitem__(self, index): 51 | data_dict = self.dataset[index] 52 | src_points = read_eth_bin_voxel(data_dict['points1'], self.npoints, self.voxel_size) 53 | dst_points = read_eth_bin_voxel(data_dict['points2'], self.npoints, self.voxel_size) 54 | Tr = data_dict['Tr'] 55 | 56 | if self.augment > 0: 57 | cross = np.random.uniform(low=dst_points.min(0), high=dst_points.max(0)) 58 | if cross[0] < 0.1 and cross[0] > 0: cross[0] = 0.1 59 | if cross[0] > -0.1 and cross[0] < 0: cross[0] = -0.1 60 | if cross[1] < 0.1 and cross[1] > 0: cross[1] = 0.1 61 | if cross[1] > -0.1 and cross[1] < 0: cross[1] = -0.1 62 | crop = dst_points[:, 0] / cross[0] + dst_points[:, 1] / cross[1] 63 | src_points = dst_points[crop < 1] 64 | Tr = np.eye(4, dtype=np.float32) 65 | 66 | if np.random.rand() < self.augment: 67 | print(src_points.shape) 68 | aug_T = np.eye(4, dtype=np.float32) 69 | aug_T[:2, 3] = np.random.random([2]) * 2. - 1. 70 | aug_T[:3,:3] = generate_rand_rotm(0., 0.) 
71 | src_points = src_points @ aug_T[:3,:3].T + aug_T[:3,3:].T # dst_points 72 | Tr = aug_T 73 | 74 | src_points = torch.from_numpy(src_points) 75 | dst_points = torch.from_numpy(dst_points) 76 | Tr = torch.from_numpy(Tr) 77 | return src_points, dst_points, Tr 78 | 79 | def __len__(self): 80 | return len(self.dataset) 81 | 82 | def read_transformation_log(self, seq:str): 83 | with open(os.path.join(seq, 'gt.log')) as f: 84 | lines = f.readlines() 85 | lines = [line.strip() for line in lines] 86 | pairs, trans = list(), list() 87 | num_pairs = len(lines) // 5 88 | for i in range(num_pairs): 89 | line_id = i * 5 90 | pairs.append(lines[line_id].split()) 91 | pairs[-1] = [int(pairs[-1][0]), int(pairs[-1][1])] 92 | transform = list() 93 | for j in range(1, 5): 94 | transform.append(lines[line_id + j].split()) 95 | trans.append(np.array(transform, dtype=np.float32)) 96 | return np.array(pairs), np.array(trans).astype(np.float32) -------------------------------------------------------------------------------- /data/gen_kitti_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from D3Feat, which can generate 3 | point cloud pairs for kitti and refine the ground-truth transformations. 4 | If you want to regenerate these results, please modify the directories below. 5 | """ 6 | import os 7 | import glob 8 | import numpy as np 9 | import open3d as o3d 10 | 11 | 12 | class KITTIDataset: 13 | DATA_FILES = { 14 | 'train': ['00', '01', '02', '03', '04', '05'], 15 | 'val': ['06', '07'], 16 | 'test': ['08', '09', '10'] 17 | } 18 | def __init__(self): 19 | self.root = '/media/jacko/SSD/KITTI/' 20 | self.save = './kitti_list/' 21 | 22 | self.gt = [self.read_groundtruth(i) for i in range(11)] 23 | self.files = {'train': [], 'val': [], 'test': []} 24 | 25 | self.prepare_kitti_ply('test') 26 | self.prepare_kitti_ply('train') 27 | self.prepare_kitti_ply('val') 28 | 29 | self.refine_poses('test') 30 | self.refine_poses('train') 31 | self.refine_poses('val') 32 | 33 | 34 | def prepare_kitti_ply(self, split='train'): 35 | for dirname in self.DATA_FILES[split]: 36 | drive_id = int(dirname) 37 | fnames = glob.glob(self.root + '/velodyne/sequences/%02d/velodyne/*.bin' % drive_id) 38 | assert len(fnames) > 0, f"Make sure that the path {self.root} has data {dirname}" 39 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 40 | 41 | all_odo = np.genfromtxt(self.root + '/results/%02d.txt' % drive_id) 42 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo]) 43 | Ts = all_pos[:, :3, 3] 44 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3)) ** 2 45 | pdist = np.sqrt(pdist.sum(-1)) 46 | more_than_10 = pdist > 10 47 | curr_time = inames[0] 48 | while curr_time in inames: 49 | next_time = np.where(more_than_10[curr_time][curr_time:curr_time + 100])[0] 50 | if len(next_time) == 0: 51 | curr_time += 1 52 | else: 53 | next_time = next_time[0] + curr_time - 1 54 | 55 | if next_time in inames: 56 | self.files[split].append((drive_id, curr_time, next_time)) 57 | curr_time = next_time + 1 58 | 59 | if split == 'train': 60 | self.num_train = len(self.files[split]) 61 | print("Num_train", self.num_train) 62 | elif split == 'val': 63 | self.num_val = len(self.files[split]) 64 | print("Num_val", self.num_val) 65 | else: 66 | # pair (8, 15, 58) is wrong. 
67 | self.files[split].remove((8, 15, 58)) 68 | self.num_test = len(self.files[split]) 69 | print("Num_test", self.num_test) 70 | 71 | def odometry_to_positions(self, odometry): 72 | T_w_cam0 = odometry.reshape(3, 4) 73 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1])) 74 | return T_w_cam0 75 | 76 | def read_groundtruth(self, seq): 77 | gt = np.genfromtxt(self.root + '/results/%02d.txt'%seq).reshape([-1, 3, 4]) 78 | gt = np.concatenate([gt,np.repeat(np.array([[[0,0,0,1.]]]),gt.shape[0],axis=0)],axis=1) 79 | # these transformations are under the left camera's coordinate system 80 | calibf = open(self.root + '/sequences/%02d/calib.txt'%seq) 81 | for t in calibf.readlines(): # [P0, P1, P2, P3, Tr] 82 | if t[0]=='T': t = t[4:]; break # Tr (camera0->LiDAR) 83 | calibf.close(); calib = np.eye(4, 4) 84 | calib[:-1, :] = np.array([float(c) for c in t.split(' ')]).reshape([3, 4]) 85 | return np.linalg.inv(calib) @ gt @ calib 86 | 87 | def refine_poses(self, split='train'): 88 | file = open(self.save + split + '.txt', 'w') 89 | for data_dict in self.files[split]: 90 | folder = os.path.join(self.root + '/velodyne/sequences/%02d/velodyne/'%data_dict[0]) 91 | src_fn = os.path.join(folder, '%06d.bin'%data_dict[1]) 92 | dst_fn = os.path.join(folder, '%06d.bin'%data_dict[2]) 93 | 94 | ply1 = o3d.geometry.PointCloud() 95 | ply2 = o3d.geometry.PointCloud() 96 | xyz1 = np.fromfile(src_fn, dtype=np.float32, count=-1).reshape([-1,4])[:, :3] 97 | xyz2 = np.fromfile(dst_fn, dtype=np.float32, count=-1).reshape([-1,4])[:, :3] 98 | 99 | ply1.points = o3d.utility.Vector3dVector(xyz1) 100 | ply2.points = o3d.utility.Vector3dVector(xyz2) 101 | ply1 = ply1.voxel_down_sample(voxel_size=0.05) 102 | ply2 = ply2.voxel_down_sample(voxel_size=0.05) 103 | 104 | trans = np.linalg.inv(self.gt[data_dict[0]][data_dict[1]]) 105 | trans = trans @ self.gt[data_dict[0]][data_dict[2]] 106 | t = o3d.pipelines.registration.registration_icp( 107 | ply2, ply1, 0.2, trans, # refine the transformation matrix via ICP 108 | o3d.pipelines.registration.TransformationEstimationPointToPoint(), 109 | o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=500)) 110 | trans = np.array(t.transformation, dtype=np.float32)[:3].reshape(-1) 111 | item = '%d %d %d '%(data_dict[0], data_dict[1], data_dict[2]) 112 | item = item + ' '.join(str(k) for k in trans) + '\n' 113 | file.write(item) 114 | print(item) 115 | file.close() 116 | 117 | 118 | if __name__ == '__main__': 119 | KITTIDataset() 120 | -------------------------------------------------------------------------------- /data/indoor_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import open3d as o3d 5 | from scipy.spatial.transform import Rotation as R 6 | 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | def read_3dmatch_bin_voxel(filename, npoints=None, voxel_size=None) -> np.ndarray: 12 | scan = np.array(o3d.io.read_point_cloud(filename).voxel_down_sample(voxel_size=voxel_size).points) 13 | scan = scan.astype('float32') 14 | if npoints is None: 15 | return scan 16 | 17 | if scan.shape[0] >= npoints: 18 | sample_idx = np.random.choice(scan.shape[0], npoints, replace=False) 19 | scan = scan[sample_idx, :] 20 | return scan 21 | 22 | 23 | class IndoorDataset(Dataset): 24 | def __init__(self, root, seqs, npoints, voxel_size, data_list, augment=0.0): 25 | super(IndoorDataset, self).__init__() 26 | self.root = root 27 | self.seqs = seqs 28 | self.npoints = npoints 29 | 
self.voxel_size = voxel_size 30 | self.augment = augment 31 | self.data_list = data_list 32 | self.dataset = self.make_dataset() 33 | 34 | def make_dataset(self): 35 | last_row = np.zeros((1,4), dtype=np.float32) 36 | last_row[:,3] = 1.0 37 | dataset = [] 38 | 39 | fn_pair_poses = os.path.join(self.data_list, self.seqs + '.pkl') 40 | fn_pair_poses = open(fn_pair_poses, 'rb') 41 | metadata = pickle.load(fn_pair_poses) 42 | fn_pair_poses.close() 43 | 44 | for i in range(len(metadata)): 45 | folder = os.path.join(self.root, 'train', metadata[i]['seq'], 'fragments') 46 | src_fn = os.path.join(folder, 'cloud_bin_%d.ply'%metadata[i]['ref_id']) 47 | dst_fn = os.path.join(folder, 'cloud_bin_%d.ply'%metadata[i]['src_id']) 48 | rela_pose = metadata[i]['transform'].astype(np.float32) 49 | rela_pose = np.concatenate([rela_pose, last_row], axis = 0) 50 | data_dict = {'points1': src_fn, 'points2': dst_fn, 'Tr': rela_pose} 51 | dataset.append(data_dict) 52 | 53 | return dataset 54 | 55 | def __getitem__(self, index): 56 | data_dict = self.dataset[index] 57 | src_points = read_3dmatch_bin_voxel(data_dict['points1'], self.npoints, self.voxel_size) 58 | dst_points = read_3dmatch_bin_voxel(data_dict['points2'], self.npoints, self.voxel_size) 59 | Tr = data_dict['Tr'] 60 | 61 | if np.random.rand() < self.augment: 62 | aug_T = np.eye(4, dtype=np.float32) 63 | aug_T[:3,:3] = self.sample_random_rotation() 64 | dst_points = dst_points @ aug_T[:3,:3] 65 | Tr = Tr @ aug_T 66 | 67 | src_points = torch.from_numpy(src_points) 68 | dst_points = torch.from_numpy(dst_points) 69 | Tr = torch.from_numpy(Tr) 70 | return src_points, dst_points, Tr 71 | 72 | def sample_random_rotation(self, pitch_scale=np.pi/3., roll_scale=np.pi/4.): 73 | roll = np.random.uniform(-roll_scale, roll_scale) 74 | pitch = np.random.uniform(-pitch_scale, pitch_scale) 75 | r = R.from_euler('xyz', [roll, pitch, 0.], degrees=False) 76 | return r.as_matrix() 77 | 78 | def __len__(self): 79 | return len(self.dataset) 80 | 81 | 82 | class IndoorTestDataset(Dataset): 83 | def __init__(self, root, seqs, npoints, voxel_size, data_list, non_consecutive=False): 84 | super(IndoorTestDataset, self).__init__() 85 | self.root = root 86 | self.seqs = seqs 87 | self.npoints = npoints 88 | self.data_list = data_list 89 | self.voxel_size = voxel_size 90 | self.non_consecutive = non_consecutive 91 | self.dataset = self.make_dataset() 92 | 93 | def make_dataset(self): 94 | dataset = [] 95 | benchmark = os.path.join(self.data_list, 'benchmark', self.seqs) 96 | scenes = os.listdir(benchmark) 97 | scenes.sort() 98 | 99 | for seq in scenes: 100 | metaseq = os.path.join(benchmark, seq) 101 | pairs, trans = self.read_transformation_log(metaseq) 102 | pairs, covs = self.read_covariance_log(metaseq) 103 | for pair, rela_pose, cov in zip(pairs, trans, covs): 104 | folder = os.path.join(self.root, 'test', seq, 'fragments') 105 | src_fn = os.path.join(folder, 'cloud_bin_%d.ply'%pair[0]) 106 | dst_fn = os.path.join(folder, 'cloud_bin_%d.ply'%pair[1]) 107 | data_dict = {'points1': src_fn, 'points2': dst_fn, 'Tr': rela_pose, 'Cov': cov} 108 | dataset.append(data_dict) 109 | 110 | return dataset 111 | 112 | def __getitem__(self, index): 113 | data_dict = self.dataset[index] 114 | src_points = read_3dmatch_bin_voxel(data_dict['points1'], self.npoints, self.voxel_size) 115 | dst_points = read_3dmatch_bin_voxel(data_dict['points2'], self.npoints, self.voxel_size) 116 | Tr, Cov = data_dict['Tr'], data_dict['Cov'] 117 | 118 | src_points = torch.from_numpy(src_points) 119 | dst_points = 
torch.from_numpy(dst_points) 120 | Tr = torch.from_numpy(Tr) 121 | Cov = torch.from_numpy(Cov) 122 | return src_points, dst_points, Tr, Cov 123 | 124 | def read_transformation_log(self, seq:str): 125 | with open(os.path.join(seq, 'gt.log')) as f: 126 | lines = f.readlines() 127 | lines = [line.strip() for line in lines] 128 | pairs, trans = list(), list() 129 | num_pairs = len(lines) // 5 130 | for i in range(num_pairs): 131 | line_id = i * 5 132 | if self.non_consecutive: 133 | item = lines[line_id].split() 134 | if int(item[1]) - int(item[0]) == 1: continue 135 | pairs.append(lines[line_id].split()) 136 | pairs[-1] = [int(pairs[-1][0]), int(pairs[-1][1])] 137 | transform = list() 138 | for j in range(1, 5): 139 | transform.append(lines[line_id + j].split()) 140 | trans.append(np.array(transform, dtype=np.float32)) 141 | return np.array(pairs), np.array(trans).astype(np.float32) 142 | 143 | def read_covariance_log(self, seq:str): 144 | with open(os.path.join(seq, 'gt.info')) as f: 145 | lines = f.readlines() 146 | lines = [line.strip() for line in lines] 147 | pairs, cov = list(), list() 148 | num_pairs = len(lines) // 7 149 | for i in range(num_pairs): 150 | line_id = i * 7 151 | if self.non_consecutive: 152 | item = lines[line_id].split() 153 | if int(item[1]) - int(item[0]) == 1: continue 154 | pairs.append(lines[line_id].split()) 155 | pairs[-1] = [int(pairs[-1][0]), int(pairs[-1][1])] 156 | covariance = list() 157 | for j in range(1, 7): 158 | covariance.append(lines[line_id + j].split()) 159 | cov.append(np.array(covariance, dtype=np.float32)) 160 | return np.array(pairs), np.array(cov).astype(np.float32) 161 | 162 | def __len__(self): 163 | return len(self.dataset) -------------------------------------------------------------------------------- /data/kitti_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | import os 6 | import numpy as np 7 | import MinkowskiEngine as ME 8 | 9 | from models.utils import generate_rand_rotm 10 | 11 | 12 | def read_kitti_bin_voxel(filename, npoints=None, voxel_size=None) -> np.ndarray: 13 | scan = np.fromfile(filename, dtype=np.float32, count=-1).reshape([-1,4]) 14 | scan = scan[:,:3] 15 | 16 | if voxel_size is not None: 17 | _, sel = ME.utils.sparse_quantize(scan / voxel_size, return_index=True) 18 | scan = scan[sel] 19 | if npoints is None: 20 | return scan.astype('float32') 21 | 22 | dist = np.linalg.norm(scan, ord=2, axis=1) 23 | N = scan.shape[0] 24 | if N >= npoints: 25 | sample_idx = np.argsort(dist)[:npoints] 26 | scan = scan[sample_idx, :].astype('float32') 27 | dist = dist[sample_idx] 28 | scan = scan[np.logical_and(dist > 3., scan[:, 2] > -3.5)] 29 | return scan 30 | 31 | class KittiDataset(Dataset): 32 | def __init__(self, root, seqs, npoints, voxel_size, data_list, augment=0.0): 33 | super(KittiDataset, self).__init__() 34 | self.root = root 35 | self.seqs = seqs 36 | self.npoints = npoints 37 | self.voxel_size = voxel_size 38 | self.augment = augment 39 | self.data_list = data_list 40 | self.dataset = self.make_dataset() 41 | 42 | def make_dataset(self): 43 | last_row = np.zeros((1,4), dtype=np.float32) 44 | last_row[:,3] = 1.0 45 | dataset = [] 46 | 47 | fn_pair_poses = os.path.join(self.data_list, self.seqs + '.txt') 48 | metadata = np.genfromtxt(fn_pair_poses).reshape([-1, 15]) 49 | for i in range(metadata.shape[0]): 50 | folder = os.path.join(self.root, '%02d'%metadata[i][0], 'velodyne') 51 | src_fn = 
os.path.join(folder, '%06d.bin'%metadata[i][1]) 52 | dst_fn = os.path.join(folder, '%06d.bin'%metadata[i][2]) 53 | rela_pose = metadata[i][3:].reshape(3,4).astype(np.float32) 54 | rela_pose = np.concatenate([rela_pose, last_row], axis = 0) 55 | data_dict = {'points1': src_fn, 'points2': dst_fn, 'Tr': rela_pose} 56 | dataset.append(data_dict) 57 | 58 | return dataset 59 | 60 | def __getitem__(self, index): 61 | data_dict = self.dataset[index] 62 | src_points = read_kitti_bin_voxel(data_dict['points1'], self.npoints, self.voxel_size) 63 | dst_points = read_kitti_bin_voxel(data_dict['points2'], self.npoints, self.voxel_size) 64 | Tr = data_dict['Tr'] 65 | 66 | if np.random.rand() < self.augment: 67 | aug_T = np.eye(4, dtype=np.float32) 68 | aug_T[:3,:3] = generate_rand_rotm(1.0, 1.0) 69 | dst_points = dst_points @ aug_T[:3,:3] 70 | Tr = Tr @ aug_T 71 | 72 | src_points = torch.from_numpy(src_points) 73 | dst_points = torch.from_numpy(dst_points) 74 | Tr = torch.from_numpy(Tr) 75 | return src_points, dst_points, Tr 76 | 77 | def __len__(self): 78 | return len(self.dataset) -------------------------------------------------------------------------------- /data/nuscenes_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import MinkowskiEngine as ME 5 | from torch.utils.data import Dataset 6 | from models.utils import generate_rand_rotm 7 | 8 | def read_nuscenes_bin_voxel(filename, npoints=None, voxel_size=None) -> np.ndarray: 9 | scan = np.fromfile(filename, dtype=np.float32, count=-1).reshape([-1,5]) 10 | scan = scan[:,:3] 11 | 12 | if voxel_size is not None: 13 | _, sel = ME.utils.sparse_quantize(scan / voxel_size, return_index=True) 14 | scan = scan[sel] 15 | if npoints is None: 16 | return scan.astype('float32') 17 | 18 | dist = np.linalg.norm(scan, ord=2, axis=1) 19 | N = scan.shape[0] 20 | if N >= npoints: 21 | sample_idx = np.argsort(dist)[:npoints] 22 | scan = scan[sample_idx, :].astype('float32') 23 | dist = dist[sample_idx] 24 | 25 | scan = scan[np.logical_and(dist > 3., scan[:, 2] > -6.)] 26 | return scan 27 | 28 | 29 | class NuscenesDataset(Dataset): 30 | def __init__(self, root, seqs, npoints, voxel_size, data_list, augment=0.0): 31 | super(NuscenesDataset, self).__init__() 32 | 33 | self.root = root 34 | self.seqs = seqs 35 | self.npoints = npoints 36 | self.voxel_size = voxel_size 37 | self.data_list = data_list 38 | self.augment = augment 39 | self.dataset = self.make_dataset() 40 | 41 | def make_dataset(self): 42 | last_row = np.zeros((1,4), dtype=np.float32) 43 | last_row[:,3] = 1.0 44 | dataset = [] 45 | 46 | if True:#for seq in self.seqs: 47 | data_root = self.root 48 | '''if seq == 'test': 49 | data_root = os.path.join(self.root, 'v1.0-test') 50 | else: 51 | data_root = os.path.join(self.root, 'v1.0-trainval')''' 52 | fn_pair_poses = os.path.join(self.data_list, self.seqs + '.txt') 53 | with open(fn_pair_poses, 'r') as f: 54 | lines = f.readlines() 55 | for line in lines: 56 | data_dict = {} 57 | line = line.strip(' \n').split(' ') 58 | src_fn = os.path.join(data_root, line[0]) 59 | dst_fn = os.path.join(data_root, line[1]) 60 | values = [] 61 | for i in range(2, len(line)): 62 | values.append(float(line[i])) 63 | values = np.array(values).astype(np.float32) 64 | rela_pose = values.reshape(3,4) 65 | rela_pose = np.concatenate([rela_pose, last_row], axis = 0) 66 | data_dict['points1'] = src_fn 67 | data_dict['points2'] = dst_fn 68 | data_dict['Tr'] = rela_pose 69 | 
dataset.append(data_dict) 70 | 71 | return dataset 72 | 73 | def __getitem__(self, index): 74 | data_dict = self.dataset[index] 75 | src_points = read_nuscenes_bin_voxel(data_dict['points1'], self.npoints, self.voxel_size) 76 | dst_points = read_nuscenes_bin_voxel(data_dict['points2'], self.npoints, self.voxel_size) 77 | Tr = np.linalg.inv(data_dict['Tr']) 78 | 79 | if np.random.rand() < self.augment: 80 | aug_T = np.eye(4, dtype=np.float32) 81 | aug_T[:3,:3] = generate_rand_rotm(1.0, 1.0) 82 | dst_points = dst_points @ aug_T[:3,:3] 83 | Tr = Tr @ aug_T # np.linalg.inv(aug_T) 84 | 85 | src_points = torch.from_numpy(src_points) 86 | dst_points = torch.from_numpy(dst_points) 87 | Tr = torch.from_numpy(Tr) 88 | return src_points, dst_points, Tr 89 | 90 | def __len__(self): 91 | return len(self.dataset) -------------------------------------------------------------------------------- /demo_3dmatch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import argparse 5 | import numpy as np 6 | import open3d as o3d 7 | from munch import munchify 8 | from engine.trainer import EpochBasedTrainer 9 | from data.indoor_data import IndoorDataset, IndoorTestDataset 10 | from models.models.cast import CAST 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--split", default='train', choices=['train', 'val', 'test']) 15 | parser.add_argument("--benchmark", default='3DMatch', choices=['3DMatch', '3DLoMatch']) 16 | parser.add_argument("--load_pretrained", default='cast-epoch-05', type=str) 17 | parser.add_argument("--id", default=0, type=int) 18 | 19 | _args = parser.parse_args() 20 | 21 | 22 | class Engine(EpochBasedTrainer): 23 | def __init__(self, cfg): 24 | super().__init__(cfg) 25 | self.train_dataset = IndoorDataset(cfg.data.root, 'train', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, 0.0) 26 | self.val_dataset = IndoorDataset(cfg.data.root, 'val', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, 0.0) 27 | self.test_dataset = IndoorTestDataset(cfg.data.root, _args.benchmark, cfg.data.npoints, cfg.data.voxel_size, cfg.data_list) 28 | 29 | self.model = CAST(cfg.model).cuda() 30 | 31 | 32 | if __name__ == "__main__": 33 | with open('./config/3dmatch.json', 'r') as cfg: 34 | args = json.load(cfg) 35 | args = munchify(args) 36 | 37 | tester = Engine(args) 38 | tester.set_eval_mode() 39 | tester.load_snapshot(_args.load_pretrained) 40 | 41 | if _args.split == 'train': 42 | data = tester.train_dataset[_args.id] 43 | data_dict = tester.train_dataset.dataset[_args.id] 44 | elif _args.split == 'val': 45 | data = tester.val_dataset[_args.id] 46 | data_dict = tester.val_dataset.dataset[_args.id] 47 | else: 48 | data = tester.test_dataset[_args.id] 49 | data_dict = tester.test_dataset.dataset[_args.id] 50 | 51 | gt_trans = data[2].numpy() 52 | ref_cloud = o3d.io.read_point_cloud(data_dict['points1']) 53 | src_cloud = o3d.io.read_point_cloud(data_dict['points2']) 54 | 55 | custom_yellow = np.asarray([[221., 184., 34.]]) / 255.0 56 | custom_blue = np.asarray([[9., 151., 247.]]) / 255.0 57 | ref_cloud.paint_uniform_color(custom_blue.T) 58 | src_cloud.paint_uniform_color(custom_yellow.T) 59 | 60 | data = [v.cuda().unsqueeze(0) for v in data] 61 | with torch.no_grad(): 62 | output_dict = tester.model(*data) 63 | trans = output_dict['refined_transform'].cpu().numpy() 64 | 65 | ref_cloud.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.3, max_nn=50)) 66 | 
src_cloud.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.3, max_nn=50)) 67 | 68 | src_cloud.transform(trans) 69 | vis = o3d.visualization.Visualizer() 70 | vis.create_window() 71 | view_option: o3d.visualization.ViewControl = vis.get_view_control() 72 | render_option: o3d.visualization.RenderOption = vis.get_render_option() 73 | render_option.background_color = np.array([0, 0, 0]) 74 | render_option.background_color = np.array([1, 1, 1]) 75 | render_option.point_size = 3.0 76 | vis.add_geometry(ref_cloud) 77 | vis.add_geometry(src_cloud) 78 | view_option.set_front([0., -0.3, -1.]) 79 | view_option.set_up([0., 0., 1.]) 80 | vis.run() -------------------------------------------------------------------------------- /demo_outdoor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import open3d as o3d 6 | from munch import munchify 7 | 8 | from data.kitti_data import KittiDataset 9 | from data.nuscenes_data import NuscenesDataset 10 | from engine.trainer import EpochBasedTrainer 11 | 12 | from models.models.cast import CAST 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--dataset", default='kitti', choices=['kitti', 'nuscenes']) 17 | parser.add_argument("--split", default='train', choices=['train', 'val', 'test']) 18 | parser.add_argument("--mode", required=True, choices=['keypts', 'corr', 'reg']) 19 | parser.add_argument("--load_pretrained", default='cast-epoch-39', type=str) 20 | parser.add_argument("--id", default=0, type=int) 21 | 22 | _args = parser.parse_args() 23 | 24 | 25 | class Engine(EpochBasedTrainer): 26 | def __init__(self, cfg): 27 | super().__init__(cfg) 28 | if cfg.dataset == 'kitti': 29 | self.train_dataset = KittiDataset(cfg.data.root, 'train', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list) 30 | self.val_dataset = KittiDataset(cfg.data.root, 'val', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list) 31 | self.test_dataset = KittiDataset(cfg.data.root, 'test', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list) 32 | elif cfg.dataset == 'nuscenes': 33 | self.train_dataset = NuscenesDataset(cfg.data.root, 'train', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list) 34 | self.val_dataset = NuscenesDataset(cfg.data.root, 'val', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list) 35 | self.test_dataset = NuscenesDataset(cfg.data.root, 'test', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list) 36 | 37 | self.model = CAST(cfg.model).cuda() 38 | 39 | def keypoints_to_spheres(keypoints, radius=0.03, color=[1.0, 0.25, 0.0]): 40 | spheres = o3d.geometry.TriangleMesh() 41 | for keypoint in keypoints: 42 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius) 43 | sphere.translate(keypoint) 44 | spheres += sphere 45 | spheres.paint_uniform_color(color) 46 | return spheres 47 | 48 | 49 | if __name__ == "__main__": 50 | with open('./config/' + _args.dataset + '.json', 'r') as cfg: 51 | args = json.load(cfg) 52 | args = munchify(args) 53 | 54 | tester = Engine(args) 55 | tester.set_eval_mode() 56 | tester.load_snapshot(_args.load_pretrained) 57 | #e.g., tester.load_snapshot('cast-epoch-39') 58 | # kitti: train 2,11,730,895,1249, test 153,240,360 59 | if _args.split == 'train': 60 | data = tester.train_dataset[_args.id] 61 | elif _args.split == 'val': 62 | data = tester.val_dataset[_args.id] 63 | else: 64 | data = tester.test_dataset[_args.id] 65 | gt_trans = data[2].numpy() 66 | 67 | 
custom_yellow = np.asarray([[221., 184., 34.]]) / 255.0 68 | custom_blue = np.asarray([[9., 151., 247.]]) / 255.0 69 | custom_green = np.asarray([[17., 238., 194.]]) / 255.0 70 | custom_red = np.asarray([[204., 51., 51.]]) / 255.0 71 | 72 | ref_cloud = o3d.geometry.PointCloud() 73 | ref_cloud.points = o3d.utility.Vector3dVector(data[0].numpy()) 74 | 75 | src_cloud = o3d.geometry.PointCloud() 76 | src_cloud.points = o3d.utility.Vector3dVector(data[1].numpy()) 77 | 78 | 79 | ref_cloud.paint_uniform_color(custom_blue.T) 80 | src_cloud.paint_uniform_color(custom_yellow.T) 81 | 82 | data = [v.cuda().unsqueeze(0) for v in data] 83 | with torch.no_grad(): 84 | output_dict = tester.model(*data) 85 | trans = output_dict['refined_transform'].cpu().numpy() 86 | corres_xyz = output_dict['corres'].cpu().numpy() 87 | corr_weight = output_dict['corr_confidence'].cpu().numpy() 88 | corres_xyz = corres_xyz[corr_weight > np.max(corr_weight)* 0.3] 89 | 90 | 91 | if _args.mode == 'reg': 92 | src_cloud.transform(trans) 93 | elif _args.mode == 'corr': 94 | lines = list() 95 | points = np.reshape(corres_xyz + np.array([[0,0,0,0,0,-15.]]), [-1,3]) 96 | lines = np.arange(points.shape[0], dtype=np.int32).reshape([-1,2]) 97 | colors = corres_xyz[:, 3:] @ gt_trans[:3, :3].T + gt_trans[:3, 3:].T - corres_xyz[:, :3] 98 | colors = np.asarray(np.linalg.norm(colors, axis=-1) < 0.6, dtype=np.float64).reshape([-1, 1]) 99 | colors = colors * custom_green + (1. - colors) * custom_red 100 | line_set = o3d.geometry.LineSet( 101 | points=o3d.utility.Vector3dVector(points), 102 | lines=o3d.utility.Vector2iVector(lines), 103 | ) 104 | line_set.colors = o3d.utility.Vector3dVector(colors) 105 | src_cloud.translate(np.array([[0., 0., -15.]]).T) 106 | else: 107 | ref_keypts = output_dict['ref_kpts'].squeeze(0).cpu().numpy() 108 | src_keypts = output_dict['src_kpts'].squeeze(0).cpu().numpy() 109 | ref_scores = 1. / np.clip(output_dict['ref_sigma'].cpu().numpy(), 1e-8, None) 110 | src_scores = 1. 
/ np.clip(output_dict['src_sigma'].cpu().numpy(), 1e-8, None) 111 | ref_scores = (ref_scores - np.min(ref_scores)) / (np.max(ref_scores) - np.min(ref_scores)) 112 | src_scores = (src_scores - np.min(src_scores)) / (np.max(src_scores) - np.min(src_scores)) 113 | 114 | if _args.dataset == 'kitti': 115 | threshold, radius = 0.2, 0.3 116 | else: threshold, radius = 0.25, 0.5 117 | 118 | ref_keypts = keypoints_to_spheres(ref_keypts[ref_scores > threshold], radius, custom_red[0]) 119 | src_keypts = keypoints_to_spheres(src_keypts[src_scores > threshold], radius, custom_red[0]) 120 | 121 | src_cloud.translate(np.array([[160., 0., 0.]]).T) 122 | src_keypts.translate(np.array([[160., 0., 0.]]).T) 123 | 124 | 125 | vis = o3d.visualization.Visualizer() 126 | vis.create_window() 127 | view_option: o3d.visualization.ViewControl = vis.get_view_control() 128 | render_option: o3d.visualization.RenderOption = vis.get_render_option() 129 | render_option.background_color = np.array([0, 0, 0]) 130 | render_option.background_color = np.array([1, 1, 1]) 131 | render_option.point_size = 3.0 132 | vis.add_geometry(ref_cloud) 133 | vis.add_geometry(src_cloud) 134 | 135 | if _args.mode == 'corr': 136 | vis.add_geometry(line_set) 137 | if _args.dataset == 'kitti': 138 | view_option.set_front([-0.1, -1., 0.7]) 139 | view_option.set_up([0., 0., 1.]) 140 | else: 141 | view_option.set_front([-1., -0.1, 0.6]) 142 | view_option.set_up([1., 0., 0.]) 143 | elif _args.mode == 'keypts': 144 | vis.add_geometry(ref_keypts) 145 | vis.add_geometry(src_keypts) 146 | 147 | view_option.set_zoom(0.4) 148 | vis.run() -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/engine/__init__.py -------------------------------------------------------------------------------- /engine/evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from models.utils import apply_transform 5 | from scipy.spatial.transform import Rotation as R 6 | 7 | 8 | class Evaluator(nn.Module): 9 | def __init__(self, cfg): 10 | super(Evaluator, self).__init__() 11 | self.hit_threshold = cfg.hit_threshold 12 | self.acceptance_overlap = cfg.acceptance_overlap 13 | self.inlier_ratio_threshold = 0.05 14 | if 'rre_threshold' in cfg.keys(): 15 | self.rre_threshold = cfg.rre_threshold 16 | self.rte_threshold = cfg.rte_threshold 17 | self.scene = 'outdoor' 18 | else: 19 | self.inlier_distance_threshold = 0.1 20 | self.rmse_threshold = cfg.rmse_threshold 21 | self.scene = 'indoor' 22 | 23 | @torch.no_grad() 24 | def keypoint_repeatability(self, keypoints1: torch.Tensor, keypoints2: torch.Tensor, transform: torch.Tensor): 25 | keypoints2 = apply_transform(keypoints2, transform) # (M, 3) 26 | dist = torch.norm(keypoints1.unsqueeze(1) - keypoints2.unsqueeze(0), dim=-1) # (M, N) 27 | forward_KR = torch.min(dist, dim=-1)[0].lt(self.hit_threshold).float().mean() 28 | backward_KR = torch.min(dist, dim=-2)[0].lt(self.hit_threshold).float().mean() 29 | return (forward_KR + backward_KR) / 2. 
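# evaluate_inlier_ratio below treats each correspondence row as (ref_xyz, src_xyz): it maps the source points through the ground-truth transform and reports the fraction whose residual is below inlier_distance_threshold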
30 | 31 | @torch.no_grad() 32 | def evaluate_inlier_ratio(self, corr_points: torch.Tensor, transform: torch.Tensor): 33 | src_corr_points = apply_transform(corr_points[..., 3:], transform) 34 | dist = torch.norm(src_corr_points - corr_points[..., :3], dim=-1) 35 | return torch.lt(dist, self.inlier_distance_threshold).float().mean() 36 | 37 | @torch.no_grad() 38 | def compute_rmse(self, transform, covariance, estimated_transform): 39 | relative_transform = torch.matmul(torch.linalg.inv(transform), estimated_transform) 40 | q = R.from_matrix(relative_transform[:3, :3].cpu().numpy()).as_quat() 41 | q = torch.from_numpy(q[:3]).float().to(transform.device) 42 | er = torch.cat([relative_transform[:3, 3], q], dim=-1) 43 | er = er.view(1, 6) @ covariance @ er.view(6, 1) / covariance[0, 0] 44 | return torch.sqrt(er) 45 | 46 | @torch.no_grad() 47 | def transform_error(self, gt_transforms: torch.Tensor, transforms: torch.Tensor): 48 | rre = 0.5 * ((transforms[:3, :3].T @ gt_transforms[:3, :3]).trace() - 1.0) 49 | rre = 180.0 * torch.arccos(rre.clamp(-1., 1.)) / np.pi 50 | rte = torch.norm(gt_transforms[:3, 3] - transforms[:3, 3], dim=-1) 51 | return rte, rre 52 | 53 | @torch.no_grad() 54 | def evaluate_coarse_inlier_ratio(self, output_dict): 55 | ref_length_c = output_dict['ref_feats_c'].shape[0] 56 | src_length_c = output_dict['src_feats_c'].shape[0] 57 | masks = torch.gt(output_dict['gt_patch_corr_overlaps'], self.acceptance_overlap) 58 | gt_node_corr_indices = output_dict['gt_patch_corr_indices'][masks] 59 | gt_node_corr_map = torch.zeros([ref_length_c, src_length_c], device=masks.device) 60 | gt_node_corr_map[gt_node_corr_indices[:, 0], gt_node_corr_indices[:, 1]] = 1.0 61 | return gt_node_corr_map[output_dict['ref_patch_corr_indices'], output_dict['src_patch_corr_indices']].mean() 62 | 63 | @torch.no_grad() 64 | def evaluate_registration(self, output_dict): 65 | if self.scene == 'outdoor': 66 | rte, rre = self.transform_error(output_dict['gt_transform'], output_dict['transform']) 67 | recall = torch.logical_and(torch.lt(rre, self.rre_threshold), torch.lt(rte, self.rte_threshold)).float() 68 | return rre, rte, recall 69 | else: 70 | rmse = self.compute_rmse(output_dict['gt_transform'], output_dict['covariance'], output_dict['transform']) 71 | return rmse, rmse < self.rmse_threshold 72 | 73 | def forward(self, output_dict): 74 | PIR = self.evaluate_coarse_inlier_ratio(output_dict) 75 | te, re = self.transform_error(output_dict['gt_transform'], output_dict['transform']) 76 | rte, rre = self.transform_error(output_dict['gt_transform'], output_dict['refined_transform']) 77 | #KR = self.keypoint_repeatability(output_dict['ref_kpts'], output_dict['src_kpts'], output_dict['gt_transform']) 78 | results = {'PIR': PIR} #{'KR': KR, 'PIR': PIR} 79 | 80 | if not self.training: 81 | results['PMR'] = PIR.gt(0.2).float() 82 | 83 | if self.scene == 'indoor': 84 | indices = output_dict['corr_confidence'] > 0.1 * torch.max(output_dict['corr_confidence']) 85 | results['IR'] = self.evaluate_inlier_ratio(output_dict['corres'][indices], output_dict['gt_transform']) 86 | FMR = results['IR'].gt(self.inlier_ratio_threshold).float() 87 | results['RE'] = re; results['TE'] = te 88 | results['RRE'] = rre; results['RTE'] = rte 89 | if not self.training: 90 | results['FMR'] = FMR 91 | if 'covariance' in output_dict.keys(): 92 | covariance = output_dict['covariance'] 93 | gt_transform = output_dict['gt_transform'] 94 | #pred_transform = output_dict['transform'] 95 | pred_transform = output_dict['refined_transform'] 96 | 
results['rmse'] = self.compute_rmse(gt_transform, covariance, pred_transform) 97 | results['RR'] = results['rmse'].lt(self.rmse_threshold).float() 98 | if results['RR'] < 0.5: 99 | results.pop('RTE'); results.pop('RRE') 100 | results.pop('TE'); results.pop('RE') 101 | 102 | else: 103 | registration_recall = torch.lt(rre, self.rre_threshold) & torch.lt(rte, self.rte_threshold) 104 | if self.training or registration_recall.item(): 105 | results['RE'] = re; results['TE'] = te 106 | results['RRE'] = rre; results['RTE'] = rte 107 | if not self.training: 108 | results['RR'] = registration_recall.float() 109 | 110 | return results 111 | -------------------------------------------------------------------------------- /engine/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict 4 | import torch.nn.functional as F 5 | from models.utils import pairwise_distance, apply_transform 6 | 7 | 8 | class ProbChamferLoss(nn.Module): 9 | def __init__(self): 10 | super(ProbChamferLoss, self).__init__() 11 | 12 | def forward(self, output_dict: Dict[str, torch.Tensor]): 13 | keypoints1 = output_dict['ref_kpts'] 14 | keypoints2 = apply_transform(output_dict['src_kpts'], output_dict['gt_transform']) 15 | diff = torch.norm(keypoints1.unsqueeze(1) - keypoints2.unsqueeze(0), dim=-1) 16 | 17 | if output_dict['ref_sigma'] is None or output_dict['src_sigma'] is None: 18 | min_dist_forward, _ = torch.min(diff, dim=-1) 19 | forward_loss = min_dist_forward.mean() 20 | min_dist_backward, _ = torch.min(diff, dim=-2) 21 | backward_loss = min_dist_backward.mean() 22 | else: 23 | min_dist_forward, min_dist_forward_I = torch.min(diff, dim=-1) 24 | selected_sigma_2 = output_dict['src_sigma'].index_select(0, min_dist_forward_I) 25 | sigma_forward = (output_dict['ref_sigma'] + selected_sigma_2) / 2. 26 | forward_loss = (sigma_forward.log() + min_dist_forward / sigma_forward).mean() 27 | 28 | min_dist_backward, min_dist_backward_I = torch.min(diff, dim=-2) 29 | selected_sigma_1 = output_dict['ref_sigma'].index_select(0, min_dist_backward_I) 30 | sigma_backward = (output_dict['src_sigma'] + selected_sigma_1) / 2. 
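# backward term mirrors the forward one: log(sigma) + d / sigma is a Laplace-style negative log-likelihood, with sigma averaged between the two matched keypoints' predicted uncertainties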
31 | backward_loss = (sigma_backward.log() + min_dist_backward / sigma_backward).mean() 32 | 33 | return forward_loss + backward_loss 34 | 35 | 36 | class WeightedCircleLoss(nn.Module): 37 | def __init__(self, pos_margin, neg_margin, pos_optimal, neg_optimal, log_scale, bilateral=True): 38 | super(WeightedCircleLoss, self).__init__() 39 | self.pos_margin = pos_margin 40 | self.neg_margin = neg_margin 41 | self.pos_optimal = pos_optimal 42 | self.neg_optimal = neg_optimal 43 | self.log_scale = log_scale 44 | self.bilateral = bilateral 45 | 46 | def forward(self, pos_masks:torch.Tensor, neg_masks:torch.Tensor, feat_dists, pos_scales=None, neg_scales=None): 47 | with torch.no_grad(): 48 | row_masks = (torch.gt(pos_masks.sum(-1), 0) & torch.gt(neg_masks.sum(-1), 0)).nonzero().squeeze() 49 | if self.bilateral: 50 | col_masks = (torch.gt(pos_masks.sum(-2), 0) & torch.gt(neg_masks.sum(-2), 0)).nonzero().squeeze() 51 | 52 | pos_weights = torch.relu(feat_dists - 1e5 * (~pos_masks).float() - self.pos_optimal) 53 | neg_weights = torch.relu(self.neg_optimal - feat_dists - 1e5 * (~neg_masks).float()) 54 | if pos_scales is not None: pos_weights = pos_weights * pos_scales 55 | if neg_scales is not None: neg_weights = neg_weights * neg_scales 56 | 57 | loss_pos_row = torch.logsumexp(self.log_scale * (feat_dists - self.pos_margin) * pos_weights, dim=-1) 58 | loss_neg_row = torch.logsumexp(self.log_scale * (self.neg_margin - feat_dists) * neg_weights, dim=-1) 59 | loss_row = F.softplus(loss_pos_row + loss_neg_row) / self.log_scale 60 | loss_row = torch.index_select(loss_row, 0, row_masks) 61 | 62 | if not self.bilateral: 63 | return loss_row.mean() 64 | 65 | loss_pos_col = torch.logsumexp(self.log_scale * (feat_dists - self.pos_margin) * pos_weights, dim=-2) 66 | loss_neg_col = torch.logsumexp(self.log_scale * (self.neg_margin - feat_dists) * neg_weights, dim=-2) 67 | loss_col = F.softplus(loss_pos_col + loss_neg_col) / self.log_scale 68 | loss_col = torch.index_select(loss_col, 0, col_masks) 69 | 70 | return (loss_row.mean() + loss_col.mean()) / 2. 
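# in the circle loss above, each valid row (and column when bilateral) contributes softplus(logsumexp over positives of log_scale*(d - pos_margin)*w + logsumexp over negatives of log_scale*(neg_margin - d)*w) / log_scale; rows or columns without both a positive and a negative entry are skipped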
71 | 72 | 73 | class SpotMatchingLoss(nn.Module): 74 | def __init__(self, cfg): 75 | super(SpotMatchingLoss, self).__init__() 76 | self.positive_overlap = cfg.positive_overlap 77 | 78 | def forward(self, output_dict): 79 | coarse_matching_scores = output_dict['coarse_matching_scores'] 80 | gt_node_corr_indices = output_dict['gt_patch_corr_indices'] 81 | gt_node_corr_overlaps = output_dict['gt_patch_corr_overlaps'] 82 | 83 | with torch.no_grad(): 84 | overlaps = torch.zeros_like(coarse_matching_scores) 85 | overlaps[gt_node_corr_indices[:, 0], gt_node_corr_indices[:, 1]] = gt_node_corr_overlaps 86 | pos_masks = torch.gt(overlaps, self.positive_overlap) 87 | 88 | row_mask = torch.zeros_like(overlaps, dtype=torch.bool) 89 | idx = overlaps.max(dim=1, keepdim=True)[1] 90 | row_mask.scatter_(1, idx, True) 91 | col_mask = torch.zeros_like(overlaps, dtype=torch.bool) 92 | idx = overlaps.max(dim=0, keepdim=True)[1] 93 | col_mask.scatter_(0, idx, True) 94 | pos_masks = overlaps * (pos_masks & row_mask & col_mask).float() 95 | 96 | if 'spot_matching_scores' in output_dict.keys(): 97 | matching_scores = output_dict['spot_matching_scores'] 98 | loss = -torch.log(matching_scores + 1e-8) * pos_masks.unsqueeze(0) 99 | loss = torch.sum(loss) / pos_masks.sum() / matching_scores.shape[0] 100 | 101 | coarse_loss = -torch.log(coarse_matching_scores + 1e-8) * pos_masks 102 | coarse_loss = torch.sum(coarse_loss) / pos_masks.sum() 103 | 104 | if 'ref_patch_overlap' in output_dict.keys(): 105 | gt_ref_patch_overlap = 1. - pos_masks.sum(-1).gt(0).float() 106 | gt_src_patch_overlap = 1. - pos_masks.sum(-2).gt(0).float() 107 | gt_ref_patch_overlap = gt_ref_patch_overlap / (gt_ref_patch_overlap.sum() + 1e-8) 108 | gt_src_patch_overlap = gt_src_patch_overlap / (gt_src_patch_overlap.sum() + 1e-8) 109 | loss_ref_ov = -torch.log(1. - output_dict['ref_patch_overlap'] + 1e-8) * gt_ref_patch_overlap 110 | loss_src_ov = -torch.log(1. 
- output_dict['src_patch_overlap'] + 1e-8) * gt_src_patch_overlap 111 | #coarse_loss = coarse_loss + loss_ref_ov.mean() + loss_src_ov.mean() 112 | coarse_loss = coarse_loss + loss_ref_ov.sum() + loss_src_ov.sum() 113 | #loss = loss + loss_ref_ov.mean() + loss_src_ov.mean() 114 | 115 | if 'spot_matching_scores' in output_dict.keys(): 116 | return loss, coarse_loss 117 | else: return coarse_loss 118 | 119 | 120 | class CoarseMatchingLoss(nn.Module): 121 | def __init__(self, cfg): 122 | super(CoarseMatchingLoss, self).__init__() 123 | self.weighted_circle_loss = WeightedCircleLoss( 124 | cfg.positive_margin, 125 | cfg.negative_margin, 126 | cfg.positive_optimal, 127 | cfg.negative_optimal, 128 | cfg.log_scale, 129 | ) 130 | self.positive_overlap = cfg.positive_overlap 131 | 132 | def forward(self, output_dict): 133 | ref_feats = output_dict['ref_feats_c'] 134 | src_feats = output_dict['src_feats_c'] 135 | gt_node_corr_indices = output_dict['gt_patch_corr_indices'] 136 | gt_node_corr_overlaps = output_dict['gt_patch_corr_overlaps'] 137 | 138 | feat_dists = torch.sqrt(pairwise_distance(ref_feats, src_feats, normalized=True)) 139 | overlaps = torch.zeros_like(feat_dists) 140 | overlaps[gt_node_corr_indices[:, 0], gt_node_corr_indices[:, 1]] = gt_node_corr_overlaps 141 | pos_masks = torch.gt(overlaps, self.positive_overlap) 142 | neg_masks = torch.eq(overlaps, 0) 143 | pos_scales = torch.sqrt(overlaps * pos_masks.float()) 144 | 145 | return self.weighted_circle_loss(pos_masks, neg_masks, feat_dists, pos_scales) 146 | 147 | 148 | class CorrespondenceLoss(nn.Module): 149 | def __init__(self, point_to_patch_threshold): 150 | super(CorrespondenceLoss, self).__init__() 151 | self.point_to_patch_threshold = point_to_patch_threshold 152 | 153 | def forward(self, output_dict: Dict[str, torch.Tensor]): 154 | gt_transform = output_dict['gt_transform'] 155 | with torch.no_grad(): 156 | ref_kpts = apply_transform(output_dict['ref_kpts'], torch.linalg.inv(gt_transform)) 157 | dist = torch.norm(output_dict['src_patch_corr_kpts'] - ref_kpts.unsqueeze(1), dim=-1) 158 | ref_mask = torch.lt(dist.min(dim=-1)[0], self.point_to_patch_threshold).nonzero().squeeze() 159 | 160 | src_kpts = apply_transform(output_dict['src_kpts'], gt_transform) 161 | dist = torch.norm(output_dict['ref_patch_corr_kpts'] - src_kpts.unsqueeze(1), dim=-1) 162 | src_mask = torch.lt(dist.min(dim=-1)[0], self.point_to_patch_threshold).nonzero().squeeze() 163 | 164 | loss_corr_ref = torch.norm(output_dict['ref_corres'] - ref_kpts, dim=-1) 165 | loss_corr_src = torch.norm(output_dict['src_corres'] - src_kpts, dim=-1) 166 | 167 | loss_corr_ref = torch.index_select(loss_corr_ref, 0, ref_mask) 168 | loss_corr_src = torch.index_select(loss_corr_src, 0, src_mask) 169 | return (loss_corr_ref.mean() + loss_corr_src.mean()) / 2. 170 | 171 | 172 | class KeypointMatchingLoss(nn.Module): 173 | """ 174 | Modified from source codes of: 175 | - REGTR https://github.com/yewzijian/RegTR. 
176 | """ 177 | def __init__(self, positive_threshold, negative_threshold): 178 | super(KeypointMatchingLoss, self).__init__() 179 | self.r_p = positive_threshold 180 | self.r_n = negative_threshold 181 | 182 | def cal_loss(self, src_xyz, tgt_grouped_xyz, tgt_corres, match_logits, match_score, transform): 183 | tgt_grouped_xyz = apply_transform(tgt_grouped_xyz, transform) 184 | tgt_corres = apply_transform(tgt_corres, transform) 185 | 186 | with torch.no_grad(): 187 | dist_keypts:torch.Tensor = torch.norm(src_xyz.unsqueeze(1) - tgt_grouped_xyz, dim=-1) 188 | dist1, idx1 = torch.topk(dist_keypts, k=1, dim=-1, largest=False) 189 | mask = dist1[..., 0] < self.r_p # Only consider points with correspondences 190 | ignore = dist_keypts < self.r_n # Ignore all the points within a certain boundary 191 | ignore.scatter_(-1, idx1, 0) # except the positive 192 | mask_id = mask.nonzero().squeeze() 193 | 194 | match_logits:torch.Tensor = match_logits - 1e4 * ignore.float() 195 | loss_feat = match_logits.logsumexp(dim=-1) - match_logits.gather(-1, idx1).squeeze(-1) 196 | loss_feat = loss_feat.index_select(0, mask_id).mean() 197 | if loss_feat.isnan(): loss_feat = 0. 198 | 199 | dist_keypts:torch.Tensor = torch.norm(src_xyz - tgt_corres, dim=-1) 200 | loss_corr = dist_keypts.index_select(0, mask_id).mean() 201 | if loss_corr.isnan(): loss_corr = 0. 202 | 203 | label = dist_keypts.lt(self.r_p) 204 | weight = torch.logical_not(label.logical_xor(dist_keypts.lt(self.r_n))) 205 | loss_ov = F.binary_cross_entropy(match_score, label.float(), weight.float()) 206 | 207 | return loss_feat, loss_ov, loss_corr 208 | 209 | def forward(self, output_dict: Dict[str, torch.Tensor]): 210 | return self.cal_loss( 211 | output_dict['corres'][:, :3], 212 | output_dict['src_patch_corr_kpts'], 213 | output_dict['corres'][:, 3:], 214 | output_dict['match_logits'], 215 | output_dict['corr_confidence'], 216 | output_dict['gt_transform'] 217 | ) -------------------------------------------------------------------------------- /engine/summary_board.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Optional, List, Dict 3 | 4 | 5 | class AverageMeter: 6 | def __init__(self, last_n=None): 7 | self._records = [] 8 | self.last_n = last_n 9 | 10 | def update(self, result): 11 | if isinstance(result, (list, tuple)): 12 | self._records += result 13 | else: 14 | self._records.append(result) 15 | 16 | def reset(self): 17 | self._records.clear() 18 | 19 | @property 20 | def records(self): 21 | if self.last_n is not None: 22 | return self._records[-self.last_n :] 23 | else: 24 | return self._records 25 | 26 | def sum(self): 27 | return np.sum(self.records) 28 | 29 | def mean(self): 30 | return np.mean(self.records) 31 | 32 | def std(self): 33 | return np.std(self.records) 34 | 35 | def median(self): 36 | return np.median(self.records) 37 | 38 | 39 | class SummaryBoard: 40 | def __init__(self, names: Optional[List[str]] = None, last_n: Optional[int] = None, adaptive=False): 41 | r"""Instantiate a SummaryBoard. 42 | 43 | Args: 44 | names (List[str]=None): create AverageMeter with the names. 45 | last_n (int=None): only the last n records are used. 46 | adaptive (bool=False): whether register basic meters automatically on the fly. 
47 | """ 48 | self.meter_names: List[str] = [] 49 | self.meter_dict: Dict[str, AverageMeter] = {} 50 | self.last_n = last_n 51 | self.adaptive = adaptive 52 | 53 | if names is not None: 54 | self.register_all(names) 55 | 56 | def register_meter(self, name): 57 | self.meter_dict[name] = AverageMeter(last_n=self.last_n) 58 | self.meter_names.append(name) 59 | 60 | def register_all(self, names): 61 | for name in names: 62 | self.register_meter(name) 63 | 64 | def reset_meter(self, name): 65 | self.meter_dict[name].reset() 66 | 67 | def reset_all(self): 68 | for name in self.meter_names: 69 | self.reset_meter(name) 70 | 71 | def check_name(self, name): 72 | if name not in self.meter_names: 73 | if self.adaptive: 74 | self.register_meter(name) 75 | else: 76 | raise KeyError('No meter for key "{}".'.format(name)) 77 | 78 | def update(self, name, value): 79 | self.check_name(name) 80 | self.meter_dict[name].update(value) 81 | 82 | def update_from_dict(self, result_dict): 83 | if not isinstance(result_dict, dict): 84 | raise TypeError('`result_dict` must be a dict: {}.'.format(type(result_dict))) 85 | for key, value in result_dict.items(): 86 | if key not in self.meter_names and self.adaptive: 87 | self.register_meter(key) 88 | if key in self.meter_names: 89 | self.meter_dict[key].update(value) 90 | 91 | def sum(self, name): 92 | self.check_name(name) 93 | return self.meter_dict[name].sum() 94 | 95 | def mean(self, name): 96 | self.check_name(name) 97 | return self.meter_dict[name].mean() 98 | 99 | def std(self, name): 100 | self.check_name(name) 101 | return self.meter_dict[name].std() 102 | 103 | def median(self, name): 104 | self.check_name(name) 105 | return self.meter_dict[name].median() 106 | 107 | def summary(self, names=None): 108 | if names is None: 109 | names = self.meter_names 110 | summary_dict = {name: self.meter_dict[name].mean() for name in names} 111 | return summary_dict -------------------------------------------------------------------------------- /engine/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import pickle 4 | import random 5 | import numpy as np 6 | from typing import Dict, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.utils.data import DataLoader 12 | from .summary_board import SummaryBoard 13 | 14 | 15 | class EpochBasedTrainer: 16 | def __init__(self, cfg): 17 | self.snapshot_prefix = cfg.runname 18 | self.max_epoch = cfg.optim.max_epoch 19 | self.log_steps = cfg.log_steps 20 | self.save_steps = cfg.save_steps 21 | self.log_dir = cfg.log_dir + '/' 22 | self.snapshot_dir = cfg.snapshot_dir + '/' 23 | 24 | random.seed(cfg.seed) 25 | np.random.seed(cfg.seed) 26 | torch.manual_seed(cfg.seed) 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed(cfg.seed) 29 | torch.cuda.manual_seed_all(cfg.seed) 30 | torch.backends.cudnn.benchmark = False 31 | torch.backends.cudnn.deterministic = True 32 | #torch.autograd.set_detect_anomaly(True) 33 | 34 | if not os.path.exists(cfg.snapshot_dir): 35 | os.makedirs(cfg.snapshot_dir) 36 | if not os.path.exists(cfg.log_dir): 37 | os.makedirs(cfg.log_dir) 38 | 39 | self.log_file:str = cfg.log_dir + '/' + cfg.runname + '_{}.pkl' 40 | self.log_file = self.log_file.format(time.strftime('%Y-%m-%d-%H-%M-%S')) 41 | self.summary_board = SummaryBoard(last_n=self.log_steps, adaptive=True) 42 | 43 | # deep learning entities 44 | self.model: Optional[nn.Module] = None 45 | self.evaluator: 
Optional[nn.Module] = None 46 | self.optimizer: Optional[optim.Optimizer] = None 47 | self.scheduler: Optional[optim.lr_scheduler.StepLR] = None 48 | self.loss_func: Optional[nn.Module] = None 49 | self.clip_grad = cfg.optim.clip_grad_norm 50 | 51 | self.train_loader: Optional[DataLoader] = None 52 | self.val_loader: Optional[DataLoader] = None 53 | 54 | self.epoch = 0 55 | self.iteration = 0 56 | 57 | 58 | @classmethod 59 | def release_cuda(clf, x): 60 | if isinstance(x, list): 61 | x = [clf.release_cuda(item) for item in x] 62 | elif isinstance(x, tuple): 63 | x = (clf.release_cuda(item) for item in x) 64 | elif isinstance(x, dict): 65 | x = {key: clf.release_cuda(value) for key, value in x.items()} 66 | elif isinstance(x, torch.Tensor): 67 | if x.numel() == 1: 68 | x = x.item() 69 | else: 70 | x = x.detach().cpu().numpy() 71 | return x 72 | 73 | @classmethod 74 | def to_cuda(clf, x): 75 | if isinstance(x, list): 76 | x = [clf.to_cuda(item) for item in x] 77 | elif isinstance(x, tuple): 78 | x = (clf.to_cuda(item) for item in x) 79 | elif isinstance(x, dict): 80 | x = {key: clf.to_cuda(value) for key, value in x.items()} 81 | elif isinstance(x, torch.Tensor): 82 | x = x.cuda() 83 | return x 84 | 85 | def save_model(self, filename): 86 | filename = self.snapshot_prefix + filename + ".pth" 87 | torch.save(self.model.state_dict(), self.snapshot_dir + filename) 88 | print('Model saved to "{}"'.format(filename)) 89 | 90 | def save_snapshot(self, filename): 91 | state_dict = { 92 | 'epoch': self.epoch, 93 | 'iteration': self.iteration, 94 | 'model': self.model.state_dict(), 95 | 'optimizer': self.optimizer.state_dict(), 96 | } 97 | if self.scheduler is not None: 98 | state_dict['scheduler'] = self.scheduler.state_dict() 99 | filename = self.snapshot_prefix + filename + '.pth.tar' 100 | torch.save(state_dict, self.snapshot_dir + filename) 101 | print('Snapshot saved to "{}"'.format(filename)) 102 | 103 | def load_snapshot(self, filename): 104 | print('Loading from "{}".'.format(filename + '.pth.tar')) 105 | state_dict = torch.load(self.snapshot_dir + filename + '.pth.tar') 106 | self.model.load_state_dict(state_dict['model'], strict=False) 107 | print('Model has been loaded.') 108 | 109 | if 'epoch' in state_dict: 110 | self.epoch = state_dict['epoch'] 111 | if 'iteration' in state_dict: 112 | self.iteration = state_dict['iteration'] 113 | if 'scheduler' in state_dict and self.scheduler is not None: 114 | self.scheduler.load_state_dict(state_dict['scheduler']) 115 | if 'optimizer' in state_dict and self.optimizer is not None: 116 | try: 117 | self.optimizer.load_state_dict(state_dict['optimizer']) 118 | except ValueError: 119 | pass 120 | 121 | 122 | def set_train_mode(self): 123 | self.model.train() 124 | if self.evaluator is not None: 125 | self.evaluator.train() 126 | torch.set_grad_enabled(True) 127 | 128 | def set_eval_mode(self): 129 | self.model.eval() 130 | if self.evaluator is not None: 131 | self.evaluator.eval() 132 | torch.set_grad_enabled(False) 133 | 134 | 135 | def step(self, data_dict) -> Dict[str,torch.Tensor]: 136 | if isinstance(data_dict, tuple) or isinstance(data_dict, list): 137 | output_dict = self.model(*data_dict) 138 | else: output_dict = self.model(data_dict) 139 | loss_dict: Dict = self.loss_func(output_dict, data_dict) 140 | if self.evaluator is not None: 141 | with torch.no_grad(): 142 | result_dict = self.evaluator(output_dict, data_dict) 143 | loss_dict.update(result_dict) 144 | return loss_dict 145 | 146 | def train_epoch(self): 147 | self.optimizer.zero_grad() 
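# each iteration below performs one optimizer update: forward/backward via step(), optional gradient-norm clipping, optimizer.step(), then the gradients are zeroed again before the next batch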
148 | steps = len(self.train_loader) 149 | for iteration, data_dict in enumerate(self.train_loader): 150 | self.iteration += 1 151 | data_dict = self.to_cuda(data_dict) 152 | result_dict = self.step(data_dict) 153 | result_dict['loss'].backward() 154 | 155 | if self.clip_grad is not None: 156 | nn.utils.clip_grad_norm_( 157 | self.model.parameters(), max_norm=self.clip_grad) 158 | 159 | self.optimizer.step() 160 | self.optimizer.zero_grad() 161 | 162 | result_dict = self.release_cuda(result_dict) 163 | self.summary_board.update_from_dict(result_dict) 164 | #torch.cuda.empty_cache() 165 | 166 | print("Epoch %d [%d/%d]"%(self.epoch, iteration+1, steps), end=' ') 167 | for key, value in result_dict.items(): 168 | print(key, "%.4f"%float(value), end='; ') 169 | print() 170 | 171 | if (iteration + 1) % self.log_steps == 0: 172 | logs = dict() 173 | for k,v in self.summary_board.meter_dict.items(): 174 | logs[k] = np.array(v._records) 175 | print("Logging into ", self.log_file) 176 | flog = open(self.log_file, 'wb') 177 | flog.write(pickle.dumps(logs)) 178 | flog.close() 179 | 180 | if self.save_steps > 0 and (iteration + 1) % self.save_steps == 0: 181 | self.save_snapshot("-epoch-%02d-%d"%(self.epoch, iteration + 1)) 182 | if self.scheduler is not None: self.scheduler.step() 183 | 184 | if self.scheduler is not None: 185 | self.scheduler.step() 186 | 187 | 188 | def validate_epoch(self): 189 | self.set_eval_mode() 190 | summary_board = SummaryBoard(adaptive=True) 191 | print("---------Start validation---------") 192 | torch.cuda.synchronize() 193 | start = time.time() 194 | 195 | for iteration, data_dict in enumerate(self.val_loader): 196 | data_dict = self.to_cuda(data_dict) 197 | result_dict = self.step(data_dict) 198 | result_dict = self.release_cuda(result_dict) 199 | 200 | summary_board.update_from_dict(result_dict) 201 | print("[%d/%d]"%(iteration, len(self.val_loader)), end=' ') 202 | for key, value in result_dict.items(): 203 | print(key, "%.4f"%float(value), end='; ') 204 | torch.cuda.empty_cache() 205 | torch.cuda.synchronize() 206 | print('%.4fs'%(time.time() - start)) 207 | 208 | self.set_train_mode() 209 | summary = summary_board.summary() 210 | summary_dict = {"val_" + k : v for k,v in summary.items()} 211 | self.summary_board.update_from_dict(summary_dict) 212 | 213 | print("Validate Epoch %02d:"%self.epoch, end=' ') 214 | for key, value in summary_dict.items(): 215 | print(key, "%.4f"%float(value), end='; ') 216 | print() 217 | 218 | logs = dict() 219 | for k,v in self.summary_board.meter_dict.items(): 220 | logs[k] = np.array(v._records) 221 | flog = open(self.log_file, 'wb') 222 | flog.write(pickle.dumps(logs)) 223 | flog.close() 224 | 225 | return summary_board 226 | 227 | 228 | def fit(self, resume_epoch=0, resume_log=None): 229 | assert self.train_loader is not None 230 | if resume_log is not None and resume_epoch > 0: 231 | self.load_snapshot(self.snapshot_prefix + "-epoch-%02d"%resume_epoch) 232 | print('Continue training from epoch %02d.'%(self.epoch + 1)) 233 | f = open(self.log_dir + resume_log, 'rb') 234 | log: Dict[str, np.ndarray] = pickle.load(f) 235 | f.close(); data = dict() 236 | for k, v in log.items(): 237 | print(k, v.shape) 238 | if k[:3] == 'val': data[k] = v[:self.epoch].tolist() 239 | else: data[k] = v[:self.iteration+1].tolist() 240 | self.summary_board.update_from_dict(data) 241 | self.set_train_mode() 242 | 243 | while self.epoch < self.max_epoch: 244 | self.epoch += 1 245 | self.train_epoch() 246 | if self.val_loader is not None: 247 | 
self.validate_epoch() 248 | self.save_snapshot("-epoch-%02d"%self.epoch) -------------------------------------------------------------------------------- /evaluate_IR_FMR.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' 3 | os.environ['CUDA_VISIBLE_DEVICES']='0' 4 | 5 | import json 6 | import torch 7 | import argparse 8 | from munch import munchify 9 | from torch.utils.data import DataLoader 10 | 11 | from data.indoor_data import IndoorTestDataset 12 | from engine.trainer import EpochBasedTrainer 13 | 14 | from models.models.cast import CAST 15 | from engine.evaluator import Evaluator 16 | 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--benchmark", default='3DMatch', choices=['3DMatch', '3DLoMatch']) 20 | parser.add_argument("--config", default='./config/3dmatch.json', type=str) 21 | parser.add_argument("--load_pretrained", default='cast-epoch-05', type=str) 22 | parser.add_argument("--ransac", default=False, action="store_true") 23 | 24 | _args = parser.parse_args() 25 | 26 | 27 | class Tester(EpochBasedTrainer): 28 | def __init__(self, cfg): 29 | super().__init__(cfg) 30 | val_dataset = IndoorTestDataset(cfg.data.root, _args.benchmark, cfg.data.npoints, cfg.data.voxel_size, cfg.data_list) 31 | self.val_loader = DataLoader(val_dataset, 1, num_workers=cfg.data.num_workers, shuffle=False, pin_memory=True) 32 | 33 | self.model = CAST(cfg.model).cuda() 34 | self.evaluator = Evaluator(cfg.eval).cuda() 35 | 36 | def step(self, data_dict): 37 | output_dict = self.model(*data_dict[:3]) 38 | trans = output_dict['gt_transform'] 39 | PIR = self.evaluator.evaluate_coarse_inlier_ratio(output_dict) 40 | results = {'PIR': PIR, 'PMR': PIR.gt(0.2).float()} 41 | 42 | indices = torch.argsort(output_dict['corr_confidence'], descending=True) 43 | corr_confidence = output_dict['corr_confidence'][indices] 44 | corres = output_dict['corres'][indices] 45 | indices = indices[corr_confidence.gt(0.1 * corr_confidence[0])] 46 | corres_ = output_dict['corres'][indices] 47 | for num in [250, 500, 1000, 2500, 5000]: 48 | results['IR@%d'%num] = self.evaluator.evaluate_inlier_ratio(corres_[:num], trans) 49 | results['FMR@%d'%num] = self.evaluator.evaluate_inlier_ratio(corres[:num], trans).gt(0.05).float() 50 | return results 51 | 52 | 53 | if __name__ == "__main__": 54 | with open(_args.config, 'r') as cfg: 55 | args = json.load(cfg) 56 | args = munchify(args) 57 | args.model.ransac = False 58 | 59 | tester = Tester(args) 60 | tester.set_eval_mode() 61 | tester.load_snapshot(_args.load_pretrained) 62 | # e.g. 
tester.load_snapshot("cast-epoch-05") 63 | tester.validate_epoch() -------------------------------------------------------------------------------- /evaluate_RR.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import numpy as np 4 | from munch import munchify 5 | from torch.utils.data import DataLoader 6 | 7 | from data.indoor_data import IndoorTestDataset 8 | from engine.trainer import EpochBasedTrainer 9 | 10 | from models.models.cast import CAST 11 | from engine.evaluator import Evaluator 12 | 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--benchmark", default='3DMatch', choices=['3DMatch', '3DLoMatch']) 16 | parser.add_argument("--config", default='./config/3dmatch.json', type=str) 17 | parser.add_argument("--load_pretrained", default='cast-epoch-05', type=str) 18 | parser.add_argument("--ransac", default=False, action="store_true") 19 | 20 | _args = parser.parse_args() 21 | 22 | 23 | class Tester(EpochBasedTrainer): 24 | def __init__(self, cfg): 25 | super().__init__(cfg) 26 | val_dataset = IndoorTestDataset(cfg.data.root, _args.benchmark, cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, True) 27 | self.val_loader = DataLoader(val_dataset, 1, num_workers=cfg.data.num_workers, shuffle=False, pin_memory=True) 28 | self.val_dataset = val_dataset 29 | 30 | self.model = CAST(cfg.model).cuda() 31 | self.evaluator = Evaluator(cfg.eval).cuda() 32 | 33 | def step(self, data_dict): 34 | output_dict = self.model(*data_dict[:3]) 35 | output_dict['covariance'] = data_dict[-1][0] 36 | return self.evaluator(output_dict) 37 | 38 | 39 | if __name__ == "__main__": 40 | with open(_args.config, 'r') as cfg: 41 | args = json.load(cfg) 42 | args = munchify(args) 43 | args.model.ransac = _args.ransac 44 | 45 | tester = Tester(args) 46 | tester.set_eval_mode() 47 | tester.load_snapshot(_args.load_pretrained) 48 | # e.g. 
tester.load_snapshot("cast-epoch-05") 49 | result_list = tester.validate_epoch() 50 | RRs = result_list.meter_dict['RR'].records 51 | splits = {} 52 | 53 | for data_dict, recall in zip(tester.val_dataset.dataset, RRs): 54 | scene = data_dict['points1'].split('/')[-3] 55 | if scene not in splits.keys(): 56 | splits[scene] = [] 57 | splits[scene].append(recall) 58 | 59 | print("Registration Recalls:") 60 | splits = {k:np.array(v).mean() for k,v in splits.items()} 61 | for k, v in splits.items(): print(k, v) 62 | print("Average Registration Recall:", np.array([v for v in splits.values()]).mean()) 63 | 64 | -------------------------------------------------------------------------------- /evaluate_eth.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import numpy as np 4 | from munch import munchify 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from data.eth_data import ETHDataset 10 | from engine.trainer import EpochBasedTrainer 11 | 12 | from models.models.cast_eth import CAST 13 | from engine.evaluator import Evaluator 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--load_pretrained", default='kitti/cast-epoch-39.pth.tar', type=str) 18 | parser.add_argument("--filter", default='DSC', choices=['DSC', 'SM']) 19 | 20 | _args = parser.parse_args() 21 | 22 | 23 | class Tester(EpochBasedTrainer): 24 | def __init__(self, cfg): 25 | super().__init__(cfg) 26 | val_dataset = ETHDataset(cfg.data.root, cfg.data.npoints, cfg.data.voxel_size) 27 | self.val_loader = DataLoader(val_dataset, 1, num_workers=cfg.data.num_workers, shuffle=False, pin_memory=True) 28 | self.val_dataset = val_dataset 29 | 30 | cfg.model.filter = _args.filter 31 | self.model = CAST(cfg.model).cuda() 32 | self.evaluator = Evaluator(cfg.eval).cuda() 33 | 34 | def step(self, data_dict): 35 | output_dict = self.model(*data_dict[:3]) 36 | gt_trans = output_dict['gt_transform'] 37 | pred_trans = output_dict['refined_transform'] 38 | rte, rre = self.evaluator.transform_error(gt_trans, pred_trans) 39 | return {"RTE": rte, "RRE": rre} 40 | 41 | 42 | if __name__ == "__main__": 43 | with open('./config/eth.json', 'r') as cfg: 44 | args = json.load(cfg) 45 | args = munchify(args) 46 | 47 | tester = Tester(args) 48 | tester.set_eval_mode() 49 | state_dict = torch.load('./ckpt/' + _args.load_pretrained) 50 | tester.model.load_state_dict(state_dict['model'], strict=False) 51 | result_list = tester.validate_epoch() 52 | 53 | RTE = np.array(result_list.meter_dict['RTE'].records) 54 | RRE = np.array(result_list.meter_dict['RRE'].records) 55 | 56 | RR = np.logical_and(RTE < 0.3, RRE < 2.0).astype(dtype=RTE.dtype) 57 | print("Threshold TE@0.3m,RE@2°", end=", ") 58 | print('RR: %.4f, RRE: %.4f, RTE: %.4f'%(RR.mean(), (RRE * RR).sum() / RR.sum(), (RTE * RR).sum() / RR.sum())) 59 | RR = np.logical_and(RTE < 0.3, RRE < 5.0).astype(dtype=RTE.dtype) 60 | print("Threshold TE@0.3m,RE@5°", end=", ") 61 | print('RR: %.4f, RRE: %.4f, RTE: %.4f'%(RR.mean(), (RRE * RR).sum() / RR.sum(), (RTE * RR).sum() / RR.sum())) 62 | RR = np.logical_and(RTE < 2.0, RRE < 5.0).astype(dtype=RTE.dtype) 63 | print("Threshold TE@2.0m,RE@5°", end=", ") 64 | print('RR: %.4f, RRE: %.4f, RTE: %.4f'%(RR.mean(), (RRE * RR).sum() / RR.sum(), (RTE * RR).sum() / RR.sum())) 65 | 66 | 67 | # RANSAC 68 | #Threshold TE@0.3m,RE@2°, RR: 0.9158, RRE: 0.5176, RTE: 0.0640 69 | #Threshold TE@0.3m,RE@5°, RR: 0.9705, RRE: 0.6422, RTE: 0.0699 70 | #Threshold TE@2.0m,RE@5°, 
RR: 0.9846, RRE: 0.6652, RTE: 0.0743 71 | # RANSAC + DSC 72 | #Threshold TE@0.3m,RE@2°, RR: 0.9285, RRE: 0.5409, RTE: 0.0626 73 | #Threshold TE@0.3m,RE@5°, RR: 0.9776, RRE: 0.6499, RTE: 0.0685 74 | #Threshold TE@2.0m,RE@5°, RR: 0.9832, RRE: 0.6593, RTE: 0.0704 75 | # RANSAC + SM 76 | #Threshold TE@0.3m,RE@2°, RR: 0.9467, RRE: 0.5428, RTE: 0.0630 77 | #Threshold TE@0.3m,RE@5°, RR: 0.9804, RRE: 0.6095, RTE: 0.0666 78 | #Threshold TE@2.0m,RE@5°, RR: 0.9874, RRE: 0.6279, RTE: 0.0690 -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' 3 | os.environ['CUDA_VISIBLE_DEVICES']='0' 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.optim.lr_scheduler import StepLR 10 | from torch.utils.data import DataLoader 11 | 12 | import json 13 | import pickle 14 | import numpy as np 15 | from typing import Dict 16 | from munch import munchify 17 | 18 | from data.kitti_data import KittiDataset 19 | from data.eth_data import ETHDataset 20 | 21 | from models.models.cast import CAST 22 | from engine.evaluator import Evaluator 23 | from engine.trainer import EpochBasedTrainer 24 | from engine.losses import SpotMatchingLoss, CoarseMatchingLoss, KeypointMatchingLoss, ProbChamferLoss 25 | 26 | 27 | class OverallLoss(torch.nn.Module): 28 | def __init__(self, cfg): 29 | super(OverallLoss, self).__init__() 30 | self.weight_det_loss = cfg.weight_det_loss 31 | self.weight_spot_loss = cfg.weight_spot_loss 32 | self.weight_feat_loss = cfg.weight_feat_loss 33 | self.weight_desc_loss = cfg.weight_desc_loss 34 | self.weight_overlap_loss = cfg.weight_overlap_loss 35 | self.weight_corr_loss = cfg.weight_corr_loss 36 | self.weight_trans_loss = cfg.weight_trans_loss 37 | self.weight_rot_loss = cfg.weight_rot_loss 38 | self.pretrain_feat_epochs = cfg.pretrain_feat_epochs 39 | 40 | self.prob_chamfer_loss = ProbChamferLoss() 41 | self.spot_matching_loss = SpotMatchingLoss(cfg) 42 | self.coarse_matching_loss = CoarseMatchingLoss(cfg) 43 | self.kpt_matching_loss = KeypointMatchingLoss(cfg.r_p, cfg.r_n) 44 | self.register_buffer('I3x3', torch.eye(3)) 45 | 46 | def forward(self, output_dict: Dict[str, torch.Tensor], epoch) -> Dict[str, torch.Tensor]: 47 | l_det = self.prob_chamfer_loss(output_dict) 48 | l_spot,l_feat = self.spot_matching_loss(output_dict) 49 | #l_feat = self.coarse_matching_loss(output_dict) 50 | loss = l_feat * self.weight_feat_loss + l_det * self.weight_det_loss + l_spot * self.weight_spot_loss 51 | 52 | loss_dict = {'l_det':l_det, 'l_spot':l_spot, 'l_feat':l_feat} 53 | 54 | l_desc, l_ov, l_corr = self.kpt_matching_loss(output_dict) 55 | l_trans = torch.norm(output_dict['transform'][:3, 3] - output_dict['gt_transform'][:3, 3]) 56 | l_rot = torch.norm(output_dict['transform'][:3, :3].T @ output_dict['gt_transform'][:3, :3] - self.I3x3) 57 | loss = loss + l_desc * self.weight_desc_loss + l_ov * self.weight_overlap_loss + l_corr * self.weight_corr_loss 58 | 59 | loss_dict.update({'l_corr':l_corr, 'l_desc':l_desc, 'l_ov':l_ov}) 60 | 61 | l_trans2 = torch.norm(output_dict['refined_transform'][:3, 3] - output_dict['gt_transform'][:3, 3]) 62 | l_rot2 = torch.norm(output_dict['refined_transform'][:3, :3].T @ output_dict['gt_transform'][:3, :3] - self.I3x3) 63 | 64 | if epoch > self.pretrain_feat_epochs: 65 | loss = loss + l_trans.clamp_max(2.) * self.weight_trans_loss + l_rot.clamp_max(1.) 
* self.weight_rot_loss 66 | loss = loss + l_trans2.clamp_max(2.) * self.weight_trans_loss + l_rot2.clamp_max(1.) * self.weight_rot_loss 67 | 68 | ret_dict = {'loss':loss, 'l_rot':l_rot} 69 | ret_dict.update(loss_dict) 70 | return ret_dict 71 | 72 | 73 | class Trainer(EpochBasedTrainer): 74 | def __init__(self, cfg, uda_cfg): 75 | super().__init__(cfg) 76 | self.cfg, self.uda_cfg = cfg, uda_cfg 77 | if not os.path.exists(uda_cfg.snapshot_dir): 78 | os.makedirs(uda_cfg.snapshot_dir) 79 | 80 | train_dataset = KittiDataset( 81 | cfg.data.root, 'train', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, cfg.data.augment) 82 | val_dataset = KittiDataset( 83 | cfg.data.root, 'test', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, 0.0) 84 | self.tune_dataset = ETHDataset( 85 | uda_cfg.data.root, uda_cfg.data.npoints, uda_cfg.data.voxel_size, uda_cfg.data.augment) 86 | test_dataset = ETHDataset( 87 | uda_cfg.data.root, uda_cfg.data.npoints, uda_cfg.data.voxel_size, 0) 88 | 89 | self.train_loader = DataLoader(train_dataset, 1, num_workers=cfg.data.num_workers, shuffle=True, pin_memory=True) 90 | self.val_loader = DataLoader(val_dataset, 1, num_workers=cfg.data.num_workers, shuffle=False, pin_memory=True) 91 | self.test_loader = DataLoader(test_dataset, 1, num_workers=cfg.data.num_workers, shuffle=False, pin_memory=True) 92 | 93 | self.model = CAST(cfg.model).cuda() 94 | self.optimizer = optim.AdamW(self.model.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay) 95 | self.scheduler = StepLR(self.optimizer, step_size=cfg.optim.step_size, gamma=cfg.optim.gamma) 96 | self.loss_func = OverallLoss(cfg.loss).cuda() 97 | self.evaluator = Evaluator(cfg.eval).cuda() 98 | 99 | def step(self, data_dict) -> Dict[str,torch.Tensor]: 100 | output_dict = self.model(*data_dict) 101 | loss_dict: Dict = self.loss_func(output_dict, self.epoch) 102 | with torch.no_grad(): 103 | result_dict = self.evaluator(output_dict) 104 | loss_dict.update(result_dict) 105 | return loss_dict 106 | 107 | def train_epoch(self): 108 | self.optimizer.zero_grad() 109 | steps = len(self.train_loader) 110 | for iteration, data_dict in enumerate(self.train_loader): 111 | self.iteration += 1 112 | self.model.backbone.cfg = self.cfg.model 113 | data_dict = self.to_cuda(data_dict) 114 | result_dict = self.step(data_dict) 115 | result_dict['loss'].backward() 116 | 117 | if self.clip_grad is not None: 118 | nn.utils.clip_grad_norm_( 119 | self.model.parameters(), max_norm=self.clip_grad) 120 | 121 | self.optimizer.step() 122 | self.optimizer.zero_grad() 123 | 124 | print("Epoch %d [%d/%d]"%(self.epoch, iteration+1, steps), end=' ') 125 | for key, value in result_dict.items(): 126 | print(key, "%.4f"%float(value), end='; ') 127 | print() 128 | 129 | self.model.backbone.cfg = self.uda_cfg.model 130 | index = self.iteration % len(self.tune_dataset) 131 | data_dict = self.tune_dataset.__getitem__(index) 132 | data_dict = [x.cuda().unsqueeze(0) for x in data_dict] 133 | result_dict = self.step(data_dict) 134 | result_dict['loss'].backward() 135 | 136 | if self.clip_grad is not None: 137 | nn.utils.clip_grad_norm_( 138 | self.model.parameters(), max_norm=self.clip_grad) 139 | 140 | self.optimizer.step() 141 | self.optimizer.zero_grad() 142 | 143 | result_dict = self.release_cuda(result_dict) 144 | self.summary_board.update_from_dict(result_dict) 145 | #torch.cuda.empty_cache() 146 | 147 | print(" %d [%d/%d]"%(self.epoch, iteration+1, steps), end=' ') 148 | for key, value in result_dict.items(): 149 | print(key, 
"%.4f"%float(value), end='; ') 150 | print() 151 | 152 | if (iteration + 1) % self.log_steps == 0: 153 | logs = dict() 154 | for k,v in self.summary_board.meter_dict.items(): 155 | logs[k] = np.array(v._records) 156 | print("Logging into ", self.log_file) 157 | flog = open(self.log_file, 'wb') 158 | flog.write(pickle.dumps(logs)) 159 | flog.close() 160 | 161 | if self.save_steps > 0 and (iteration + 1) % self.save_steps == 0: 162 | self.save_snapshot("-epoch-%02d-%d"%(self.epoch, iteration + 1)) 163 | if self.scheduler is not None: self.scheduler.step() 164 | 165 | if self.scheduler is not None: 166 | self.scheduler.step() 167 | 168 | 169 | if __name__ == "__main__": 170 | with open('./config/kitti.json', 'r') as cfg: 171 | args = json.load(cfg) 172 | args = munchify(args) 173 | with open('./config/eth.json', 'r') as cfg: 174 | uda_args = json.load(cfg) 175 | uda_args = munchify(uda_args) 176 | 177 | engine = Trainer(args, uda_args) 178 | engine.load_snapshot('cast-epoch-39') 179 | engine.snapshot_dir = uda_args.snapshot_dir 180 | engine.max_epoch = engine.epoch + uda_args.optim.max_epoch 181 | for i in range(uda_args.optim.max_epoch): 182 | engine.train_epoch() 183 | engine.save_snapshot("-epoch-%02d"%(i+1)) -------------------------------------------------------------------------------- /models/cast/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/models/cast/__init__.py -------------------------------------------------------------------------------- /models/cast/consistency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import torch 4 | 5 | from pytorch3d.ops import knn_points 6 | from models.utils import apply_transform, weighted_svd 7 | 8 | 9 | def registration_ransac_based_on_correspondence( 10 | ref_corres_xyz: torch.Tensor, 11 | src_corres_xyz: torch.Tensor, 12 | corr_weight: torch.Tensor = None, 13 | verified_ref_points: torch.Tensor = None, 14 | verified_src_points: torch.Tensor = None, 15 | inlier_threshold = 0.05, 16 | topk = 250, 17 | ransac_iters = 50000, 18 | ransac_n = 4 19 | ): 20 | if corr_weight is None: 21 | indices = torch.arange(ref_corres_xyz.shape[0], device=ref_corres_xyz.device)[:topk] 22 | else: 23 | indices = torch.argsort(corr_weight, descending=True)[:topk] 24 | if verified_ref_points is None: 25 | ref_points = ref_corres_xyz[indices] 26 | else: 27 | ref_points = torch.cat([ref_corres_xyz[indices], verified_ref_points], dim=0) 28 | if verified_src_points is None: 29 | src_points = src_corres_xyz[indices] 30 | else: 31 | src_points = torch.cat([src_corres_xyz[indices], verified_src_points], dim=0) 32 | indices = np.arange(indices.shape[0]) 33 | correspondences = np.stack([indices, indices], axis=1) 34 | correspondences = o3d.utility.Vector2iVector(correspondences) 35 | 36 | ref_pcd = o3d.geometry.PointCloud() 37 | ref_pcd.points = o3d.utility.Vector3dVector(ref_points.detach().cpu().numpy()) 38 | src_pcd = o3d.geometry.PointCloud() 39 | src_pcd.points = o3d.utility.Vector3dVector(src_points.detach().cpu().numpy()) 40 | 41 | transform = o3d.pipelines.registration.registration_ransac_based_on_correspondence( 42 | src_pcd, ref_pcd, correspondences, inlier_threshold, 43 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False), ransac_n, 44 | [o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), 45 | 
o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(inlier_threshold)], 46 | o3d.pipelines.registration.RANSACConvergenceCriteria(ransac_iters, 0.999) 47 | ).transformation 48 | 49 | return torch.FloatTensor(np.array(transform)).to(ref_corres_xyz.device) 50 | 51 | -------------------------------------------------------------------------------- /models/cast/correspondence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch3d.ops import knn_points, knn_gather 5 | 6 | from models.kpconv import UnaryBlock 7 | from models.transformer.vanilla_transformer import AttentionLayer 8 | 9 | 10 | class KeypointMatching(nn.Module): 11 | def __init__(self, d_embed, num_neighbors, learnable=True): 12 | super(KeypointMatching, self).__init__() 13 | self.num_neighbors = num_neighbors 14 | self.learnable = learnable 15 | if self.learnable: 16 | self.proj_q = nn.Linear(d_embed, d_embed, bias=False) 17 | self.proj_k = nn.Linear(d_embed, d_embed, bias=False) 18 | 19 | self.W = torch.nn.Parameter(torch.zeros(d_embed, d_embed), requires_grad=True) 20 | torch.nn.init.normal_(self.W, std=0.1) 21 | else: 22 | self.proj_q, self.proj_k = nn.Identity(), nn.Identity() 23 | self.W = torch.nn.Parameter(torch.eye(d_embed) * 0.5, requires_grad=False) 24 | 25 | def forward(self, feat, knn_xyz, knn_feat:torch.Tensor, weights=None, knn_weights=None, knn_mask:torch.Tensor=None): 26 | q = self.proj_q(feat).unsqueeze(1) # (N, 1, C) 27 | k = self.proj_k(knn_feat).transpose(1, 2) # (N, C, K) 28 | attention_scores = torch.matmul(q, k).squeeze(1) # (N, K) 29 | if knn_mask is not None: 30 | attention_scores = attention_scores - (~knn_mask).float() * 1e12 31 | 32 | if self.num_neighbors > 0 and self.num_neighbors < k.shape[-1]: 33 | neighbor_mask = torch.full_like(attention_scores, fill_value=float('-inf')) 34 | neighbor_mask[:, torch.topk(attention_scores, k=self.num_neighbors, dim=-1)[1]] = 0 35 | attention_scores = attention_scores + neighbor_mask 36 | 37 | attention_scores = torch.softmax(attention_scores, dim=-1) # (N, K) 38 | corres_xyz = torch.einsum('nk,nkc->nc', attention_scores, knn_xyz) # (N, 3) 39 | corres_feat = torch.einsum('nk,nkc->nc', attention_scores, knn_feat) # (N, C) 40 | 41 | W_triu = torch.triu(self.W) 42 | W_symmetrical = W_triu + W_triu.T 43 | match_logits = torch.einsum('nc,cd,nkd->nk', feat, W_symmetrical, knn_feat) # (N, K) 44 | logit = torch.einsum('nc,cd,nd->n', feat, W_symmetrical, corres_feat).unsqueeze(-1) # (N, 1) 45 | if knn_weights is None: 46 | attentive_feats = torch.cat([feat, corres_feat, logit], dim=-1) # (N, 2C+1) 47 | else: 48 | corres_weight = torch.sum(attention_scores * knn_weights.squeeze(-1), dim=-1, keepdim=True) # (N, 1) 49 | attentive_feats = torch.cat([feat, corres_feat, weights.unsqueeze(-1), corres_weight, logit], dim=-1) # (N, 2C+3) 50 | return corres_xyz, attentive_feats, match_logits 51 | 52 | 53 | class FineMatching(nn.Module): 54 | def __init__(self, d_embed, num_neighbors, max_distance, learnable=True): 55 | super(FineMatching, self).__init__() 56 | self.k = num_neighbors 57 | self.max_dist = max_distance 58 | if learnable: 59 | self.W = torch.nn.Parameter(torch.zeros(d_embed, d_embed), requires_grad=True) 60 | torch.nn.init.normal_(self.W, std=0.1) 61 | else: 62 | self.W = torch.nn.Parameter(torch.eye(d_embed) * 0.5, requires_grad=False) 63 | 64 | def forward(self, ref_points:torch.Tensor, src_points:torch.Tensor, ref_feats:torch.Tensor, 
src_feats:torch.Tensor): 65 | dist, knn_indices, knn_xyz = knn_points(src_points.unsqueeze(0), ref_points.unsqueeze(0), K=self.k, return_nn=True) 66 | weight = torch.relu(1. - dist.squeeze() / self.max_dist) # (N, K) or (N,) (k=1) 67 | if self.k == 1: # not learnable 68 | return knn_xyz.squeeze(), weight # (N, 3), (N,) 69 | 70 | W_triu = torch.triu(self.W) 71 | W_symmetrical = W_triu + W_triu.T 72 | knn_feats = knn_gather(ref_feats.unsqueeze(0), knn_indices).squeeze(0) # (N, K, C) 73 | attention_scores = torch.einsum('nc,cd,nkd->nk', src_feats, W_symmetrical, knn_feats) # (N, K) 74 | attention_scores = torch.softmax(attention_scores * weight, dim=-1) # (N, K) 75 | 76 | corres = torch.einsum('nk,nkc->nc', attention_scores, knn_xyz.squeeze()) # (N, 3) 77 | dist = torch.norm(src_points - corres, dim=-1) 78 | weight = torch.relu(1. - dist / self.max_dist) 79 | return corres, weight 80 | 81 | 82 | class CompatibilityGraphEmbedding(nn.Module): 83 | def __init__(self, in_channels, out_channels, num_layers, sigma_d): 84 | super(CompatibilityGraphEmbedding, self).__init__() 85 | self.num_layers = num_layers 86 | self.layer = nn.Linear(in_channels+6, out_channels) 87 | self.mlps = nn.ModuleList([UnaryBlock(out_channels, out_channels) for _ in range(num_layers)]) 88 | self.attns = nn.ModuleList([AttentionLayer(out_channels, 1) for _ in range(num_layers)]) 89 | self.classifier = nn.Sequential( 90 | nn.Linear(out_channels, 32), nn.ReLU(), 91 | nn.Linear(32, 32), nn.ReLU(), 92 | nn.Linear(32, 1), nn.Sigmoid() 93 | ) 94 | self.sigma_spat = nn.Parameter(torch.tensor(sigma_d).float(), requires_grad=False) 95 | 96 | def forward(self, ref_keypts:torch.Tensor, src_keypts:torch.Tensor, corr_feat): 97 | feat = torch.cat([ref_keypts, src_keypts], dim=-1) # (N, 6) 98 | feat = feat - feat.mean(-1, keepdim=True) # (N, 6) 99 | feat = torch.cat([corr_feat, feat], dim=-1) # (N, C+6) 100 | feat = self.layer(feat.unsqueeze(0)) # (1, N, C) 101 | 102 | with torch.no_grad(): 103 | geo_compatibility = torch.cdist(ref_keypts, ref_keypts) - torch.cdist(src_keypts, src_keypts) 104 | geo_compatibility = torch.clamp(1.0 - geo_compatibility ** 2 / self.sigma_spat ** 2, min=0) 105 | geo_compatibility = geo_compatibility.unsqueeze(0) 106 | 107 | for i in range(self.num_layers): 108 | feat = self.mlps[i](feat) 109 | feat = self.attns[i](feat, feat, attention_factors=geo_compatibility)[0] 110 | 111 | feat = F.normalize(feat, p=2, dim=-1).squeeze(0) 112 | return feat, self.classifier(feat).squeeze() 113 | -------------------------------------------------------------------------------- /models/cast/spot_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch3d.ops import knn_gather 5 | from einops import rearrange 6 | 7 | from models.transformer.output_layer import AttentionOutput 8 | from models.transformer.positional_encoding import RotaryPositionalEmbedding 9 | from models.kpconv import UnaryBlock 10 | 11 | 12 | class Upsampling(nn.Module): 13 | def __init__(self, in_channels, out_channels): 14 | super(Upsampling, self).__init__() 15 | self.unary = nn.Sequential( 16 | UnaryBlock(in_channels, out_channels), 17 | UnaryBlock(out_channels, out_channels) 18 | ) 19 | self.output = UnaryBlock(out_channels, out_channels) 20 | 21 | def forward(self, query, support, upsample_indices): 22 | """ 23 | Args: 24 | query (Tensor): (B, N, C) 25 | support (Tensor): (B, M, C') 26 | upsample_indices (Tensor): (B, N, 1) 
27 | return: 28 | latent (Tensor): (B, N, C) 29 | """ 30 | latent = knn_gather(support, upsample_indices).squeeze(2) 31 | return self.output(self.unary(latent) + query) 32 | 33 | 34 | class Downsampling(nn.Module): 35 | def __init__(self, in_channels, out_channels): 36 | super(Downsampling, self).__init__() 37 | self.unary = nn.Sequential( 38 | UnaryBlock(in_channels, out_channels), 39 | UnaryBlock(out_channels, out_channels) 40 | ) 41 | self.output = UnaryBlock(out_channels, out_channels) 42 | 43 | def forward(self, q_feats, s_feats, q_points:torch.Tensor, s_points:torch.Tensor, downsample_indices): 44 | """ 45 | Args: 46 | q_feats (Tensor): (B, N, C) 47 | s_feats (Tensor): (B, M, C') 48 | q_points (Tensor): (B, N, 3) 49 | s_points (Tensor): (B, N, K, 3) 50 | downsample_indices (Tensor): (B, N, K) 51 | return: 52 | latent (Tensor): (B, M, C) 53 | """ 54 | grouped_feats = knn_gather(s_feats, downsample_indices) # (B, N, K, C') 55 | knn_weights = 1. / ((s_points - q_points.unsqueeze(2)).pow(2).sum(-1) + 1e-8) # (B, N, K) 56 | knn_weights = knn_weights / knn_weights.sum(dim=-1, keepdim=True) # (B, N, K) 57 | latent = torch.sum(grouped_feats * knn_weights.unsqueeze(-1), dim=2) # (B, N, C) 58 | return self.output(self.unary(latent) + q_feats) 59 | 60 | 61 | class SparseTransformerLayer(nn.Module): 62 | def __init__(self, d_model, num_heads, pe=True, dropout=None, activation_fn='relu'): 63 | super(SparseTransformerLayer, self).__init__() 64 | self.d_model = d_model 65 | self.num_heads = num_heads 66 | self.d_model_per_head = d_model // num_heads 67 | self.pe = pe 68 | 69 | self.proj_q = nn.Linear(self.d_model, self.d_model) 70 | self.proj_k = nn.Linear(self.d_model, self.d_model) 71 | self.proj_v = nn.Linear(self.d_model, self.d_model) 72 | 73 | self.linear = nn.Linear(d_model, d_model) 74 | if dropout is None or dropout <= 0: 75 | self.dropout = nn.Identity() 76 | else: self.dropout = nn.Dropout(dropout) 77 | self.norm = nn.LayerNorm(d_model) 78 | self.output = AttentionOutput(d_model, dropout, activation_fn) 79 | if pe: self.rpe = RotaryPositionalEmbedding(self.d_model) 80 | 81 | @torch.no_grad() 82 | def select_spots(self, input_knn, memory_knn, confidence_scores, matching_indices, num_spots): 83 | """ 84 | Args: 85 | input_knn (Tensor): (B, N, k+1) 86 | memory_knn (Tensor): (B, M, K) 87 | confidence_scores (Tensor): (B, N, 1) 88 | matching_indices (Tensor): (B, N, 1) 89 | 90 | Returns: 91 | output_states: torch.Tensor (B, N, C) 92 | """ 93 | knn_scores = knn_gather(confidence_scores, input_knn[...,1:]).squeeze(-1) # (B, N, k) 94 | confidence_scores, confident_knn = knn_scores.topk(k=num_spots) # (B, N, S) 95 | confident_knn = torch.gather(input_knn[...,1:], -1, confident_knn) # (B, N, S) 96 | confident_knn = torch.cat([input_knn[...,:1], confident_knn], dim=-1) # (B, N, S+1) 97 | 98 | spot_indices = knn_gather(matching_indices, confident_knn).squeeze(-1) # (B, N, S+1) 99 | spot_indices = knn_gather(memory_knn, spot_indices) # (B, N, S+1, K) 100 | spot_indices = rearrange(spot_indices, 'b n s k -> b n (s k)') # (B, N, (S+1)*K) 101 | 102 | # avoid redundant indices from spot areas 103 | B, N, M = input_knn.shape[0], input_knn.shape[1], memory_knn.shape[1] 104 | attention_mask = torch.zeros((B, N, M), device=input_knn.device) 105 | attention_mask.scatter_(-1, spot_indices, 1.) 
# (B, N, M) 106 | spot_mask, spot_indices = attention_mask.topk(spot_indices.shape[-1]) # (B, N, (S+1)*K) 107 | return spot_mask, spot_indices 108 | 109 | def forward(self, input_states, memory_states, indices, input_coord=None, memory_coord=None, attention_mask=None): 110 | """Sparse Transformer Layer 111 | 112 | Args: 113 | input_states (Tensor): (B, N, C) 114 | memory_states (Tensor): (B, M, C) 115 | indices (Tensor): (B, N, K) 116 | input_coord (Tensor): (B, N, 3) 117 | memory_coord (Tensor): (B, M, 3) 118 | attention_mask (Tensor): (B, N, K) 119 | 120 | Returns: 121 | output_states: torch.Tensor (B, N, C) 122 | """ 123 | q = self.proj_q(input_states) # (B, N, H*C) 124 | k = knn_gather(self.proj_k(memory_states), indices) # (B, N, K, H*C) 125 | v = knn_gather(self.proj_v(memory_states), indices) # (B, N, K, H*C) 126 | if self.pe and memory_coord is not None and input_coord is not None: 127 | k = self.rpe(knn_gather(memory_coord, indices) - input_coord.unsqueeze(2), k) 128 | 129 | q = rearrange(q, 'b n (h c) -> b h n c', h=self.num_heads) 130 | k = rearrange(k, 'b n m (h c) -> b h n m c', h=self.num_heads) 131 | v = rearrange(v, 'b n m (h c) -> b h n m c', h=self.num_heads) 132 | 133 | attention_scores = torch.einsum('bhnc,bhnmc->bhnm', q, k) / self.d_model_per_head ** 0.5 134 | if attention_mask is not None: 135 | attention_scores = attention_scores - 1e6 * (1. - attention_mask.unsqueeze(1)) 136 | attention_scores = F.softmax(attention_scores, dim=-1) 137 | hidden_states = torch.sum(attention_scores.unsqueeze(-1) * v, dim=-2) 138 | hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)') 139 | 140 | hidden_states = self.linear(hidden_states) 141 | hidden_states = self.dropout(hidden_states) 142 | output_states = self.norm(hidden_states + input_states) 143 | output_states = self.output(output_states) 144 | return output_states 145 | -------------------------------------------------------------------------------- /models/kpconv/__init__.py: -------------------------------------------------------------------------------- 1 | from models.kpconv.backbone import KPConvFPN 2 | from models.kpconv.kpconv import KPConv 3 | from models.kpconv.modules import ( 4 | ConvBlock, 5 | ResidualBlock, 6 | NearestUpsampleBlock, 7 | KeypointDetector, 8 | DescExtractor, 9 | UnaryBlock, 10 | GroupNorm, 11 | nearest_upsample, 12 | maxpool, 13 | ) -------------------------------------------------------------------------------- /models/kpconv/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.kpconv.modules import ConvBlock, ResidualBlock, NearestUpsampleBlock, KeypointDetector, DescExtractor 5 | 6 | 7 | class KPConvFPN(nn.Module): 8 | def __init__(self, cfg): 9 | super(KPConvFPN, self).__init__() 10 | self.cfg = cfg 11 | init_dim = cfg.init_dim * 2 12 | init_sigma = cfg.init_sigma 13 | init_radius = cfg.init_radius 14 | 15 | self.encoder1_1 = ConvBlock(cfg.input_dim, cfg.init_dim, cfg.kernel_size, init_radius, init_sigma) 16 | self.encoder1_2 = ResidualBlock(cfg.init_dim, cfg.init_dim * 2, cfg.kernel_size, init_radius, init_sigma) 17 | 18 | self.encoder = nn.ModuleList() 19 | for _ in range(1, cfg.kpconv_layers): 20 | self.encoder.append(nn.ModuleList([ 21 | ResidualBlock(init_dim, init_dim, cfg.kernel_size, init_radius, init_sigma, strided=True), 22 | ResidualBlock(init_dim, init_dim * 2, cfg.kernel_size, init_radius * 2, init_sigma * 2), 23 | ResidualBlock(init_dim * 2, init_dim * 2, 
cfg.kernel_size, init_radius * 2, init_sigma * 2), 24 | ])) 25 | init_dim = init_dim * 2 26 | init_sigma = init_sigma * 2 27 | init_radius = init_radius * 2 28 | 29 | self.decoder = nn.ModuleList() 30 | for _ in range(2, cfg.kpconv_layers): 31 | init_dim = init_dim // 2 32 | self.decoder.append(NearestUpsampleBlock(init_dim * 3, init_dim)) 33 | 34 | self.detector = KeypointDetector(32, cfg.init_dim * 4, cfg.init_dim) 35 | self.desc_extractor = DescExtractor(cfg.init_dim * 4, cfg.init_dim) 36 | 37 | def forward(self, points_list, neighbors_list, subsampling_list, upsampling_list): 38 | feats = torch.ones_like(points_list[0][:, :1]) 39 | feats = self.encoder1_1(feats, points_list[0], points_list[0], neighbors_list[0]) 40 | feats = self.encoder1_2(feats, points_list[0], points_list[0], neighbors_list[0]) 41 | 42 | feats_list = [] 43 | for i in range(self.cfg.kpconv_layers - 1): 44 | feats = self.encoder[i][0](feats, points_list[i + 1], points_list[i], subsampling_list[i]) 45 | feats = self.encoder[i][1](feats, points_list[i + 1], points_list[i + 1], neighbors_list[i + 1]) 46 | feats = self.encoder[i][2](feats, points_list[i + 1], points_list[i + 1], neighbors_list[i + 1]) 47 | feats_list.append(feats) 48 | 49 | for i in range(1, self.cfg.kpconv_layers - 1): 50 | feats_list[-i - 1] = self.decoder[i - 1](feats_list[-i - 1], feats_list[-i], upsampling_list[-i]) 51 | 52 | xyz, sigma, grouped_feat, attentive_feat = self.detector(points_list[2], points_list[1], feats_list[0]) 53 | desc = self.desc_extractor(grouped_feat, attentive_feat) 54 | 55 | return {'feats':feats_list, 'keypoints':xyz, 'sigma':sigma, 'desc':desc} 56 | -------------------------------------------------------------------------------- /models/kpconv/dispositions/k_015_center_3D.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/models/kpconv/dispositions/k_015_center_3D.ply -------------------------------------------------------------------------------- /models/kpconv/kpconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.utils import index_select 7 | from models.kpconv.kernel_points import load_kernels 8 | 9 | 10 | class KPConv(nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size, radius, sigma, bias=False, dimension=3): 12 | """Initialize parameters for KPConv. 13 | 14 | Modified from [KPConv-PyTorch](https://github.com/HuguesTHOMAS/KPConv-PyTorch). 15 | 16 | Args: 17 | in_channels: dimension of input features. 18 | out_channels: dimension of output features. 19 | kernel_size: Number of kernel points. 20 | radius: radius used for kernel point init. 21 | sigma: influence radius of each kernel point. 22 | bias: use bias or not (default: False) 23 | dimension: dimension of the point space. 
24 | inf: value of infinity to generate the padding point 25 | eps: epsilon for gaussian influence 26 | """ 27 | super(KPConv, self).__init__() 28 | self.kernel_size = kernel_size 29 | self.in_channels = in_channels 30 | self.out_channels = out_channels 31 | self.radius = radius 32 | self.sigma = sigma 33 | 34 | self.weights = nn.Parameter(torch.zeros(self.kernel_size, in_channels, out_channels)) 35 | nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) 36 | if bias: 37 | self.bias = nn.Parameter(torch.zeros(self.out_channels), requires_grad=True) 38 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights) 39 | bound = 1 / math.sqrt(fan_in) 40 | nn.init.uniform_(self.bias, -bound, bound) 41 | else: 42 | self.register_parameter('bias', None) 43 | 44 | kernel_points = load_kernels(self.radius, self.kernel_size, dimension, fixed='center') # (N, 3) 45 | self.kernel_points = nn.Parameter(torch.tensor(kernel_points, dtype=torch.float32), requires_grad=False) 46 | 47 | def forward(self, s_feats, q_points, s_points, neighbor_indices): 48 | """KPConv forward. 49 | 50 | Args: 51 | s_feats (Tensor): (N, C_in) 52 | q_points (Tensor): (M, 3) 53 | s_points (Tensor): (N, 3) 54 | neighbor_indices (LongTensor): (M, H) 55 | Returns: 56 | q_feats (Tensor): (M, C_out) 57 | """ 58 | s_points = torch.cat([s_points, torch.zeros_like(s_points[:1, :]) + 1e8], 0) # (N, 3) -> (N+1, 3) 59 | neighbors = index_select(s_points, neighbor_indices, dim=0) # (N+1, 3) -> (M, H, 3) 60 | neighbors = neighbors - q_points.unsqueeze(1) # (M, H, 3) 61 | 62 | # Get Kernel point influences 63 | neighbors = neighbors.unsqueeze(2) # (M, H, 3) -> (M, H, 1, 3) 64 | differences = neighbors - self.kernel_points # (M, H, 1, 3) x (K, 3) -> (M, H, K, 3) 65 | sq_distances = torch.sum(differences ** 2, dim=3) # (M, H, K) 66 | neighbor_weights = torch.clamp(1 - torch.sqrt(sq_distances) / self.sigma, min=0.0) # (M, H, K) 67 | neighbor_weights = torch.transpose(neighbor_weights, 1, 2) # (M, H, K) -> (M, K, H) 68 | 69 | # apply neighbor weights 70 | s_feats = torch.cat((s_feats, torch.zeros_like(s_feats[:1, :])), 0) # (N, C) -> (N+1, C) 71 | neighbor_feats = index_select(s_feats, neighbor_indices, dim=0) # (N+1, C) -> (M, H, C) 72 | weighted_feats = torch.matmul(neighbor_weights, neighbor_feats) # (M, K, H) x (M, H, C) -> (M, K, C) 73 | 74 | # apply convolutional weights 75 | weighted_feats = weighted_feats.permute(1, 0, 2) # (M, K, C) -> (K, M, C) 76 | kernel_outputs = torch.matmul(weighted_feats, self.weights) # (K, M, C) x (K, C, C_out) -> (K, M, C_out) 77 | output_feats = torch.sum(kernel_outputs, dim=0, keepdim=False) # (K, M, C_out) -> (M, C_out) 78 | 79 | # normalization 80 | neighbor_feats_sum = torch.sum(neighbor_feats, dim=-1) 81 | neighbor_num = torch.sum(torch.gt(neighbor_feats_sum, 0.0), dim=-1) 82 | neighbor_num = torch.max(neighbor_num, torch.ones_like(neighbor_num)) 83 | output_feats = output_feats / neighbor_num.unsqueeze(1) 84 | 85 | # add bias 86 | if self.bias is not None: 87 | output_feats = output_feats + self.bias 88 | 89 | return output_feats 90 | 91 | def __repr__(self): 92 | format_string = self.__class__.__name__ + '(' 93 | format_string += 'kernel_size: {}'.format(self.kernel_size) 94 | format_string += ', in_channels: {}'.format(self.in_channels) 95 | format_string += ', out_channels: {}'.format(self.out_channels) 96 | format_string += ', radius: {:g}'.format(self.radius) 97 | format_string += ', sigma: {:g}'.format(self.sigma) 98 | format_string += ', bias: {}'.format(self.bias is not None) 99 | 
format_string += ')' 100 | return format_string 101 | -------------------------------------------------------------------------------- /models/kpconv/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch3d.ops import knn_points 5 | from models.utils import index_select 6 | from models.kpconv.kpconv import KPConv 7 | 8 | 9 | def nearest_upsample(x, upsample_indices): 10 | """Pools features from the closest neighbors. 11 | 12 | Args: 13 | x: [M, C] features matrix 14 | upsample_indices: [N, K] Only the first column is used for pooling 15 | 16 | Returns: 17 | x: [N, C] pooled features matrix 18 | """ 19 | x = torch.cat((x, torch.zeros_like(x[:1, :])), 0) 20 | x = index_select(x, upsample_indices[:, 0], dim=0) 21 | return x 22 | 23 | 24 | def maxpool(x, neighbor_indices): 25 | """Max pooling from neighbors. 26 | 27 | Args: 28 | x: [M, C] features matrix 29 | neighbor_indices: [N, K] pooling indices 30 | 31 | Returns: 32 | pooled_feats: [N, C] pooled features matrix 33 | """ 34 | x = torch.cat((x, torch.zeros_like(x[:1, :])), 0) 35 | neighbor_feats = index_select(x, neighbor_indices, dim=0) 36 | pooled_feats = neighbor_feats.max(1)[0] 37 | return pooled_feats 38 | 39 | 40 | class GroupNorm(nn.Module): 41 | def __init__(self, num_groups, num_channels): 42 | super(GroupNorm, self).__init__() 43 | self.num_groups = num_groups 44 | self.num_channels = num_channels 45 | self.norm = nn.GroupNorm(self.num_groups, self.num_channels) 46 | 47 | def forward(self, x: torch.Tensor): 48 | ndim = x.ndim 49 | if ndim == 2: 50 | x = x.unsqueeze(0) # (N, C) -> (1, N, C) 51 | x = x.transpose(1, 2) # (B, N, C) -> (B, C, N) 52 | x = self.norm(x) 53 | x = x.transpose(1, 2) # (B, C, N) -> (B, N, C) 54 | if ndim == 2: 55 | x = x.squeeze(0) 56 | return x 57 | 58 | def __repr__(self): 59 | return self.norm.__repr__() 60 | 61 | 62 | ACT_LAYERS = { 63 | 'relu': nn.ReLU(), 64 | 'leaky_relu': nn.LeakyReLU(0.1), 65 | 'sigmoid': nn.Sigmoid(), 66 | 'softplus': nn.Softplus(), 67 | 'tanh': nn.Tanh(), 68 | 'elu': nn.ELU(), 69 | 'gelu': nn.GELU(), 70 | None: nn.Identity(), 71 | } 72 | 73 | class UnaryBlock(nn.Module): 74 | def __init__(self, in_channels, out_channels, group_norm=32, activation_fn='leaky_relu', bias=True, layer_norm=False): 75 | super(UnaryBlock, self).__init__() 76 | self.in_channels = in_channels 77 | self.out_channels = out_channels 78 | self.activation_fn = activation_fn 79 | self.mlp = nn.Linear(in_channels, out_channels, bias=bias) 80 | if layer_norm: 81 | self.norm = nn.LayerNorm(out_channels) 82 | else: 83 | self.norm = GroupNorm(group_norm, out_channels) 84 | self.activation = ACT_LAYERS[activation_fn] 85 | 86 | def forward(self, x): 87 | x = self.mlp(x) 88 | x = self.norm(x) 89 | if self.activation is not None: 90 | x = self.activation(x) 91 | return x 92 | 93 | def __repr__(self): 94 | format_string = self.__class__.__name__ + '(' 95 | format_string += '{}, {}'.format(self.in_channels, self.out_channels) 96 | format_string += ', ' + self.norm.__repr__() 97 | format_string += ', ' + self.activation.__repr__() 98 | format_string += ')' 99 | return format_string 100 | 101 | 102 | class ConvBlock(nn.Module): 103 | def __init__(self, in_channels, out_channels, kernel_size, radius, 104 | sigma, group_norm=32, bias=True, layer_norm=False): 105 | super(ConvBlock, self).__init__() 106 | 107 | self.in_channels = in_channels 108 | self.out_channels = out_channels 109 | 110 | 
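# KPConv aggregates neighbor features via kernel-point correlations; in forward() its output is normalized (GroupNorm or LayerNorm) and passed through LeakyReLU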
self.KPConv = KPConv(in_channels, out_channels, kernel_size, radius, sigma, bias) 111 | if layer_norm: 112 | self.norm = nn.LayerNorm(out_channels) 113 | else: 114 | self.norm = GroupNorm(group_norm, out_channels) 115 | self.leaky_relu = nn.LeakyReLU(0.1) 116 | 117 | def forward(self, s_feats, q_points, s_points, neighbor_indices): 118 | x = self.KPConv(s_feats, q_points, s_points, neighbor_indices) 119 | return self.leaky_relu(self.norm(x)) 120 | 121 | 122 | class ResidualBlock(nn.Module): 123 | def __init__(self, in_channels, out_channels, kernel_size, radius, 124 | sigma, group_norm=32, strided=False, bias=True, layer_norm=False, 125 | ): 126 | super(ResidualBlock, self).__init__() 127 | self.in_channels = in_channels 128 | self.out_channels = out_channels 129 | self.strided = strided 130 | mid_channels = out_channels // 4 131 | 132 | if in_channels != mid_channels: 133 | self.unary1 = UnaryBlock(in_channels, mid_channels, group_norm, bias=bias, layer_norm=layer_norm) 134 | else: 135 | self.unary1 = nn.Identity() 136 | 137 | self.KPConv = KPConv(mid_channels, mid_channels, kernel_size, radius, sigma, bias=bias) 138 | if layer_norm: 139 | self.norm = nn.LayerNorm(mid_channels) 140 | else: 141 | self.norm = GroupNorm(group_norm, mid_channels) 142 | 143 | self.unary2 = UnaryBlock( 144 | mid_channels, out_channels, group_norm, activation_fn=None, bias=bias, layer_norm=layer_norm 145 | ) 146 | 147 | if in_channels != out_channels: 148 | self.unary_shortcut = UnaryBlock( 149 | in_channels, out_channels, group_norm, activation_fn=None, bias=bias, layer_norm=layer_norm 150 | ) 151 | else: 152 | self.unary_shortcut = nn.Identity() 153 | 154 | self.leaky_relu = nn.LeakyReLU(0.1) 155 | 156 | def forward(self, s_feats, q_points, s_points, neighbor_indices): 157 | x = self.unary1(s_feats) 158 | x = self.KPConv(x, q_points, s_points, neighbor_indices) 159 | x = self.leaky_relu(self.norm(x)) 160 | x = self.unary2(x) 161 | if self.strided: 162 | shortcut = maxpool(s_feats, neighbor_indices) 163 | else: 164 | shortcut = s_feats 165 | shortcut = self.unary_shortcut(shortcut) 166 | return self.leaky_relu(x + shortcut) 167 | 168 | 169 | class NearestUpsampleBlock(nn.Module): 170 | def __init__(self, in_channels, out_channels, group_norm=32): 171 | super(NearestUpsampleBlock, self).__init__() 172 | 173 | self.in_channels = in_channels 174 | self.out_channels = out_channels 175 | self.group_norm = group_norm 176 | if isinstance(self.group_norm, int): 177 | self.unary = UnaryBlock(in_channels, out_channels, group_norm) 178 | else: 179 | self.unary = nn.Linear(in_channels, out_channels) 180 | 181 | def forward(self, query, support, upsample_indices): 182 | latent = nearest_upsample(support, upsample_indices) 183 | latent = torch.cat([latent, query], dim=1) 184 | return self.unary(latent) 185 | 186 | 187 | def knn_group(xyz1: torch.Tensor, xyz2: torch.Tensor, features2: torch.Tensor, k): 188 | _, knn_idx, knn_xyz = knn_points(xyz1.unsqueeze(0), xyz2.unsqueeze(0), K=k, return_nn=True) 189 | knn_idx, knn_xyz = torch.squeeze(knn_idx, dim=0), torch.squeeze(knn_xyz, dim=0) 190 | rela_xyz = knn_xyz - xyz1.unsqueeze(1) # (M, k, 3) 191 | rela_dist = torch.norm(rela_xyz, dim=-1, keepdim=True) # (M, k, 1) 192 | grouped_features = torch.cat([rela_xyz, rela_dist], dim=-1) # (M, k, 4) 193 | if features2 is not None: 194 | knn_features = index_select(features2, knn_idx, dim=0) # (M, k, C) 195 | grouped_features = torch.cat([rela_xyz, rela_dist, knn_features], dim=-1) 196 | return grouped_features, knn_xyz 197 | 198 | 199 | 
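# --- Editor's note: illustrative usage sketch, not part of the original file. ---
# A minimal example of how the detector/descriptor modules defined below are wired
# together, mirroring their use in models/kpconv/backbone.py; variable names and
# tensor shapes here are assumptions for illustration only.
#
#   detector = KeypointDetector(32, init_dim * 4, init_dim)      # k neighbors, C_in, C_out
#   desc_extractor = DescExtractor(init_dim * 4, init_dim)       # C_in, C_out
#   # sampled_xyz: (M, 3) coarse nodes, xyz: (N, 3) dense points, feats: (N, C_in)
#   keypoints, sigmas, grouped_feat, attentive_feat = detector(sampled_xyz, xyz, feats)
#   desc = desc_extractor(grouped_feat, attentive_feat)          # (M, C_out) keypoint descriptors
# -------------------------------------------------------------------------------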
class KeypointDetector(nn.Module): 200 | def __init__(self, k, in_channels, out_channels): 201 | super(KeypointDetector, self).__init__() 202 | self.k = k 203 | self.convs = nn.Sequential( 204 | UnaryBlock(in_channels+4, out_channels, bias=False), 205 | UnaryBlock(out_channels, out_channels, bias=False) 206 | ) 207 | self.mlp = nn.Sequential( 208 | UnaryBlock(out_channels, out_channels), 209 | nn.Linear(out_channels, 1), 210 | nn.Softplus() 211 | ) 212 | 213 | def forward(self, sampled_xyz, xyz, features): 214 | grouped_features, knn_xyz = knn_group(sampled_xyz, xyz, features, self.k) 215 | embedding: torch.Tensor = self.convs(grouped_features) # (M, k, C) 216 | attentive_weights = F.softmax(embedding.max(dim=-1)[0], dim=-1).unsqueeze(-1) # (M, k, 1) 217 | keypoints = torch.sum(attentive_weights * knn_xyz, dim=-2) # (M, k, 3) 218 | 219 | attentive_feature_map = embedding * attentive_weights # (M, k, C) 220 | attentive_feature = torch.sum(attentive_feature_map, dim=-2) # (M, k, C) 221 | sigmas = torch.squeeze(self.mlp(attentive_feature) + 0.001, dim=-1) # (M,) 222 | return keypoints, sigmas, grouped_features[..., 4:], attentive_feature_map 223 | 224 | 225 | class DescExtractor(nn.Module): 226 | def __init__(self, in_channels, out_channels): 227 | super(DescExtractor, self).__init__() 228 | self.mlp = nn.Sequential( 229 | UnaryBlock(in_channels*2+out_channels, out_channels, bias=False), 230 | UnaryBlock(out_channels, out_channels, bias=False), 231 | ) 232 | 233 | def forward(self, x1, attentive_feature_map): 234 | #x1 = self.convs(grouped_features), # (B, N, k, C) 235 | x2 = torch.max(x1, dim=1, keepdim=True)[0] # (N, 1, C_in) 236 | x2 = x2.repeat(1, x1.shape[1], 1) # (N, k, C_in) 237 | x2 = torch.cat((x2, x1), dim=-1) # (N, k, 2C_in) 238 | x2 = torch.cat((x2, attentive_feature_map), dim=-1) # (N, k, 2C_in+C_det) 239 | desc = torch.max(self.mlp(x2), dim=1, keepdim=False)[0] # (N, C_out) 240 | return torch.relu(desc) -------------------------------------------------------------------------------- /models/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RenlangHuang/CAST/56ed94e4109809fc92e28b6c10d726473ace1321/models/transformer/__init__.py -------------------------------------------------------------------------------- /models/transformer/conditional_transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from models.transformer.pe_transformer import PETransformerLayer 3 | from models.transformer.rpe_transformer import RPETransformerLayer 4 | from models.transformer.vanilla_transformer import TransformerLayer 5 | 6 | 7 | def _check_block_type(block): 8 | if block not in ['self', 'cross']: 9 | raise ValueError('Unsupported block type "{}".'.format(block)) 10 | 11 | 12 | class VanillaConditionalTransformer(nn.Module): 13 | def __init__(self, blocks, d_model, num_heads, dropout=None, activation_fn='relu', return_attention_scores=False): 14 | super(VanillaConditionalTransformer, self).__init__() 15 | self.blocks = blocks 16 | layers = [] 17 | for block in self.blocks: 18 | _check_block_type(block) 19 | layers.append(TransformerLayer(d_model, num_heads, dropout, activation_fn)) 20 | self.layers = nn.ModuleList(layers) 21 | self.return_attention_scores = return_attention_scores 22 | 23 | def forward(self, feats0, feats1, masks0=None, masks1=None): 24 | attention_scores = [] 25 | for i, block in enumerate(self.blocks): 26 | if block == 
'self': 27 | feats0, scores0 = self.layers[i](feats0, feats0, memory_masks=masks0) 28 | feats1, scores1 = self.layers[i](feats1, feats1, memory_masks=masks1) 29 | else: 30 | feats0, scores0 = self.layers[i](feats0, feats1, memory_masks=masks1) 31 | feats1, scores1 = self.layers[i](feats1, feats0, memory_masks=masks0) 32 | if self.return_attention_scores: 33 | attention_scores.append([scores0, scores1]) 34 | if self.return_attention_scores: 35 | return feats0, feats1, attention_scores 36 | else: 37 | return feats0, feats1 38 | 39 | 40 | class PEConditionalTransformer(nn.Module): 41 | def __init__(self, blocks, d_model, num_heads, dropout=None, activation_fn='relu', return_attention_scores=False): 42 | super(PEConditionalTransformer, self).__init__() 43 | self.blocks = blocks 44 | layers = [] 45 | for block in self.blocks: 46 | _check_block_type(block) 47 | if block == 'self': 48 | layers.append(PETransformerLayer(d_model, num_heads, dropout, activation_fn)) 49 | else: 50 | layers.append(TransformerLayer(d_model, num_heads, dropout, activation_fn)) 51 | self.layers = nn.ModuleList(layers) 52 | self.return_attention_scores = return_attention_scores 53 | 54 | def forward(self, feats0, feats1, embeddings0, embeddings1, masks0=None, masks1=None): 55 | attention_scores = [] 56 | for i, block in enumerate(self.blocks): 57 | if block == 'self': 58 | feats0, scores0 = self.layers[i](feats0, feats0, embeddings0, embeddings0, memory_masks=masks0) 59 | feats1, scores1 = self.layers[i](feats1, feats1, embeddings1, embeddings1, memory_masks=masks1) 60 | else: 61 | feats0, scores0 = self.layers[i](feats0, feats1, memory_masks=masks1) 62 | feats1, scores1 = self.layers[i](feats1, feats0, memory_masks=masks0) 63 | if self.return_attention_scores: 64 | attention_scores.append([scores0, scores1]) 65 | if self.return_attention_scores: 66 | return feats0, feats1, attention_scores 67 | else: 68 | return feats0, feats1 69 | 70 | 71 | class RPEConditionalTransformer(nn.Module): 72 | def __init__( 73 | self, blocks, d_model, num_heads, dropout=None, 74 | activation_fn='relu', return_attention_scores=False, parallel=False, 75 | ): 76 | super(RPEConditionalTransformer, self).__init__() 77 | self.blocks = blocks 78 | layers = [] 79 | for block in self.blocks: 80 | _check_block_type(block) 81 | if block == 'self': 82 | layers.append(RPETransformerLayer(d_model, num_heads, dropout, activation_fn)) 83 | else: 84 | layers.append(TransformerLayer(d_model, num_heads, dropout, activation_fn)) 85 | self.layers = nn.ModuleList(layers) 86 | self.return_attention_scores = return_attention_scores 87 | self.parallel = parallel 88 | 89 | def forward(self, feats0, feats1, embeddings0, embeddings1, masks0=None, masks1=None): 90 | attention_scores = [] 91 | for i, block in enumerate(self.blocks): 92 | if block == 'self': 93 | feats0, scores0 = self.layers[i](feats0, feats0, embeddings0, memory_masks=masks0) 94 | feats1, scores1 = self.layers[i](feats1, feats1, embeddings1, memory_masks=masks1) 95 | else: 96 | if self.parallel: 97 | new_feats0, scores0 = self.layers[i](feats0, feats1, memory_masks=masks1) 98 | new_feats1, scores1 = self.layers[i](feats1, feats0, memory_masks=masks0) 99 | feats0 = new_feats0 100 | feats1 = new_feats1 101 | else: 102 | feats0, scores0 = self.layers[i](feats0, feats1, memory_masks=masks1) 103 | feats1, scores1 = self.layers[i](feats1, feats0, memory_masks=masks0) 104 | if self.return_attention_scores: 105 | attention_scores.append([scores0, scores1]) 106 | if self.return_attention_scores: 107 | return 
feats0, feats1, attention_scores 108 | else: 109 | return feats0, feats1 110 | -------------------------------------------------------------------------------- /models/transformer/linear_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | from models.transformer.output_layer import AttentionOutput 7 | 8 | 9 | class LinearMultiHeadAttention(nn.Module): 10 | def __init__(self, d_model, num_heads, normalize=True): 11 | super(LinearMultiHeadAttention, self).__init__() 12 | self.d_model = d_model 13 | self.normalize = normalize 14 | self.num_heads = num_heads 15 | self.d_model_per_head = d_model // num_heads 16 | 17 | self.proj_q = nn.Linear(self.d_model, self.d_model) 18 | self.proj_k = nn.Linear(self.d_model, self.d_model) 19 | self.proj_v = nn.Linear(self.d_model, self.d_model) 20 | 21 | def forward(self, input_q, input_k, input_v): 22 | q = rearrange(self.proj_q(input_q), 'b n (h c) -> b h n c', h=self.num_heads) 23 | k = rearrange(self.proj_k(input_k), 'b m (h c) -> b h m c', h=self.num_heads) 24 | v = rearrange(self.proj_v(input_v), 'b m (h c) -> b h m c', h=self.num_heads) 25 | 26 | q, k = F.elu(q,1.) + 1., F.elu(k,1.) + 1. 27 | hidden_states = torch.matmul(q, torch.einsum('bhmc,bhmd->bhcd', k, v)) 28 | if self.normalize: 29 | hidden_states = hidden_states / (torch.matmul(q, k.sum(2).unsqueeze(-1)) + 1e-4) 30 | hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)') 31 | return hidden_states 32 | 33 | 34 | class LinearAttentionLayer(nn.Module): 35 | def __init__(self, d_model, num_heads, dropout=None, normalize=True): 36 | super(LinearAttentionLayer, self).__init__() 37 | self.attention = LinearMultiHeadAttention(d_model, num_heads, normalize) 38 | self.linear = nn.Linear(d_model, d_model) 39 | if dropout is None or dropout <= 0: 40 | self.dropout = nn.Identity() 41 | else: self.dropout = nn.Dropout(dropout) 42 | self.norm = nn.LayerNorm(d_model) 43 | 44 | def forward(self, input_states, memory_states): 45 | hidden_states = self.attention(input_states, memory_states, memory_states) 46 | hidden_states = self.linear(hidden_states) 47 | hidden_states = self.dropout(hidden_states) 48 | output_states = self.norm(hidden_states + input_states) 49 | return output_states 50 | 51 | 52 | class LinearTransformerLayer(nn.Module): 53 | def __init__(self, d_model, num_heads, dropout=None, activation_fn='relu', normalize=True): 54 | super(LinearTransformerLayer, self).__init__() 55 | self.attention = LinearAttentionLayer(d_model, num_heads, dropout, normalize) 56 | self.output = AttentionOutput(d_model, dropout, activation_fn) 57 | 58 | def forward(self, input_states, memory_states): 59 | hidden_states = self.attention(input_states, memory_states) 60 | output_states = self.output(hidden_states) 61 | return output_states 62 | -------------------------------------------------------------------------------- /models/transformer/output_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | ACT_LAYERS = { 5 | 'relu': nn.ReLU(), 6 | 'leaky_relu': nn.LeakyReLU(0.1), 7 | 'sigmoid': nn.Sigmoid(), 8 | 'softplus': nn.Softplus(), 9 | 'tanh': nn.Tanh(), 10 | 'elu': nn.ELU(), 11 | 'gelu': nn.GELU(), 12 | None: nn.Identity(), 13 | } 14 | 15 | class AttentionOutput(nn.Module): 16 | def __init__(self, d_model, dropout=None, activation_fn='relu'): 17 | super(AttentionOutput, 
self).__init__() 18 | self.expand = nn.Linear(d_model, d_model * 2) 19 | self.activation = ACT_LAYERS[activation_fn] 20 | self.squeeze = nn.Linear(d_model * 2, d_model) 21 | if dropout is None or dropout <= 0: 22 | self.dropout = nn.Identity() 23 | else: 24 | self.dropout = nn.Dropout(dropout) 25 | self.norm = nn.LayerNorm(d_model) 26 | 27 | def forward(self, input_states): 28 | hidden_states = self.expand(input_states) 29 | hidden_states = self.activation(hidden_states) 30 | hidden_states = self.squeeze(hidden_states) 31 | hidden_states = self.dropout(hidden_states) 32 | output_states = self.norm(input_states + hidden_states) 33 | return output_states 34 | -------------------------------------------------------------------------------- /models/transformer/pe_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | from models.transformer.positional_encoding import RotaryPositionalEmbedding 7 | from models.transformer.output_layer import AttentionOutput 8 | 9 | 10 | class PEMultiHeadAttention(nn.Module): 11 | def __init__(self, d_model, num_heads, dropout=None, rotary=True): 12 | super(PEMultiHeadAttention, self).__init__() 13 | self.d_model = d_model 14 | self.num_heads = num_heads 15 | self.d_model_per_head = d_model // num_heads 16 | self.rotary_encoding = rotary 17 | 18 | self.proj_q = nn.Linear(self.d_model, self.d_model) 19 | self.proj_k = nn.Linear(self.d_model, self.d_model) 20 | self.proj_v = nn.Linear(self.d_model, self.d_model) 21 | if self.rotary_encoding: 22 | self.proj_p = RotaryPositionalEmbedding(self.d_model) 23 | else: 24 | self.proj_p = nn.Linear(self.d_model, self.d_model) 25 | 26 | if dropout is None or dropout <= 0: 27 | self.dropout = nn.Identity() 28 | else: self.dropout = nn.Dropout(dropout) 29 | 30 | def forward(self, input_q, input_k, input_v, embed_q, embed_k, key_masks=None, attention_factors=None): 31 | """Self-attention with positional embedding forward propagation. 
32 | 33 | Args: 34 | input_q: torch.Tensor (B, N, C) 35 | input_k: torch.Tensor (B, M, C) 36 | input_v: torch.Tensor (B, M, C) 37 | embed_q: torch.Tensor (B, N, C) | (B, N, 3) 38 | embed_k: torch.Tensor (B, M, C) | (B, M, 3) 39 | key_masks: torch.Tensor (B, M), True if ignored, False if preserved 40 | attention_factors: torch.Tensor (B, N, M) 41 | 42 | Returns: 43 | hidden_states: torch.Tensor (B, C, N) 44 | attention_scores: torch.Tensor (B, H, N, M) 45 | """ 46 | if self.rotary_encoding: 47 | q = rearrange(self.proj_p(embed_q, self.proj_q(input_q)), 'b n (h c) -> b h n c', h=self.num_heads) 48 | k = rearrange(self.proj_p(embed_k, self.proj_k(input_k)), 'b m (h c) -> b h m c', h=self.num_heads) 49 | else: 50 | q = rearrange(self.proj_q(input_q) + self.proj_p(embed_q), 'b n (h c) -> b h n c', h=self.num_heads) 51 | k = rearrange(self.proj_k(input_k) + self.proj_p(embed_k), 'b m (h c) -> b h m c', h=self.num_heads) 52 | v = rearrange(self.proj_v(input_v), 'b m (h c) -> b h m c', h=self.num_heads) 53 | 54 | attention_scores = torch.einsum('bhnc,bhmc->bhnm', q, k) / self.d_model_per_head ** 0.5 55 | if attention_factors is not None: 56 | attention_scores = attention_factors.unsqueeze(1) * attention_scores 57 | if key_masks is not None: 58 | attention_scores = attention_scores.masked_fill(key_masks.unsqueeze(1).unsqueeze(1), float('-inf')) 59 | attention_scores = F.softmax(attention_scores, dim=-1) 60 | attention_scores = self.dropout(attention_scores) 61 | 62 | hidden_states = torch.matmul(attention_scores, v) 63 | hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)') 64 | return hidden_states, attention_scores 65 | 66 | 67 | class PEAttentionLayer(nn.Module): 68 | def __init__(self, d_model, num_heads, dropout=None, rotary=True): 69 | super(PEAttentionLayer, self).__init__() 70 | self.attention = PEMultiHeadAttention(d_model, num_heads, dropout=dropout, rotary=rotary) 71 | self.linear = nn.Linear(d_model, d_model) 72 | if dropout is None or dropout <= 0: 73 | self.dropout = nn.Identity() 74 | else: self.dropout = nn.Dropout(dropout) 75 | self.norm = nn.LayerNorm(d_model) 76 | 77 | def forward( 78 | self, 79 | input_states, 80 | memory_states, 81 | input_embeddings, 82 | memory_embeddings, 83 | memory_masks=None, 84 | attention_factors=None, 85 | ): 86 | hidden_states, attention_scores = self.attention( 87 | input_states, 88 | memory_states, 89 | memory_states, 90 | input_embeddings, 91 | memory_embeddings, 92 | key_masks=memory_masks, 93 | attention_factors=attention_factors, 94 | ) 95 | hidden_states = self.linear(hidden_states) 96 | hidden_states = self.dropout(hidden_states) 97 | output_states = self.norm(hidden_states + input_states) 98 | return output_states, attention_scores 99 | 100 | 101 | class PETransformerLayer(nn.Module): 102 | def __init__(self, d_model, num_heads, dropout=None, activation_fn='relu', rotary=True): 103 | super(PETransformerLayer, self).__init__() 104 | self.attention = PEAttentionLayer(d_model, num_heads, dropout=dropout, rotary=rotary) 105 | self.output = AttentionOutput(d_model, dropout=dropout, activation_fn=activation_fn) 106 | 107 | def forward( 108 | self, 109 | input_states, 110 | memory_states, 111 | input_embeddings, 112 | memory_embeddings, 113 | memory_masks=None, 114 | attention_factors=None, 115 | ): 116 | hidden_states, attention_scores = self.attention( 117 | input_states, 118 | memory_states, 119 | input_embeddings, 120 | memory_embeddings, 121 | memory_masks=memory_masks, 122 | attention_factors=attention_factors, 123 | ) 124 | 
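# position-wise feed-forward sub-layer (expand -> activation -> squeeze, with dropout and a residual LayerNorm; see AttentionOutput)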
output_states = self.output(hidden_states) 125 | return output_states, attention_scores 126 | -------------------------------------------------------------------------------- /models/transformer/positional_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from models.utils import pairwise_distance 5 | 6 | 7 | class SinusoidalPositionalEmbedding(nn.Module): 8 | def __init__(self, d_model): 9 | super(SinusoidalPositionalEmbedding, self).__init__() 10 | self.d_model = d_model 11 | div_indices = torch.arange(0, d_model, 2).float() 12 | div_term = torch.exp(div_indices * (-np.log(10000.0) / d_model)) 13 | self.register_buffer('div_term', div_term) 14 | 15 | @torch.no_grad() 16 | def forward(self, emb_indices): 17 | input_shape = emb_indices.shape 18 | omegas = emb_indices.view(-1, 1, 1) * self.div_term.view(1, -1, 1) # (-1, d_model/2, 1) 19 | embeddings = torch.cat([torch.sin(omegas), torch.cos(omegas)], dim=2) # (-1, d_model/2, 2) 20 | embeddings = embeddings.view(*input_shape, self.d_model) # (*, d_model) 21 | return embeddings.detach() 22 | 23 | 24 | class GeometricStructureEmbedding(nn.Module): 25 | def __init__(self, hidden_dim, sigma_d, sigma_a, angle_k, reduction_a='max'): 26 | super(GeometricStructureEmbedding, self).__init__() 27 | self.sigma_d = sigma_d 28 | self.sigma_a = sigma_a 29 | self.factor_a = 180.0 / (self.sigma_a * np.pi) 30 | self.angle_k = angle_k 31 | 32 | self.embedding = SinusoidalPositionalEmbedding(hidden_dim) 33 | self.proj_d = nn.Linear(hidden_dim, hidden_dim) 34 | self.proj_a = nn.Linear(hidden_dim, hidden_dim) 35 | 36 | self.reduction_a = reduction_a 37 | if self.reduction_a not in ['max', 'mean']: 38 | raise ValueError(f'Unsupported reduction mode: {self.reduction_a}.') 39 | 40 | @torch.no_grad() 41 | def get_embedding_indices(self, points: torch.Tensor): 42 | """Compute the indices of pair-wise distance embedding and triplet-wise angular embedding. 
43 | 44 | Args: 45 | points: torch.Tensor (B, N, 3), input point cloud 46 | 47 | Returns: 48 | d_indices: torch.FloatTensor (B, N, N), distance embedding indices 49 | a_indices: torch.FloatTensor (B, N, N, k), angular embedding indices 50 | """ 51 | batch_size, num_point, _ = points.shape 52 | dist_map = torch.sqrt(pairwise_distance(points, points)) # (B, N, N) 53 | d_indices = dist_map / self.sigma_d 54 | 55 | k = self.angle_k 56 | knn_indices = dist_map.topk(k=k + 1, dim=2, largest=False)[1][:, :, 1:] # (B, N, k) 57 | knn_indices = knn_indices.unsqueeze(3).expand(batch_size, num_point, k, 3) # (B, N, k, 3) 58 | expanded_points = points.unsqueeze(1).expand(batch_size, num_point, num_point, 3) # (B, N, N, 3) 59 | knn_points = torch.gather(expanded_points, dim=2, index=knn_indices) # (B, N, k, 3) 60 | ref_vectors = knn_points - points.unsqueeze(2) # (B, N, k, 3) 61 | anc_vectors = points.unsqueeze(1) - points.unsqueeze(2) # (B, N, N, 3) 62 | ref_vectors = ref_vectors.unsqueeze(2).expand(batch_size, num_point, num_point, k, 3) # (B, N, N, k, 3) 63 | anc_vectors = anc_vectors.unsqueeze(3).expand(batch_size, num_point, num_point, k, 3) # (B, N, N, k, 3) 64 | sin_values = torch.norm(torch.cross(ref_vectors, anc_vectors, dim=-1), dim=-1) # (B, N, N, k) 65 | cos_values = torch.sum(ref_vectors * anc_vectors, dim=-1) # (B, N, N, k) 66 | angles = torch.atan2(sin_values, cos_values) # (B, N, N, k) 67 | a_indices = angles * self.factor_a 68 | 69 | return d_indices, a_indices 70 | 71 | def forward(self, points: torch.Tensor): 72 | d_indices, a_indices = self.get_embedding_indices(points) 73 | d_embeddings = self.embedding(d_indices) 74 | d_embeddings = self.proj_d(d_embeddings) 75 | a_embeddings = self.embedding(a_indices) 76 | a_embeddings = self.proj_a(a_embeddings) 77 | if self.reduction_a == 'max': 78 | a_embeddings = a_embeddings.max(dim=3)[0] 79 | else: 80 | a_embeddings = a_embeddings.mean(dim=3) 81 | 82 | return d_embeddings + a_embeddings 83 | 84 | 85 | class RotaryPositionalEmbedding(nn.Module): 86 | def __init__(self, d_model): 87 | super(RotaryPositionalEmbedding, self).__init__() 88 | self.linear = nn.Linear(3, d_model // 2) 89 | 90 | def embed(self, emb_coordinates): 91 | x = self.linear(emb_coordinates) 92 | return torch.sin(x), torch.cos(x) 93 | 94 | def encode(self, sin_embeddings, cos_embeddings, features): 95 | feats1 = features[...,0::2] * cos_embeddings - features[...,1::2] * sin_embeddings 96 | feats2 = features[...,0::2] * sin_embeddings + features[...,1::2] * cos_embeddings 97 | return torch.stack([feats1, feats2], dim=-1).view(features.shape) 98 | 99 | def forward(self, emb_coordinates, features): 100 | sin_embeddings, cos_embeddings = self.embed(emb_coordinates) 101 | return self.encode(sin_embeddings, cos_embeddings, features) -------------------------------------------------------------------------------- /models/transformer/rpe_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from models.transformer.output_layer import AttentionOutput 6 | 7 | 8 | class RPEMultiHeadAttention(nn.Module): 9 | def __init__(self, d_model, num_heads, dropout=None): 10 | super(RPEMultiHeadAttention, self).__init__() 11 | if d_model % num_heads != 0: 12 | raise ValueError('`d_model` ({}) must be a multiple of `num_heads` ({}).'.format(d_model, num_heads)) 13 | 14 | self.d_model = d_model 15 | self.num_heads = num_heads 16 | 
self.d_model_per_head = d_model // num_heads 17 | 18 | self.proj_q = nn.Linear(self.d_model, self.d_model) 19 | self.proj_k = nn.Linear(self.d_model, self.d_model) 20 | self.proj_v = nn.Linear(self.d_model, self.d_model) 21 | self.proj_p = nn.Linear(self.d_model, self.d_model) 22 | 23 | if dropout is None or dropout <= 0: 24 | self.dropout = nn.Identity() 25 | else: self.dropout = nn.Dropout(dropout) 26 | 27 | def forward(self, input_q, input_k, input_v, embed_qk, key_weights=None, key_masks=None, attention_factors=None): 28 | """Scaled Dot-Product Attention with Pre-computed Relative Positional Embedding (forward) 29 | 30 | Args: 31 | input_q: torch.Tensor (B, N, C) 32 | input_k: torch.Tensor (B, M, C) 33 | input_v: torch.Tensor (B, M, C) 34 | embed_qk: torch.Tensor (B, N, M, C), relative positional embedding 35 | key_weights: torch.Tensor (B, M), soft masks for the keys 36 | key_masks: torch.Tensor (B, M), True if ignored, False if preserved 37 | attention_factors: torch.Tensor (B, N, M) 38 | 39 | Returns: 40 | hidden_states: torch.Tensor (B, C, N) 41 | attention_scores: torch.Tensor (B, H, N, M) 42 | """ 43 | q = rearrange(self.proj_q(input_q), 'b n (h c) -> b h n c', h=self.num_heads) 44 | k = rearrange(self.proj_k(input_k), 'b m (h c) -> b h m c', h=self.num_heads) 45 | v = rearrange(self.proj_v(input_v), 'b m (h c) -> b h m c', h=self.num_heads) 46 | p = rearrange(self.proj_p(embed_qk), 'b n m (h c) -> b h n m c', h=self.num_heads) 47 | 48 | attention_scores_p = torch.einsum('bhnc,bhnmc->bhnm', q, p) 49 | attention_scores_e = torch.einsum('bhnc,bhmc->bhnm', q, k) 50 | attention_scores = (attention_scores_e + attention_scores_p) / self.d_model_per_head ** 0.5 51 | if attention_factors is not None: 52 | attention_scores = attention_factors.unsqueeze(1) * attention_scores 53 | if key_weights is not None: 54 | attention_scores = attention_scores * key_weights.unsqueeze(1).unsqueeze(1) 55 | if key_masks is not None: 56 | attention_scores = attention_scores.masked_fill(key_masks.unsqueeze(1).unsqueeze(1), float('-inf')) 57 | attention_scores = F.softmax(attention_scores, dim=-1) 58 | attention_scores = self.dropout(attention_scores) 59 | 60 | hidden_states = torch.matmul(attention_scores, v) 61 | 62 | hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)') 63 | 64 | return hidden_states, attention_scores 65 | 66 | 67 | class RPEAttentionLayer(nn.Module): 68 | def __init__(self, d_model, num_heads, dropout=None): 69 | super(RPEAttentionLayer, self).__init__() 70 | self.attention = RPEMultiHeadAttention(d_model, num_heads, dropout=dropout) 71 | self.linear = nn.Linear(d_model, d_model) 72 | if dropout is None or dropout <= 0: 73 | self.dropout = nn.Identity() 74 | else: self.dropout = nn.Dropout(dropout) 75 | self.norm = nn.LayerNorm(d_model) 76 | 77 | def forward( 78 | self, 79 | input_states, 80 | memory_states, 81 | position_states, 82 | memory_weights=None, 83 | memory_masks=None, 84 | attention_factors=None, 85 | ): 86 | hidden_states, attention_scores = self.attention( 87 | input_states, 88 | memory_states, 89 | memory_states, 90 | position_states, 91 | key_weights=memory_weights, 92 | key_masks=memory_masks, 93 | attention_factors=attention_factors, 94 | ) 95 | hidden_states = self.linear(hidden_states) 96 | hidden_states = self.dropout(hidden_states) 97 | output_states = self.norm(hidden_states + input_states) 98 | return output_states, attention_scores 99 | 100 | 101 | class RPETransformerLayer(nn.Module): 102 | def __init__(self, d_model, num_heads, dropout=None, 
activation_fn='relu'): 103 | super(RPETransformerLayer, self).__init__() 104 | self.attention = RPEAttentionLayer(d_model, num_heads, dropout=dropout) 105 | self.output = AttentionOutput(d_model, dropout=dropout, activation_fn=activation_fn) 106 | 107 | def forward( 108 | self, 109 | input_states, 110 | memory_states, 111 | position_states, 112 | memory_weights=None, 113 | memory_masks=None, 114 | attention_factors=None, 115 | ): 116 | hidden_states, attention_scores = self.attention( 117 | input_states, 118 | memory_states, 119 | position_states, 120 | memory_weights=memory_weights, 121 | memory_masks=memory_masks, 122 | attention_factors=attention_factors, 123 | ) 124 | output_states = self.output(hidden_states) 125 | return output_states, attention_scores 126 | -------------------------------------------------------------------------------- /models/transformer/vanilla_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | from models.transformer.output_layer import AttentionOutput 7 | 8 | 9 | class MultiHeadAttention(nn.Module): 10 | def __init__(self, d_model, num_heads, dropout=None): 11 | super(MultiHeadAttention, self).__init__() 12 | self.d_model = d_model 13 | self.num_heads = num_heads 14 | self.d_model_per_head = d_model // num_heads 15 | 16 | self.proj_q = nn.Linear(self.d_model, self.d_model) 17 | self.proj_k = nn.Linear(self.d_model, self.d_model) 18 | self.proj_v = nn.Linear(self.d_model, self.d_model) 19 | 20 | if dropout is None or dropout <= 0: 21 | self.dropout = nn.Identity() 22 | else: self.dropout = nn.Dropout(dropout) 23 | 24 | def forward(self, input_q, input_k, input_v, 25 | key_weights=None, key_masks=None, attention_factors=None, attention_masks=None): 26 | """Vanilla attention forward propagation. 
27 | 28 | Args: 29 | input_q (Tensor): input tensor for query (B, N, C) 30 | input_k (Tensor): input tensor for key (B, M, C) 31 | input_v (Tensor): input tensor for value (B, M, C) 32 | key_weights (Tensor): soft masks for the keys (B, M) 33 | key_masks (BoolTensor): True if ignored, False if preserved (B, M) 34 | attention_factors (Tensor): factors for attention matrix (B, N, M) 35 | attention_masks (BoolTensor): True if ignored, False if preserved (B, N, M) 36 | 37 | Returns: 38 | hidden_states: torch.Tensor (B, C, N) 39 | attention_scores: intermediate values 40 | 'attention_scores': torch.Tensor (B, H, N, M), attention scores before dropout 41 | """ 42 | q = rearrange(self.proj_q(input_q), 'b n (h c) -> b h n c', h=self.num_heads) 43 | k = rearrange(self.proj_k(input_k), 'b m (h c) -> b h m c', h=self.num_heads) 44 | v = rearrange(self.proj_v(input_v), 'b m (h c) -> b h m c', h=self.num_heads) 45 | 46 | attention_scores = torch.einsum('bhnc,bhmc->bhnm', q, k) / self.d_model_per_head ** 0.5 47 | if attention_factors is not None: 48 | attention_scores = attention_factors.unsqueeze(1) * attention_scores 49 | if key_weights is not None: 50 | attention_scores = attention_scores * key_weights.unsqueeze(1).unsqueeze(1) 51 | if key_masks is not None: 52 | attention_scores = attention_scores.masked_fill(key_masks.unsqueeze(1).unsqueeze(1), float('-inf')) 53 | if attention_masks is not None: 54 | attention_scores = attention_scores.masked_fill(attention_masks, float('-inf')) 55 | attention_scores = F.softmax(attention_scores, dim=-1) 56 | attention_scores = self.dropout(attention_scores) 57 | 58 | hidden_states = torch.matmul(attention_scores, v) 59 | hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)') 60 | return hidden_states, attention_scores 61 | 62 | 63 | class AttentionLayer(nn.Module): 64 | def __init__(self, d_model, num_heads, dropout=None): 65 | super(AttentionLayer, self).__init__() 66 | self.attention = MultiHeadAttention(d_model, num_heads, dropout) 67 | self.linear = nn.Linear(d_model, d_model) 68 | if dropout is None or dropout <= 0: 69 | self.dropout = nn.Identity() 70 | else: self.dropout = nn.Dropout(dropout) 71 | self.norm = nn.LayerNorm(d_model) 72 | 73 | def forward(self, 74 | input_states, 75 | memory_states, 76 | memory_weights=None, 77 | memory_masks=None, 78 | attention_factors=None, 79 | attention_masks=None, 80 | ): 81 | hidden_states, attention_scores = self.attention( 82 | input_states, 83 | memory_states, 84 | memory_states, 85 | key_weights=memory_weights, 86 | key_masks=memory_masks, 87 | attention_factors=attention_factors, 88 | attention_masks=attention_masks, 89 | ) 90 | hidden_states = self.linear(hidden_states) 91 | hidden_states = self.dropout(hidden_states) 92 | output_states = self.norm(hidden_states + input_states) 93 | return output_states, attention_scores 94 | 95 | 96 | class TransformerLayer(nn.Module): 97 | def __init__(self, d_model, num_heads, dropout=None, activation_fn='relu'): 98 | super(TransformerLayer, self).__init__() 99 | self.attention = AttentionLayer(d_model, num_heads, dropout) 100 | self.output = AttentionOutput(d_model, dropout, activation_fn) 101 | 102 | def forward( 103 | self, 104 | input_states, 105 | memory_states, 106 | memory_weights=None, 107 | memory_masks=None, 108 | attention_factors=None, 109 | attention_masks=None, 110 | ): 111 | hidden_states, attention_scores = self.attention( 112 | input_states, 113 | memory_states, 114 | memory_weights=memory_weights, 115 | memory_masks=memory_masks, 116 | 
attention_factors=attention_factors, 117 | attention_masks=attention_masks, 118 | ) 119 | output_states = self.output(hidden_states) 120 | return output_states, attention_scores 121 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import Optional 4 | import MinkowskiEngine as ME 5 | import torch.nn.functional as F 6 | from scipy.spatial.transform import Rotation 7 | from pytorch3d.ops import ball_query, packed_to_padded 8 | 9 | 10 | def weighted_svd(src_points: torch.Tensor, ref_points: torch.Tensor, weights: Optional[torch.Tensor]=None, orthogonalization=True): 11 | """Compute rigid transformation from `src_points` to `ref_points` using weighted SVD (Kabsch). 12 | 13 | Args: 14 | src_points: torch.Tensor (B, N, 3) or (N, 3) 15 | ref_points: torch.Tensor (B, N, 3) or (N, 3) 16 | weights: torch.Tensor (B, N) or (N,) (default: None) 17 | 18 | Returns: 19 | transform: torch.Tensor (B, 4, 4) or (4, 4) 20 | """ 21 | if src_points.ndim == 2: 22 | src_points = src_points.unsqueeze(0) 23 | ref_points = ref_points.unsqueeze(0) 24 | if weights is not None: 25 | weights = weights.unsqueeze(0) 26 | squeeze_first = True 27 | else: 28 | squeeze_first = False 29 | 30 | batch_size = src_points.shape[0] 31 | if weights is None: 32 | weights = torch.ones_like(src_points[:, :, 0]) 33 | else: weights = torch.clamp(weights, 0.) 34 | weights = weights / (torch.sum(weights, dim=1, keepdim=True) + 1e-5) 35 | weights = weights.unsqueeze(2) # (B, N, 1) 36 | 37 | src_centroid = torch.sum(src_points * weights, dim=1, keepdim=True) # (B, 1, 3) 38 | ref_centroid = torch.sum(ref_points * weights, dim=1, keepdim=True) # (B, 1, 3) 39 | src_points_centered = src_points - src_centroid # (B, N, 3) 40 | ref_points_centered = ref_points - ref_centroid # (B, N, 3) 41 | 42 | H = src_points_centered.permute(0, 2, 1) @ (weights * ref_points_centered) 43 | U, _, V = torch.svd(H.cpu()) # H = USV^T, SVD operates faster on CPU than on GPU 44 | Ut, V = U.transpose(1, 2).to(H.device), V.to(H.device) 45 | eye = torch.eye(3, device=H.device).unsqueeze(0).repeat(batch_size, 1, 1) 46 | eye[:, -1, -1] = torch.sign(torch.det(V @ Ut)) 47 | R = V @ eye @ Ut 48 | 49 | if orthogonalization: 50 | rot_0 = R[..., 0] / torch.norm(R[...,0], dim=-1, keepdim=True) 51 | rot_1 = R[..., 1] - torch.sum(R[..., 1] * rot_0, dim=-1, keepdim=True) * rot_0 52 | rot_1 = rot_1 / torch.norm(rot_1, dim=-1, keepdim=True) 53 | rot_2 = R[..., 2] - torch.sum(R[..., 2] * rot_0, dim=-1, keepdim=True) * rot_0 \ 54 | - torch.sum(R[..., 2] * rot_1, dim=-1, keepdim=True) * rot_1 55 | rot_2 = rot_2 / torch.norm(rot_2, dim=-1, keepdim=True) 56 | R = torch.stack([rot_0, rot_1, rot_2], dim=-1) 57 | 58 | t = ref_centroid.permute(0, 2, 1) - R @ src_centroid.permute(0, 2, 1) 59 | transform = torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1).cuda() 60 | transform[:, :3, :3], transform[:, :3, 3] = R, t.squeeze(2) 61 | if squeeze_first: transform = transform.squeeze(0) 62 | return transform 63 | 64 | 65 | def grid_subsample_gpu(points:torch.Tensor, batches_len:torch.Tensor, voxel_size): 66 | """ 67 | Same as `grid_subsample`, but implemented in GPU using Minkowski engine's sparse quantization. 68 | Note: This function is not deterministic and may return subsampled points in a different order. 
69 | """ 70 | B = len(batches_len) 71 | batch_start_end = F.pad(torch.cumsum(batches_len, 0), (1, 0)) 72 | device = points.device 73 | 74 | coord_batched = ME.utils.batched_coordinates( 75 | [points[batch_start_end[b]:batch_start_end[b + 1]] / voxel_size for b in range(B)], device=device) 76 | sparse_tensor = ME.SparseTensor( 77 | features=points, 78 | coordinates=coord_batched, 79 | quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE 80 | ) 81 | s_points = sparse_tensor.features 82 | s_len = torch.tensor([f.shape[0] for f in sparse_tensor.decomposed_features], device=device) 83 | return s_points, s_len 84 | 85 | 86 | def radius_search_gpu(queries, supports, q_batches, s_batches, radius, max_neighbors): 87 | """ 88 | Same as `radius_search`, but implemented by GPU using PyTorch3D's ball_query functions. 89 | Computes neighbors for a batch of queries and supports, apply radius search 90 | :param queries: (N1, 3) the query points 91 | :param supports: (N2, 3) the support points 92 | :param q_batches: (B) the list of lengths of batch elements in queries 93 | :param s_batches: (B) the list of lengths of batch elements in supports 94 | :param radius: float32 95 | :return: neighbors indices 96 | """ 97 | B = len(q_batches) 98 | N_spts_total = supports.shape[0] 99 | q_first_idx = F.pad(torch.cumsum(q_batches, dim=0)[:-1], (1, 0)) 100 | queries_padded = packed_to_padded(queries, q_first_idx, q_batches.max().item()) # (B, N_max, 3) 101 | s_first_idx = F.pad(torch.cumsum(s_batches, dim=0)[:-1], (1, 0)) 102 | supports_padded = packed_to_padded(supports, s_first_idx, s_batches.max().item()) # (B, N_max, 3) 103 | 104 | idx = ball_query(queries_padded, supports_padded, 105 | q_batches, s_batches, 106 | K=max_neighbors, radius=radius).idx # (N_clouds, N_pts, K) 107 | idx[idx < 0] = torch.iinfo(idx.dtype).min 108 | 109 | idx_packed = torch.cat([idx[b][:q_batches[b]] + s_first_idx[b] for b in range(B)], dim=0) 110 | idx_packed[idx_packed < 0] = N_spts_total 111 | 112 | return idx_packed 113 | 114 | 115 | 116 | def index_select(data: torch.Tensor, index: torch.LongTensor, dim: int) -> torch.Tensor: 117 | """Advanced index select. 118 | 119 | Returns a tensor `output` which indexes the `data` tensor along dimension `dim` 120 | using the entries in `index` which is a `LongTensor`. 121 | 122 | Different from `torch.index_select`, `index` does not has to be 1-D. The `dim`-th 123 | dimension of `data` will be expanded to the number of dimensions in `index`. 124 | 125 | For example, suppose the shape `data` is $(a_0, a_1, ..., a_{n-1})$, the shape of `index` is 126 | $(b_0, b_1, ..., b_{m-1})$, and `dim` is $i$, then `output` is $(n+m-1)$-d tensor, whose shape is 127 | $(a_0, ..., a_{i-1}, b_0, b_1, ..., b_{m-1}, a_{i+1}, ..., a_{n-1})$. 128 | 129 | Args: 130 | data (Tensor): (a_0, a_1, ..., a_{n-1}) 131 | index (LongTensor): (b_0, b_1, ..., b_{m-1}) 132 | dim: int 133 | 134 | Returns: 135 | output (Tensor): (a_0, ..., a_{dim-1}, b_0, ..., b_{m-1}, a_{dim+1}, ..., a_{n-1}) 136 | """ 137 | output = data.index_select(dim, index.view(-1)) 138 | if index.ndim > 1: 139 | output_shape = data.shape[:dim] + index.shape + data.shape[dim:][1:] 140 | output = output.view(*output_shape) 141 | return output 142 | 143 | 144 | def pairwise_distance(x: torch.Tensor, y: torch.Tensor, normalized = False) -> torch.Tensor: 145 | """Pairwise distance of two (batched) point clouds. 
146 | 147 | Args: 148 | x (Tensor): (*, N, C) or (*, C, N) 149 | y (Tensor): (*, M, C) or (*, C, M) 150 | normalized (bool=False): if the points are normalized, "x2 + y2 = 1", so "d2 = 2 - 2xy". 151 | 152 | Returns: 153 | dist: torch.Tensor (*, N, M) 154 | """ 155 | xy = torch.matmul(x, y.transpose(-1, -2)) # (*, N, C) x [(*, M, C) -> (*, C, M)] 156 | if normalized: 157 | sq_distances = 2.0 - 2.0 * xy 158 | else: 159 | x2 = torch.sum(x ** 2, dim=-1).unsqueeze(-1) # (*, N, C) or (*, C, N) -> (*, N) -> (*, N, 1) 160 | y2 = torch.sum(y ** 2, dim=-1).unsqueeze(-2) # (*, M, C) or (*, C, M) -> (*, M) -> (*, 1, M) 161 | sq_distances = x2 - 2.0 * xy + y2 162 | return sq_distances.clamp(min=0.0) 163 | 164 | 165 | def apply_transform(points: torch.Tensor, transform: torch.Tensor, normals: Optional[torch.Tensor] = None): 166 | """Rigid transform to points and normals (optional). There are two cases supported: 167 | 1. points and normals are (*, 3), transform is (4, 4), the output points are (*, 3). 168 | In this case, the transform is applied to all points. 169 | 2. points and normals are (B, N, 3), transform is (B, 4, 4), the output points are (B, N, 3). 170 | In this case, the transform is applied batch-wise. The points can be broadcast if B=1. 171 | 172 | Args: 173 | points (Tensor): (*, 3) or (B, N, 3) 174 | normals (optional[Tensor]=None): same shape as points. 175 | transform (Tensor): (4, 4) or (B, 4, 4) 176 | 177 | Returns: 178 | points (Tensor): same shape as points. 179 | normals (Tensor): same shape as points. 180 | """ 181 | if normals is not None: 182 | assert points.shape == normals.shape 183 | if transform.ndim == 2: 184 | rotation = transform[:3, :3] 185 | translation = transform[:3, 3] 186 | points_shape = points.shape 187 | points = points.reshape(-1, 3) 188 | points = torch.matmul(points, rotation.transpose(-1, -2)) + translation 189 | points = points.reshape(*points_shape) 190 | if normals is not None: 191 | normals = normals.reshape(-1, 3) 192 | normals = torch.matmul(normals, rotation.transpose(-1, -2)) 193 | normals = normals.reshape(*points_shape) 194 | elif transform.ndim == 3 and points.ndim == 3: 195 | rotation = transform[:, :3, :3] # (B, 3, 3) 196 | translation = transform[:, None, :3, 3] # (B, 1, 3) 197 | points = torch.matmul(points, rotation.transpose(-1, -2)) + translation 198 | if normals is not None: 199 | normals = torch.matmul(normals, rotation.transpose(-1, -2)) 200 | else: raise ValueError('Incompatible tensor shapes.') 201 | if normals is not None: 202 | return points, normals 203 | else: 204 | return points 205 | 206 | 207 | @torch.no_grad() 208 | def point_to_node_partition(points: torch.Tensor, nodes: torch.Tensor, point_limit): 209 | """Point-to-Node partition to the point cloud. 
210 | 211 | Args: 212 | points (Tensor): (N, 3) 213 | nodes (Tensor): (M, 3) 214 | point_limit (int): max number of points to each node 215 | 216 | Returns: 217 | node_masks (BoolTensor): (M,) 218 | node_knn_indices (LongTensor): (M, K) 219 | node_knn_masks (BoolTensor): (M, K) 220 | """ 221 | sq_dist_mat = pairwise_distance(nodes, points) # (M, N) 222 | point_to_node = sq_dist_mat.min(dim=0)[1] # (N,) 223 | node_masks = torch.zeros_like(nodes[:,0], dtype=torch.bool) # (M,) 224 | node_masks.index_fill_(0, point_to_node, True) 225 | 226 | matching_masks = torch.zeros_like(sq_dist_mat, dtype=torch.bool) # (M, N) 227 | point_indices = torch.arange(points.shape[0], device=points.device) # (N,) 228 | matching_masks[point_to_node, point_indices] = True # (M, N) 229 | sq_dist_mat.masked_fill_(~matching_masks, 1e12) # (M, N) 230 | 231 | node_knn_indices = sq_dist_mat.topk(k=point_limit, dim=1, largest=False)[1] # (M, K) 232 | node_knn_node_indices = index_select(point_to_node, node_knn_indices, dim=0) # (M, K) 233 | node_indices = torch.arange(nodes.shape[0], device=nodes.device).unsqueeze(1) # (M, 1) 234 | node_knn_masks = node_knn_node_indices.eq(node_indices) # (M, K) 235 | node_knn_indices.masked_fill_(~node_knn_masks, points.shape[0]) 236 | 237 | return node_masks, node_knn_indices, node_knn_masks 238 | 239 | 240 | def generate_rand_rotm(x_lim=5.0, y_lim=5.0, z_lim=180.0) -> np.ndarray: 241 | rand_z = np.random.uniform(low=-z_lim, high=z_lim) 242 | rand_y = np.random.uniform(low=-y_lim, high=y_lim) 243 | rand_x = np.random.uniform(low=-x_lim, high=x_lim) 244 | 245 | rand_eul = np.array([rand_z, rand_y, rand_x]) 246 | r = Rotation.from_euler('zyx', rand_eul, degrees=True) 247 | rotm = r.as_matrix() 248 | return rotm 249 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | scipy 4 | munch 5 | open3d==0.16.0 6 | einops 7 | ninja -------------------------------------------------------------------------------- /trainval.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID' 3 | os.environ['CUDA_VISIBLE_DEVICES']='0' 4 | 5 | 6 | import torch 7 | import torch.optim as optim 8 | from torch.optim.lr_scheduler import StepLR 9 | from torch.utils.data import DataLoader 10 | 11 | import json 12 | from typing import Dict 13 | from munch import munchify 14 | 15 | from data.kitti_data import KittiDataset 16 | from data.nuscenes_data import NuscenesDataset 17 | from data.indoor_data import IndoorDataset 18 | 19 | from models.models.cast import CAST 20 | from engine.evaluator import Evaluator 21 | from engine.trainer import EpochBasedTrainer 22 | from engine.losses import SpotMatchingLoss, CoarseMatchingLoss, KeypointMatchingLoss, ProbChamferLoss 23 | 24 | 25 | class OverallLoss(torch.nn.Module): 26 | def __init__(self, cfg): 27 | super(OverallLoss, self).__init__() 28 | self.weight_det_loss = cfg.weight_det_loss 29 | self.weight_spot_loss = cfg.weight_spot_loss 30 | self.weight_feat_loss = cfg.weight_feat_loss 31 | self.weight_desc_loss = cfg.weight_desc_loss 32 | self.weight_overlap_loss = cfg.weight_overlap_loss 33 | self.weight_corr_loss = cfg.weight_corr_loss 34 | self.weight_trans_loss = cfg.weight_trans_loss 35 | self.weight_rot_loss = cfg.weight_rot_loss 36 | self.pretrain_feat_epochs = cfg.pretrain_feat_epochs 37 | 38 | self.prob_chamfer_loss = 
ProbChamferLoss() 39 | self.spot_matching_loss = SpotMatchingLoss(cfg) 40 | self.coarse_matching_loss = CoarseMatchingLoss(cfg) 41 | self.kpt_matching_loss = KeypointMatchingLoss(cfg.r_p, cfg.r_n) 42 | self.register_buffer('I3x3', torch.eye(3)) 43 | 44 | def forward(self, output_dict: Dict[str, torch.Tensor], epoch) -> Dict[str, torch.Tensor]: 45 | l_det = self.prob_chamfer_loss(output_dict) 46 | l_spot, l_feat = self.spot_matching_loss(output_dict) 47 | #l_feat = self.coarse_matching_loss(output_dict) 48 | loss = l_feat * self.weight_feat_loss + l_det * self.weight_det_loss + l_spot * self.weight_spot_loss 49 | 50 | loss_dict = {'l_det':l_det, 'l_spot':l_spot, 'l_feat':l_feat} 51 | 52 | l_desc, l_ov, l_corr = self.kpt_matching_loss(output_dict) 53 | l_trans = torch.norm(output_dict['transform'][:3, 3] - output_dict['gt_transform'][:3, 3]) 54 | l_rot = torch.norm(output_dict['transform'][:3, :3].T @ output_dict['gt_transform'][:3, :3] - self.I3x3) 55 | loss = loss + l_desc * self.weight_desc_loss + l_ov * self.weight_overlap_loss + l_corr * self.weight_corr_loss 56 | 57 | loss_dict.update({'l_corr':l_corr, 'l_desc':l_desc, 'l_ov':l_ov}) 58 | 59 | l_trans2 = torch.norm(output_dict['refined_transform'][:3, 3] - output_dict['gt_transform'][:3, 3]) 60 | l_rot2 = torch.norm(output_dict['refined_transform'][:3, :3].T @ output_dict['gt_transform'][:3, :3] - self.I3x3) 61 | 62 | if epoch > self.pretrain_feat_epochs:  # pose supervision is enabled only after the feature pretraining stage 63 | loss = loss + l_trans.clamp_max(2.) * self.weight_trans_loss + l_rot.clamp_max(1.) * self.weight_rot_loss 64 | loss = loss + l_trans2.clamp_max(2.) * self.weight_trans_loss + l_rot2.clamp_max(1.) * self.weight_rot_loss 65 | 66 | ret_dict = {'loss':loss, 'l_rot':l_rot} 67 | ret_dict.update(loss_dict) 68 | return ret_dict 69 | 70 | 71 | class Trainer(EpochBasedTrainer): 72 | def __init__(self, cfg): 73 | super().__init__(cfg) 74 | if cfg.dataset == 'kitti': 75 | train_dataset = KittiDataset( 76 | cfg.data.root, 'train', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, cfg.data.augment) 77 | val_dataset = KittiDataset( 78 | cfg.data.root, 'test', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, 0.0) 79 | elif cfg.dataset == 'nuscenes': 80 | train_dataset = NuscenesDataset( 81 | cfg.data.root, 'train', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, cfg.data.augment) 82 | val_dataset = NuscenesDataset( 83 | cfg.data.root, 'test', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, 0.0) 84 | elif cfg.dataset == '3dmatch': 85 | train_dataset = IndoorDataset( 86 | cfg.data.root, 'train', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, cfg.data.augment) 87 | val_dataset = IndoorDataset( 88 | cfg.data.root, 'val', cfg.data.npoints, cfg.data.voxel_size, cfg.data_list, 0.0) 89 | else: 90 | raise NotImplementedError(f'Unsupported dataset: {cfg.dataset}') 91 | 92 | self.train_loader = DataLoader(train_dataset, 1, num_workers=cfg.data.num_workers, shuffle=True, pin_memory=True) 93 | self.val_loader = DataLoader(val_dataset, 1, num_workers=cfg.data.num_workers, shuffle=False, pin_memory=True) 94 | 95 | self.model = CAST(cfg.model).cuda() 96 | self.optimizer = optim.AdamW(self.model.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay) 97 | self.scheduler = StepLR(self.optimizer, step_size=cfg.optim.step_size, gamma=cfg.optim.gamma) 98 | self.loss_func = OverallLoss(cfg.loss).cuda() 99 | self.evaluator = Evaluator(cfg.eval).cuda() 100 | 101 | self.pretrain_feat_epochs = cfg.loss.pretrain_feat_epochs 102 | 103 | def step(self, data_dict) -> Dict[str, torch.Tensor]: 104 | output_dict = 
self.model(*data_dict, self.epoch <= self.pretrain_feat_epochs) 105 | loss_dict: Dict = self.loss_func(output_dict, self.epoch) 106 | with torch.no_grad(): 107 | result_dict = self.evaluator(output_dict) 108 | loss_dict.update(result_dict) 109 | return loss_dict 110 | 111 | 112 | 113 | if __name__ == "__main__": 114 | import argparse 115 | 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--mode", required=True, choices=["train", "test"]) 118 | parser.add_argument("--config", required=True, type=str) 119 | parser.add_argument("--resume_epoch", default=0, type=int) 120 | parser.add_argument("--resume_log", default=None, type=str) 121 | parser.add_argument("--load_pretrained", default=None, type=str) 122 | 123 | _args = parser.parse_args() 124 | 125 | with open(_args.config, 'r') as cfg: 126 | args = json.load(cfg) 127 | args = munchify(args) 128 | 129 | if _args.mode == "train": 130 | Trainer(args).fit(_args.resume_epoch, _args.resume_log) 131 | elif _args.mode == "test": 132 | tester = Trainer(args) 133 | tester.load_snapshot(_args.load_pretrained) 134 | # e.g. tester.load_snapshot("cast-epoch-39") 135 | tester.validate_epoch() 136 | else: raise ValueError("Unspecified mode.") --------------------------------------------------------------------------------
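Usage sketch (appended note, not a file from the repository): trainval.py above is driven by a JSON config that is loaded and munchified before being passed to Trainer. The snippet below is a minimal programmatic equivalent of "python trainval.py --mode train --config <config>.json"; the config path is illustrative, the snapshot name is taken from the comment in trainval.py, and everything else uses only the classes and methods shown above.

import json
from munch import munchify

from trainval import Trainer  # assumes the snippet is run from the repository root

with open('config/kitti.json', 'r') as f:  # illustrative config path
    cfg = munchify(json.load(f))

trainer = Trainer(cfg)   # builds the datasets, CAST model, optimizer, scheduler, loss and evaluator
trainer.fit(0, None)     # train from scratch (resume_epoch=0, no resume log)

# Evaluate a trained snapshot:
tester = Trainer(cfg)
tester.load_snapshot('cast-epoch-39')
tester.validate_epoch()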