├── .gitignore ├── LICENSE ├── README.md ├── assets └── teaser.png ├── data ├── 3DMatch │ └── metadata │ │ ├── 3DLoMatch.pkl │ │ ├── 3DMatch.pkl │ │ ├── benchmarks │ │ ├── 3DLoMatch │ │ │ ├── 7-scenes-redkitchen │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-home_at-home_at_scan1_2013_jan_1 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-home_md-home_md_scan9_2012_sep_30 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_uc-scan3 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_umd-maryland_hotel1 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_umd-maryland_hotel3 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-mit_76_studyroom-76-1studyroom2 │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ │ └── sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika │ │ │ │ ├── gt.info │ │ │ │ ├── gt.log │ │ │ │ └── gt_overlap.log │ │ └── 3DMatch │ │ │ ├── 7-scenes-redkitchen │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-home_at-home_at_scan1_2013_jan_1 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-home_md-home_md_scan9_2012_sep_30 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_uc-scan3 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_umd-maryland_hotel1 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-hotel_umd-maryland_hotel3 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ ├── sun3d-mit_76_studyroom-76-1studyroom2 │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ │ └── sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika │ │ │ ├── gt.info │ │ │ ├── gt.log │ │ │ └── gt_overlap.log │ │ ├── split │ │ ├── train_3dmatch.txt │ │ └── val_3dmatch.txt │ │ ├── train.pkl │ │ └── val.pkl ├── Kitti │ ├── downsample_pcd.py │ └── metadata │ │ ├── test.pkl │ │ ├── train.pkl │ │ └── val.pkl ├── ModelNet │ └── split_data.py └── demo │ ├── gt.npy │ ├── ref.npy │ └── src.npy ├── experiments ├── geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn │ ├── backbone.py │ ├── config.py │ ├── dataset.py │ ├── demo.py │ ├── eval.py │ ├── eval.sh │ ├── eval_all.sh │ ├── eval_dgr.py │ ├── loss.py │ ├── model.py │ ├── test.py │ └── trainval.py ├── geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn │ ├── backbone.py │ ├── config.py │ ├── dataset.py │ ├── eval.py │ ├── eval.sh │ ├── loss.py │ ├── model.py │ ├── test.py │ └── trainval.py └── geotransformer.modelnet.rpmnet.stage4.gse.k3.max.oacl.stage2.sinkhorn │ ├── backbone.py │ ├── config.py │ ├── dataset.py │ ├── loss.py │ ├── model.py │ ├── test.py │ └── trainval.py ├── geotransformer ├── __init__.py ├── datasets │ ├── __init__.py │ └── registration │ │ ├── __init__.py │ │ ├── kitti │ │ ├── __init__.py │ │ └── dataset.py │ │ ├── modelnet │ │ ├── __init__.py │ │ └── dataset.py │ │ └── threedmatch │ │ ├── __init__.py │ │ ├── dataset.py │ │ └── utils.py ├── engine │ ├── __init__.py │ ├── base_tester.py │ ├── base_trainer.py │ ├── epoch_based_trainer.py │ ├── iter_based_trainer.py │ ├── logger.py │ └── single_tester.py ├── extensions │ ├── common │ │ └── torch_helper.h │ ├── cpu │ │ ├── grid_subsampling │ │ │ ├── grid_subsampling.cpp │ │ │ ├── grid_subsampling.h │ │ │ ├── grid_subsampling_cpu.cpp │ │ │ └── grid_subsampling_cpu.h │ │ └── radius_neighbors │ 
│ │ ├── radius_neighbors.cpp │ │ │ ├── radius_neighbors.h │ │ │ ├── radius_neighbors_cpu.cpp │ │ │ └── radius_neighbors_cpu.h │ ├── extra │ │ ├── cloud │ │ │ ├── cloud.cpp │ │ │ └── cloud.h │ │ └── nanoflann │ │ │ └── nanoflann.hpp │ └── pybind.cpp ├── modules │ ├── __init__.py │ ├── geotransformer │ │ ├── __init__.py │ │ ├── geotransformer.py │ │ ├── local_global_registration.py │ │ ├── point_matching.py │ │ ├── superpoint_matching.py │ │ └── superpoint_target.py │ ├── kpconv │ │ ├── __init__.py │ │ ├── dispositions │ │ │ └── k_015_center_3D.ply │ │ ├── functional.py │ │ ├── kernel_points.py │ │ ├── kpconv.py │ │ └── modules.py │ ├── layers │ │ ├── __init__.py │ │ ├── conv_block.py │ │ └── factory.py │ ├── loss │ │ ├── __init__.py │ │ └── circle_loss.py │ ├── ops │ │ ├── __init__.py │ │ ├── grid_subsample.py │ │ ├── index_select.py │ │ ├── pairwise_distance.py │ │ ├── pointcloud_partition.py │ │ ├── radius_search.py │ │ ├── transformation.py │ │ └── vector_angle.py │ ├── registration │ │ ├── __init__.py │ │ ├── matching.py │ │ ├── metrics.py │ │ └── procrustes.py │ ├── sinkhorn │ │ ├── __init__.py │ │ └── learnable_sinkhorn.py │ └── transformer │ │ ├── __init__.py │ │ ├── conditional_transformer.py │ │ ├── lrpe_transformer.py │ │ ├── output_layer.py │ │ ├── pe_transformer.py │ │ ├── positional_embedding.py │ │ ├── rpe_transformer.py │ │ └── vanilla_transformer.py ├── transforms │ ├── __init__.py │ └── functional.py └── utils │ ├── __init__.py │ ├── average_meter.py │ ├── common.py │ ├── data.py │ ├── open3d.py │ ├── pointcloud.py │ ├── registration.py │ ├── summary_board.py │ ├── timer.py │ ├── torch.py │ └── visualization.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .vscode 3 | build 4 | output 5 | *.egg-info 6 | *.so 7 | **/__pycache__ 8 | weights/* 9 | pretrained 10 | data/Kitti/downsampled 11 | data/ModelNet/*.pkl 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Zheng Qin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/assets/teaser.png -------------------------------------------------------------------------------- /data/3DMatch/metadata/3DLoMatch.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/3DMatch/metadata/3DLoMatch.pkl -------------------------------------------------------------------------------- /data/3DMatch/metadata/3DMatch.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/3DMatch/metadata/3DMatch.pkl -------------------------------------------------------------------------------- /data/3DMatch/metadata/split/train_3dmatch.txt: -------------------------------------------------------------------------------- 1 | 7-scenes-chess 2 | 7-scenes-fire 3 | 7-scenes-office 4 | 7-scenes-pumpkin 5 | 7-scenes-stairs 6 | analysis-by-synthesis-apt1-kitchen 7 | analysis-by-synthesis-apt1-living 8 | analysis-by-synthesis-apt2-bed 9 | analysis-by-synthesis-apt2-kitchen 10 | analysis-by-synthesis-apt2-living 11 | analysis-by-synthesis-apt2-luke 12 | analysis-by-synthesis-office2-5a 13 | analysis-by-synthesis-office2-5b 14 | bundlefusion-apt0_1 15 | bundlefusion-apt0_2 16 | bundlefusion-apt0_3 17 | bundlefusion-apt0_4 18 | bundlefusion-apt1_1 19 | bundlefusion-apt1_2 20 | bundlefusion-apt1_3 21 | bundlefusion-apt1_4 22 | bundlefusion-apt2_1 23 | bundlefusion-apt2_2 24 | bundlefusion-copyroom_1 25 | bundlefusion-copyroom_2 26 | bundlefusion-office1_1 27 | bundlefusion-office1_2 28 | bundlefusion-office2 29 | bundlefusion-office3 30 | rgbd-scenes-v2-scene_01 31 | rgbd-scenes-v2-scene_02 32 | rgbd-scenes-v2-scene_03 33 | rgbd-scenes-v2-scene_04 34 | rgbd-scenes-v2-scene_05 35 | rgbd-scenes-v2-scene_06 36 | rgbd-scenes-v2-scene_07 37 | rgbd-scenes-v2-scene_08 38 | rgbd-scenes-v2-scene_09 39 | rgbd-scenes-v2-scene_11 40 | rgbd-scenes-v2-scene_12 41 | rgbd-scenes-v2-scene_13 42 | rgbd-scenes-v2-scene_14 43 | sun3d-brown_bm_1-brown_bm_1_1 44 | sun3d-brown_bm_1-brown_bm_1_2 45 | sun3d-brown_bm_1-brown_bm_1_3 46 | sun3d-brown_cogsci_1-brown_cogsci_1 47 | sun3d-brown_cs_2-brown_cs2_1 48 | sun3d-brown_cs_2-brown_cs2_2 49 | sun3d-brown_cs_3-brown_cs3 50 | sun3d-harvard_c3-hv_c3_1 51 | sun3d-harvard_c5-hv_c5_1 52 | sun3d-harvard_c6-hv_c6_1 53 | sun3d-harvard_c8-hv_c8_3 54 | sun3d-hotel_nips2012-nips_4_1 55 | sun3d-hotel_nips2012-nips_4_2 56 | sun3d-hotel_sf-scan1_1 57 | sun3d-hotel_sf-scan1_2 58 | sun3d-hotel_sf-scan1_3 59 | sun3d-hotel_sf-scan1_4 60 | sun3d-mit_32_d507-d507_2_1 61 | sun3d-mit_32_d507-d507_2_2 62 | sun3d-mit_46_ted_lab1-ted_lab_2_1 63 | sun3d-mit_46_ted_lab1-ted_lab_2_2 64 | sun3d-mit_46_ted_lab1-ted_lab_2_3 65 | sun3d-mit_46_ted_lab1-ted_lab_2_4 66 | sun3d-mit_76_417-76-417b_1 67 | sun3d-mit_76_417-76-417b_2_1 68 | sun3d-mit_76_417-76-417b_3 69 | sun3d-mit_76_417-76-417b_4 70 | sun3d-mit_76_417-76-417b_5 71 | sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika 72 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika_1 73 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika_2 74 | 
sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika_3
75 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika_4
76 |
--------------------------------------------------------------------------------
/data/3DMatch/metadata/split/val_3dmatch.txt:
--------------------------------------------------------------------------------
1 | sun3d-brown_bm_4-brown_bm_4
2 | sun3d-harvard_c11-hv_c11_2
3 | 7-scenes-heads
4 | rgbd-scenes-v2-scene_10
5 | bundlefusion-office0_1
6 | bundlefusion-office0_2
7 | bundlefusion-office0_3
8 | analysis-by-synthesis-apt2-kitchen
9 |
--------------------------------------------------------------------------------
/data/3DMatch/metadata/train.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/3DMatch/metadata/train.pkl
--------------------------------------------------------------------------------
/data/3DMatch/metadata/val.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/3DMatch/metadata/val.pkl
--------------------------------------------------------------------------------
/data/Kitti/downsample_pcd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import open3d as o3d
4 | import numpy as np
5 | import glob
6 | from tqdm import tqdm
7 |
8 |
9 | def main():
10 |     for i in range(11):
11 |         seq_id = '{:02d}'.format(i)
12 |         # ensure the per-sequence output directory exists before saving
13 |         os.makedirs(osp.join('downsampled', seq_id), exist_ok=True)
14 |         file_names = glob.glob(osp.join('sequences', seq_id, 'velodyne', '*.bin'))
15 |         for file_name in tqdm(file_names):
16 |             frame = file_name.split('/')[-1][:-4]
17 |             new_file_name = osp.join('downsampled', seq_id, frame + '.npy')
18 |             points = np.fromfile(file_name, dtype=np.float32).reshape(-1, 4)
19 |             points = points[:, :3]
20 |             pcd = o3d.geometry.PointCloud()
21 |             pcd.points = o3d.utility.Vector3dVector(points)
22 |             pcd = pcd.voxel_down_sample(0.3)
23 |             points = np.array(pcd.points).astype(np.float32)
24 |             np.save(new_file_name, points)
25 |
26 |
27 | if __name__ == '__main__':
28 |     main()
29 |
--------------------------------------------------------------------------------
/data/Kitti/metadata/test.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/Kitti/metadata/test.pkl
--------------------------------------------------------------------------------
/data/Kitti/metadata/train.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/Kitti/metadata/train.pkl
--------------------------------------------------------------------------------
/data/Kitti/metadata/val.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/Kitti/metadata/val.pkl
--------------------------------------------------------------------------------
/data/ModelNet/split_data.py:
--------------------------------------------------------------------------------
1 | import h5py
2 | import numpy as np
3 | import pickle
4 |
5 |
6 | def dump_pickle(data, filename):
7 |     with open(filename, 'wb') as f:
8 |         pickle.dump(data, f)
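# Illustrative helper (hypothetical, not part of the original script): the
# read-side counterpart of dump_pickle above. Each entry written by process()
# below is a dict with keys 'points' (N, 3), 'normals' (N, 3) and 'label'
# (a scalar class index).
def load_pickle(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)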
9 |
10 |
11 | def process(subset):
12 |     with open(f'modelnet40_ply_hdf5_2048/{subset}_files.txt') as f:
13 |         lines = f.readlines()
14 |     all_points = []
15 |     all_normals = []
16 |     all_labels = []
17 |     for line in lines:
18 |         filename = line.strip()
19 |         h5file = h5py.File(f'modelnet40_ply_hdf5_2048/{filename}', 'r')
20 |         all_points.append(h5file['data'][:])
21 |         all_normals.append(h5file['normal'][:])
22 |         all_labels.append(h5file['label'][:].flatten().astype(np.int64))
23 |     points = np.concatenate(all_points, axis=0)
24 |     normals = np.concatenate(all_normals, axis=0)
25 |     labels = np.concatenate(all_labels, axis=0)
26 |     print(f'{subset} data loaded.')
27 |     all_data = []
28 |     num_data = points.shape[0]
29 |     for i in range(num_data):
30 |         all_data.append(dict(points=points[i], normals=normals[i], label=labels[i]))
31 |     if subset == 'train':
32 |         indices = np.random.permutation(num_data)
33 |         num_train = int(num_data * 0.8)
34 |         train_indices = indices[:num_train]
35 |         val_indices = indices[num_train:]
36 |         train_data = [all_data[i] for i in train_indices.tolist()]
37 |         dump_pickle(train_data, 'train.pkl')
38 |         val_data = [all_data[i] for i in val_indices.tolist()]
39 |         dump_pickle(val_data, 'val.pkl')
40 |     else:
41 |         dump_pickle(all_data, 'test.pkl')
42 |
43 |
44 | for subset in ['train', 'test']:
45 |     process(subset)
46 |
--------------------------------------------------------------------------------
/data/demo/gt.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/demo/gt.npy
--------------------------------------------------------------------------------
/data/demo/ref.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/demo/ref.npy
--------------------------------------------------------------------------------
/data/demo/src.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/data/demo/src.npy
--------------------------------------------------------------------------------
/experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | from geotransformer.modules.kpconv import ConvBlock, ResidualBlock, UnaryBlock, LastUnaryBlock, nearest_upsample
6 |
7 |
8 | class KPConvFPN(nn.Module):
9 |     def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm):
10 |         super(KPConvFPN, self).__init__()
11 |
12 |         self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm)
13 |         self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm)
14 |
15 |         self.encoder2_1 = ResidualBlock(
16 |             init_dim * 2, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm, strided=True
17 |         )
18 |         self.encoder2_2 = ResidualBlock(
19 |             init_dim * 2, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
20 |         )
21 |         self.encoder2_3 = ResidualBlock(
22 |             init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
23 |         )
24
| 25 | self.encoder3_1 = ResidualBlock( 26 | init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm, strided=True 27 | ) 28 | self.encoder3_2 = ResidualBlock( 29 | init_dim * 4, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm 30 | ) 31 | self.encoder3_3 = ResidualBlock( 32 | init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm 33 | ) 34 | 35 | self.encoder4_1 = ResidualBlock( 36 | init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm, strided=True 37 | ) 38 | self.encoder4_2 = ResidualBlock( 39 | init_dim * 8, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm 40 | ) 41 | self.encoder4_3 = ResidualBlock( 42 | init_dim * 16, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm 43 | ) 44 | 45 | self.decoder3 = UnaryBlock(init_dim * 24, init_dim * 8, group_norm) 46 | self.decoder2 = LastUnaryBlock(init_dim * 12, output_dim) 47 | 48 | def forward(self, feats, data_dict): 49 | feats_list = [] 50 | 51 | points_list = data_dict['points'] 52 | neighbors_list = data_dict['neighbors'] 53 | subsampling_list = data_dict['subsampling'] 54 | upsampling_list = data_dict['upsampling'] 55 | 56 | feats_s1 = feats 57 | feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0], neighbors_list[0]) 58 | feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0], neighbors_list[0]) 59 | 60 | feats_s2 = self.encoder2_1(feats_s1, points_list[1], points_list[0], subsampling_list[0]) 61 | feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1], neighbors_list[1]) 62 | feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1], neighbors_list[1]) 63 | 64 | feats_s3 = self.encoder3_1(feats_s2, points_list[2], points_list[1], subsampling_list[1]) 65 | feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2], neighbors_list[2]) 66 | feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2], neighbors_list[2]) 67 | 68 | feats_s4 = self.encoder4_1(feats_s3, points_list[3], points_list[2], subsampling_list[2]) 69 | feats_s4 = self.encoder4_2(feats_s4, points_list[3], points_list[3], neighbors_list[3]) 70 | feats_s4 = self.encoder4_3(feats_s4, points_list[3], points_list[3], neighbors_list[3]) 71 | 72 | latent_s4 = feats_s4 73 | feats_list.append(feats_s4) 74 | 75 | latent_s3 = nearest_upsample(latent_s4, upsampling_list[2]) 76 | latent_s3 = torch.cat([latent_s3, feats_s3], dim=1) 77 | latent_s3 = self.decoder3(latent_s3) 78 | feats_list.append(latent_s3) 79 | 80 | latent_s2 = nearest_upsample(latent_s3, upsampling_list[1]) 81 | latent_s2 = torch.cat([latent_s2, feats_s2], dim=1) 82 | latent_s2 = self.decoder2(latent_s2) 83 | feats_list.append(latent_s2) 84 | 85 | feats_list.reverse() 86 | 87 | return feats_list 88 | -------------------------------------------------------------------------------- /experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | 5 | from easydict import EasyDict as edict 6 | 7 | from geotransformer.utils.common import ensure_dir 8 | 9 | 10 | _C = edict() 11 | 12 | # common 13 | _C.seed = 7351 14 | 15 | # dirs 16 | _C.working_dir = osp.dirname(osp.realpath(__file__)) 17 | _C.root_dir = osp.dirname(osp.dirname(_C.working_dir)) 18 | _C.exp_name = osp.basename(_C.working_dir) 19 | _C.output_dir = osp.join(_C.root_dir, 
'output', _C.exp_name) 20 | _C.snapshot_dir = osp.join(_C.output_dir, 'snapshots') 21 | _C.log_dir = osp.join(_C.output_dir, 'logs') 22 | _C.event_dir = osp.join(_C.output_dir, 'events') 23 | _C.feature_dir = osp.join(_C.output_dir, 'features') 24 | _C.registration_dir = osp.join(_C.output_dir, 'registration') 25 | 26 | ensure_dir(_C.output_dir) 27 | ensure_dir(_C.snapshot_dir) 28 | ensure_dir(_C.log_dir) 29 | ensure_dir(_C.event_dir) 30 | ensure_dir(_C.feature_dir) 31 | ensure_dir(_C.registration_dir) 32 | 33 | # data 34 | _C.data = edict() 35 | _C.data.dataset_root = osp.join(_C.root_dir, 'data', '3DMatch') 36 | 37 | # train data 38 | _C.train = edict() 39 | _C.train.batch_size = 1 40 | _C.train.num_workers = 8 41 | _C.train.point_limit = 30000 42 | _C.train.use_augmentation = True 43 | _C.train.augmentation_noise = 0.005 44 | _C.train.augmentation_rotation = 1.0 45 | 46 | # test data 47 | _C.test = edict() 48 | _C.test.batch_size = 1 49 | _C.test.num_workers = 8 50 | _C.test.point_limit = None 51 | 52 | # evaluation 53 | _C.eval = edict() 54 | _C.eval.acceptance_overlap = 0.0 55 | _C.eval.acceptance_radius = 0.1 56 | _C.eval.inlier_ratio_threshold = 0.05 57 | _C.eval.rmse_threshold = 0.2 58 | _C.eval.rre_threshold = 15.0 59 | _C.eval.rte_threshold = 0.3 60 | 61 | # ransac 62 | _C.ransac = edict() 63 | _C.ransac.distance_threshold = 0.05 64 | _C.ransac.num_points = 3 65 | _C.ransac.num_iterations = 1000 66 | 67 | # optim 68 | _C.optim = edict() 69 | _C.optim.lr = 1e-4 70 | _C.optim.lr_decay = 0.95 71 | _C.optim.lr_decay_steps = 1 72 | _C.optim.weight_decay = 1e-6 73 | _C.optim.max_epoch = 40 74 | _C.optim.grad_acc_steps = 1 75 | 76 | # model - backbone 77 | _C.backbone = edict() 78 | _C.backbone.num_stages = 4 79 | _C.backbone.init_voxel_size = 0.025 80 | _C.backbone.kernel_size = 15 81 | _C.backbone.base_radius = 2.5 82 | _C.backbone.base_sigma = 2.0 83 | _C.backbone.init_radius = _C.backbone.base_radius * _C.backbone.init_voxel_size 84 | _C.backbone.init_sigma = _C.backbone.base_sigma * _C.backbone.init_voxel_size 85 | _C.backbone.group_norm = 32 86 | _C.backbone.input_dim = 1 87 | _C.backbone.init_dim = 64 88 | _C.backbone.output_dim = 256 89 | 90 | # model - Global 91 | _C.model = edict() 92 | _C.model.ground_truth_matching_radius = 0.05 93 | _C.model.num_points_in_patch = 64 94 | _C.model.num_sinkhorn_iterations = 100 95 | 96 | # model - Coarse Matching 97 | _C.coarse_matching = edict() 98 | _C.coarse_matching.num_targets = 128 99 | _C.coarse_matching.overlap_threshold = 0.1 100 | _C.coarse_matching.num_correspondences = 256 101 | _C.coarse_matching.dual_normalization = True 102 | 103 | # model - GeoTransformer 104 | _C.geotransformer = edict() 105 | _C.geotransformer.input_dim = 1024 106 | _C.geotransformer.hidden_dim = 256 107 | _C.geotransformer.output_dim = 256 108 | _C.geotransformer.num_heads = 4 109 | _C.geotransformer.blocks = ['self', 'cross', 'self', 'cross', 'self', 'cross'] 110 | _C.geotransformer.sigma_d = 0.2 111 | _C.geotransformer.sigma_a = 15 112 | _C.geotransformer.angle_k = 3 113 | _C.geotransformer.reduction_a = 'max' 114 | 115 | # model - Fine Matching 116 | _C.fine_matching = edict() 117 | _C.fine_matching.topk = 3 118 | _C.fine_matching.acceptance_radius = 0.1 119 | _C.fine_matching.mutual = True 120 | _C.fine_matching.confidence_threshold = 0.05 121 | _C.fine_matching.use_dustbin = False 122 | _C.fine_matching.use_global_score = False 123 | _C.fine_matching.correspondence_threshold = 3 124 | _C.fine_matching.correspondence_limit = None 125 | 
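# Note (interpretation inferred from the parameter names, not documented in
# this file): topk bounds the fine matches kept per superpoint patch,
# acceptance_radius is the inlier distance used by local-to-global
# registration, and num_refinement_steps below sets the number of pose
# refinement iterations.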
_C.fine_matching.num_refinement_steps = 5 126 | 127 | # loss - Coarse level 128 | _C.coarse_loss = edict() 129 | _C.coarse_loss.positive_margin = 0.1 130 | _C.coarse_loss.negative_margin = 1.4 131 | _C.coarse_loss.positive_optimal = 0.1 132 | _C.coarse_loss.negative_optimal = 1.4 133 | _C.coarse_loss.log_scale = 24 134 | _C.coarse_loss.positive_overlap = 0.1 135 | 136 | # loss - Fine level 137 | _C.fine_loss = edict() 138 | _C.fine_loss.positive_radius = 0.05 139 | 140 | # loss - Overall 141 | _C.loss = edict() 142 | _C.loss.weight_coarse_loss = 1.0 143 | _C.loss.weight_fine_loss = 1.0 144 | 145 | 146 | def make_cfg(): 147 | return _C 148 | 149 | 150 | def parse_args(): 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument('--link_output', dest='link_output', action='store_true', help='link output dir') 153 | args = parser.parse_args() 154 | return args 155 | 156 | 157 | def main(): 158 | cfg = make_cfg() 159 | args = parse_args() 160 | if args.link_output: 161 | os.symlink(cfg.output_dir, 'output') 162 | 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/dataset.py: -------------------------------------------------------------------------------- 1 | from geotransformer.datasets.registration.threedmatch.dataset import ThreeDMatchPairDataset 2 | from geotransformer.utils.data import ( 3 | registration_collate_fn_stack_mode, 4 | calibrate_neighbors_stack_mode, 5 | build_dataloader_stack_mode, 6 | ) 7 | 8 | 9 | def train_valid_data_loader(cfg, distributed): 10 | train_dataset = ThreeDMatchPairDataset( 11 | cfg.data.dataset_root, 12 | 'train', 13 | point_limit=cfg.train.point_limit, 14 | use_augmentation=cfg.train.use_augmentation, 15 | augmentation_noise=cfg.train.augmentation_noise, 16 | augmentation_rotation=cfg.train.augmentation_rotation, 17 | ) 18 | neighbor_limits = calibrate_neighbors_stack_mode( 19 | train_dataset, 20 | registration_collate_fn_stack_mode, 21 | cfg.backbone.num_stages, 22 | cfg.backbone.init_voxel_size, 23 | cfg.backbone.init_radius, 24 | ) 25 | train_loader = build_dataloader_stack_mode( 26 | train_dataset, 27 | registration_collate_fn_stack_mode, 28 | cfg.backbone.num_stages, 29 | cfg.backbone.init_voxel_size, 30 | cfg.backbone.init_radius, 31 | neighbor_limits, 32 | batch_size=cfg.train.batch_size, 33 | num_workers=cfg.train.num_workers, 34 | shuffle=True, 35 | distributed=distributed, 36 | ) 37 | 38 | valid_dataset = ThreeDMatchPairDataset( 39 | cfg.data.dataset_root, 40 | 'val', 41 | point_limit=cfg.test.point_limit, 42 | use_augmentation=False, 43 | ) 44 | valid_loader = build_dataloader_stack_mode( 45 | valid_dataset, 46 | registration_collate_fn_stack_mode, 47 | cfg.backbone.num_stages, 48 | cfg.backbone.init_voxel_size, 49 | cfg.backbone.init_radius, 50 | neighbor_limits, 51 | batch_size=cfg.test.batch_size, 52 | num_workers=cfg.test.num_workers, 53 | shuffle=False, 54 | distributed=distributed, 55 | ) 56 | 57 | return train_loader, valid_loader, neighbor_limits 58 | 59 | 60 | def test_data_loader(cfg, benchmark): 61 | train_dataset = ThreeDMatchPairDataset( 62 | cfg.data.dataset_root, 63 | 'train', 64 | point_limit=cfg.train.point_limit, 65 | use_augmentation=cfg.train.use_augmentation, 66 | augmentation_noise=cfg.train.augmentation_noise, 67 | augmentation_rotation=cfg.train.augmentation_rotation, 68 | ) 69 | neighbor_limits = calibrate_neighbors_stack_mode( 70 | train_dataset, 71 | 
registration_collate_fn_stack_mode,
72 |         cfg.backbone.num_stages,
73 |         cfg.backbone.init_voxel_size,
74 |         cfg.backbone.init_radius,
75 |     )
76 |
77 |     test_dataset = ThreeDMatchPairDataset(
78 |         cfg.data.dataset_root,
79 |         benchmark,
80 |         point_limit=cfg.test.point_limit,
81 |         use_augmentation=False,
82 |     )
83 |     test_loader = build_dataloader_stack_mode(
84 |         test_dataset,
85 |         registration_collate_fn_stack_mode,
86 |         cfg.backbone.num_stages,
87 |         cfg.backbone.init_voxel_size,
88 |         cfg.backbone.init_radius,
89 |         neighbor_limits,
90 |         batch_size=cfg.test.batch_size,
91 |         num_workers=cfg.test.num_workers,
92 |         shuffle=False,
93 |     )
94 |
95 |     return test_loader, neighbor_limits
96 |
--------------------------------------------------------------------------------
/experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import numpy as np
5 |
6 | from geotransformer.utils.data import registration_collate_fn_stack_mode
7 | from geotransformer.utils.torch import to_cuda, release_cuda
8 | from geotransformer.utils.open3d import make_open3d_point_cloud, get_color, draw_geometries
9 | from geotransformer.utils.registration import compute_registration_error
10 |
11 | from config import make_cfg
12 | from model import create_model
13 |
14 |
15 | def make_parser():
16 |     parser = argparse.ArgumentParser()
17 |     parser.add_argument("--src_file", required=True, help="src point cloud numpy file")
18 |     parser.add_argument("--ref_file", required=True, help="ref point cloud numpy file")
19 |     parser.add_argument("--gt_file", required=True, help="ground-truth transformation file")
20 |     parser.add_argument("--weights", required=True, help="model weights file")
21 |     return parser
22 |
23 |
24 | def load_data(args):
25 |     src_points = np.load(args.src_file)
26 |     ref_points = np.load(args.ref_file)
27 |     src_feats = np.ones_like(src_points[:, :1])
28 |     ref_feats = np.ones_like(ref_points[:, :1])
29 |
30 |     data_dict = {
31 |         "ref_points": ref_points.astype(np.float32),
32 |         "src_points": src_points.astype(np.float32),
33 |         "ref_feats": ref_feats.astype(np.float32),
34 |         "src_feats": src_feats.astype(np.float32),
35 |     }
36 |
37 |     if args.gt_file is not None:
38 |         transform = np.load(args.gt_file)
39 |         data_dict["transform"] = transform.astype(np.float32)
40 |
41 |     return data_dict
42 |
43 |
44 | def main():
45 |     parser = make_parser()
46 |     args = parser.parse_args()
47 |
48 |     cfg = make_cfg()
49 |
50 |     # prepare data
51 |     data_dict = load_data(args)
52 |     neighbor_limits = [38, 36, 36, 38]  # default setting in 3DMatch
53 |     data_dict = registration_collate_fn_stack_mode(
54 |         [data_dict], cfg.backbone.num_stages, cfg.backbone.init_voxel_size, cfg.backbone.init_radius, neighbor_limits
55 |     )
56 |
57 |     # prepare model
58 |     model = create_model(cfg).cuda()
59 |     state_dict = torch.load(args.weights)
60 |     model.load_state_dict(state_dict["model"])
61 |
62 |     # prediction
63 |     data_dict = to_cuda(data_dict)
64 |     output_dict = model(data_dict)
65 |     data_dict = release_cuda(data_dict)
66 |     output_dict = release_cuda(output_dict)
67 |
68 |     # get results
69 |     ref_points = output_dict["ref_points"]
70 |     src_points = output_dict["src_points"]
71 |     estimated_transform = output_dict["estimated_transform"]
72 |     transform = data_dict["transform"]
73 |
74 |     # visualization
75 |     ref_pcd = make_open3d_point_cloud(ref_points)
76 |     ref_pcd.estimate_normals()
77 |
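# Note: normals are estimated here only so the Open3D viewer can shade the
# clouds; the registration itself runs on the raw points with the constant
# one-dimensional features built in load_data above.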
    ref_pcd.paint_uniform_color(get_color("custom_yellow"))
78 |     src_pcd = make_open3d_point_cloud(src_points)
79 |     src_pcd.estimate_normals()
80 |     src_pcd.paint_uniform_color(get_color("custom_blue"))
81 |     draw_geometries(ref_pcd, src_pcd)
82 |     src_pcd = src_pcd.transform(estimated_transform)
83 |     draw_geometries(ref_pcd, src_pcd)
84 |
85 |     # compute error
86 |     rre, rte = compute_registration_error(transform, estimated_transform)
87 |     print(f"RRE(deg): {rre:.3f}, RTE(m): {rte:.3f}")
88 |
89 |
90 | if __name__ == "__main__":
91 |     main()
92 |
--------------------------------------------------------------------------------
/experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/eval.sh:
--------------------------------------------------------------------------------
1 | if [ "$3" = "test" ]; then
2 |     python test.py --test_epoch=$1 --benchmark=$2
3 | fi
4 | python eval.py --test_epoch=$1 --benchmark=$2 --method=lgr
5 | # for n in 250 500 1000 2500; do
6 | #     python eval.py --test_epoch=$1 --num_corr=$n --run_matching --run_registration --benchmark=$2
7 | # done
8 |
--------------------------------------------------------------------------------
/experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/eval_all.sh:
--------------------------------------------------------------------------------
1 | for n in $(seq 20 40); do
2 |     python test.py --test_epoch=$n --benchmark=$1 --verbose
3 |     python eval.py --test_epoch=$n --benchmark=$1 --method=lgr
4 | done
5 | # for n in 250 500 1000 2500; do
6 | #     python eval.py --test_epoch=$1 --num_corr=$n --run_matching --run_registration --benchmark=$2
7 | # done
8 |
--------------------------------------------------------------------------------
/experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import time
4 |
5 | import numpy as np
6 |
7 | from geotransformer.engine import SingleTester
8 | from geotransformer.utils.torch import release_cuda
9 | from geotransformer.utils.common import ensure_dir, get_log_string
10 |
11 | from dataset import test_data_loader
12 | from config import make_cfg
13 | from model import create_model
14 | from loss import Evaluator
15 |
16 |
17 | def make_parser():
18 |     parser = argparse.ArgumentParser()
19 |     parser.add_argument('--benchmark', choices=['3DMatch', '3DLoMatch', 'val'], help='test benchmark')
20 |     return parser
21 |
22 |
23 | class Tester(SingleTester):
24 |     def __init__(self, cfg):
25 |         super().__init__(cfg, parser=make_parser())
26 |
27 |         # dataloader
28 |         start_time = time.time()
29 |         data_loader, neighbor_limits = test_data_loader(cfg, self.args.benchmark)
30 |         loading_time = time.time() - start_time
31 |         message = f'Data loader created: {loading_time:.3f}s elapsed.'
32 |         self.logger.info(message)
33 |         message = f'Calibrate neighbors: {neighbor_limits}.'
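# Note: neighbor_limits were calibrated on the training split inside
# test_data_loader (see dataset.py), so test-time neighborhood truncation
# follows the same statistics as training.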
34 | self.logger.info(message) 35 | self.register_loader(data_loader) 36 | 37 | # model 38 | model = create_model(cfg).cuda() 39 | self.register_model(model) 40 | 41 | # evaluator 42 | self.evaluator = Evaluator(cfg).cuda() 43 | 44 | # preparation 45 | self.output_dir = osp.join(cfg.feature_dir, self.args.benchmark) 46 | ensure_dir(self.output_dir) 47 | 48 | def test_step(self, iteration, data_dict): 49 | output_dict = self.model(data_dict) 50 | return output_dict 51 | 52 | def eval_step(self, iteration, data_dict, output_dict): 53 | result_dict = self.evaluator(output_dict, data_dict) 54 | return result_dict 55 | 56 | def summary_string(self, iteration, data_dict, output_dict, result_dict): 57 | scene_name = data_dict['scene_name'] 58 | ref_frame = data_dict['ref_frame'] 59 | src_frame = data_dict['src_frame'] 60 | message = f'{scene_name}, id0: {ref_frame}, id1: {src_frame}' 61 | message += ', ' + get_log_string(result_dict=result_dict) 62 | message += ', nCorr: {}'.format(output_dict['corr_scores'].shape[0]) 63 | return message 64 | 65 | def after_test_step(self, iteration, data_dict, output_dict, result_dict): 66 | scene_name = data_dict['scene_name'] 67 | ref_id = data_dict['ref_frame'] 68 | src_id = data_dict['src_frame'] 69 | 70 | ensure_dir(osp.join(self.output_dir, scene_name)) 71 | file_name = osp.join(self.output_dir, scene_name, f'{ref_id}_{src_id}.npz') 72 | np.savez_compressed( 73 | file_name, 74 | ref_points=release_cuda(output_dict['ref_points']), 75 | src_points=release_cuda(output_dict['src_points']), 76 | ref_points_f=release_cuda(output_dict['ref_points_f']), 77 | src_points_f=release_cuda(output_dict['src_points_f']), 78 | ref_points_c=release_cuda(output_dict['ref_points_c']), 79 | src_points_c=release_cuda(output_dict['src_points_c']), 80 | ref_feats_c=release_cuda(output_dict['ref_feats_c']), 81 | src_feats_c=release_cuda(output_dict['src_feats_c']), 82 | ref_node_corr_indices=release_cuda(output_dict['ref_node_corr_indices']), 83 | src_node_corr_indices=release_cuda(output_dict['src_node_corr_indices']), 84 | ref_corr_points=release_cuda(output_dict['ref_corr_points']), 85 | src_corr_points=release_cuda(output_dict['src_corr_points']), 86 | corr_scores=release_cuda(output_dict['corr_scores']), 87 | gt_node_corr_indices=release_cuda(output_dict['gt_node_corr_indices']), 88 | gt_node_corr_overlaps=release_cuda(output_dict['gt_node_corr_overlaps']), 89 | estimated_transform=release_cuda(output_dict['estimated_transform']), 90 | transform=release_cuda(data_dict['transform']), 91 | overlap=data_dict['overlap'], 92 | ) 93 | 94 | 95 | def main(): 96 | cfg = make_cfg() 97 | tester = Tester(cfg) 98 | tester.run() 99 | 100 | 101 | if __name__ == '__main__': 102 | main() 103 | -------------------------------------------------------------------------------- /experiments/geotransformer.3dmatch.stage4.gse.k3.max.oacl.stage2.sinkhorn/trainval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import torch.optim as optim 5 | 6 | from geotransformer.engine import EpochBasedTrainer 7 | 8 | from config import make_cfg 9 | from dataset import train_valid_data_loader 10 | from model import create_model 11 | from loss import OverallLoss, Evaluator 12 | 13 | 14 | class Trainer(EpochBasedTrainer): 15 | def __init__(self, cfg): 16 | super().__init__(cfg, max_epoch=cfg.optim.max_epoch) 17 | 18 | # dataloader 19 | start_time = time.time() 20 | train_loader, val_loader, neighbor_limits = train_valid_data_loader(cfg, 
self.distributed)
21 |         loading_time = time.time() - start_time
22 |         message = 'Data loader created: {:.3f}s elapsed.'.format(loading_time)
23 |         self.logger.info(message)
24 |         message = 'Calibrate neighbors: {}.'.format(neighbor_limits)
25 |         self.logger.info(message)
26 |         self.register_loader(train_loader, val_loader)
27 |
28 |         # model, optimizer, scheduler
29 |         model = create_model(cfg).cuda()
30 |         model = self.register_model(model)
31 |         optimizer = optim.Adam(model.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay)
32 |         self.register_optimizer(optimizer)
33 |         scheduler = optim.lr_scheduler.StepLR(optimizer, cfg.optim.lr_decay_steps, gamma=cfg.optim.lr_decay)
34 |         self.register_scheduler(scheduler)
35 |
36 |         # loss function, evaluator
37 |         self.loss_func = OverallLoss(cfg).cuda()
38 |         self.evaluator = Evaluator(cfg).cuda()
39 |
40 |     def train_step(self, epoch, iteration, data_dict):
41 |         output_dict = self.model(data_dict)
42 |         loss_dict = self.loss_func(output_dict, data_dict)
43 |         result_dict = self.evaluator(output_dict, data_dict)
44 |         loss_dict.update(result_dict)
45 |         return output_dict, loss_dict
46 |
47 |     def val_step(self, epoch, iteration, data_dict):
48 |         output_dict = self.model(data_dict)
49 |         loss_dict = self.loss_func(output_dict, data_dict)
50 |         result_dict = self.evaluator(output_dict, data_dict)
51 |         loss_dict.update(result_dict)
52 |         return output_dict, loss_dict
53 |
54 |
55 | def main():
56 |     cfg = make_cfg()
57 |     trainer = Trainer(cfg)
58 |     trainer.run()
59 |
60 |
61 | if __name__ == '__main__':
62 |     main()
63 |
--------------------------------------------------------------------------------
/experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from geotransformer.modules.kpconv import ConvBlock, ResidualBlock, UnaryBlock, LastUnaryBlock, nearest_upsample
5 |
6 |
7 | class KPConvFPN(nn.Module):
8 |     def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm):
9 |         super(KPConvFPN, self).__init__()
10 |
11 |         self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm)
12 |         self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm)
13 |
14 |         self.encoder2_1 = ResidualBlock(
15 |             init_dim * 2, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm, strided=True
16 |         )
17 |         self.encoder2_2 = ResidualBlock(
18 |             init_dim * 2, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
19 |         )
20 |         self.encoder2_3 = ResidualBlock(
21 |             init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm
22 |         )
23 |
24 |         self.encoder3_1 = ResidualBlock(
25 |             init_dim * 4,
26 |             init_dim * 4,
27 |             kernel_size,
28 |             init_radius * 2,
29 |             init_sigma * 2,
30 |             group_norm,
31 |             strided=True,
32 |         )
33 |         self.encoder3_2 = ResidualBlock(
34 |             init_dim * 4, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
35 |         )
36 |         self.encoder3_3 = ResidualBlock(
37 |             init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm
38 |         )
39 |
40 |         self.encoder4_1 = ResidualBlock(
41 |             init_dim * 8,
42 |             init_dim * 8,
43 |             kernel_size,
44 |             init_radius * 4,
45 |             init_sigma * 4,
46 |             group_norm,
47 |             strided=True,
48 |         )
49 |         self.encoder4_2 = ResidualBlock(
50 |             init_dim * 8, init_dim * 16, kernel_size, init_radius * 8,
init_sigma * 8, group_norm 51 | ) 52 | self.encoder4_3 = ResidualBlock( 53 | init_dim * 16, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm 54 | ) 55 | 56 | self.encoder5_1 = ResidualBlock( 57 | init_dim * 16, 58 | init_dim * 16, 59 | kernel_size, 60 | init_radius * 8, 61 | init_sigma * 8, 62 | group_norm, 63 | strided=True, 64 | ) 65 | self.encoder5_2 = ResidualBlock( 66 | init_dim * 16, init_dim * 32, kernel_size, init_radius * 16, init_sigma * 16, group_norm 67 | ) 68 | self.encoder5_3 = ResidualBlock( 69 | init_dim * 32, init_dim * 32, kernel_size, init_radius * 16, init_sigma * 16, group_norm 70 | ) 71 | 72 | self.decoder4 = UnaryBlock(init_dim * 48, init_dim * 16, group_norm) 73 | self.decoder3 = UnaryBlock(init_dim * 24, init_dim * 8, group_norm) 74 | self.decoder2 = LastUnaryBlock(init_dim * 12, output_dim) 75 | 76 | def forward(self, feats, data_dict): 77 | feats_list = [] 78 | 79 | points_list = data_dict['points'] 80 | neighbors_list = data_dict['neighbors'] 81 | subsampling_list = data_dict['subsampling'] 82 | upsampling_list = data_dict['upsampling'] 83 | 84 | feats_s1 = feats 85 | feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0], neighbors_list[0]) 86 | feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0], neighbors_list[0]) 87 | 88 | feats_s2 = self.encoder2_1(feats_s1, points_list[1], points_list[0], subsampling_list[0]) 89 | feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1], neighbors_list[1]) 90 | feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1], neighbors_list[1]) 91 | 92 | feats_s3 = self.encoder3_1(feats_s2, points_list[2], points_list[1], subsampling_list[1]) 93 | feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2], neighbors_list[2]) 94 | feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2], neighbors_list[2]) 95 | 96 | feats_s4 = self.encoder4_1(feats_s3, points_list[3], points_list[2], subsampling_list[2]) 97 | feats_s4 = self.encoder4_2(feats_s4, points_list[3], points_list[3], neighbors_list[3]) 98 | feats_s4 = self.encoder4_3(feats_s4, points_list[3], points_list[3], neighbors_list[3]) 99 | 100 | feats_s5 = self.encoder5_1(feats_s4, points_list[4], points_list[3], subsampling_list[3]) 101 | feats_s5 = self.encoder5_2(feats_s5, points_list[4], points_list[4], neighbors_list[4]) 102 | feats_s5 = self.encoder5_3(feats_s5, points_list[4], points_list[4], neighbors_list[4]) 103 | 104 | latent_s5 = feats_s5 105 | feats_list.append(feats_s5) 106 | 107 | latent_s4 = nearest_upsample(latent_s5, upsampling_list[3]) 108 | latent_s4 = torch.cat([latent_s4, feats_s4], dim=1) 109 | latent_s4 = self.decoder4(latent_s4) 110 | feats_list.append(latent_s4) 111 | 112 | latent_s3 = nearest_upsample(latent_s4, upsampling_list[2]) 113 | latent_s3 = torch.cat([latent_s3, feats_s3], dim=1) 114 | latent_s3 = self.decoder3(latent_s3) 115 | feats_list.append(latent_s3) 116 | 117 | latent_s2 = nearest_upsample(latent_s3, upsampling_list[1]) 118 | latent_s2 = torch.cat([latent_s2, feats_s2], dim=1) 119 | latent_s2 = self.decoder2(latent_s2) 120 | feats_list.append(latent_s2) 121 | 122 | feats_list.reverse() 123 | 124 | return feats_list 125 | -------------------------------------------------------------------------------- /experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | 5 | from easydict 
import EasyDict as edict 6 | 7 | from geotransformer.utils.common import ensure_dir 8 | 9 | 10 | _C = edict() 11 | 12 | # random seed 13 | _C.seed = 7351 14 | 15 | # dirs 16 | _C.working_dir = osp.dirname(osp.realpath(__file__)) 17 | _C.root_dir = osp.dirname(osp.dirname(_C.working_dir)) 18 | _C.exp_name = osp.basename(_C.working_dir) 19 | _C.output_dir = osp.join(_C.root_dir, 'output', _C.exp_name) 20 | _C.snapshot_dir = osp.join(_C.output_dir, 'snapshots') 21 | _C.log_dir = osp.join(_C.output_dir, 'logs') 22 | _C.event_dir = osp.join(_C.output_dir, 'events') 23 | _C.feature_dir = osp.join(_C.output_dir, 'features') 24 | 25 | ensure_dir(_C.output_dir) 26 | ensure_dir(_C.snapshot_dir) 27 | ensure_dir(_C.log_dir) 28 | ensure_dir(_C.event_dir) 29 | ensure_dir(_C.feature_dir) 30 | 31 | # data 32 | _C.data = edict() 33 | _C.data.dataset_root = osp.join(_C.root_dir, 'data', 'Kitti') 34 | 35 | # train data 36 | _C.train = edict() 37 | _C.train.batch_size = 1 38 | _C.train.num_workers = 8 39 | _C.train.point_limit = 30000 40 | _C.train.use_augmentation = True 41 | _C.train.augmentation_noise = 0.01 42 | _C.train.augmentation_min_scale = 0.8 43 | _C.train.augmentation_max_scale = 1.2 44 | _C.train.augmentation_shift = 2.0 45 | _C.train.augmentation_rotation = 1.0 46 | 47 | # test config 48 | _C.test = edict() 49 | _C.test.batch_size = 1 50 | _C.test.num_workers = 8 51 | _C.test.point_limit = None 52 | 53 | # eval config 54 | _C.eval = edict() 55 | _C.eval.acceptance_overlap = 0.0 56 | _C.eval.acceptance_radius = 1.0 57 | _C.eval.inlier_ratio_threshold = 0.05 58 | _C.eval.rre_threshold = 5.0 59 | _C.eval.rte_threshold = 2.0 60 | 61 | # ransac 62 | _C.ransac = edict() 63 | _C.ransac.distance_threshold = 0.3 64 | _C.ransac.num_points = 4 65 | _C.ransac.num_iterations = 50000 66 | 67 | # optim config 68 | _C.optim = edict() 69 | _C.optim.lr = 1e-4 70 | _C.optim.lr_decay = 0.95 71 | _C.optim.lr_decay_steps = 4 72 | _C.optim.weight_decay = 1e-6 73 | _C.optim.max_epoch = 160 74 | _C.optim.grad_acc_steps = 1 75 | 76 | # model - backbone 77 | _C.backbone = edict() 78 | _C.backbone.num_stages = 5 79 | _C.backbone.init_voxel_size = 0.3 80 | _C.backbone.kernel_size = 15 81 | _C.backbone.base_radius = 4.25 82 | _C.backbone.base_sigma = 2.0 83 | _C.backbone.init_radius = _C.backbone.base_radius * _C.backbone.init_voxel_size 84 | _C.backbone.init_sigma = _C.backbone.base_sigma * _C.backbone.init_voxel_size 85 | _C.backbone.group_norm = 32 86 | _C.backbone.input_dim = 1 87 | _C.backbone.init_dim = 64 88 | _C.backbone.output_dim = 256 89 | 90 | # model - Global 91 | _C.model = edict() 92 | _C.model.ground_truth_matching_radius = 0.6 93 | _C.model.num_points_in_patch = 128 94 | _C.model.num_sinkhorn_iterations = 100 95 | 96 | # model - Coarse Matching 97 | _C.coarse_matching = edict() 98 | _C.coarse_matching.num_targets = 128 99 | _C.coarse_matching.overlap_threshold = 0.1 100 | _C.coarse_matching.num_correspondences = 256 101 | _C.coarse_matching.dual_normalization = True 102 | 103 | # model - GeoTransformer 104 | _C.geotransformer = edict() 105 | _C.geotransformer.input_dim = 2048 106 | _C.geotransformer.hidden_dim = 128 107 | _C.geotransformer.output_dim = 256 108 | _C.geotransformer.num_heads = 4 109 | _C.geotransformer.blocks = ['self', 'cross', 'self', 'cross', 'self', 'cross'] 110 | _C.geotransformer.sigma_d = 4.8 111 | _C.geotransformer.sigma_a = 15 112 | _C.geotransformer.angle_k = 3 113 | _C.geotransformer.reduction_a = 'max' 114 | 115 | # model - Fine Matching 116 | _C.fine_matching = edict() 117 | 
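# Note: the thresholds in this block are the outdoor-scale counterparts of
# the 3DMatch config above; e.g. acceptance_radius grows from 0.1 m to 0.6 m,
# in line with the coarser 0.3 m initial voxel size used for Kitti.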
_C.fine_matching.topk = 2 118 | _C.fine_matching.acceptance_radius = 0.6 119 | _C.fine_matching.mutual = True 120 | _C.fine_matching.confidence_threshold = 0.05 121 | _C.fine_matching.use_dustbin = False 122 | _C.fine_matching.use_global_score = False 123 | _C.fine_matching.correspondence_threshold = 3 124 | _C.fine_matching.correspondence_limit = None 125 | _C.fine_matching.num_refinement_steps = 5 126 | 127 | # loss - Coarse level 128 | _C.coarse_loss = edict() 129 | _C.coarse_loss.positive_margin = 0.1 130 | _C.coarse_loss.negative_margin = 1.4 131 | _C.coarse_loss.positive_optimal = 0.1 132 | _C.coarse_loss.negative_optimal = 1.4 133 | _C.coarse_loss.log_scale = 40 134 | _C.coarse_loss.positive_overlap = 0.1 135 | 136 | # loss - Fine level 137 | _C.fine_loss = edict() 138 | _C.fine_loss.positive_radius = 0.6 139 | 140 | # loss - Overall 141 | _C.loss = edict() 142 | _C.loss.weight_coarse_loss = 1.0 143 | _C.loss.weight_fine_loss = 1.0 144 | 145 | 146 | def make_cfg(): 147 | return _C 148 | 149 | 150 | def parse_args(): 151 | parser = argparse.ArgumentParser() 152 | parser.add_argument('--link_output', dest='link_output', action='store_true', help='link output dir') 153 | args = parser.parse_args() 154 | return args 155 | 156 | 157 | def main(): 158 | cfg = make_cfg() 159 | args = parse_args() 160 | if args.link_output: 161 | os.symlink(cfg.output_dir, 'output') 162 | 163 | 164 | if __name__ == '__main__': 165 | main() 166 | -------------------------------------------------------------------------------- /experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn/dataset.py: -------------------------------------------------------------------------------- 1 | from geotransformer.datasets.registration.kitti.dataset import OdometryKittiPairDataset 2 | from geotransformer.utils.data import ( 3 | registration_collate_fn_stack_mode, 4 | calibrate_neighbors_stack_mode, 5 | build_dataloader_stack_mode, 6 | ) 7 | 8 | 9 | def train_valid_data_loader(cfg, distributed): 10 | train_dataset = OdometryKittiPairDataset( 11 | cfg.data.dataset_root, 12 | 'train', 13 | point_limit=cfg.train.point_limit, 14 | use_augmentation=cfg.train.use_augmentation, 15 | augmentation_noise=cfg.train.augmentation_noise, 16 | augmentation_min_scale=cfg.train.augmentation_min_scale, 17 | augmentation_max_scale=cfg.train.augmentation_max_scale, 18 | augmentation_shift=cfg.train.augmentation_shift, 19 | augmentation_rotation=cfg.train.augmentation_rotation, 20 | ) 21 | neighbor_limits = calibrate_neighbors_stack_mode( 22 | train_dataset, 23 | registration_collate_fn_stack_mode, 24 | cfg.backbone.num_stages, 25 | cfg.backbone.init_voxel_size, 26 | cfg.backbone.init_radius, 27 | ) 28 | train_loader = build_dataloader_stack_mode( 29 | train_dataset, 30 | registration_collate_fn_stack_mode, 31 | cfg.backbone.num_stages, 32 | cfg.backbone.init_voxel_size, 33 | cfg.backbone.init_radius, 34 | neighbor_limits, 35 | batch_size=cfg.train.batch_size, 36 | num_workers=cfg.train.num_workers, 37 | shuffle=True, 38 | distributed=distributed, 39 | ) 40 | 41 | valid_dataset = OdometryKittiPairDataset( 42 | cfg.data.dataset_root, 43 | 'val', 44 | point_limit=cfg.test.point_limit, 45 | use_augmentation=False, 46 | ) 47 | valid_loader = build_dataloader_stack_mode( 48 | valid_dataset, 49 | registration_collate_fn_stack_mode, 50 | cfg.backbone.num_stages, 51 | cfg.backbone.init_voxel_size, 52 | cfg.backbone.init_radius, 53 | neighbor_limits, 54 | batch_size=cfg.test.batch_size, 55 | num_workers=cfg.test.num_workers, 56 | 
shuffle=False,
57 |         distributed=distributed,
58 |     )
59 |
60 |     return train_loader, valid_loader, neighbor_limits
61 |
62 |
63 | def test_data_loader(cfg):
64 |     train_dataset = OdometryKittiPairDataset(
65 |         cfg.data.dataset_root,
66 |         'train',
67 |         point_limit=cfg.train.point_limit,
68 |         use_augmentation=cfg.train.use_augmentation,
69 |         augmentation_noise=cfg.train.augmentation_noise,
70 |         augmentation_min_scale=cfg.train.augmentation_min_scale,
71 |         augmentation_max_scale=cfg.train.augmentation_max_scale,
72 |         augmentation_shift=cfg.train.augmentation_shift,
73 |         augmentation_rotation=cfg.train.augmentation_rotation,
74 |     )
75 |     neighbor_limits = calibrate_neighbors_stack_mode(
76 |         train_dataset,
77 |         registration_collate_fn_stack_mode,
78 |         cfg.backbone.num_stages,
79 |         cfg.backbone.init_voxel_size,
80 |         cfg.backbone.init_radius,
81 |     )
82 |
83 |     test_dataset = OdometryKittiPairDataset(
84 |         cfg.data.dataset_root,
85 |         'test',
86 |         point_limit=cfg.test.point_limit,
87 |         use_augmentation=False,
88 |     )
89 |     test_loader = build_dataloader_stack_mode(
90 |         test_dataset,
91 |         registration_collate_fn_stack_mode,
92 |         cfg.backbone.num_stages,
93 |         cfg.backbone.init_voxel_size,
94 |         cfg.backbone.init_radius,
95 |         neighbor_limits,
96 |         batch_size=cfg.test.batch_size,
97 |         num_workers=cfg.test.num_workers,
98 |         shuffle=False,
99 |     )
100 |
101 |     return test_loader, neighbor_limits
102 |
--------------------------------------------------------------------------------
/experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn/eval.sh:
--------------------------------------------------------------------------------
1 | if [ "$2" = "test" ]; then
2 |     python test.py --test_epoch=$1
3 | fi
4 | python eval.py --test_epoch=$1 --method=lgr
5 |
--------------------------------------------------------------------------------
/experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import time
4 |
5 | import numpy as np
6 |
7 | from geotransformer.engine import SingleTester
8 | from geotransformer.utils.common import ensure_dir, get_log_string
9 | from geotransformer.utils.torch import release_cuda
10 |
11 | from config import make_cfg
12 | from dataset import test_data_loader
13 | from loss import Evaluator
14 | from model import create_model
15 |
16 |
17 | class Tester(SingleTester):
18 |     def __init__(self, cfg):
19 |         super().__init__(cfg)
20 |
21 |         # dataloader
22 |         start_time = time.time()
23 |         data_loader, neighbor_limits = test_data_loader(cfg)
24 |         loading_time = time.time() - start_time
25 |         message = f'Data loader created: {loading_time:.3f}s elapsed.'
26 |         self.logger.info(message)
27 |         message = f'Calibrate neighbors: {neighbor_limits}.'
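# Note: the per-pair .npz files written in after_test_step below appear to be
# the input that eval.py (run from eval.sh with --method=lgr) consumes to
# compute the final registration metrics.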
28 | self.logger.info(message) 29 | self.register_loader(data_loader) 30 | 31 | # model 32 | model = create_model(cfg).cuda() 33 | self.register_model(model) 34 | 35 | # evaluator 36 | self.evaluator = Evaluator(cfg).cuda() 37 | 38 | # preparation 39 | self.output_dir = osp.join(cfg.feature_dir) 40 | ensure_dir(self.output_dir) 41 | 42 | def test_step(self, iteration, data_dict): 43 | output_dict = self.model(data_dict) 44 | return output_dict 45 | 46 | def eval_step(self, iteration, data_dict, output_dict): 47 | result_dict = self.evaluator(output_dict, data_dict) 48 | return result_dict 49 | 50 | def summary_string(self, iteration, data_dict, output_dict, result_dict): 51 | seq_id = data_dict['seq_id'] 52 | ref_frame = data_dict['ref_frame'] 53 | src_frame = data_dict['src_frame'] 54 | message = f'seq_id: {seq_id}, id0: {ref_frame}, id1: {src_frame}' 55 | message += ', ' + get_log_string(result_dict=result_dict) 56 | message += ', nCorr: {}'.format(output_dict['corr_scores'].shape[0]) 57 | return message 58 | 59 | def after_test_step(self, iteration, data_dict, output_dict, result_dict): 60 | seq_id = data_dict['seq_id'] 61 | ref_frame = data_dict['ref_frame'] 62 | src_frame = data_dict['src_frame'] 63 | 64 | file_name = osp.join(self.output_dir, f'{seq_id}_{src_frame}_{ref_frame}.npz') 65 | np.savez_compressed( 66 | file_name, 67 | ref_points=release_cuda(output_dict['ref_points']), 68 | src_points=release_cuda(output_dict['src_points']), 69 | ref_points_f=release_cuda(output_dict['ref_points_f']), 70 | src_points_f=release_cuda(output_dict['src_points_f']), 71 | ref_points_c=release_cuda(output_dict['ref_points_c']), 72 | src_points_c=release_cuda(output_dict['src_points_c']), 73 | ref_feats_c=release_cuda(output_dict['ref_feats_c']), 74 | src_feats_c=release_cuda(output_dict['src_feats_c']), 75 | ref_node_corr_indices=release_cuda(output_dict['ref_node_corr_indices']), 76 | src_node_corr_indices=release_cuda(output_dict['src_node_corr_indices']), 77 | ref_corr_points=release_cuda(output_dict['ref_corr_points']), 78 | src_corr_points=release_cuda(output_dict['src_corr_points']), 79 | corr_scores=release_cuda(output_dict['corr_scores']), 80 | gt_node_corr_indices=release_cuda(output_dict['gt_node_corr_indices']), 81 | gt_node_corr_overlaps=release_cuda(output_dict['gt_node_corr_overlaps']), 82 | estimated_transform=release_cuda(output_dict['estimated_transform']), 83 | transform=release_cuda(data_dict['transform']), 84 | ) 85 | 86 | 87 | def main(): 88 | cfg = make_cfg() 89 | tester = Tester(cfg) 90 | tester.run() 91 | 92 | 93 | if __name__ == '__main__': 94 | main() 95 | -------------------------------------------------------------------------------- /experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn/trainval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import torch.optim as optim 5 | 6 | from geotransformer.engine import EpochBasedTrainer 7 | 8 | from config import make_cfg 9 | from dataset import train_valid_data_loader 10 | from model import create_model 11 | from loss import OverallLoss, Evaluator 12 | 13 | 14 | class Trainer(EpochBasedTrainer): 15 | def __init__(self, cfg): 16 | super().__init__(cfg, max_epoch=cfg.optim.max_epoch) 17 | 18 | # dataloader 19 | start_time = time.time() 20 | train_loader, val_loader, neighbor_limits = train_valid_data_loader(cfg, self.distributed) 21 | loading_time = time.time() - start_time 22 | message = 'Data loader created: {:.3f}s 
elapsed.'.format(loading_time) 23 | self.logger.info(message) 24 | message = 'Calibrate neighbors: {}.'.format(neighbor_limits) 25 | self.logger.info(message) 26 | self.register_loader(train_loader, val_loader) 27 | 28 | # model, optimizer, scheduler 29 | model = create_model(cfg).cuda() 30 | model = self.register_model(model) 31 | optimizer = optim.Adam(model.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay) 32 | self.register_optimizer(optimizer) 33 | scheduler = optim.lr_scheduler.StepLR(optimizer, cfg.optim.lr_decay_steps, gamma=cfg.optim.lr_decay) 34 | self.register_scheduler(scheduler) 35 | 36 | # loss function, evaluator 37 | self.loss_func = OverallLoss(cfg).cuda() 38 | self.evaluator = Evaluator(cfg).cuda() 39 | 40 | def train_step(self, epoch, iteration, data_dict): 41 | output_dict = self.model(data_dict) 42 | loss_dict = self.loss_func(output_dict, data_dict) 43 | result_dict = self.evaluator(output_dict, data_dict) 44 | loss_dict.update(result_dict) 45 | return output_dict, loss_dict 46 | 47 | def val_step(self, epoch, iteration, data_dict): 48 | output_dict = self.model(data_dict) 49 | loss_dict = self.loss_func(output_dict, data_dict) 50 | result_dict = self.evaluator(output_dict, data_dict) 51 | loss_dict.update(result_dict) 52 | return output_dict, loss_dict 53 | 54 | 55 | def main(): 56 | cfg = make_cfg() 57 | trainer = Trainer(cfg) 58 | trainer.run() 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /experiments/geotransformer.modelnet.rpmnet.stage4.gse.k3.max.oacl.stage2.sinkhorn/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | from geotransformer.modules.kpconv import ConvBlock, ResidualBlock, UnaryBlock, LastUnaryBlock, nearest_upsample 6 | 7 | 8 | class KPConvFPN(nn.Module): 9 | def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm): 10 | super(KPConvFPN, self).__init__() 11 | 12 | self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm) 13 | self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm) 14 | 15 | self.encoder2_1 = ResidualBlock( 16 | init_dim * 2, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm, strided=True 17 | ) 18 | self.encoder2_2 = ResidualBlock( 19 | init_dim * 2, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm 20 | ) 21 | self.encoder2_3 = ResidualBlock( 22 | init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm 23 | ) 24 | 25 | self.encoder3_1 = ResidualBlock( 26 | init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm, strided=True 27 | ) 28 | self.encoder3_2 = ResidualBlock( 29 | init_dim * 4, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm 30 | ) 31 | self.encoder3_3 = ResidualBlock( 32 | init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm 33 | ) 34 | 35 | self.decoder2 = UnaryBlock(init_dim * 12, init_dim * 4, group_norm) 36 | self.decoder1 = LastUnaryBlock(init_dim * 6, output_dim) 37 | 38 | def forward(self, feats, data_dict): 39 | feats_list = [] 40 | 41 | points_list = data_dict['points'] 42 | neighbors_list = data_dict['neighbors'] 43 | subsampling_list = data_dict['subsampling'] 44 | upsampling_list =
data_dict['upsampling'] 45 | 46 | feats_s1 = feats 47 | feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0], neighbors_list[0]) 48 | feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0], neighbors_list[0]) 49 | 50 | feats_s2 = self.encoder2_1(feats_s1, points_list[1], points_list[0], subsampling_list[0]) 51 | feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1], neighbors_list[1]) 52 | feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1], neighbors_list[1]) 53 | 54 | feats_s3 = self.encoder3_1(feats_s2, points_list[2], points_list[1], subsampling_list[1]) 55 | feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2], neighbors_list[2]) 56 | feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2], neighbors_list[2]) 57 | 58 | latent_s3 = feats_s3 59 | feats_list.append(feats_s3) 60 | 61 | latent_s2 = nearest_upsample(latent_s3, upsampling_list[1]) 62 | latent_s2 = torch.cat([latent_s2, feats_s2], dim=1) 63 | latent_s2 = self.decoder2(latent_s2) 64 | feats_list.append(latent_s2) 65 | 66 | latent_s1 = nearest_upsample(latent_s2, upsampling_list[0]) 67 | latent_s1 = torch.cat([latent_s1, feats_s1], dim=1) 68 | latent_s1 = self.decoder1(latent_s1) 69 | feats_list.append(latent_s1) 70 | 71 | feats_list.reverse() 72 | 73 | return feats_list 74 | -------------------------------------------------------------------------------- /experiments/geotransformer.modelnet.rpmnet.stage4.gse.k3.max.oacl.stage2.sinkhorn/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | 5 | from easydict import EasyDict as edict 6 | 7 | from geotransformer.utils.common import ensure_dir 8 | 9 | 10 | _C = edict() 11 | 12 | # common 13 | _C.seed = 7351 14 | 15 | # dirs 16 | _C.working_dir = osp.dirname(osp.realpath(__file__)) 17 | _C.root_dir = osp.dirname(osp.dirname(_C.working_dir)) 18 | _C.exp_name = osp.basename(_C.working_dir) 19 | _C.output_dir = osp.join(_C.root_dir, "output", _C.exp_name) 20 | _C.snapshot_dir = osp.join(_C.output_dir, "snapshots") 21 | _C.log_dir = osp.join(_C.output_dir, "logs") 22 | _C.event_dir = osp.join(_C.output_dir, "events") 23 | 24 | ensure_dir(_C.output_dir) 25 | ensure_dir(_C.snapshot_dir) 26 | ensure_dir(_C.log_dir) 27 | ensure_dir(_C.event_dir) 28 | 29 | # data 30 | _C.data = edict() 31 | _C.data.dataset_root = osp.join(_C.root_dir, "data", "ModelNet") 32 | _C.data.num_points = 717 33 | _C.data.voxel_size = None 34 | _C.data.rotation_magnitude = 45.0 35 | _C.data.translation_magnitude = 0.5 36 | _C.data.keep_ratio = 0.7 37 | _C.data.crop_method = "plane" 38 | _C.data.asymmetric = True 39 | _C.data.twice_sample = True 40 | _C.data.twice_transform = False 41 | 42 | # train data 43 | _C.train = edict() 44 | _C.train.batch_size = 1 45 | _C.train.num_workers = 8 46 | _C.train.noise_magnitude = 0.05 47 | _C.train.class_indices = "all" 48 | 49 | # test data 50 | _C.test = edict() 51 | _C.test.batch_size = 1 52 | _C.test.num_workers = 8 53 | _C.test.noise_magnitude = 0.05 54 | _C.test.class_indices = "all" 55 | 56 | # evaluation 57 | _C.eval = edict() 58 | _C.eval.acceptance_overlap = 0.0 59 | _C.eval.acceptance_radius = 0.1 60 | _C.eval.inlier_ratio_threshold = 0.05 61 | _C.eval.rre_threshold = 1.0 62 | _C.eval.rte_threshold = 0.1 63 | 64 | # ransac 65 | _C.ransac = edict() 66 | _C.ransac.distance_threshold = 0.05 67 | _C.ransac.num_points = 3 68 | _C.ransac.num_iterations = 1000 69 | 70 | # optim 71 | _C.optim = 
edict() 72 | _C.optim.lr = 1e-4 73 | _C.optim.weight_decay = 1e-6 74 | _C.optim.warmup_steps = 10000 75 | _C.optim.eta_init = 0.1 76 | _C.optim.eta_min = 0.1 77 | _C.optim.max_iteration = 400000 78 | _C.optim.snapshot_steps = 10000 79 | _C.optim.grad_acc_steps = 1 80 | 81 | # model - backbone 82 | _C.backbone = edict() 83 | _C.backbone.num_stages = 3 84 | _C.backbone.init_voxel_size = 0.05 85 | _C.backbone.kernel_size = 15 86 | _C.backbone.base_radius = 2.5 87 | _C.backbone.base_sigma = 2.0 88 | _C.backbone.init_radius = _C.backbone.base_radius * _C.backbone.init_voxel_size 89 | _C.backbone.init_sigma = _C.backbone.base_sigma * _C.backbone.init_voxel_size 90 | _C.backbone.group_norm = 32 91 | _C.backbone.input_dim = 1 92 | _C.backbone.init_dim = 64 93 | _C.backbone.output_dim = 256 94 | 95 | # model - Global 96 | _C.model = edict() 97 | _C.model.ground_truth_matching_radius = 0.05 98 | _C.model.num_points_in_patch = 128 99 | _C.model.num_sinkhorn_iterations = 100 100 | 101 | # model - Coarse Matching 102 | _C.coarse_matching = edict() 103 | _C.coarse_matching.num_targets = 128 104 | _C.coarse_matching.overlap_threshold = 0.1 105 | _C.coarse_matching.num_correspondences = 128 106 | _C.coarse_matching.dual_normalization = True 107 | 108 | # model - GeoTransformer 109 | _C.geotransformer = edict() 110 | _C.geotransformer.input_dim = 512 111 | _C.geotransformer.hidden_dim = 256 112 | _C.geotransformer.output_dim = 256 113 | _C.geotransformer.num_heads = 4 114 | _C.geotransformer.blocks = ["self", "cross", "self", "cross", "self", "cross"] 115 | _C.geotransformer.sigma_d = 0.2 116 | _C.geotransformer.sigma_a = 15 117 | _C.geotransformer.angle_k = 3 118 | _C.geotransformer.reduction_a = "max" 119 | 120 | # model - Fine Matching 121 | _C.fine_matching = edict() 122 | _C.fine_matching.topk = 3 123 | _C.fine_matching.acceptance_radius = 0.1 124 | _C.fine_matching.mutual = True 125 | _C.fine_matching.confidence_threshold = 0.05 126 | _C.fine_matching.use_dustbin = False 127 | _C.fine_matching.use_global_score = False 128 | _C.fine_matching.correspondence_threshold = 3 129 | _C.fine_matching.correspondence_limit = None 130 | _C.fine_matching.num_refinement_steps = 5 131 | 132 | # loss - Coarse level 133 | _C.coarse_loss = edict() 134 | _C.coarse_loss.positive_margin = 0.1 135 | _C.coarse_loss.negative_margin = 1.4 136 | _C.coarse_loss.positive_optimal = 0.1 137 | _C.coarse_loss.negative_optimal = 1.4 138 | _C.coarse_loss.log_scale = 24 139 | _C.coarse_loss.positive_overlap = 0.1 140 | 141 | # loss - Fine level 142 | _C.fine_loss = edict() 143 | _C.fine_loss.positive_radius = 0.05 144 | 145 | # loss - Overall 146 | _C.loss = edict() 147 | _C.loss.weight_coarse_loss = 1.0 148 | _C.loss.weight_fine_loss = 1.0 149 | 150 | 151 | def make_cfg(): 152 | return _C 153 | 154 | 155 | def parse_args(): 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument("--link_output", dest="link_output", action="store_true", help="link output dir") 158 | args = parser.parse_args() 159 | return args 160 | 161 | 162 | def main(): 163 | args = parse_args() 164 | cfg = make_cfg() 165 | if args.link_output: 166 | os.symlink(cfg.output_dir, "output") 167 | 168 | 169 | if __name__ == "__main__": 170 | main() 171 | -------------------------------------------------------------------------------- /experiments/geotransformer.modelnet.rpmnet.stage4.gse.k3.max.oacl.stage2.sinkhorn/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import time 3 | 4 | from 
geotransformer.engine import SingleTester 5 | from geotransformer.utils.common import get_log_string 6 | 7 | from dataset import test_data_loader 8 | from config import make_cfg 9 | from model import create_model 10 | from loss import Evaluator 11 | 12 | 13 | class Tester(SingleTester): 14 | def __init__(self, cfg): 15 | super().__init__(cfg) 16 | 17 | # dataloader 18 | start_time = time.time() 19 | data_loader, neighbor_limits = test_data_loader(cfg) 20 | loading_time = time.time() - start_time 21 | message = f'Data loader created: {loading_time:.3f}s elapsed.' 22 | self.logger.info(message) 23 | message = f'Calibrate neighbors: {neighbor_limits}.' 24 | self.logger.info(message) 25 | self.register_loader(data_loader) 26 | 27 | # model 28 | model = create_model(cfg).cuda() 29 | self.register_model(model) 30 | 31 | # evaluator 32 | self.evaluator = Evaluator(cfg).cuda() 33 | 34 | def test_step(self, iteration, data_dict): 35 | output_dict = self.model(data_dict) 36 | return output_dict 37 | 38 | def eval_step(self, iteration, data_dict, output_dict): 39 | result_dict = self.evaluator(output_dict, data_dict) 40 | return result_dict 41 | 42 | def summary_string(self, iteration, data_dict, output_dict, result_dict): 43 | message = get_log_string(result_dict=result_dict) 44 | message += ', nCorr: {}'.format(output_dict['corr_scores'].shape[0]) 45 | return message 46 | 47 | 48 | def main(): 49 | cfg = make_cfg() 50 | tester = Tester(cfg) 51 | tester.run() 52 | 53 | 54 | if __name__ == '__main__': 55 | main() 56 | -------------------------------------------------------------------------------- /experiments/geotransformer.modelnet.rpmnet.stage4.gse.k3.max.oacl.stage2.sinkhorn/trainval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | 10 | 11 | from geotransformer.engine.iter_based_trainer import IterBasedTrainer 12 | from geotransformer.utils.torch import build_warmup_cosine_lr_scheduler 13 | 14 | from config import make_cfg 15 | from dataset import train_valid_data_loader 16 | from model import create_model 17 | from loss import OverallLoss, Evaluator 18 | 19 | 20 | class Trainer(IterBasedTrainer): 21 | def __init__(self, cfg): 22 | super().__init__(cfg, max_iteration=cfg.optim.max_iteration, snapshot_steps=cfg.optim.snapshot_steps) 23 | 24 | # dataloader 25 | start_time = time.time() 26 | train_loader, val_loader, neighbor_limits = train_valid_data_loader(cfg, self.distributed) 27 | loading_time = time.time() - start_time 28 | message = 'Data loader created: {:.3f}s elapsed.'.format(loading_time) 29 | self.logger.info(message) 30 | message = 'Calibrate neighbors: {}.'.format(neighbor_limits) 31 | self.logger.info(message) 32 | self.register_loader(train_loader, val_loader) 33 | 34 | # model, optimizer, scheduler 35 | model = create_model(cfg).cuda() 36 | model = self.register_model(model) 37 | optimizer = optim.Adam(model.parameters(), lr=cfg.optim.lr, weight_decay=cfg.optim.weight_decay) 38 | self.register_optimizer(optimizer) 39 | scheduler = build_warmup_cosine_lr_scheduler( 40 | optimizer, 41 | total_steps=cfg.optim.max_iteration, 42 | warmup_steps=cfg.optim.warmup_steps, 43 | eta_init=cfg.optim.eta_init, 44 | eta_min=cfg.optim.eta_min, 45 | grad_acc_steps=cfg.optim.grad_acc_steps, 46 | ) 47 | self.register_scheduler(scheduler) 48 | 49 | # loss function, evaluator 50 |
self.loss_func = OverallLoss(cfg).cuda() 51 | self.evaluator = Evaluator(cfg).cuda() 52 | 53 | def train_step(self, iteration, data_dict): 54 | output_dict = self.model(data_dict) 55 | loss_dict = self.loss_func(output_dict, data_dict) 56 | result_dict = self.evaluator(output_dict, data_dict) 57 | loss_dict.update(result_dict) 58 | return output_dict, loss_dict 59 | 60 | def val_step(self, iteration, data_dict): 61 | output_dict = self.model(data_dict) 62 | loss_dict = self.loss_func(output_dict, data_dict) 63 | result_dict = self.evaluator(output_dict, data_dict) 64 | loss_dict.update(result_dict) 65 | return output_dict, loss_dict 66 | 67 | 68 | def main(): 69 | cfg = make_cfg() 70 | trainer = Trainer(cfg) 71 | trainer.run() 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /geotransformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/geotransformer/__init__.py -------------------------------------------------------------------------------- /geotransformer/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/geotransformer/datasets/__init__.py -------------------------------------------------------------------------------- /geotransformer/datasets/registration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/geotransformer/datasets/registration/__init__.py -------------------------------------------------------------------------------- /geotransformer/datasets/registration/kitti/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.datasets.registration.kitti.dataset import OdometryKittiPairDataset 2 | 3 | 4 | __all__ = [ 5 | 'OdometryKittiPairDataset', 6 | ] 7 | -------------------------------------------------------------------------------- /geotransformer/datasets/registration/kitti/dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | 4 | import numpy as np 5 | import torch.utils.data 6 | 7 | from geotransformer.utils.common import load_pickle 8 | from geotransformer.utils.pointcloud import ( 9 | random_sample_rotation, 10 | get_transform_from_rotation_translation, 11 | get_rotation_translation_from_transform, 12 | ) 13 | from geotransformer.utils.registration import get_correspondences 14 | 15 | 16 | class OdometryKittiPairDataset(torch.utils.data.Dataset): 17 | ODOMETRY_KITTI_DATA_SPLIT = { 18 | 'train': ['00', '01', '02', '03', '04', '05'], 19 | 'val': ['06', '07'], 20 | 'test': ['08', '09', '10'], 21 | } 22 | 23 | def __init__( 24 | self, 25 | dataset_root, 26 | subset, 27 | point_limit=None, 28 | use_augmentation=False, 29 | augmentation_noise=0.005, 30 | augmentation_min_scale=0.8, 31 | augmentation_max_scale=1.2, 32 | augmentation_shift=2.0, 33 | augmentation_rotation=1.0, 34 | return_corr_indices=False, 35 | matching_radius=None, 36 | ): 37 | super(OdometryKittiPairDataset, self).__init__() 38 | 39 | self.dataset_root = dataset_root 40 | self.subset = subset 41 | self.point_limit = point_limit 42 | 43 
| self.use_augmentation = use_augmentation 44 | self.augmentation_noise = augmentation_noise 45 | self.augmentation_min_scale = augmentation_min_scale 46 | self.augmentation_max_scale = augmentation_max_scale 47 | self.augmentation_shift = augmentation_shift 48 | self.augmentation_rotation = augmentation_rotation 49 | 50 | self.return_corr_indices = return_corr_indices 51 | self.matching_radius = matching_radius 52 | if self.return_corr_indices and self.matching_radius is None: 53 | raise ValueError('"matching_radius" is None but "return_corr_indices" is set.') 54 | 55 | self.metadata = load_pickle(osp.join(self.dataset_root, 'metadata', f'{subset}.pkl')) 56 | 57 | def _augment_point_cloud(self, ref_points, src_points, transform): 58 | rotation, translation = get_rotation_translation_from_transform(transform) 59 | # add gaussian noise 60 | ref_points = ref_points + (np.random.rand(ref_points.shape[0], 3) - 0.5) * self.augmentation_noise 61 | src_points = src_points + (np.random.rand(src_points.shape[0], 3) - 0.5) * self.augmentation_noise 62 | # random rotation 63 | aug_rotation = random_sample_rotation(self.augmentation_rotation) 64 | if random.random() > 0.5: 65 | ref_points = np.matmul(ref_points, aug_rotation.T) 66 | rotation = np.matmul(aug_rotation, rotation) 67 | translation = np.matmul(aug_rotation, translation) 68 | else: 69 | src_points = np.matmul(src_points, aug_rotation.T) 70 | rotation = np.matmul(rotation, aug_rotation.T) 71 | # random scaling 72 | scale = random.random() 73 | scale = self.augmentation_min_scale + (self.augmentation_max_scale - self.augmentation_min_scale) * scale 74 | ref_points = ref_points * scale 75 | src_points = src_points * scale 76 | translation = translation * scale 77 | # random shift 78 | ref_shift = np.random.uniform(-self.augmentation_shift, self.augmentation_shift, 3) 79 | src_shift = np.random.uniform(-self.augmentation_shift, self.augmentation_shift, 3) 80 | ref_points = ref_points + ref_shift 81 | src_points = src_points + src_shift 82 | translation = -np.matmul(src_shift[None, :], rotation.T) + translation + ref_shift 83 | # compose transform from rotation and translation 84 | transform = get_transform_from_rotation_translation(rotation, translation) 85 | return ref_points, src_points, transform 86 | 87 | def _load_point_cloud(self, file_name): 88 | points = np.load(file_name) 89 | if self.point_limit is not None and points.shape[0] > self.point_limit: 90 | indices = np.random.permutation(points.shape[0])[: self.point_limit] 91 | points = points[indices] 92 | return points 93 | 94 | def __getitem__(self, index): 95 | data_dict = {} 96 | 97 | metadata = self.metadata[index] 98 | data_dict['seq_id'] = metadata['seq_id'] 99 | data_dict['ref_frame'] = metadata['frame0'] 100 | data_dict['src_frame'] = metadata['frame1'] 101 | 102 | ref_points = self._load_point_cloud(osp.join(self.dataset_root, metadata['pcd0'])) 103 | src_points = self._load_point_cloud(osp.join(self.dataset_root, metadata['pcd1'])) 104 | transform = metadata['transform'] 105 | 106 | if self.use_augmentation: 107 | ref_points, src_points, transform = self._augment_point_cloud(ref_points, src_points, transform) 108 | 109 | if self.return_corr_indices: 110 | corr_indices = get_correspondences(ref_points, src_points, transform, self.matching_radius) 111 | data_dict['corr_indices'] = corr_indices 112 | 113 | data_dict['ref_points'] = ref_points.astype(np.float32) 114 | data_dict['src_points'] = src_points.astype(np.float32) 115 | data_dict['ref_feats'] = 
np.ones((ref_points.shape[0], 1), dtype=np.float32) 116 | data_dict['src_feats'] = np.ones((src_points.shape[0], 1), dtype=np.float32) 117 | data_dict['transform'] = transform.astype(np.float32) 118 | 119 | return data_dict 120 | 121 | def __len__(self): 122 | return len(self.metadata) 123 | -------------------------------------------------------------------------------- /geotransformer/datasets/registration/modelnet/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.datasets.registration.modelnet.dataset import ModelNetPairDataset 2 | -------------------------------------------------------------------------------- /geotransformer/datasets/registration/threedmatch/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.datasets.registration.threedmatch.dataset import ThreeDMatchPairDataset 2 | # from geotransformer.datasets.registration.threedmatch.dataset_minkowski import ThreeDMatchPairMinkowskiDataset 3 | 4 | 5 | __all__ = [ 6 | 'ThreeDMatchPairDataset', 7 | # 'ThreeDMatchPairMinkowskiDataset', 8 | ] 9 | -------------------------------------------------------------------------------- /geotransformer/datasets/registration/threedmatch/dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import random 4 | from typing import Dict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.utils.data 9 | 10 | from geotransformer.utils.pointcloud import ( 11 | random_sample_rotation, 12 | random_sample_rotation_v2, 13 | get_transform_from_rotation_translation, 14 | ) 15 | from geotransformer.utils.registration import get_correspondences 16 | 17 | 18 | class ThreeDMatchPairDataset(torch.utils.data.Dataset): 19 | def __init__( 20 | self, 21 | dataset_root, 22 | subset, 23 | point_limit=None, 24 | use_augmentation=False, 25 | augmentation_noise=0.005, 26 | augmentation_rotation=1, 27 | overlap_threshold=None, 28 | return_corr_indices=False, 29 | matching_radius=None, 30 | rotated=False, 31 | ): 32 | super(ThreeDMatchPairDataset, self).__init__() 33 | 34 | self.dataset_root = dataset_root 35 | self.metadata_root = osp.join(self.dataset_root, 'metadata') 36 | self.data_root = osp.join(self.dataset_root, 'data') 37 | 38 | self.subset = subset 39 | self.point_limit = point_limit 40 | self.overlap_threshold = overlap_threshold 41 | self.rotated = rotated 42 | 43 | self.return_corr_indices = return_corr_indices 44 | self.matching_radius = matching_radius 45 | if self.return_corr_indices and self.matching_radius is None: 46 | raise ValueError('"matching_radius" is None but "return_corr_indices" is set.') 47 | 48 | self.use_augmentation = use_augmentation 49 | self.aug_noise = augmentation_noise 50 | self.aug_rotation = augmentation_rotation 51 | 52 | with open(osp.join(self.metadata_root, f'{subset}.pkl'), 'rb') as f: 53 | self.metadata_list = pickle.load(f) 54 | if self.overlap_threshold is not None: 55 | self.metadata_list = [x for x in self.metadata_list if x['overlap'] > self.overlap_threshold] 56 | 57 | def __len__(self): 58 | return len(self.metadata_list) 59 | 60 | def _load_point_cloud(self, file_name): 61 | points = torch.load(osp.join(self.data_root, file_name)) 62 | # NOTE: setting "point_limit" with "num_workers" > 1 will cause nondeterminism. 
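# A minimal sketch of a deterministic alternative (an assumption, not this repo's
# behavior): seed a per-item generator from the file name so the subsampling no
# longer depends on worker scheduling (requires "import zlib" at the top of the file):
#   rng = np.random.default_rng(zlib.crc32(file_name.encode()))
#   indices = rng.permutation(points.shape[0])[: self.point_limit]
#   points = points[indices]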
63 | if self.point_limit is not None and points.shape[0] > self.point_limit: 64 | indices = np.random.permutation(points.shape[0])[: self.point_limit] 65 | points = points[indices] 66 | return points 67 | 68 | def _augment_point_cloud(self, ref_points, src_points, rotation, translation): 69 | r"""Augment point clouds. 70 | 71 | ref_points = src_points @ rotation.T + translation 72 | 73 | 1. Random rotation to one point cloud. 74 | 2. Random noise. 75 | """ 76 | aug_rotation = random_sample_rotation(self.aug_rotation) 77 | if random.random() > 0.5: 78 | ref_points = np.matmul(ref_points, aug_rotation.T) 79 | rotation = np.matmul(aug_rotation, rotation) 80 | translation = np.matmul(aug_rotation, translation) 81 | else: 82 | src_points = np.matmul(src_points, aug_rotation.T) 83 | rotation = np.matmul(rotation, aug_rotation.T) 84 | 85 | ref_points += (np.random.rand(ref_points.shape[0], 3) - 0.5) * self.aug_noise 86 | src_points += (np.random.rand(src_points.shape[0], 3) - 0.5) * self.aug_noise 87 | 88 | return ref_points, src_points, rotation, translation 89 | 90 | def __getitem__(self, index): 91 | data_dict = {} 92 | 93 | # metadata 94 | metadata: Dict = self.metadata_list[index] 95 | data_dict['scene_name'] = metadata['scene_name'] 96 | data_dict['ref_frame'] = metadata['frag_id0'] 97 | data_dict['src_frame'] = metadata['frag_id1'] 98 | data_dict['overlap'] = metadata['overlap'] 99 | 100 | # get transformation 101 | rotation = metadata['rotation'] 102 | translation = metadata['translation'] 103 | 104 | # get point cloud 105 | ref_points = self._load_point_cloud(metadata['pcd0']) 106 | src_points = self._load_point_cloud(metadata['pcd1']) 107 | 108 | # augmentation 109 | if self.use_augmentation: 110 | ref_points, src_points, rotation, translation = self._augment_point_cloud( 111 | ref_points, src_points, rotation, translation 112 | ) 113 | 114 | if self.rotated: 115 | ref_rotation = random_sample_rotation_v2() 116 | ref_points = np.matmul(ref_points, ref_rotation.T) 117 | rotation = np.matmul(ref_rotation, rotation) 118 | translation = np.matmul(ref_rotation, translation) 119 | 120 | src_rotation = random_sample_rotation_v2() 121 | src_points = np.matmul(src_points, src_rotation.T) 122 | rotation = np.matmul(rotation, src_rotation.T) 123 | 124 | transform = get_transform_from_rotation_translation(rotation, translation) 125 | 126 | # get correspondences 127 | if self.return_corr_indices: 128 | corr_indices = get_correspondences(ref_points, src_points, transform, self.matching_radius) 129 | data_dict['corr_indices'] = corr_indices 130 | 131 | data_dict['ref_points'] = ref_points.astype(np.float32) 132 | data_dict['src_points'] = src_points.astype(np.float32) 133 | data_dict['ref_feats'] = np.ones((ref_points.shape[0], 1), dtype=np.float32) 134 | data_dict['src_feats'] = np.ones((src_points.shape[0], 1), dtype=np.float32) 135 | data_dict['transform'] = transform.astype(np.float32) 136 | 137 | return data_dict 138 | -------------------------------------------------------------------------------- /geotransformer/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.engine.epoch_based_trainer import EpochBasedTrainer 2 | from geotransformer.engine.iter_based_trainer import IterBasedTrainer 3 | from geotransformer.engine.single_tester import SingleTester 4 | from geotransformer.engine.logger import Logger 5 | -------------------------------------------------------------------------------- /geotransformer/engine/base_tester.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os.path as osp 4 | import time 5 | import json 6 | import abc 7 | 8 | import torch 9 | import ipdb 10 | 11 | from geotransformer.utils.torch import initialize 12 | from geotransformer.engine.logger import Logger 13 | 14 | 15 | def inject_default_parser(parser=None): 16 | if parser is None: 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--snapshot', default=None, help='load from snapshot') 19 | parser.add_argument('--test_epoch', type=int, default=None, help='test epoch') 20 | parser.add_argument('--test_iter', type=int, default=None, help='test iteration') 21 | return parser 22 | 23 | 24 | class BaseTester(abc.ABC): 25 | def __init__(self, cfg, parser=None, cudnn_deterministic=True): 26 | # parser 27 | parser = inject_default_parser(parser) 28 | self.args = parser.parse_args() 29 | 30 | # logger 31 | log_file = osp.join(cfg.log_dir, 'test-{}.log'.format(time.strftime('%Y%m%d-%H%M%S'))) 32 | self.logger = Logger(log_file=log_file) 33 | 34 | # command executed 35 | message = 'Command executed: ' + ' '.join(sys.argv) 36 | self.logger.info(message) 37 | 38 | # find snapshot 39 | if self.args.snapshot is None: 40 | if self.args.test_epoch is not None: 41 | self.args.snapshot = osp.join(cfg.snapshot_dir, 'epoch-{}.pth.tar'.format(self.args.test_epoch)) 42 | elif self.args.test_iter is not None: 43 | self.args.snapshot = osp.join(cfg.snapshot_dir, 'iter-{}.pth.tar'.format(self.args.test_iter)) 44 | if self.args.snapshot is None: 45 | raise RuntimeError('Snapshot is not specified.') 46 | 47 | # print config 48 | message = 'Configs:\n' + json.dumps(cfg, indent=4) 49 | self.logger.info(message) 50 | 51 | # cuda and distributed 52 | if not torch.cuda.is_available(): 53 | raise RuntimeError('No CUDA devices available.') 54 | self.cudnn_deterministic = cudnn_deterministic 55 | self.seed = cfg.seed 56 | initialize(seed=self.seed, cudnn_deterministic=self.cudnn_deterministic) 57 | 58 | # state 59 | self.model = None 60 | self.iteration = None 61 | 62 | self.test_loader = None 63 | self.saved_states = {} 64 | 65 | def load_snapshot(self, snapshot): 66 | self.logger.info('Loading from "{}".'.format(snapshot)) 67 | state_dict = torch.load(snapshot, map_location=torch.device('cpu')) 68 | assert 'model' in state_dict, 'No model can be loaded.' 69 | self.model.load_state_dict(state_dict['model'], strict=True) 70 | self.logger.info('Model has been loaded.') 71 | 72 | def register_model(self, model): 73 | r"""Register model. 
The model is used as-is (no DDP wrapping in the tester).""" 74 | self.model = model 75 | message = 'Model description:\n' + str(model) 76 | self.logger.info(message) 77 | return model 78 | 79 | def register_loader(self, test_loader): 80 | r"""Register data loader.""" 81 | self.test_loader = test_loader 82 | 83 | @abc.abstractmethod 84 | def run(self): 85 | raise NotImplementedError 86 | -------------------------------------------------------------------------------- /geotransformer/engine/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import coloredlogs 4 | 5 | 6 | def create_logger(log_file=None): 7 | logger = logging.getLogger() 8 | logger.handlers.clear() 9 | logger.setLevel(level=logging.DEBUG) 10 | logger.propagate = False 11 | 12 | format_str = '[%(asctime)s] [%(levelname).4s] %(message)s' 13 | 14 | stream_handler = logging.StreamHandler() 15 | colored_formatter = coloredlogs.ColoredFormatter(format_str) 16 | stream_handler.setFormatter(colored_formatter) 17 | logger.addHandler(stream_handler) 18 | 19 | if log_file is not None: 20 | file_handler = logging.FileHandler(log_file) 21 | formatter = logging.Formatter(format_str, datefmt='%Y-%m-%d %H:%M:%S') 22 | file_handler.setFormatter(formatter) 23 | logger.addHandler(file_handler) 24 | 25 | return logger 26 | 27 | 28 | class Logger: 29 | def __init__(self, log_file=None, local_rank=-1): 30 | if local_rank == 0 or local_rank == -1: 31 | self.logger = create_logger(log_file=log_file) 32 | else: 33 | self.logger = None 34 | 35 | def debug(self, message): 36 | if self.logger is not None: 37 | self.logger.debug(message) 38 | 39 | def info(self, message): 40 | if self.logger is not None: 41 | self.logger.info(message) 42 | 43 | def warning(self, message): 44 | if self.logger is not None: 45 | self.logger.warning(message) 46 | 47 | def error(self, message): 48 | if self.logger is not None: 49 | self.logger.error(message) 50 | 51 | def critical(self, message): 52 | if self.logger is not None: 53 | self.logger.critical(message) 54 | -------------------------------------------------------------------------------- /geotransformer/engine/single_tester.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | 5 | from tqdm import tqdm 6 | 7 | from geotransformer.engine.base_tester import BaseTester 8 | from geotransformer.utils.summary_board import SummaryBoard 9 | from geotransformer.utils.timer import Timer 10 | from geotransformer.utils.common import get_log_string 11 | from geotransformer.utils.torch import release_cuda, to_cuda 12 | 13 | 14 | class SingleTester(BaseTester): 15 | def __init__(self, cfg, parser=None, cudnn_deterministic=True): 16 | super().__init__(cfg, parser=parser, cudnn_deterministic=cudnn_deterministic) 17 | 18 | def before_test_epoch(self): 19 | pass 20 | 21 | def before_test_step(self, iteration, data_dict): 22 | pass 23 | 24 | def test_step(self, iteration, data_dict) -> Dict: 25 | pass 26 | 27 | def eval_step(self, iteration, data_dict, output_dict) -> Dict: 28 | pass 29 | 30 | def after_test_step(self, iteration, data_dict, output_dict, result_dict): 31 | pass 32 | 33 | def after_test_epoch(self): 34 | pass 35 | 36 | def summary_string(self, iteration, data_dict, output_dict, result_dict): 37 | return get_log_string(result_dict) 38 | 39 | def run(self): 40 | assert self.test_loader is not None 41 | self.load_snapshot(self.args.snapshot) 42 | self.model.eval() 43 |
torch.set_grad_enabled(False) 44 | self.before_test_epoch() 45 | summary_board = SummaryBoard(adaptive=True) 46 | timer = Timer() 47 | total_iterations = len(self.test_loader) 48 | pbar = tqdm(enumerate(self.test_loader), total=total_iterations) 49 | for iteration, data_dict in pbar: 50 | # on start 51 | self.iteration = iteration + 1 52 | data_dict = to_cuda(data_dict) 53 | self.before_test_step(self.iteration, data_dict) 54 | # test step 55 | torch.cuda.synchronize() 56 | timer.add_prepare_time() 57 | output_dict = self.test_step(self.iteration, data_dict) 58 | torch.cuda.synchronize() 59 | timer.add_process_time() 60 | # eval step 61 | result_dict = self.eval_step(self.iteration, data_dict, output_dict) 62 | # after step 63 | self.after_test_step(self.iteration, data_dict, output_dict, result_dict) 64 | # logging 65 | result_dict = release_cuda(result_dict) 66 | summary_board.update_from_result_dict(result_dict) 67 | message = self.summary_string(self.iteration, data_dict, output_dict, result_dict) 68 | message += f', {timer.tostring()}' 69 | pbar.set_description(message) 70 | torch.cuda.empty_cache() 71 | self.after_test_epoch() 72 | summary_dict = summary_board.summary() 73 | message = get_log_string(result_dict=summary_dict, timer=timer) 74 | self.logger.critical(message) 75 | -------------------------------------------------------------------------------- /geotransformer/extensions/common/torch_helper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #define CHECK_CUDA(x) \ 7 | TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 8 | 9 | #define CHECK_CPU(x) \ 10 | TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor") 11 | 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | 15 | #define CHECK_INPUT(x) \ 16 | CHECK_CUDA(x); \ 17 | CHECK_CONTIGUOUS(x) 18 | 19 | #define CHECK_IS_INT(x) \ 20 | do { \ 21 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, \ 22 | #x " must be an int tensor"); \ 23 | } while (0) 24 | 25 | #define CHECK_IS_LONG(x) \ 26 | do { \ 27 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Long, \ 28 | #x " must be an long tensor"); \ 29 | } while (0) 30 | 31 | #define CHECK_IS_FLOAT(x) \ 32 | do { \ 33 | TORCH_CHECK(x.scalar_type() == at::ScalarType::Float, \ 34 | #x " must be a float tensor"); \ 35 | } while (0) 36 | -------------------------------------------------------------------------------- /geotransformer/extensions/cpu/grid_subsampling/grid_subsampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "grid_subsampling.h" 3 | #include "grid_subsampling_cpu.h" 4 | 5 | std::vector grid_subsampling( 6 | at::Tensor points, 7 | at::Tensor lengths, 8 | float voxel_size 9 | ) { 10 | CHECK_CPU(points); 11 | CHECK_CPU(lengths); 12 | CHECK_IS_FLOAT(points); 13 | CHECK_IS_LONG(lengths); 14 | CHECK_CONTIGUOUS(points); 15 | CHECK_CONTIGUOUS(lengths); 16 | 17 | std::size_t batch_size = lengths.size(0); 18 | std::size_t total_points = points.size(0); 19 | 20 | std::vector vec_points = std::vector( 21 | reinterpret_cast(points.data_ptr()), 22 | reinterpret_cast(points.data_ptr()) + total_points 23 | ); 24 | std::vector vec_s_points; 25 | 26 | std::vector vec_lengths = std::vector( 27 | lengths.data_ptr(), 28 | lengths.data_ptr() + batch_size 29 | ); 30 | std::vector vec_s_lengths; 31 | 32 | grid_subsampling_cpu( 33 | vec_points, 34 | vec_s_points, 35 | 
vec_lengths, 36 | vec_s_lengths, 37 | voxel_size 38 | ); 39 | 40 | std::size_t total_s_points = vec_s_points.size(); 41 | at::Tensor s_points = torch::zeros( 42 | {total_s_points, 3}, 43 | at::device(points.device()).dtype(at::ScalarType::Float) 44 | ); 45 | at::Tensor s_lengths = torch::zeros( 46 | {batch_size}, 47 | at::device(lengths.device()).dtype(at::ScalarType::Long) 48 | ); 49 | 50 | std::memcpy( 51 | s_points.data_ptr(), 52 | reinterpret_cast(vec_s_points.data()), 53 | sizeof(float) * total_s_points * 3 54 | ); 55 | std::memcpy( 56 | s_lengths.data_ptr(), 57 | vec_s_lengths.data(), 58 | sizeof(long) * batch_size 59 | ); 60 | 61 | return {s_points, s_lengths}; 62 | } 63 | -------------------------------------------------------------------------------- /geotransformer/extensions/cpu/grid_subsampling/grid_subsampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "../../common/torch_helper.h" 5 | 6 | std::vector grid_subsampling( 7 | at::Tensor points, 8 | at::Tensor lengths, 9 | float voxel_size 10 | ); 11 | -------------------------------------------------------------------------------- /geotransformer/extensions/cpu/grid_subsampling/grid_subsampling_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "grid_subsampling_cpu.h" 2 | 3 | void single_grid_subsampling_cpu( 4 | std::vector& points, 5 | std::vector& s_points, 6 | float voxel_size 7 | ) { 8 | // float sub_scale = 1. / voxel_size; 9 | PointXYZ minCorner = min_point(points); 10 | PointXYZ maxCorner = max_point(points); 11 | PointXYZ originCorner = floor(minCorner * (1. / voxel_size)) * voxel_size; 12 | 13 | std::size_t sampleNX = static_cast( 14 | // floor((maxCorner.x - originCorner.x) * sub_scale) + 1 15 | floor((maxCorner.x - originCorner.x) / voxel_size) + 1 16 | ); 17 | std::size_t sampleNY = static_cast( 18 | // floor((maxCorner.y - originCorner.y) * sub_scale) + 1 19 | floor((maxCorner.y - originCorner.y) / voxel_size) + 1 20 | ); 21 | 22 | std::size_t iX = 0; 23 | std::size_t iY = 0; 24 | std::size_t iZ = 0; 25 | std::size_t mapIdx = 0; 26 | std::unordered_map data; 27 | 28 | for (auto& p : points) { 29 | // iX = static_cast(floor((p.x - originCorner.x) * sub_scale)); 30 | // iY = static_cast(floor((p.y - originCorner.y) * sub_scale)); 31 | // iZ = static_cast(floor((p.z - originCorner.z) * sub_scale)); 32 | iX = static_cast(floor((p.x - originCorner.x) / voxel_size)); 33 | iY = static_cast(floor((p.y - originCorner.y) / voxel_size)); 34 | iZ = static_cast(floor((p.z - originCorner.z) / voxel_size)); 35 | mapIdx = iX + sampleNX * iY + sampleNX * sampleNY * iZ; 36 | 37 | if (!data.count(mapIdx)) { 38 | data.emplace(mapIdx, SampledData()); 39 | } 40 | 41 | data[mapIdx].update(p); 42 | } 43 | 44 | s_points.reserve(data.size()); 45 | for (auto& v : data) { 46 | s_points.push_back(v.second.point * (1.0 / v.second.count)); 47 | } 48 | } 49 | 50 | void grid_subsampling_cpu( 51 | std::vector& points, 52 | std::vector& s_points, 53 | std::vector& lengths, 54 | std::vector& s_lengths, 55 | float voxel_size 56 | ) { 57 | std::size_t start_index = 0; 58 | std::size_t batch_size = lengths.size(); 59 | for (std::size_t b = 0; b < batch_size; b++) { 60 | std::vector cur_points = std::vector( 61 | points.begin() + start_index, 62 | points.begin() + start_index + lengths[b] 63 | ); 64 | std::vector cur_s_points; 65 | 66 | single_grid_subsampling_cpu(cur_points, cur_s_points, voxel_size); 67 | 68 | 
s_points.insert(s_points.end(), cur_s_points.begin(), cur_s_points.end()); 69 | s_lengths.push_back(cur_s_points.size()); 70 | 71 | start_index += lengths[b]; 72 | } 73 | 74 | return; 75 | } 76 | -------------------------------------------------------------------------------- /geotransformer/extensions/cpu/grid_subsampling/grid_subsampling_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "../../extra/cloud/cloud.h" 6 | 7 | class SampledData { 8 | public: 9 | int count; 10 | PointXYZ point; 11 | 12 | SampledData() { 13 | count = 0; 14 | point = PointXYZ(); 15 | } 16 | 17 | void update(const PointXYZ& p) { 18 | count += 1; 19 | point += p; 20 | } 21 | }; 22 | 23 | void single_grid_subsampling_cpu( 24 | std::vector& o_points, 25 | std::vector& s_points, 26 | float voxel_size 27 | ); 28 | 29 | void grid_subsampling_cpu( 30 | std::vector& o_points, 31 | std::vector& s_points, 32 | std::vector& o_lengths, 33 | std::vector& s_lengths, 34 | float voxel_size 35 | ); 36 | 37 | -------------------------------------------------------------------------------- /geotransformer/extensions/cpu/radius_neighbors/radius_neighbors.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "radius_neighbors.h" 3 | #include "radius_neighbors_cpu.h" 4 | 5 | at::Tensor radius_neighbors( 6 | at::Tensor q_points, 7 | at::Tensor s_points, 8 | at::Tensor q_lengths, 9 | at::Tensor s_lengths, 10 | float radius 11 | ) { 12 | CHECK_CPU(q_points); 13 | CHECK_CPU(s_points); 14 | CHECK_CPU(q_lengths); 15 | CHECK_CPU(s_lengths); 16 | CHECK_IS_FLOAT(q_points); 17 | CHECK_IS_FLOAT(s_points); 18 | CHECK_IS_LONG(q_lengths); 19 | CHECK_IS_LONG(s_lengths); 20 | CHECK_CONTIGUOUS(q_points); 21 | CHECK_CONTIGUOUS(s_points); 22 | CHECK_CONTIGUOUS(q_lengths); 23 | CHECK_CONTIGUOUS(s_lengths); 24 | 25 | std::size_t total_q_points = q_points.size(0); 26 | std::size_t total_s_points = s_points.size(0); 27 | std::size_t batch_size = q_lengths.size(0); 28 | 29 | std::vector vec_q_points = std::vector( 30 | reinterpret_cast(q_points.data_ptr()), 31 | reinterpret_cast(q_points.data_ptr()) + total_q_points 32 | ); 33 | std::vector vec_s_points = std::vector( 34 | reinterpret_cast(s_points.data_ptr()), 35 | reinterpret_cast(s_points.data_ptr()) + total_s_points 36 | ); 37 | std::vector vec_q_lengths = std::vector( 38 | q_lengths.data_ptr(), q_lengths.data_ptr() + batch_size 39 | ); 40 | std::vector vec_s_lengths = std::vector( 41 | s_lengths.data_ptr(), s_lengths.data_ptr() + batch_size 42 | ); 43 | std::vector vec_neighbor_indices; 44 | 45 | radius_neighbors_cpu( 46 | vec_q_points, 47 | vec_s_points, 48 | vec_q_lengths, 49 | vec_s_lengths, 50 | vec_neighbor_indices, 51 | radius 52 | ); 53 | 54 | std::size_t max_neighbors = vec_neighbor_indices.size() / total_q_points; 55 | 56 | at::Tensor neighbor_indices = torch::zeros( 57 | {total_q_points, max_neighbors}, 58 | at::device(q_points.device()).dtype(at::ScalarType::Long) 59 | ); 60 | 61 | std::memcpy( 62 | neighbor_indices.data_ptr(), 63 | vec_neighbor_indices.data(), 64 | sizeof(long) * total_q_points * max_neighbors 65 | ); 66 | 67 | return neighbor_indices; 68 | } 69 | -------------------------------------------------------------------------------- /geotransformer/extensions/cpu/radius_neighbors/radius_neighbors.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 
"../../common/torch_helper.h" 4 | 5 | at::Tensor radius_neighbors( 6 | at::Tensor q_points, 7 | at::Tensor s_points, 8 | at::Tensor q_lengths, 9 | at::Tensor s_lengths, 10 | float radius 11 | ); 12 | -------------------------------------------------------------------------------- /geotransformer/extensions/cpu/radius_neighbors/radius_neighbors_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "radius_neighbors_cpu.h" 2 | 3 | void radius_neighbors_cpu( 4 | std::vector& q_points, 5 | std::vector& s_points, 6 | std::vector& q_lengths, 7 | std::vector& s_lengths, 8 | std::vector& neighbor_indices, 9 | float radius 10 | ) { 11 | std::size_t i0 = 0; 12 | float r2 = radius * radius; 13 | 14 | std::size_t max_count = 0; 15 | std::vector>> all_inds_dists( 16 | q_points.size() 17 | ); 18 | 19 | std::size_t b = 0; 20 | std::size_t q_start_index = 0; 21 | std::size_t s_start_index = 0; 22 | 23 | PointCloud current_cloud; 24 | current_cloud.pts = std::vector( 25 | s_points.begin() + s_start_index, 26 | s_points.begin() + s_start_index + s_lengths[b] 27 | ); 28 | 29 | nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10); 30 | my_kd_tree_t* index = new my_kd_tree_t(3, current_cloud, tree_params);; 31 | index->buildIndex(); 32 | 33 | nanoflann::SearchParams search_params; 34 | search_params.sorted = true; 35 | 36 | for (auto& p0 : q_points) { 37 | if (i0 == q_start_index + q_lengths[b]) { 38 | q_start_index += q_lengths[b]; 39 | s_start_index += s_lengths[b]; 40 | b++; 41 | 42 | current_cloud.pts.clear(); 43 | current_cloud.pts = std::vector( 44 | s_points.begin() + s_start_index, 45 | s_points.begin() + s_start_index + s_lengths[b] 46 | ); 47 | 48 | delete index; 49 | index = new my_kd_tree_t(3, current_cloud, tree_params); 50 | index->buildIndex(); 51 | } 52 | 53 | all_inds_dists[i0].reserve(max_count); 54 | float query_pt[3] = {p0.x, p0.y, p0.z}; 55 | std::size_t nMatches = index->radiusSearch( 56 | query_pt, r2, all_inds_dists[i0], search_params 57 | ); 58 | 59 | if (nMatches > max_count) { 60 | max_count = nMatches; 61 | } 62 | 63 | i0++; 64 | } 65 | 66 | delete index; 67 | 68 | neighbor_indices.resize(q_points.size() * max_count); 69 | i0 = 0; 70 | s_start_index = 0; 71 | q_start_index = 0; 72 | b = 0; 73 | for (auto& inds_dists : all_inds_dists) { 74 | if (i0 == q_start_index + q_lengths[b]) { 75 | q_start_index += q_lengths[b]; 76 | s_start_index += s_lengths[b]; 77 | b++; 78 | } 79 | 80 | for (std::size_t j = 0; j < max_count; j++) { 81 | std::size_t i = i0 * max_count + j; 82 | if (j < inds_dists.size()) { 83 | neighbor_indices[i] = inds_dists[j].first + s_start_index; 84 | } else { 85 | neighbor_indices[i] = s_points.size(); 86 | } 87 | } 88 | 89 | i0++; 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /geotransformer/extensions/cpu/radius_neighbors/radius_neighbors_cpu.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include "../../extra/cloud/cloud.h" 3 | #include "../../extra/nanoflann/nanoflann.hpp" 4 | 5 | typedef nanoflann::KDTreeSingleIndexAdaptor< 6 | nanoflann::L2_Simple_Adaptor, PointCloud, 3 7 | > my_kd_tree_t; 8 | 9 | void radius_neighbors_cpu( 10 | std::vector& q_points, 11 | std::vector& s_points, 12 | std::vector& q_lengths, 13 | std::vector& s_lengths, 14 | std::vector& neighbor_indices, 15 | float radius 16 | ); 17 | -------------------------------------------------------------------------------- 
/geotransformer/extensions/extra/cloud/cloud.cpp: -------------------------------------------------------------------------------- 1 | // Modified from https://github.com/HuguesTHOMAS/KPConv-PyTorch 2 | #include "cloud.h" 3 | 4 | PointXYZ max_point(std::vector points) { 5 | PointXYZ maxP(points[0]); 6 | 7 | for (auto p : points) { 8 | if (p.x > maxP.x) { 9 | maxP.x = p.x; 10 | } 11 | if (p.y > maxP.y) { 12 | maxP.y = p.y; 13 | } 14 | if (p.z > maxP.z) { 15 | maxP.z = p.z; 16 | } 17 | } 18 | 19 | return maxP; 20 | } 21 | 22 | PointXYZ min_point(std::vector points) { 23 | PointXYZ minP(points[0]); 24 | 25 | for (auto p : points) { 26 | if (p.x < minP.x) { 27 | minP.x = p.x; 28 | } 29 | if (p.y < minP.y) { 30 | minP.y = p.y; 31 | } 32 | if (p.z < minP.z) { 33 | minP.z = p.z; 34 | } 35 | } 36 | 37 | return minP; 38 | } -------------------------------------------------------------------------------- /geotransformer/extensions/extra/cloud/cloud.h: -------------------------------------------------------------------------------- 1 | // Modified from https://github.com/HuguesTHOMAS/KPConv-PyTorch 2 | #pragma once 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | class PointXYZ { 15 | public: 16 | float x, y, z; 17 | 18 | PointXYZ() { 19 | x = 0; 20 | y = 0; 21 | z = 0; 22 | } 23 | 24 | PointXYZ(float x0, float y0, float z0) { 25 | x = x0; 26 | y = y0; 27 | z = z0; 28 | } 29 | 30 | float operator [] (int i) const { 31 | if (i == 0) { 32 | return x; 33 | } 34 | else if (i == 1) { 35 | return y; 36 | } 37 | else { 38 | return z; 39 | } 40 | } 41 | 42 | float dot(const PointXYZ P) const { 43 | return x * P.x + y * P.y + z * P.z; 44 | } 45 | 46 | float sq_norm() { 47 | return x * x + y * y + z * z; 48 | } 49 | 50 | PointXYZ cross(const PointXYZ P) const { 51 | return PointXYZ(y * P.z - z * P.y, z * P.x - x * P.z, x * P.y - y * P.x); 52 | } 53 | 54 | PointXYZ& operator+=(const PointXYZ& P) { 55 | x += P.x; 56 | y += P.y; 57 | z += P.z; 58 | return *this; 59 | } 60 | 61 | PointXYZ& operator-=(const PointXYZ& P) { 62 | x -= P.x; 63 | y -= P.y; 64 | z -= P.z; 65 | return *this; 66 | } 67 | 68 | PointXYZ& operator*=(const float& a) { 69 | x *= a; 70 | y *= a; 71 | z *= a; 72 | return *this; 73 | } 74 | }; 75 | 76 | inline PointXYZ operator + (const PointXYZ A, const PointXYZ B) { 77 | return PointXYZ(A.x + B.x, A.y + B.y, A.z + B.z); 78 | } 79 | 80 | inline PointXYZ operator - (const PointXYZ A, const PointXYZ B) { 81 | return PointXYZ(A.x - B.x, A.y - B.y, A.z - B.z); 82 | } 83 | 84 | inline PointXYZ operator * (const PointXYZ P, const float a) { 85 | return PointXYZ(P.x * a, P.y * a, P.z * a); 86 | } 87 | 88 | inline PointXYZ operator * (const float a, const PointXYZ P) { 89 | return PointXYZ(P.x * a, P.y * a, P.z * a); 90 | } 91 | 92 | inline std::ostream& operator << (std::ostream& os, const PointXYZ P) { 93 | return os << "[" << P.x << ", " << P.y << ", " << P.z << "]"; 94 | } 95 | 96 | inline bool operator == (const PointXYZ A, const PointXYZ B) { 97 | return A.x == B.x && A.y == B.y && A.z == B.z; 98 | } 99 | 100 | inline PointXYZ floor(const PointXYZ P) { 101 | return PointXYZ(std::floor(P.x), std::floor(P.y), std::floor(P.z)); 102 | } 103 | 104 | PointXYZ max_point(std::vector points); 105 | 106 | PointXYZ min_point(std::vector points); 107 | 108 | struct PointCloud { 109 | std::vector pts; 110 | 111 | inline size_t kdtree_get_point_count() const { 112 | return pts.size(); 113 | } 114 | 115 | // Returns the dim'th 
component of the idx'th point in the class: 116 | // Since this is inlined and the "dim" argument is typically an immediate value, the 117 | // "if/else's" are actually solved at compile time. 118 | inline float kdtree_get_pt(const size_t idx, const size_t dim) const { 119 | if (dim == 0) { 120 | return pts[idx].x; 121 | } 122 | else if (dim == 1) { 123 | return pts[idx].y; 124 | } 125 | else { 126 | return pts[idx].z; 127 | } 128 | } 129 | 130 | // Optional bounding-box computation: return false to default to a standard bbox computation loop. 131 | // Return true if the BBOX was already computed by the class and returned in "bb" so it can be avoided to redo it again. 132 | // Look at bb.size() to find out the expected dimensionality (e.g. 2 or 3 for point clouds) 133 | template 134 | bool kdtree_get_bbox(BBOX& /* bb */) const { 135 | return false; 136 | } 137 | }; 138 | -------------------------------------------------------------------------------- /geotransformer/extensions/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "cpu/radius_neighbors/radius_neighbors.h" 4 | #include "cpu/grid_subsampling/grid_subsampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | // CPU extensions 8 | m.def( 9 | "radius_neighbors", 10 | &radius_neighbors, 11 | "Radius neighbors (CPU)" 12 | ); 13 | m.def( 14 | "grid_subsampling", 15 | &grid_subsampling, 16 | "Grid subsampling (CPU)" 17 | ); 18 | } 19 | -------------------------------------------------------------------------------- /geotransformer/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/geotransformer/modules/__init__.py -------------------------------------------------------------------------------- /geotransformer/modules/geotransformer/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.modules.geotransformer.geotransformer import GeometricStructureEmbedding, GeometricTransformer 2 | from geotransformer.modules.geotransformer.superpoint_matching import SuperPointMatching 3 | from geotransformer.modules.geotransformer.superpoint_target import SuperPointTargetGenerator 4 | from geotransformer.modules.geotransformer.point_matching import PointMatching 5 | from geotransformer.modules.geotransformer.local_global_registration import LocalGlobalRegistration 6 | -------------------------------------------------------------------------------- /geotransformer/modules/geotransformer/geotransformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from geotransformer.modules.ops import pairwise_distance 6 | from geotransformer.modules.transformer import SinusoidalPositionalEmbedding, RPEConditionalTransformer 7 | 8 | 9 | class GeometricStructureEmbedding(nn.Module): 10 | def __init__(self, hidden_dim, sigma_d, sigma_a, angle_k, reduction_a='max'): 11 | super(GeometricStructureEmbedding, self).__init__() 12 | self.sigma_d = sigma_d 13 | self.sigma_a = sigma_a 14 | self.factor_a = 180.0 / (self.sigma_a * np.pi) 15 | self.angle_k = angle_k 16 | 17 | self.embedding = SinusoidalPositionalEmbedding(hidden_dim) 18 | self.proj_d = nn.Linear(hidden_dim, hidden_dim) 19 | self.proj_a = nn.Linear(hidden_dim, hidden_dim) 20 | 21 | self.reduction_a = reduction_a 22 
| if self.reduction_a not in ['max', 'mean']: 23 | raise ValueError(f'Unsupported reduction mode: {self.reduction_a}.') 24 | 25 | @torch.no_grad() 26 | def get_embedding_indices(self, points): 27 | r"""Compute the indices of pair-wise distance embedding and triplet-wise angular embedding. 28 | 29 | Args: 30 | points: torch.Tensor (B, N, 3), input point cloud 31 | 32 | Returns: 33 | d_indices: torch.FloatTensor (B, N, N), distance embedding indices 34 | a_indices: torch.FloatTensor (B, N, N, k), angular embedding indices 35 | """ 36 | batch_size, num_point, _ = points.shape 37 | 38 | dist_map = torch.sqrt(pairwise_distance(points, points)) # (B, N, N) 39 | d_indices = dist_map / self.sigma_d 40 | 41 | k = self.angle_k 42 | knn_indices = dist_map.topk(k=k + 1, dim=2, largest=False)[1][:, :, 1:] # (B, N, k) 43 | knn_indices = knn_indices.unsqueeze(3).expand(batch_size, num_point, k, 3) # (B, N, k, 3) 44 | expanded_points = points.unsqueeze(1).expand(batch_size, num_point, num_point, 3) # (B, N, N, 3) 45 | knn_points = torch.gather(expanded_points, dim=2, index=knn_indices) # (B, N, k, 3) 46 | ref_vectors = knn_points - points.unsqueeze(2) # (B, N, k, 3) 47 | anc_vectors = points.unsqueeze(1) - points.unsqueeze(2) # (B, N, N, 3) 48 | ref_vectors = ref_vectors.unsqueeze(2).expand(batch_size, num_point, num_point, k, 3) # (B, N, N, k, 3) 49 | anc_vectors = anc_vectors.unsqueeze(3).expand(batch_size, num_point, num_point, k, 3) # (B, N, N, k, 3) 50 | sin_values = torch.linalg.norm(torch.cross(ref_vectors, anc_vectors, dim=-1), dim=-1) # (B, N, N, k) 51 | cos_values = torch.sum(ref_vectors * anc_vectors, dim=-1) # (B, N, N, k) 52 | angles = torch.atan2(sin_values, cos_values) # (B, N, N, k) 53 | a_indices = angles * self.factor_a 54 | 55 | return d_indices, a_indices 56 | 57 | def forward(self, points): 58 | d_indices, a_indices = self.get_embedding_indices(points) 59 | 60 | d_embeddings = self.embedding(d_indices) 61 | d_embeddings = self.proj_d(d_embeddings) 62 | 63 | a_embeddings = self.embedding(a_indices) 64 | a_embeddings = self.proj_a(a_embeddings) 65 | if self.reduction_a == 'max': 66 | a_embeddings = a_embeddings.max(dim=3)[0] 67 | else: 68 | a_embeddings = a_embeddings.mean(dim=3) 69 | 70 | embeddings = d_embeddings + a_embeddings 71 | 72 | return embeddings 73 | 74 | 75 | class GeometricTransformer(nn.Module): 76 | def __init__( 77 | self, 78 | input_dim, 79 | output_dim, 80 | hidden_dim, 81 | num_heads, 82 | blocks, 83 | sigma_d, 84 | sigma_a, 85 | angle_k, 86 | dropout=None, 87 | activation_fn='ReLU', 88 | reduction_a='max', 89 | ): 90 | r"""Geometric Transformer (GeoTransformer). 
91 | 92 | Args: 93 | input_dim: input feature dimension 94 | output_dim: output feature dimension 95 | hidden_dim: hidden feature dimension 96 | num_heads: number of heads in the transformer 97 | blocks: list of 'self' or 'cross' 98 | sigma_d: temperature of the distance embedding 99 | sigma_a: temperature of the angular embedding 100 | angle_k: number of nearest neighbors for angular embedding 101 | activation_fn: activation function 102 | reduction_a: reduction mode of angular embedding ['max', 'mean'] 103 | """ 104 | super(GeometricTransformer, self).__init__() 105 | 106 | self.embedding = GeometricStructureEmbedding(hidden_dim, sigma_d, sigma_a, angle_k, reduction_a=reduction_a) 107 | 108 | self.in_proj = nn.Linear(input_dim, hidden_dim) 109 | self.transformer = RPEConditionalTransformer( 110 | blocks, hidden_dim, num_heads, dropout=dropout, activation_fn=activation_fn 111 | ) 112 | self.out_proj = nn.Linear(hidden_dim, output_dim) 113 | 114 | def forward( 115 | self, 116 | ref_points, 117 | src_points, 118 | ref_feats, 119 | src_feats, 120 | ref_masks=None, 121 | src_masks=None, 122 | ): 123 | r"""Geometric Transformer 124 | 125 | Args: 126 | ref_points (Tensor): (B, N, 3) 127 | src_points (Tensor): (B, M, 3) 128 | ref_feats (Tensor): (B, N, C) 129 | src_feats (Tensor): (B, M, C) 130 | ref_masks (Optional[BoolTensor]): (B, N) 131 | src_masks (Optional[BoolTensor]): (B, M) 132 | 133 | Returns: 134 | ref_feats: torch.Tensor (B, N, C) 135 | src_feats: torch.Tensor (B, M, C) 136 | """ 137 | ref_embeddings = self.embedding(ref_points) 138 | src_embeddings = self.embedding(src_points) 139 | 140 | ref_feats = self.in_proj(ref_feats) 141 | src_feats = self.in_proj(src_feats) 142 | 143 | ref_feats, src_feats = self.transformer( 144 | ref_feats, 145 | src_feats, 146 | ref_embeddings, 147 | src_embeddings, 148 | masks0=ref_masks, 149 | masks1=src_masks, 150 | ) 151 | 152 | ref_feats = self.out_proj(ref_feats) 153 | src_feats = self.out_proj(src_feats) 154 | 155 | return ref_feats, src_feats 156 | -------------------------------------------------------------------------------- /geotransformer/modules/geotransformer/point_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PointMatching(nn.Module): 6 | def __init__( 7 | self, 8 | k: int, 9 | mutual: bool = True, 10 | confidence_threshold: float = 0.05, 11 | use_dustbin: bool = False, 12 | use_global_score: bool = False, 13 | remove_duplicate: bool = False, 14 | ): 15 | r"""Point Matching with Local-to-Global Registration. 16 | 17 | Args: 18 | k (int): top-k selection for matching. 19 | mutual (bool=True): mutual or non-mutual matching. 20 | confidence_threshold (float=0.05): ignore matches whose scores are below this threshold. 21 | use_dustbin (bool=False): whether dustbin row/column is used in the score matrix. 22 | use_global_score (bool=False): whether to use patch correspondence scores.
23 | """ 24 | super(PointMatching, self).__init__() 25 | self.k = k 26 | self.mutual = mutual 27 | self.confidence_threshold = confidence_threshold 28 | self.use_dustbin = use_dustbin 29 | self.use_global_score = use_global_score 30 | self.remove_duplicate = remove_duplicate 31 | 32 | def compute_correspondence_matrix(self, score_mat, ref_knn_masks, src_knn_masks): 33 | r"""Compute matching matrix and score matrix for each patch correspondence.""" 34 | mask_mat = torch.logical_and(ref_knn_masks.unsqueeze(2), src_knn_masks.unsqueeze(1)) 35 | 36 | batch_size, ref_length, src_length = score_mat.shape 37 | batch_indices = torch.arange(batch_size).cuda() 38 | 39 | # correspondences from reference side 40 | ref_topk_scores, ref_topk_indices = score_mat.topk(k=self.k, dim=2) # (B, N, K) 41 | ref_batch_indices = batch_indices.view(batch_size, 1, 1).expand(-1, ref_length, self.k) # (B, N, K) 42 | ref_indices = torch.arange(ref_length).cuda().view(1, ref_length, 1).expand(batch_size, -1, self.k) # (B, N, K) 43 | ref_score_mat = torch.zeros_like(score_mat) 44 | ref_score_mat[ref_batch_indices, ref_indices, ref_topk_indices] = ref_topk_scores 45 | ref_corr_mat = torch.gt(ref_score_mat, self.confidence_threshold) 46 | 47 | # correspondences from source side 48 | src_topk_scores, src_topk_indices = score_mat.topk(k=self.k, dim=1) # (B, K, N) 49 | src_batch_indices = batch_indices.view(batch_size, 1, 1).expand(-1, self.k, src_length) # (B, K, N) 50 | src_indices = torch.arange(src_length).cuda().view(1, 1, src_length).expand(batch_size, self.k, -1) # (B, K, N) 51 | src_score_mat = torch.zeros_like(score_mat) 52 | src_score_mat[src_batch_indices, src_topk_indices, src_indices] = src_topk_scores 53 | src_corr_mat = torch.gt(src_score_mat, self.confidence_threshold) 54 | 55 | # merge results from two sides 56 | if self.mutual: 57 | corr_mat = torch.logical_and(ref_corr_mat, src_corr_mat) 58 | else: 59 | corr_mat = torch.logical_or(ref_corr_mat, src_corr_mat) 60 | 61 | if self.use_dustbin: 62 | corr_mat = corr_mat[:, -1:, -1] 63 | 64 | corr_mat = torch.logical_and(corr_mat, mask_mat) 65 | 66 | return corr_mat 67 | 68 | def forward( 69 | self, 70 | ref_knn_points, 71 | src_knn_points, 72 | ref_knn_masks, 73 | src_knn_masks, 74 | ref_knn_indices, 75 | src_knn_indices, 76 | score_mat, 77 | global_scores, 78 | ): 79 | r"""Point Matching Module forward propagation with Local-to-Global registration. 
80 | 81 | Args: 82 | ref_knn_points (Tensor): (B, K, 3) 83 | src_knn_points (Tensor): (B, K, 3) 84 | ref_knn_masks (BoolTensor): (B, K) 85 | src_knn_masks (BoolTensor): (B, K) 86 | ref_knn_indices (LongTensor): (B, K) 87 | src_knn_indices (LongTensor): (B, K) 88 | score_mat (Tensor): (B, K, K) or (B, K + 1, K + 1), log likelihood 89 | global_scores (Tensor): (B,) 90 | 91 | Returns: 92 | ref_corr_points (Tensor): (C, 3) 93 | src_corr_points (Tensor): (C, 3) 94 | ref_corr_indices (LongTensor): (C,) 95 | src_corr_indices (LongTensor): (C,) 96 | corr_scores (Tensor): (C,) 97 | """ 98 | score_mat = torch.exp(score_mat) 99 | 100 | corr_mat = self.compute_correspondence_matrix(score_mat, ref_knn_masks, src_knn_masks) # (B, K, K) 101 | 102 | if self.use_dustbin: 103 | score_mat = score_mat[:, :-1, :-1] 104 | if self.use_global_score: 105 | score_mat = score_mat * global_scores.view(-1, 1, 1) 106 | score_mat = score_mat * corr_mat.float() 107 | 108 | batch_indices, ref_indices, src_indices = torch.nonzero(corr_mat, as_tuple=True) 109 | ref_corr_indices = ref_knn_indices[batch_indices, ref_indices] 110 | src_corr_indices = src_knn_indices[batch_indices, src_indices] 111 | ref_corr_points = ref_knn_points[batch_indices, ref_indices] 112 | src_corr_points = src_knn_points[batch_indices, src_indices] 113 | corr_scores = score_mat[batch_indices, ref_indices, src_indices] 114 | 115 | return ref_corr_points, src_corr_points, ref_corr_indices, src_corr_indices, corr_scores 116 | -------------------------------------------------------------------------------- /geotransformer/modules/geotransformer/superpoint_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from geotransformer.modules.ops import pairwise_distance 5 | 6 | 7 | class SuperPointMatching(nn.Module): 8 | def __init__(self, num_correspondences, dual_normalization=True): 9 | super(SuperPointMatching, self).__init__() 10 | self.num_correspondences = num_correspondences 11 | self.dual_normalization = dual_normalization 12 | 13 | def forward(self, ref_feats, src_feats, ref_masks=None, src_masks=None): 14 | r"""Extract superpoint correspondences. 15 | 16 | Args: 17 | ref_feats (Tensor): features of the superpoints in reference point cloud. 18 | src_feats (Tensor): features of the superpoints in source point cloud. 19 | ref_masks (BoolTensor=None): masks of the superpoints in reference point cloud (False if empty). 20 | src_masks (BoolTensor=None): masks of the superpoints in source point cloud (False if empty). 21 | 22 | Returns: 23 | ref_corr_indices (LongTensor): indices of the corresponding superpoints in reference point cloud. 24 | src_corr_indices (LongTensor): indices of the corresponding superpoints in source point cloud. 25 | corr_scores (Tensor): scores of the correspondences. 
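The forward pass above flattens the boolean correspondence matrix with `torch.nonzero` and uses the resulting index triplets to gather points and scores; a toy illustration:

```python
import torch

# Turning a boolean correspondence matrix into (batch, ref, src) index
# lists, as the forward pass above does.
corr_mat = torch.zeros(2, 4, 4, dtype=torch.bool)
corr_mat[0, 1, 2] = True
corr_mat[1, 0, 3] = True
batch_indices, ref_indices, src_indices = torch.nonzero(corr_mat, as_tuple=True)
print(batch_indices.tolist(), ref_indices.tolist(), src_indices.tolist())
# [0, 1] [1, 0] [2, 3]
```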
26 | """ 27 | if ref_masks is None: 28 | ref_masks = torch.ones(size=(ref_feats.shape[0],), dtype=torch.bool).cuda() 29 | if src_masks is None: 30 | src_masks = torch.ones(size=(src_feats.shape[0],), dtype=torch.bool).cuda() 31 | # remove empty patch 32 | ref_indices = torch.nonzero(ref_masks, as_tuple=True)[0] 33 | src_indices = torch.nonzero(src_masks, as_tuple=True)[0] 34 | ref_feats = ref_feats[ref_indices] 35 | src_feats = src_feats[src_indices] 36 | # select top-k proposals 37 | matching_scores = torch.exp(-pairwise_distance(ref_feats, src_feats, normalized=True)) 38 | if self.dual_normalization: 39 | ref_matching_scores = matching_scores / matching_scores.sum(dim=1, keepdim=True) 40 | src_matching_scores = matching_scores / matching_scores.sum(dim=0, keepdim=True) 41 | matching_scores = ref_matching_scores * src_matching_scores 42 | num_correspondences = min(self.num_correspondences, matching_scores.numel()) 43 | corr_scores, corr_indices = matching_scores.view(-1).topk(k=num_correspondences, largest=True) 44 | ref_sel_indices = corr_indices // matching_scores.shape[1] 45 | src_sel_indices = corr_indices % matching_scores.shape[1] 46 | # recover original indices 47 | ref_corr_indices = ref_indices[ref_sel_indices] 48 | src_corr_indices = src_indices[src_sel_indices] 49 | 50 | return ref_corr_indices, src_corr_indices, corr_scores 51 | -------------------------------------------------------------------------------- /geotransformer/modules/geotransformer/superpoint_target.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class SuperPointTargetGenerator(nn.Module): 7 | def __init__(self, num_targets, overlap_threshold): 8 | super(SuperPointTargetGenerator, self).__init__() 9 | self.num_targets = num_targets 10 | self.overlap_threshold = overlap_threshold 11 | 12 | @torch.no_grad() 13 | def forward(self, gt_corr_indices, gt_corr_overlaps): 14 | r"""Generate ground truth superpoint (patch) correspondences. 15 | 16 | Randomly select "num_targets" correspondences whose overlap is above "overlap_threshold". 17 | 18 | Args: 19 | gt_corr_indices (LongTensor): ground truth superpoint correspondences (N, 2) 20 | gt_corr_overlaps (Tensor): ground truth superpoint correspondences overlap (N,) 21 | 22 | Returns: 23 | gt_ref_corr_indices (LongTensor): selected superpoints in reference point cloud. 24 | gt_src_corr_indices (LongTensor): selected superpoints in source point cloud. 25 | gt_corr_overlaps (LongTensor): overlaps of the selected superpoint correspondences. 
26 | """ 27 | gt_corr_masks = torch.gt(gt_corr_overlaps, self.overlap_threshold) 28 | gt_corr_overlaps = gt_corr_overlaps[gt_corr_masks] 29 | gt_corr_indices = gt_corr_indices[gt_corr_masks] 30 | 31 | if gt_corr_indices.shape[0] > self.num_targets: 32 | indices = np.arange(gt_corr_indices.shape[0]) 33 | sel_indices = np.random.choice(indices, self.num_targets, replace=False) 34 | sel_indices = torch.from_numpy(sel_indices).cuda() 35 | gt_corr_indices = gt_corr_indices[sel_indices] 36 | gt_corr_overlaps = gt_corr_overlaps[sel_indices] 37 | 38 | gt_ref_corr_indices = gt_corr_indices[:, 0] 39 | gt_src_corr_indices = gt_corr_indices[:, 1] 40 | 41 | return gt_ref_corr_indices, gt_src_corr_indices, gt_corr_overlaps 42 | -------------------------------------------------------------------------------- /geotransformer/modules/kpconv/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.modules.kpconv.kpconv import KPConv 2 | from geotransformer.modules.kpconv.modules import ( 3 | ConvBlock, 4 | ResidualBlock, 5 | UnaryBlock, 6 | LastUnaryBlock, 7 | GroupNorm, 8 | KNNInterpolate, 9 | GlobalAvgPool, 10 | MaxPool, 11 | ) 12 | from geotransformer.modules.kpconv.functional import nearest_upsample, global_avgpool, maxpool 13 | -------------------------------------------------------------------------------- /geotransformer/modules/kpconv/dispositions/k_015_center_3D.ply: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/geotransformer/modules/kpconv/dispositions/k_015_center_3D.ply -------------------------------------------------------------------------------- /geotransformer/modules/kpconv/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from geotransformer.modules.ops import index_select 4 | 5 | 6 | def nearest_upsample(x, upsample_indices): 7 | """Pools features from the closest neighbors. 8 | 9 | WARNING: this function assumes the neighbors are ordered. 10 | 11 | Args: 12 | x: [n1, d] features matrix 13 | upsample_indices: [n2, max_num] Only the first column is used for pooling 14 | 15 | Returns: 16 | x: [n2, d] pooled features matrix 17 | """ 18 | # Add a last row with minimum features for shadow pools 19 | x = torch.cat((x, torch.zeros_like(x[:1, :])), 0) 20 | # Get features for each pooling location [n2, d] 21 | x = index_select(x, upsample_indices[:, 0], dim=0) 22 | return x 23 | 24 | 25 | def knn_interpolate(s_feats, q_points, s_points, neighbor_indices, k, eps=1e-8): 26 | r"""K-NN interpolate. 27 | 28 | WARNING: this function assumes the neighbors are ordered. 
29 | 30 | Args: 31 | s_feats (Tensor): (M, C) 32 | q_points (Tensor): (N, 3) 33 | s_points (Tensor): (M, 3) 34 | neighbor_indices (LongTensor): (N, X) 35 | k (int) 36 | eps (float) 37 | 38 | Returns: 39 | q_feats (Tensor): (N, C) 40 | """ 41 | s_points = torch.cat((s_points, torch.zeros_like(s_points[:1, :])), 0) # (M + 1, 3) 42 | s_feats = torch.cat((s_feats, torch.zeros_like(s_feats[:1, :])), 0) # (M + 1, C) 43 | knn_indices = neighbor_indices[:, :k].contiguous() 44 | knn_points = index_select(s_points, knn_indices, dim=0) # (N, k, 3) 45 | knn_feats = index_select(s_feats, knn_indices, dim=0) # (N, k, C) 46 | knn_sq_distances = (q_points.unsqueeze(1) - knn_points).pow(2).sum(dim=-1) # (N, k) 47 | knn_masks = torch.ne(knn_indices, s_points.shape[0] - 1).float() # (N, k) 48 | knn_weights = knn_masks / (knn_sq_distances + eps) # (N, k) 49 | knn_weights = knn_weights / (knn_weights.sum(dim=1, keepdim=True) + eps) # (N, k) 50 | q_feats = (knn_feats * knn_weights.unsqueeze(-1)).sum(dim=1) # (N, C) 51 | return q_feats 52 | 53 | 54 | def maxpool(x, neighbor_indices): 55 | """Max pooling from neighbors. 56 | 57 | Args: 58 | x: [n1, d] features matrix 59 | neighbor_indices: [n2, max_num] pooling indices 60 | 61 | Returns: 62 | pooled_feats: [n2, d] pooled features matrix 63 | """ 64 | x = torch.cat((x, torch.zeros_like(x[:1, :])), 0) 65 | neighbor_feats = index_select(x, neighbor_indices, dim=0) 66 | pooled_feats = neighbor_feats.max(1)[0] 67 | return pooled_feats 68 | 69 | 70 | def global_avgpool(x, batch_lengths): 71 | """Global average pooling over batch. 72 | 73 | Args: 74 | x: [N, D] input features 75 | batch_lengths: [B] list of batch lengths 76 | 77 | Returns: 78 | x: [B, D] averaged features 79 | """ 80 | # Loop over the clouds of the batch 81 | averaged_features = [] 82 | i0 = 0 83 | for b_i, length in enumerate(batch_lengths): 84 | # Average features for each batch cloud 85 | averaged_features.append(torch.mean(x[i0 : i0 + length], dim=0)) 86 | # Increment for next cloud 87 | i0 += length 88 | # Average features in each batch 89 | x = torch.stack(averaged_features) 90 | return x 91 | -------------------------------------------------------------------------------- /geotransformer/modules/kpconv/kpconv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from geotransformer.modules.ops import index_select 7 | from geotransformer.modules.kpconv.kernel_points import load_kernels 8 | 9 | 10 | class KPConv(nn.Module): 11 | def __init__( 12 | self, 13 | in_channels, 14 | out_channels, 15 | kernel_size, 16 | radius, 17 | sigma, 18 | bias=False, 19 | dimension=3, 20 | inf=1e6, 21 | eps=1e-9, 22 | ): 23 | """Initialize parameters for KPConv. 24 | 25 | Modified from [KPConv-PyTorch](https://github.com/HuguesTHOMAS/KPConv-PyTorch). 26 | 27 | Deformable KPConv is not supported. 28 | 29 | Args: 30 | in_channels: dimension of input features. 31 | out_channels: dimension of output features. 32 | kernel_size: Number of kernel points. 33 | radius: radius used for kernel point init. 34 | sigma: influence radius of each kernel point. 35 | bias: use bias or not (default: False) 36 | dimension: dimension of the point space. 
37 | inf: value of infinity to generate the padding point 38 | eps: epsilon for gaussian influence 39 | """ 40 | super(KPConv, self).__init__() 41 | 42 | # Save parameters 43 | self.kernel_size = kernel_size 44 | self.in_channels = in_channels 45 | self.out_channels = out_channels 46 | self.radius = radius 47 | self.sigma = sigma 48 | self.dimension = dimension 49 | 50 | self.inf = inf 51 | self.eps = eps 52 | 53 | # Initialize weights 54 | self.weights = nn.Parameter(torch.zeros(self.kernel_size, in_channels, out_channels)) 55 | if bias: 56 | self.bias = nn.Parameter(torch.zeros(self.out_channels)) 57 | else: 58 | self.register_parameter('bias', None) 59 | 60 | # Reset parameters 61 | self.reset_parameters() 62 | 63 | # Initialize kernel points 64 | kernel_points = self.initialize_kernel_points() # (N, 3) 65 | self.register_buffer('kernel_points', kernel_points) 66 | 67 | def reset_parameters(self): 68 | nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) 69 | if self.bias is not None: 70 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights) 71 | bound = 1 / math.sqrt(fan_in) 72 | nn.init.uniform_(self.bias, -bound, bound) 73 | 74 | def initialize_kernel_points(self): 75 | """Initialize the kernel point positions in a sphere.""" 76 | kernel_points = load_kernels(self.radius, self.kernel_size, dimension=self.dimension, fixed='center') 77 | return torch.from_numpy(kernel_points).float() 78 | 79 | def forward(self, s_feats, q_points, s_points, neighbor_indices): 80 | r"""KPConv forward. 81 | 82 | Args: 83 | s_feats (Tensor): (N, C_in) 84 | q_points (Tensor): (M, 3) 85 | s_points (Tensor): (N, 3) 86 | neighbor_indices (LongTensor): (M, H) 87 | 88 | Returns: 89 | q_feats (Tensor): (M, C_out) 90 | """ 91 | s_points = torch.cat([s_points, torch.zeros_like(s_points[:1, :]) + self.inf], 0) # (N, 3) -> (N+1, 3) 92 | neighbors = index_select(s_points, neighbor_indices, dim=0) # (N+1, 3) -> (M, H, 3) 93 | neighbors = neighbors - q_points.unsqueeze(1) # (M, H, 3) 94 | 95 | # Get Kernel point influences 96 | neighbors = neighbors.unsqueeze(2) # (M, H, 3) -> (M, H, 1, 3) 97 | differences = neighbors - self.kernel_points # (M, H, 1, 3) x (K, 3) -> (M, H, K, 3) 98 | sq_distances = torch.sum(differences ** 2, dim=3) # (M, H, K) 99 | neighbor_weights = torch.clamp(1 - torch.sqrt(sq_distances) / self.sigma, min=0.0) # (M, H, K) 100 | neighbor_weights = torch.transpose(neighbor_weights, 1, 2) # (M, H, K) -> (M, K, H) 101 | 102 | # apply neighbor weights 103 | s_feats = torch.cat((s_feats, torch.zeros_like(s_feats[:1, :])), 0) # (N, C) -> (N+1, C) 104 | neighbor_feats = index_select(s_feats, neighbor_indices, dim=0) # (N+1, C) -> (M, H, C) 105 | weighted_feats = torch.matmul(neighbor_weights, neighbor_feats) # (M, K, H) x (M, H, C) -> (M, K, C) 106 | 107 | # apply convolutional weights 108 | weighted_feats = weighted_feats.permute(1, 0, 2) # (M, K, C) -> (K, M, C) 109 | kernel_outputs = torch.matmul(weighted_feats, self.weights) # (K, M, C) x (K, C, C_out) -> (K, M, C_out) 110 | output_feats = torch.sum(kernel_outputs, dim=0, keepdim=False) # (K, M, C_out) -> (M, C_out) 111 | 112 | # normalization 113 | neighbor_feats_sum = torch.sum(neighbor_feats, dim=-1) 114 | neighbor_num = torch.sum(torch.gt(neighbor_feats_sum, 0.0), dim=-1) 115 | neighbor_num = torch.max(neighbor_num, torch.ones_like(neighbor_num)) 116 | output_feats = output_feats / neighbor_num.unsqueeze(1) 117 | 118 | # add bias 119 | if self.bias is not None: 120 | output_feats = output_feats + self.bias 121 | 122 | return output_feats 
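A minimal sketch of the kernel-point influence computed in `forward` above, reduced to a single query point; the neighbor count, kernel size, channels and `sigma` are made-up toy values:

```python
import torch

# Linear correlation clamp(1 - ||y - k|| / sigma, 0) followed by the
# (K, H) x (H, C) aggregation, for one query point.
torch.manual_seed(0)
H, K, C = 8, 5, 4            # neighbors, kernel points, channels
sigma = 0.5
neighbors = torch.rand(H, 3)      # neighbor offsets around the query point
kernel_points = torch.rand(K, 3)

sq_d = ((neighbors.unsqueeze(1) - kernel_points) ** 2).sum(-1)  # (H, K)
influence = torch.clamp(1.0 - sq_d.sqrt() / sigma, min=0.0)     # (H, K)
neighbor_feats = torch.rand(H, C)
weighted = influence.t() @ neighbor_feats                       # (K, C)
print(weighted.shape)  # each kernel point aggregates nearby neighbor features
```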
123 | 124 | def __repr__(self): 125 | format_string = self.__class__.__name__ + '(' 126 | format_string += 'kernel_size: {}'.format(self.kernel_size) 127 | format_string += ', in_channels: {}'.format(self.in_channels) 128 | format_string += ', out_channels: {}'.format(self.out_channels) 129 | format_string += ', radius: {:g}'.format(self.radius) 130 | format_string += ', sigma: {:g}'.format(self.sigma) 131 | format_string += ', bias: {}'.format(self.bias is not None) 132 | format_string += ')' 133 | return format_string 134 | -------------------------------------------------------------------------------- /geotransformer/modules/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.modules.layers.conv_block import ConvBlock 2 | from geotransformer.modules.layers.factory import build_dropout_layer, build_conv_layer, build_norm_layer, build_act_layer 3 | -------------------------------------------------------------------------------- /geotransformer/modules/layers/conv_block.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from geotransformer.modules.layers.factory import build_conv_layer, build_norm_layer, build_act_layer 7 | 8 | 9 | class ConvBlock(nn.Module): 10 | def __init__( 11 | self, 12 | in_channels, 13 | out_channels, 14 | kernel_size=None, 15 | stride=1, 16 | padding=0, 17 | dilation=1, 18 | groups=1, 19 | padding_mode='zeros', 20 | depth_multiplier=None, 21 | conv_cfg=None, 22 | norm_cfg=None, 23 | act_cfg=None, 24 | act_before_norm=False, 25 | ): 26 | r"""Conv-Norm-Act Block. 27 | 28 | Args: 29 | act_before_norm (bool=False): If True, conv-act-norm. If False, conv-norm-act. 
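On the `bias` handling in the block below: when a `BatchNorm` follows the convolution, the conv bias is redundant because BN subtracts the per-channel mean, so the block disables it. A toy check (layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

# BN cancels any conv bias: per-channel output means are ~0 either way.
torch.manual_seed(0)
x = torch.randn(4, 8, 16)
y_bias = nn.BatchNorm1d(8)(nn.Conv1d(8, 8, 1, bias=True)(x))
y_nobias = nn.BatchNorm1d(8)(nn.Conv1d(8, 8, 1, bias=False)(x))
print(y_bias.mean(dim=(0, 2)))    # ~0 per channel
print(y_nobias.mean(dim=(0, 2)))  # ~0 per channel, with fewer parameters
```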
30 | """ 31 | super().__init__() 32 | 33 | assert conv_cfg is not None 34 | 35 | if isinstance(norm_cfg, str): 36 | norm_cfg = {'type': norm_cfg} 37 | if isinstance(act_cfg, str): 38 | act_cfg = {'type': act_cfg} 39 | 40 | norm_type = norm_cfg['type'] 41 | if norm_type in ['BatchNorm', 'InstanceNorm']: 42 | norm_cfg['type'] = norm_type + conv_cfg[-2:] 43 | 44 | self.act_before_norm = act_before_norm 45 | 46 | bias = True 47 | if not self.act_before_norm: 48 | # conv-norm-act 49 | norm_type = norm_cfg['type'] 50 | if norm_type.startswith('BatchNorm') or norm_type.startswith('InstanceNorm'): 51 | bias = False 52 | if conv_cfg == 'Linear': 53 | layer_cfg = { 54 | 'type': conv_cfg, 55 | 'in_features': in_channels, 56 | 'out_features': out_channels, 57 | 'bias': bias, 58 | } 59 | elif conv_cfg.startswith('SeparableConv'): 60 | if groups != 1: 61 | warnings.warn(f'`groups={groups}` is ignored when building {conv_cfg} layer.') 62 | layer_cfg = { 63 | 'type': conv_cfg, 64 | 'in_channels': in_channels, 65 | 'out_channels': out_channels, 66 | 'kernel_size': kernel_size, 67 | 'stride': stride, 68 | 'padding': padding, 69 | 'dilation': dilation, 70 | 'depth_multiplier': depth_multiplier, 71 | 'bias': bias, 72 | 'padding_mode': padding_mode, 73 | } 74 | else: 75 | if depth_multiplier is not None: 76 | warnings.warn(f'`depth_multiplier={depth_multiplier}` is ignored when building {conv_cfg} layer.') 77 | layer_cfg = { 78 | 'type': conv_cfg, 79 | 'in_channels': in_channels, 80 | 'out_channels': out_channels, 81 | 'kernel_size': kernel_size, 82 | 'stride': stride, 83 | 'padding': padding, 84 | 'dilation': dilation, 85 | 'groups': groups, 86 | 'bias': bias, 87 | 'padding_mode': padding_mode, 88 | } 89 | 90 | self.conv = build_conv_layer(layer_cfg) 91 | 92 | norm_layer = build_norm_layer(out_channels, norm_cfg) 93 | act_layer = build_act_layer(act_cfg) 94 | if self.act_before_norm: 95 | self.act = act_layer 96 | self.norm = norm_layer 97 | else: 98 | self.norm = norm_layer 99 | self.act = act_layer 100 | 101 | def forward(self, x): 102 | x = self.conv(x) 103 | if self.act_before_norm: 104 | x = self.norm(self.act(x)) 105 | else: 106 | x = self.act(self.norm(x)) 107 | return x 108 | -------------------------------------------------------------------------------- /geotransformer/modules/layers/factory.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Dict, Optional, Tuple 2 | 3 | import torch.nn as nn 4 | 5 | 6 | NORM_LAYERS = { 7 | 'BatchNorm1d': nn.BatchNorm1d, 8 | 'BatchNorm2d': nn.BatchNorm2d, 9 | 'BatchNorm3d': nn.BatchNorm3d, 10 | 'InstanceNorm1d': nn.InstanceNorm1d, 11 | 'InstanceNorm2d': nn.InstanceNorm2d, 12 | 'InstanceNorm3d': nn.InstanceNorm3d, 13 | 'GroupNorm': nn.GroupNorm, 14 | 'LayerNorm': nn.LayerNorm, 15 | } 16 | 17 | 18 | ACT_LAYERS = { 19 | 'ReLU': nn.ReLU, 20 | 'LeakyReLU': nn.LeakyReLU, 21 | 'ELU': nn.ELU, 22 | 'GELU': nn.GELU, 23 | 'Sigmoid': nn.Sigmoid, 24 | 'Softplus': nn.Softplus, 25 | 'Tanh': nn.Tanh, 26 | 'Identity': nn.Identity, 27 | } 28 | 29 | 30 | CONV_LAYERS = { 31 | 'Linear': nn.Linear, 32 | 'Conv1d': nn.Conv1d, 33 | 'Conv2d': nn.Conv2d, 34 | 'Conv3d': nn.Conv3d, 35 | } 36 | 37 | 38 | def parse_cfg(cfg: Union[str, Dict]) -> Tuple[str, Dict]: 39 | assert isinstance(cfg, (str, Dict)), 'Illegal cfg type: {}.'.format(type(cfg)) 40 | if isinstance(cfg, str): 41 | cfg = {'type': cfg} 42 | else: 43 | cfg = cfg.copy() 44 | layer = cfg.pop('type') 45 | return layer, cfg 46 | 47 | 48 | def build_dropout_layer(p: 
Optional[float], **kwargs) -> nn.Module: 49 | r"""Factory function for dropout layer.""" 50 | if p is None or p == 0: 51 | return nn.Identity() 52 | else: 53 | return nn.Dropout(p=p, **kwargs) 54 | 55 | 56 | def build_norm_layer(num_features, norm_cfg: Optional[Union[str, Dict]]) -> nn.Module: 57 | r"""Factory function for normalization layers.""" 58 | if norm_cfg is None: 59 | return nn.Identity() 60 | layer, kwargs = parse_cfg(norm_cfg) 61 | assert layer in NORM_LAYERS, f'Illegal normalization: {layer}.' 62 | if layer == 'GroupNorm': 63 | kwargs['num_channels'] = num_features 64 | elif layer == 'LayerNorm': 65 | kwargs['normalized_shape'] = num_features 66 | else: 67 | kwargs['num_features'] = num_features 68 | return NORM_LAYERS[layer](**kwargs) 69 | 70 | 71 | def build_act_layer(act_cfg: Optional[Union[str, Dict]]) -> nn.Module: 72 | r"""Factory function for activation functions.""" 73 | if act_cfg is None: 74 | return nn.Identity() 75 | layer, kwargs = parse_cfg(act_cfg) 76 | assert layer in ACT_LAYERS, f'Illegal activation: {layer}.' 77 | if layer == 'LeakyReLU': 78 | if 'negative_slope' not in kwargs: 79 | kwargs['negative_slope'] = 0.2 80 | return ACT_LAYERS[layer](**kwargs) 81 | 82 | 83 | def build_conv_layer(conv_cfg: Union[str, Dict]) -> nn.Module: 84 | r"""Factory function for convolution or linear layers.""" 85 | layer, kwargs = parse_cfg(conv_cfg) 86 | assert layer in CONV_LAYERS, f'Illegal layer: {layer}.' 87 | return CONV_LAYERS[layer](**kwargs) 88 | -------------------------------------------------------------------------------- /geotransformer/modules/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.modules.loss.circle_loss import CircleLoss, WeightedCircleLoss 2 | 3 | -------------------------------------------------------------------------------- /geotransformer/modules/loss/circle_loss.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | 6 | 7 | def circle_loss( 8 | pos_masks, 9 | neg_masks, 10 | feat_dists, 11 | pos_margin, 12 | neg_margin, 13 | pos_optimal, 14 | neg_optimal, 15 | log_scale, 16 | ): 17 | # get anchors that have both positive and negative pairs 18 | row_masks = (torch.gt(pos_masks.sum(-1), 0) & torch.gt(neg_masks.sum(-1), 0)).detach() 19 | col_masks = (torch.gt(pos_masks.sum(-2), 0) & torch.gt(neg_masks.sum(-2), 0)).detach() 20 | 21 | # get alpha for both positive and negative pairs 22 | pos_weights = feat_dists - 1e5 * (~pos_masks).float() # mask the non-positive 23 | pos_weights = pos_weights - pos_optimal # mask the uninformative positive 24 | pos_weights = torch.maximum(torch.zeros_like(pos_weights), pos_weights).detach() 25 | 26 | neg_weights = feat_dists + 1e5 * (~neg_masks).float() # mask the non-negative 27 | neg_weights = neg_optimal - neg_weights # mask the uninformative negative 28 | neg_weights = torch.maximum(torch.zeros_like(neg_weights), neg_weights).detach() 29 | 30 | loss_pos_row = torch.logsumexp(log_scale * (feat_dists - pos_margin) * pos_weights, dim=-1) 31 | loss_pos_col = torch.logsumexp(log_scale * (feat_dists - pos_margin) * pos_weights, dim=-2) 32 | 33 | loss_neg_row = torch.logsumexp(log_scale * (neg_margin - feat_dists) * neg_weights, dim=-1) 34 | loss_neg_col = torch.logsumexp(log_scale * (neg_margin - feat_dists) * neg_weights, dim=-2) 35 | 36 | loss_row = F.softplus(loss_pos_row + loss_neg_row) / log_scale 37 | loss_col = 
F.softplus(loss_pos_col + loss_neg_col) / log_scale 38 | 39 | loss = (loss_row[row_masks].mean() + loss_col[col_masks].mean()) / 2 40 | 41 | return loss 42 | 43 | 44 | def weighted_circle_loss( 45 | pos_masks, 46 | neg_masks, 47 | feat_dists, 48 | pos_margin, 49 | neg_margin, 50 | pos_optimal, 51 | neg_optimal, 52 | log_scale, 53 | pos_scales=None, 54 | neg_scales=None, 55 | ): 56 | # get anchors that have both positive and negative pairs 57 | row_masks = (torch.gt(pos_masks.sum(-1), 0) & torch.gt(neg_masks.sum(-1), 0)).detach() 58 | col_masks = (torch.gt(pos_masks.sum(-2), 0) & torch.gt(neg_masks.sum(-2), 0)).detach() 59 | 60 | # get alpha for both positive and negative pairs 61 | pos_weights = feat_dists - 1e5 * (~pos_masks).float() # mask the non-positive 62 | pos_weights = pos_weights - pos_optimal # mask the uninformative positive 63 | pos_weights = torch.maximum(torch.zeros_like(pos_weights), pos_weights) 64 | if pos_scales is not None: 65 | pos_weights = pos_weights * pos_scales 66 | pos_weights = pos_weights.detach() 67 | 68 | neg_weights = feat_dists + 1e5 * (~neg_masks).float() # mask the non-negative 69 | neg_weights = neg_optimal - neg_weights # mask the uninformative negative 70 | neg_weights = torch.maximum(torch.zeros_like(neg_weights), neg_weights) 71 | if neg_scales is not None: 72 | neg_weights = neg_weights * neg_scales 73 | neg_weights = neg_weights.detach() 74 | 75 | loss_pos_row = torch.logsumexp(log_scale * (feat_dists - pos_margin) * pos_weights, dim=-1) 76 | loss_pos_col = torch.logsumexp(log_scale * (feat_dists - pos_margin) * pos_weights, dim=-2) 77 | 78 | loss_neg_row = torch.logsumexp(log_scale * (neg_margin - feat_dists) * neg_weights, dim=-1) 79 | loss_neg_col = torch.logsumexp(log_scale * (neg_margin - feat_dists) * neg_weights, dim=-2) 80 | 81 | loss_row = F.softplus(loss_pos_row + loss_neg_row) / log_scale 82 | loss_col = F.softplus(loss_pos_col + loss_neg_col) / log_scale 83 | 84 | loss = (loss_row[row_masks].mean() + loss_col[col_masks].mean()) / 2 85 | 86 | return loss 87 | 88 | 89 | class CircleLoss(nn.Module): 90 | def __init__(self, pos_margin, neg_margin, pos_optimal, neg_optimal, log_scale): 91 | super(CircleLoss, self).__init__() 92 | self.pos_margin = pos_margin 93 | self.neg_margin = neg_margin 94 | self.pos_optimal = pos_optimal 95 | self.neg_optimal = neg_optimal 96 | self.log_scale = log_scale 97 | 98 | def forward(self, pos_masks, neg_masks, feat_dists): 99 | return circle_loss( 100 | pos_masks, 101 | neg_masks, 102 | feat_dists, 103 | self.pos_margin, 104 | self.neg_margin, 105 | self.pos_optimal, 106 | self.neg_optimal, 107 | self.log_scale, 108 | ) 109 | 110 | 111 | class WeightedCircleLoss(nn.Module): 112 | def __init__(self, pos_margin, neg_margin, pos_optimal, neg_optimal, log_scale): 113 | super(WeightedCircleLoss, self).__init__() 114 | self.pos_margin = pos_margin 115 | self.neg_margin = neg_margin 116 | self.pos_optimal = pos_optimal 117 | self.neg_optimal = neg_optimal 118 | self.log_scale = log_scale 119 | 120 | def forward(self, pos_masks, neg_masks, feat_dists, pos_scales=None, neg_scales=None): 121 | return weighted_circle_loss( 122 | pos_masks, 123 | neg_masks, 124 | feat_dists, 125 | self.pos_margin, 126 | self.neg_margin, 127 | self.pos_optimal, 128 | self.neg_optimal, 129 | self.log_scale, 130 | pos_scales=pos_scales, 131 | neg_scales=neg_scales, 132 | ) 133 | -------------------------------------------------------------------------------- /geotransformer/modules/ops/__init__.py: 
-------------------------------------------------------------------------------- 1 | from geotransformer.modules.ops.grid_subsample import grid_subsample 2 | from geotransformer.modules.ops.index_select import index_select 3 | from geotransformer.modules.ops.pairwise_distance import pairwise_distance 4 | from geotransformer.modules.ops.pointcloud_partition import ( 5 | get_point_to_node_indices, 6 | point_to_node_partition, 7 | knn_partition, 8 | ball_query_partition, 9 | ) 10 | from geotransformer.modules.ops.radius_search import radius_search 11 | from geotransformer.modules.ops.transformation import ( 12 | apply_transform, 13 | apply_rotation, 14 | inverse_transform, 15 | skew_symmetric_matrix, 16 | rodrigues_rotation_matrix, 17 | rodrigues_alignment_matrix, 18 | get_transform_from_rotation_translation, 19 | get_rotation_translation_from_transform, 20 | ) 21 | from geotransformer.modules.ops.vector_angle import vector_angle, rad2deg, deg2rad 22 | -------------------------------------------------------------------------------- /geotransformer/modules/ops/grid_subsample.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | ext_module = importlib.import_module('geotransformer.ext') 5 | 6 | 7 | def grid_subsample(points, lengths, voxel_size): 8 | """Grid subsampling in stack mode. 9 | 10 | This function is implemented on CPU. 11 | 12 | Args: 13 | points (Tensor): stacked points. (N, 3) 14 | lengths (Tensor): number of points in the stacked batch. (B,) 15 | voxel_size (float): voxel size. 16 | 17 | Returns: 18 | s_points (Tensor): stacked subsampled points (M, 3) 19 | s_lengths (Tensor): numbers of subsampled points in the batch. (B,) 20 | """ 21 | s_points, s_lengths = ext_module.grid_subsampling(points, lengths, voxel_size) 22 | return s_points, s_lengths 23 | -------------------------------------------------------------------------------- /geotransformer/modules/ops/index_select.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def index_select(data: torch.Tensor, index: torch.LongTensor, dim: int) -> torch.Tensor: 5 | r"""Advanced index select. 6 | 7 | Returns a tensor `output` which indexes the `data` tensor along dimension `dim` 8 | using the entries in `index` which is a `LongTensor`. 9 | 10 | Different from `torch.index_select`, `index` does not have to be 1-D. The `dim`-th 11 | dimension of `data` will be expanded to the number of dimensions in `index`. 12 | 13 | For example, suppose the shape of `data` is $(a_0, a_1, ..., a_{n-1})$, the shape of `index` is 14 | $(b_0, b_1, ..., b_{m-1})$, and `dim` is $i$, then `output` is an $(n+m-1)$-d tensor, whose shape is 15 | $(a_0, ..., a_{i-1}, b_0, b_1, ..., b_{m-1}, a_{i+1}, ..., a_{n-1})$.
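A toy check of the shape rule described above:

```python
import torch

# data (5, 4), index (2, 3), dim=0 -> output (2, 3, 4),
# with output[i, j] == data[index[i, j]].
data = torch.arange(20).view(5, 4)
index = torch.tensor([[0, 2, 4], [1, 1, 3]])
output = data.index_select(0, index.view(-1)).view(2, 3, 4)
print(output.shape)                        # torch.Size([2, 3, 4])
print(torch.equal(output[0, 1], data[2]))  # True
```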
16 | 17 | Args: 18 | data (Tensor): (a_0, a_1, ..., a_{n-1}) 19 | index (LongTensor): (b_0, b_1, ..., b_{m-1}) 20 | dim: int 21 | 22 | Returns: 23 | output (Tensor): (a_0, ..., a_{dim-1}, b_0, ..., b_{m-1}, a_{dim+1}, ..., a_{n-1}) 24 | """ 25 | output = data.index_select(dim, index.view(-1)) 26 | 27 | if index.ndim > 1: 28 | output_shape = data.shape[:dim] + index.shape + data.shape[dim:][1:] 29 | output = output.view(*output_shape) 30 | 31 | return output 32 | -------------------------------------------------------------------------------- /geotransformer/modules/ops/pairwise_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pairwise_distance( 5 | x: torch.Tensor, y: torch.Tensor, normalized: bool = False, channel_first: bool = False 6 | ) -> torch.Tensor: 7 | r"""Pairwise squared distance of two (batched) point clouds. 8 | 9 | Args: 10 | x (Tensor): (*, N, C) or (*, C, N) 11 | y (Tensor): (*, M, C) or (*, C, M) 12 | normalized (bool=False): if the points are normalized, we have "x2 = y2 = 1", so "d2 = 2 - 2xy". 13 | channel_first (bool=False): if True, the points shape is (*, C, N). 14 | 15 | Returns: 16 | sq_distances: torch.Tensor (*, N, M), squared distances 17 | """ 18 | if channel_first: 19 | channel_dim = -2 20 | xy = torch.matmul(x.transpose(-1, -2), y) # [(*, C, N) -> (*, N, C)] x (*, C, M) 21 | else: 22 | channel_dim = -1 23 | xy = torch.matmul(x, y.transpose(-1, -2)) # (*, N, C) x [(*, M, C) -> (*, C, M)] 24 | if normalized: 25 | sq_distances = 2.0 - 2.0 * xy 26 | else: 27 | x2 = torch.sum(x ** 2, dim=channel_dim).unsqueeze(-1) # (*, N, C) or (*, C, N) -> (*, N) -> (*, N, 1) 28 | y2 = torch.sum(y ** 2, dim=channel_dim).unsqueeze(-2) # (*, M, C) or (*, C, M) -> (*, M) -> (*, 1, M) 29 | sq_distances = x2 - 2 * xy + y2 30 | sq_distances = sq_distances.clamp(min=0.0) 31 | return sq_distances 32 | -------------------------------------------------------------------------------- /geotransformer/modules/ops/radius_search.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | ext_module = importlib.import_module('geotransformer.ext') 5 | 6 | 7 | def radius_search(q_points, s_points, q_lengths, s_lengths, radius, neighbor_limit): 8 | r"""Computes neighbors for a batch of q_points and s_points, applying radius search (in stack mode). 9 | 10 | This function is implemented on CPU. 11 | 12 | Args: 13 | q_points (Tensor): the query points (N, 3) 14 | s_points (Tensor): the support points (M, 3) 15 | q_lengths (Tensor): the list of lengths of batch elements in q_points 16 | s_lengths (Tensor): the list of lengths of batch elements in s_points 17 | radius (float): maximum distance of neighbors 18 | neighbor_limit (int): maximum number of neighbors 19 | 20 | Returns: 21 | neighbors (Tensor): the k nearest neighbors of q_points in s_points (N, k). 22 | Filled with M if there are fewer than k neighbors.
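Back to `pairwise_distance` above: a quick numerical check of the `x2 - 2xy + y2` expansion against `torch.cdist` (toy sizes; should print `True`):

```python
import torch

# The expanded squared distance agrees with torch.cdist up to float error.
torch.manual_seed(0)
x, y = torch.rand(3, 5, 8), torch.rand(3, 6, 8)
xy = x @ y.transpose(-1, -2)
x2 = x.pow(2).sum(-1).unsqueeze(-1)
y2 = y.pow(2).sum(-1).unsqueeze(-2)
sq_distances = (x2 - 2.0 * xy + y2).clamp(min=0.0)
print(torch.allclose(sq_distances.sqrt(), torch.cdist(x, y), atol=1e-4))  # True
```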
23 | """ 24 | neighbor_indices = ext_module.radius_neighbors(q_points, s_points, q_lengths, s_lengths, radius) 25 | if neighbor_limit > 0: 26 | neighbor_indices = neighbor_indices[:, :neighbor_limit] 27 | return neighbor_indices 28 | -------------------------------------------------------------------------------- /geotransformer/modules/ops/vector_angle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def rad2deg(rad: torch.Tensor) -> torch.Tensor: 6 | factor = 180.0 / np.pi 7 | deg = rad * factor 8 | return deg 9 | 10 | 11 | def deg2rad(deg: torch.Tensor) -> torch.Tensor: 12 | factor = np.pi / 180.0 13 | rad = deg * factor 14 | return rad 15 | 16 | 17 | def vector_angle(x: torch.Tensor, y: torch.Tensor, dim: int, use_degree: bool = False): 18 | r"""Compute the angles between two set of 3D vectors. 19 | 20 | Args: 21 | x (Tensor): set of vectors (*, 3, *) 22 | y (Tensor): set of vectors (*, 3, *). 23 | dim (int): dimension index of the coordinates. 24 | use_degree (bool=False): If True, return angles in degree instead of rad. 25 | 26 | Returns: 27 | angles (Tensor): (*) 28 | """ 29 | cross = torch.linalg.norm(torch.cross(x, y, dim=dim), dim=dim) # (*, 3 *) x (*, 3, *) -> (*, 3, *) -> (*) 30 | dot = torch.sum(x * y, dim=dim) # (*, 3 *) x (*, 3, *) -> (*) 31 | angles = torch.atan2(cross, dot) # (*) 32 | if use_degree: 33 | angles = rad2deg(angles) 34 | return angles 35 | -------------------------------------------------------------------------------- /geotransformer/modules/registration/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.modules.registration.matching import ( 2 | extract_correspondences_from_feats, 3 | extract_correspondences_from_scores, 4 | extract_correspondences_from_scores_topk, 5 | extract_correspondences_from_scores_threshold, 6 | dense_correspondences_to_node_correspondences, 7 | get_node_correspondences, 8 | node_correspondences_to_dense_correspondences, 9 | get_node_occlusion_ratios, 10 | get_node_overlap_ratios, 11 | ) 12 | from geotransformer.modules.registration.metrics import ( 13 | modified_chamfer_distance, 14 | relative_rotation_error, 15 | relative_translation_error, 16 | isotropic_transform_error, 17 | anisotropic_transform_error, 18 | ) 19 | from geotransformer.modules.registration.procrustes import weighted_procrustes, WeightedProcrustes 20 | -------------------------------------------------------------------------------- /geotransformer/modules/registration/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from geotransformer.modules.ops import apply_transform, pairwise_distance, get_rotation_translation_from_transform 5 | from geotransformer.utils.registration import compute_transform_mse_and_mae 6 | 7 | 8 | def modified_chamfer_distance(raw_points, ref_points, src_points, gt_transform, transform, reduction='mean'): 9 | r"""Compute the modified chamfer distance. 
10 | 11 | Args: 12 | raw_points (Tensor): (B, N_raw, 3) 13 | ref_points (Tensor): (B, N_ref, 3) 14 | src_points (Tensor): (B, N_src, 3) 15 | gt_transform (Tensor): (B, 4, 4) 16 | transform (Tensor): (B, 4, 4) 17 | reduction (str='mean'): reduction method, 'mean', 'sum' or 'none' 18 | 19 | Returns: 20 | chamfer_distance 21 | """ 22 | assert reduction in ['mean', 'sum', 'none'] 23 | 24 | # P_t -> Q_raw 25 | aligned_src_points = apply_transform(src_points, transform) # (B, N_src, 3) 26 | sq_dist_mat_p_q = pairwise_distance(aligned_src_points, raw_points) # (B, N_src, N_raw) 27 | nn_sq_distances_p_q = sq_dist_mat_p_q.min(dim=-1)[0] # (B, N_src) 28 | chamfer_distance_p_q = torch.sqrt(nn_sq_distances_p_q).mean(dim=-1) # (B) 29 | 30 | # Q -> P_raw 31 | composed_transform = torch.matmul(transform, torch.inverse(gt_transform)) # (B, 4, 4) 32 | aligned_raw_points = apply_transform(raw_points, composed_transform) # (B, N_raw, 3) 33 | sq_dist_mat_q_p = pairwise_distance(ref_points, aligned_raw_points) # (B, N_ref, N_raw) 34 | nn_sq_distances_q_p = sq_dist_mat_q_p.min(dim=-1)[0] # (B, N_ref) 35 | chamfer_distance_q_p = torch.sqrt(nn_sq_distances_q_p).mean(dim=-1) # (B) 36 | 37 | # sum up 38 | chamfer_distance = chamfer_distance_p_q + chamfer_distance_q_p # (B) 39 | 40 | if reduction == 'mean': 41 | chamfer_distance = chamfer_distance.mean() 42 | elif reduction == 'sum': 43 | chamfer_distance = chamfer_distance.sum() 44 | return chamfer_distance 45 | 46 | 47 | def relative_rotation_error(gt_rotations, rotations): 48 | r"""Isotropic Relative Rotation Error. 49 | 50 | RRE = acos((trace(R^T \cdot \bar{R}) - 1) / 2) 51 | 52 | Args: 53 | gt_rotations (Tensor): ground truth rotation matrix (*, 3, 3) 54 | rotations (Tensor): estimated rotation matrix (*, 3, 3) 55 | 56 | Returns: 57 | rre (Tensor): relative rotation errors (*) 58 | """ 59 | mat = torch.matmul(rotations.transpose(-1, -2), gt_rotations) 60 | trace = mat[..., 0, 0] + mat[..., 1, 1] + mat[..., 2, 2] 61 | x = 0.5 * (trace - 1.0) 62 | x = x.clamp(min=-1.0, max=1.0) 63 | x = torch.arccos(x) 64 | rre = 180.0 * x / np.pi 65 | return rre 66 | 67 | 68 | def relative_translation_error(gt_translations, translations): 69 | r"""Isotropic Relative Translation Error. 70 | 71 | RTE = \lVert t - \bar{t} \rVert_2 72 | 73 | Args: 74 | gt_translations (Tensor): ground truth translation vector (*, 3) 75 | translations (Tensor): estimated translation vector (*, 3) 76 | 77 | Returns: 78 | rte (Tensor): relative translation errors (*) 79 | """ 80 | rte = torch.linalg.norm(gt_translations - translations, dim=-1) 81 | return rte 82 | 83 | 84 | def isotropic_transform_error(gt_transforms, transforms, reduction='mean'): 85 | r"""Compute the isotropic Relative Rotation Error and Relative Translation Error. 86 | 87 | Args: 88 | gt_transforms (Tensor): ground truth transformation matrix (*, 4, 4) 89 | transforms (Tensor): estimated transformation matrix (*, 4, 4) 90 | reduction (str='mean'): reduction method, 'mean', 'sum' or 'none' 91 | 92 | Returns: 93 | rre (Tensor): relative rotation error. 94 | rte (Tensor): relative translation error.
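A sanity check of the trace-based RRE formula above: a 30-degree rotation about the z-axis measured against the identity should give an error of about 30 degrees.

```python
import math
import torch

# trace(R) = 1 + 2*cos(30 deg), so acos((trace - 1) / 2) recovers 30 deg.
theta = math.radians(30.0)
R = torch.tensor([
    [math.cos(theta), -math.sin(theta), 0.0],
    [math.sin(theta), math.cos(theta), 0.0],
    [0.0, 0.0, 1.0],
])
x = torch.clamp(0.5 * (torch.diagonal(R).sum() - 1.0), -1.0, 1.0)
print(torch.rad2deg(torch.arccos(x)))  # tensor(30.0000)
```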
95 | """ 96 | assert reduction in ['mean', 'sum', 'none'] 97 | 98 | gt_rotations, gt_translations = get_rotation_translation_from_transform(gt_transforms) 99 | rotations, translations = get_rotation_translation_from_transform(transforms) 100 | 101 | rre = relative_rotation_error(gt_rotations, rotations) # (*) 102 | rte = relative_translation_error(gt_translations, translations) # (*) 103 | 104 | if reduction == 'mean': 105 | rre = rre.mean() 106 | rte = rte.mean() 107 | elif reduction == 'sum': 108 | rre = rre.sum() 109 | rte = rte.sum() 110 | 111 | return rre, rte 112 | 113 | 114 | def anisotropic_transform_error(gt_transforms, transforms, reduction='mean'): 115 | r"""Compute the anisotropic Relative Rotation Error and Relative Translation Error. 116 | 117 | This function calls numpy-based implementation to achieve batch-wise computation and thus is non-differentiable. 118 | 119 | Args: 120 | gt_transforms (Tensor): ground truth transformation matrix (B, 4, 4) 121 | transforms (Tensor): estimated transformation matrix (B, 4, 4) 122 | reduction (str='mean'): reduction method, 'mean', 'sum' or 'none' 123 | 124 | Returns: 125 | r_mse (Tensor): rotation mse. 126 | r_mae (Tensor): rotation mae. 127 | t_mse (Tensor): translation mse. 128 | t_mae (Tensor): translation mae. 129 | """ 130 | assert reduction in ['mean', 'sum', 'none'] 131 | 132 | batch_size = gt_transforms.shape[0] 133 | gt_transforms_array = gt_transforms.detach().cpu().numpy() 134 | transforms_array = transforms.detach().cpu().numpy() 135 | 136 | all_r_mse = [] 137 | all_r_mae = [] 138 | all_t_mse = [] 139 | all_t_mae = [] 140 | for i in range(batch_size): 141 | r_mse, r_mae, t_mse, t_mae = compute_transform_mse_and_mae(gt_transforms_array[i], transforms_array[i]) 142 | all_r_mse.append(r_mse) 143 | all_r_mae.append(r_mae) 144 | all_t_mse.append(t_mse) 145 | all_t_mae.append(t_mae) 146 | r_mse = torch.as_tensor(all_r_mse).to(gt_transforms) 147 | r_mae = torch.as_tensor(all_r_mae).to(gt_transforms) 148 | t_mse = torch.as_tensor(all_t_mse).to(gt_transforms) 149 | t_mae = torch.as_tensor(all_t_mae).to(gt_transforms) 150 | 151 | if reduction == 'mean': 152 | r_mse = r_mse.mean() 153 | r_mae = r_mae.mean() 154 | t_mse = t_mse.mean() 155 | t_mae = t_mae.mean() 156 | elif reduction == 'sum': 157 | r_mse = r_mse.sum() 158 | r_mae = r_mae.sum() 159 | t_mse = t_mse.sum() 160 | t_mae = t_mae.sum() 161 | 162 | return r_mse, r_mae, t_mse, t_mae 163 | -------------------------------------------------------------------------------- /geotransformer/modules/registration/procrustes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import ipdb 4 | 5 | 6 | def weighted_procrustes( 7 | src_points, 8 | ref_points, 9 | weights=None, 10 | weight_thresh=0.0, 11 | eps=1e-5, 12 | return_transform=False, 13 | ): 14 | r"""Compute rigid transformation from `src_points` to `ref_points` using weighted SVD. 15 | 16 | Modified from [PointDSC](https://github.com/XuyangBai/PointDSC/blob/master/models/common.py). 17 | 18 | Args: 19 | src_points: torch.Tensor (B, N, 3) or (N, 3) 20 | ref_points: torch.Tensor (B, N, 3) or (N, 3) 21 | weights: torch.Tensor (B, N) or (N,) (default: None) 22 | weight_thresh: float (default: 0.) 
23 | eps: float (default: 1e-5) 24 | return_transform: bool (default: False) 25 | 26 | Returns: 27 | R: torch.Tensor (B, 3, 3) or (3, 3) 28 | t: torch.Tensor (B, 3) or (3,) 29 | transform: torch.Tensor (B, 4, 4) or (4, 4) 30 | """ 31 | if src_points.ndim == 2: 32 | src_points = src_points.unsqueeze(0) 33 | ref_points = ref_points.unsqueeze(0) 34 | if weights is not None: 35 | weights = weights.unsqueeze(0) 36 | squeeze_first = True 37 | else: 38 | squeeze_first = False 39 | 40 | batch_size = src_points.shape[0] 41 | if weights is None: 42 | weights = torch.ones_like(src_points[:, :, 0]) 43 | weights = torch.where(torch.lt(weights, weight_thresh), torch.zeros_like(weights), weights) 44 | weights = weights / (torch.sum(weights, dim=1, keepdim=True) + eps) 45 | weights = weights.unsqueeze(2) # (B, N, 1) 46 | 47 | src_centroid = torch.sum(src_points * weights, dim=1, keepdim=True) # (B, 1, 3) 48 | ref_centroid = torch.sum(ref_points * weights, dim=1, keepdim=True) # (B, 1, 3) 49 | src_points_centered = src_points - src_centroid # (B, N, 3) 50 | ref_points_centered = ref_points - ref_centroid # (B, N, 3) 51 | 52 | H = src_points_centered.permute(0, 2, 1) @ (weights * ref_points_centered) 53 | U, _, V = torch.svd(H.cpu()) # H = USV^T 54 | Ut, V = U.transpose(1, 2).cuda(), V.cuda() 55 | eye = torch.eye(3).unsqueeze(0).repeat(batch_size, 1, 1).cuda() 56 | eye[:, -1, -1] = torch.sign(torch.det(V @ Ut)) 57 | R = V @ eye @ Ut 58 | 59 | t = ref_centroid.permute(0, 2, 1) - R @ src_centroid.permute(0, 2, 1) 60 | t = t.squeeze(2) 61 | 62 | if return_transform: 63 | transform = torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1).cuda() 64 | transform[:, :3, :3] = R 65 | transform[:, :3, 3] = t 66 | if squeeze_first: 67 | transform = transform.squeeze(0) 68 | return transform 69 | else: 70 | if squeeze_first: 71 | R = R.squeeze(0) 72 | t = t.squeeze(0) 73 | return R, t 74 | 75 | 76 | class WeightedProcrustes(nn.Module): 77 | def __init__(self, weight_thresh=0.0, eps=1e-5, return_transform=False): 78 | super(WeightedProcrustes, self).__init__() 79 | self.weight_thresh = weight_thresh 80 | self.eps = eps 81 | self.return_transform = return_transform 82 | 83 | def forward(self, src_points, tgt_points, weights=None): 84 | return weighted_procrustes( 85 | src_points, 86 | tgt_points, 87 | weights=weights, 88 | weight_thresh=self.weight_thresh, 89 | eps=self.eps, 90 | return_transform=self.return_transform, 91 | ) 92 | -------------------------------------------------------------------------------- /geotransformer/modules/sinkhorn/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.modules.sinkhorn.learnable_sinkhorn import LearnableLogOptimalTransport 2 | -------------------------------------------------------------------------------- /geotransformer/modules/sinkhorn/learnable_sinkhorn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class LearnableLogOptimalTransport(nn.Module): 6 | def __init__(self, num_iterations, inf=1e12): 7 | r"""Sinkhorn Optimal transport with dustbin parameter (SuperGlue style).""" 8 | super(LearnableLogOptimalTransport, self).__init__() 9 | self.num_iterations = num_iterations 10 | self.register_parameter('alpha', torch.nn.Parameter(torch.tensor(1.0))) 11 | self.inf = inf 12 | 13 | def log_sinkhorn_normalization(self, scores, log_mu, log_nu): 14 | u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu) 15 | for _ in 
range(self.num_iterations): 16 | u = log_mu - torch.logsumexp(scores + v.unsqueeze(1), dim=2) 17 | v = log_nu - torch.logsumexp(scores + u.unsqueeze(2), dim=1) 18 | return scores + u.unsqueeze(2) + v.unsqueeze(1) 19 | 20 | def forward(self, scores, row_masks=None, col_masks=None): 21 | r"""Sinkhorn Optimal Transport (SuperGlue style) forward. 22 | 23 | Args: 24 | scores: torch.Tensor (B, M, N) 25 | row_masks: torch.Tensor (B, M) 26 | col_masks: torch.Tensor (B, N) 27 | 28 | Returns: 29 | matching_scores: torch.Tensor (B, M+1, N+1) 30 | """ 31 | batch_size, num_row, num_col = scores.shape 32 | 33 | if row_masks is None: 34 | row_masks = torch.ones(size=(batch_size, num_row), dtype=torch.bool).cuda() 35 | if col_masks is None: 36 | col_masks = torch.ones(size=(batch_size, num_col), dtype=torch.bool).cuda() 37 | 38 | padded_row_masks = torch.zeros(size=(batch_size, num_row + 1), dtype=torch.bool).cuda() 39 | padded_row_masks[:, :num_row] = ~row_masks 40 | padded_col_masks = torch.zeros(size=(batch_size, num_col + 1), dtype=torch.bool).cuda() 41 | padded_col_masks[:, :num_col] = ~col_masks 42 | padded_score_masks = torch.logical_or(padded_row_masks.unsqueeze(2), padded_col_masks.unsqueeze(1)) 43 | 44 | padded_col = self.alpha.expand(batch_size, num_row, 1) 45 | padded_row = self.alpha.expand(batch_size, 1, num_col + 1) 46 | padded_scores = torch.cat([torch.cat([scores, padded_col], dim=-1), padded_row], dim=1) 47 | padded_scores.masked_fill_(padded_score_masks, -self.inf) 48 | 49 | num_valid_row = row_masks.float().sum(1) 50 | num_valid_col = col_masks.float().sum(1) 51 | norm = -torch.log(num_valid_row + num_valid_col) # (B,) 52 | 53 | log_mu = torch.empty(size=(batch_size, num_row + 1)).cuda() 54 | log_mu[:, :num_row] = norm.unsqueeze(1) 55 | log_mu[:, num_row] = torch.log(num_valid_col) + norm 56 | log_mu[padded_row_masks] = -self.inf 57 | 58 | log_nu = torch.empty(size=(batch_size, num_col + 1)).cuda() 59 | log_nu[:, :num_col] = norm.unsqueeze(1) 60 | log_nu[:, num_col] = torch.log(num_valid_row) + norm 61 | log_nu[padded_col_masks] = -self.inf 62 | 63 | outputs = self.log_sinkhorn_normalization(padded_scores, log_mu, log_nu) 64 | outputs = outputs - norm.unsqueeze(1).unsqueeze(2) 65 | 66 | return outputs 67 | 68 | def __repr__(self): 69 | format_string = self.__class__.__name__ + '(num_iterations={})'.format(self.num_iterations) 70 | return format_string 71 | -------------------------------------------------------------------------------- /geotransformer/modules/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from geotransformer.modules.transformer.conditional_transformer import ( 2 | VanillaConditionalTransformer, 3 | PEConditionalTransformer, 4 | RPEConditionalTransformer, 5 | LRPEConditionalTransformer, 6 | ) 7 | from geotransformer.modules.transformer.lrpe_transformer import LRPETransformerLayer 8 | from geotransformer.modules.transformer.pe_transformer import PETransformerLayer 9 | from geotransformer.modules.transformer.positional_embedding import ( 10 | SinusoidalPositionalEmbedding, 11 | LearnablePositionalEmbedding, 12 | ) 13 | from geotransformer.modules.transformer.rpe_transformer import RPETransformerLayer 14 | from geotransformer.modules.transformer.vanilla_transformer import ( 15 | TransformerLayer, 16 | TransformerDecoderLayer, 17 | TransformerEncoder, 18 | TransformerDecoder, 19 | ) 20 | -------------------------------------------------------------------------------- 
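Before the next module, a hedged sketch of the log-domain Sinkhorn iteration implemented in `log_sinkhorn_normalization` above, stripped of the dustbin and mask machinery; the sizes are toy values, and with uniform marginals the exponentiated result approaches a doubly stochastic matrix scaled by 1/n:

```python
import math
import torch

# Log-domain Sinkhorn on a toy 4x4 score matrix with uniform marginals.
# Rows and columns of P should each sum to ~1/4.
torch.manual_seed(0)
scores = torch.rand(1, 4, 4)
log_mu = torch.full((1, 4), -math.log(4.0))
log_nu = torch.full((1, 4), -math.log(4.0))
u, v = torch.zeros_like(log_mu), torch.zeros_like(log_nu)
for _ in range(100):
    u = log_mu - torch.logsumexp(scores + v.unsqueeze(1), dim=2)
    v = log_nu - torch.logsumexp(scores + u.unsqueeze(2), dim=1)
P = torch.exp(scores + u.unsqueeze(2) + v.unsqueeze(1))
print(P.sum(dim=2))  # ~[0.25, 0.25, 0.25, 0.25]
print(P.sum(dim=1))  # ~[0.25, 0.25, 0.25, 0.25]
```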
/geotransformer/modules/transformer/lrpe_transformer.py: -------------------------------------------------------------------------------- 1 | r"""Transformer with Learnable Relative Positional Embeddings. 2 | 3 | Relative positional embedding is injected in each multi-head attention layer. 4 | 5 | The shape of input tensor should be (B, N, C). 6 | Implemented with `nn.Linear` and `nn.LayerNorm` (with affine). 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange 13 | 14 | from geotransformer.modules.layers import build_dropout_layer 15 | from geotransformer.modules.transformer.output_layer import AttentionOutput 16 | from geotransformer.modules.transformer.positional_embedding import LearnablePositionalEmbedding 17 | 18 | 19 | class LRPEMultiHeadAttention(nn.Module): 20 | def __init__(self, d_model, num_heads, num_embeddings, dropout=None): 21 | super(LRPEMultiHeadAttention, self).__init__() 22 | if d_model % num_heads != 0: 23 | raise ValueError(f'"d_model" ({d_model}) is not divisible by "num_heads" ({num_heads}).') 24 | 25 | self.d_model = d_model 26 | self.num_heads = num_heads 27 | self.d_model_per_head = d_model // num_heads 28 | self.num_embeddings = num_embeddings 29 | 30 | self.proj_q = nn.Linear(self.d_model, self.d_model) 31 | self.proj_k = nn.Linear(self.d_model, self.d_model) 32 | self.proj_v = nn.Linear(self.d_model, self.d_model) 33 | 34 | self.embedding = LearnablePositionalEmbedding(num_embeddings, d_model, dropout=dropout) 35 | 36 | self.dropout = build_dropout_layer(dropout) 37 | 38 | def transpose_for_scores(self, x): 39 | x = x.view(x.shape[0], x.shape[1], self.num_heads, self.d_model_per_head) 40 | x = x.permute(0, 2, 1, 3) 41 | return x 42 | 43 | def get_embeddings(self, q, emb_indices): 44 | emb_all_indices = torch.arange(self.num_embeddings).cuda() # (P,) 45 | emb_bank = rearrange(self.embedding(emb_all_indices), 'p (h c) -> h p c', h=self.num_heads) 46 | attention_scores = torch.einsum('bhnc,hpc->bhnp', q, emb_bank) 47 | emb_indices = emb_indices.unsqueeze(1).expand(-1, self.num_heads, -1, -1) # (B, N, M) -> (B, H, N, M) 48 | attention_scores = torch.gather(attention_scores, dim=-1, index=emb_indices) # (B, H, N, P) -> (B, H, N, M) 49 | return attention_scores 50 | 51 | def forward( 52 | self, 53 | input_q, 54 | input_k, 55 | input_v, 56 | emb_indices_qk, 57 | key_masks=None, 58 | attention_factors=None, 59 | ): 60 | r"""Scaled Dot-Product Attention with Learnable Relative Positional Embedding (forward) 61 | 62 | Args: 63 | input_q: torch.Tensor (B, N, C) 64 | input_k: torch.Tensor (B, M, C) 65 | input_v: torch.Tensor (B, M, C) 66 | emb_indices_qk: torch.Tensor (B, N, M), relative position indices 67 | key_masks: torch.Tensor (B, M), True if ignored, False if preserved 68 | attention_factors: torch.Tensor (B, N, M) 69 | 70 | Returns: 71 | hidden_states: torch.Tensor (B, N, C) 72 | attention_scores: torch.Tensor (B, H, N, M) 73 | """ 74 | q = rearrange(self.proj_q(input_q), 'b n (h c) -> b h n c', h=self.num_heads) 75 | k = rearrange(self.proj_k(input_k), 'b m (h c) -> b h m c', h=self.num_heads) 76 | v = rearrange(self.proj_v(input_v), 'b m (h c) -> b h m c', h=self.num_heads) 77 | 78 | attention_scores_p = self.get_embeddings(q, emb_indices_qk) 79 | 80 | attention_scores_e = torch.einsum('bhnc,bhmc->bhnm', q, k) 81 | attention_scores = (attention_scores_e + attention_scores_p) / self.d_model_per_head ** 0.5 82 | if attention_factors is not None: 83 | attention_scores = 
attention_factors.unsqueeze(1) * attention_scores 84 | if key_masks is not None: 85 | attention_scores = attention_scores.masked_fill(key_masks.unsqueeze(1).unsqueeze(1), float('-inf')) 86 | attention_scores = F.softmax(attention_scores, dim=-1) 87 | attention_scores = self.dropout(attention_scores) 88 | 89 | hidden_states = torch.matmul(attention_scores, v) 90 | 91 | hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)') 92 | 93 | return hidden_states, attention_scores 94 | 95 | 96 | class LRPEAttentionLayer(nn.Module): 97 | def __init__(self, d_model, num_heads, rpe_size, dropout=None): 98 | super(LRPEAttentionLayer, self).__init__() 99 | self.attention = LRPEMultiHeadAttention(d_model, num_heads, rpe_size, dropout=dropout) 100 | self.linear = nn.Linear(d_model, d_model) 101 | self.dropout = build_dropout_layer(dropout) 102 | self.norm = nn.LayerNorm(d_model) 103 | 104 | def forward( 105 | self, 106 | input_states, 107 | memory_states, 108 | position_states, 109 | memory_masks=None, 110 | attention_factors=None, 111 | ): 112 | hidden_states, attention_scores = self.attention( 113 | input_states, 114 | memory_states, 115 | memory_states, 116 | position_states, 117 | key_masks=memory_masks, 118 | attention_factors=attention_factors, 119 | ) 120 | hidden_states = self.linear(hidden_states) 121 | hidden_states = self.dropout(hidden_states) 122 | output_states = self.norm(hidden_states + input_states) 123 | return output_states, attention_scores 124 | 125 | 126 | class LRPETransformerLayer(nn.Module): 127 | def __init__(self, d_model, num_heads, rpe_size, dropout=None, activation_fn='ReLU'): 128 | super(LRPETransformerLayer, self).__init__() 129 | self.attention = LRPEAttentionLayer(d_model, num_heads, rpe_size, dropout=dropout) 130 | self.output = AttentionOutput(d_model, dropout=dropout, activation_fn=activation_fn) 131 | 132 | def forward( 133 | self, 134 | input_states, 135 | memory_states, 136 | position_states, 137 | memory_masks=None, 138 | attention_factors=None, 139 | ): 140 | hidden_states, attention_scores = self.attention( 141 | input_states, 142 | memory_states, 143 | position_states, 144 | memory_masks=memory_masks, 145 | attention_factors=attention_factors, 146 | ) 147 | output_states = self.output(hidden_states) 148 | return output_states, attention_scores 149 | -------------------------------------------------------------------------------- /geotransformer/modules/transformer/output_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from geotransformer.modules.layers import build_act_layer, build_dropout_layer 4 | 5 | 6 | class AttentionOutput(nn.Module): 7 | def __init__(self, d_model, dropout=None, activation_fn='ReLU'): 8 | super(AttentionOutput, self).__init__() 9 | self.expand = nn.Linear(d_model, d_model * 2) 10 | self.activation = build_act_layer(activation_fn) 11 | self.squeeze = nn.Linear(d_model * 2, d_model) 12 | self.dropout = build_dropout_layer(dropout) 13 | self.norm = nn.LayerNorm(d_model) 14 | 15 | def forward(self, input_states): 16 | hidden_states = self.expand(input_states) 17 | hidden_states = self.activation(hidden_states) 18 | hidden_states = self.squeeze(hidden_states) 19 | hidden_states = self.dropout(hidden_states) 20 | output_states = self.norm(input_states + hidden_states) 21 | return output_states 22 | -------------------------------------------------------------------------------- /geotransformer/modules/transformer/pe_transformer.py: 
-------------------------------------------------------------------------------- 1 | r"""Transformer with absolute positional embeddings. 2 | 3 | The positional embeddings are added to the queries and keys in each attention layer. The shape of input tensor should be (B, N, C). Implemented with `nn.Linear` and `nn.LayerNorm` (with affine). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | 11 | from geotransformer.modules.layers import build_dropout_layer 12 | from geotransformer.modules.transformer.output_layer import AttentionOutput 13 | 14 | 15 | class PEMultiHeadAttention(nn.Module): 16 | def __init__(self, d_model, num_heads, dropout=None): 17 | super(PEMultiHeadAttention, self).__init__() 18 | if d_model % num_heads != 0: 19 | raise ValueError('`d_model` ({}) must be a multiple of `num_heads` ({}).'.format(d_model, num_heads)) 20 | 21 | self.d_model = d_model 22 | self.num_heads = num_heads 23 | self.d_model_per_head = d_model // num_heads 24 | 25 | self.proj_q = nn.Linear(self.d_model, self.d_model) 26 | self.proj_k = nn.Linear(self.d_model, self.d_model) 27 | self.proj_v = nn.Linear(self.d_model, self.d_model) 28 | self.proj_p = nn.Linear(self.d_model, self.d_model) 29 | 30 | self.dropout = build_dropout_layer(dropout) 31 | 32 | def forward( 33 | self, 34 | input_q, 35 | input_k, 36 | input_v, 37 | embed_q, 38 | embed_k, 39 | key_masks=None, 40 | attention_factors=None, 41 | ): 42 | r"""Scaled Dot-Product Attention with Absolute Positional Embedding (forward) 43 | 44 | Args: 45 | input_q: torch.Tensor (B, N, C) 46 | input_k: torch.Tensor (B, M, C) 47 | input_v: torch.Tensor (B, M, C) 48 | embed_q: torch.Tensor (B, N, C) 49 | embed_k: torch.Tensor (B, M, C) 50 | key_masks: torch.Tensor (B, M), True if ignored, False if preserved 51 | attention_factors: torch.Tensor (B, N, M) 52 | 53 | Returns: 54 | hidden_states: torch.Tensor (B, N, C) 55 | attention_scores: torch.Tensor (B, H, N, M) 56 | """ 57 | q = rearrange(self.proj_q(input_q) + self.proj_p(embed_q), 'b n (h c) -> b h n c', h=self.num_heads) 58 | k = rearrange(self.proj_k(input_k) + self.proj_p(embed_k), 'b m (h c) -> b h m c', h=self.num_heads) 59 | v = rearrange(self.proj_v(input_v), 'b m (h c) -> b h m c', h=self.num_heads) 60 | 61 | attention_scores = torch.einsum('bhnc,bhmc->bhnm', q, k) / self.d_model_per_head ** 0.5 62 | if attention_factors is not None: 63 | attention_scores = attention_factors.unsqueeze(1) * attention_scores 64 | if key_masks is not None: 65 | attention_scores = attention_scores.masked_fill(key_masks.unsqueeze(1).unsqueeze(1), float('-inf')) 66 | attention_scores = F.softmax(attention_scores, dim=-1) 67 | attention_scores = self.dropout(attention_scores) 68 | 69 | hidden_states = torch.matmul(attention_scores, v) 70 | 71 | hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)') 72 | 73 | return hidden_states, attention_scores 74 | 75 | 76 | class PEAttentionLayer(nn.Module): 77 | def __init__(self, d_model, num_heads, dropout=None): 78 | super(PEAttentionLayer, self).__init__() 79 | self.attention = PEMultiHeadAttention(d_model, num_heads, dropout=dropout) 80 | self.linear = nn.Linear(d_model, d_model) 81 | self.dropout = build_dropout_layer(dropout) 82 | self.norm = nn.LayerNorm(d_model) 83 | 84 | def forward( 85 | self, 86 | input_states, 87 | memory_states, 88 | input_embeddings, 89 | memory_embeddings, 90 | memory_masks=None, 91 | attention_factors=None, 92 | ): 93 | hidden_states, attention_scores = self.attention( 94 | input_states, 95 | memory_states, 96 | memory_states, 97 | input_embeddings, 98 | 
memory_embeddings, 99 | key_masks=memory_masks, 100 | attention_factors=attention_factors, 101 | ) 102 | hidden_states = self.linear(hidden_states) 103 | hidden_states = self.dropout(hidden_states) 104 | output_states = self.norm(hidden_states + input_states) 105 | return output_states, attention_scores 106 | 107 | 108 | class PETransformerLayer(nn.Module): 109 | def __init__(self, d_model, num_heads, dropout=None, activation_fn='ReLU'): 110 | super(PETransformerLayer, self).__init__() 111 | self.attention = PEAttentionLayer(d_model, num_heads, dropout=dropout) 112 | self.output = AttentionOutput(d_model, dropout=dropout, activation_fn=activation_fn) 113 | 114 | def forward( 115 | self, 116 | input_states, 117 | memory_states, 118 | input_embeddings, 119 | memory_embeddings, 120 | memory_masks=None, 121 | attention_factors=None, 122 | ): 123 | hidden_states, attention_scores = self.attention( 124 | input_states, 125 | memory_states, 126 | input_embeddings, 127 | memory_embeddings, 128 | memory_masks=memory_masks, 129 | attention_factors=attention_factors, 130 | ) 131 | output_states = self.output(hidden_states) 132 | return output_states, attention_scores 133 | -------------------------------------------------------------------------------- /geotransformer/modules/transformer/positional_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from geotransformer.modules.layers import build_dropout_layer 6 | 7 | 8 | class SinusoidalPositionalEmbedding(nn.Module): 9 | def __init__(self, d_model): 10 | super(SinusoidalPositionalEmbedding, self).__init__() 11 | if d_model % 2 != 0: 12 | raise ValueError(f'Sinusoidal positional encoding with odd d_model: {d_model}') 13 | self.d_model = d_model 14 | div_indices = torch.arange(0, d_model, 2).float() 15 | div_term = torch.exp(div_indices * (-np.log(10000.0) / d_model)) 16 | self.register_buffer('div_term', div_term) 17 | 18 | def forward(self, emb_indices): 19 | r"""Sinusoidal Positional Embedding. 20 | 21 | Args: 22 | emb_indices: torch.Tensor (*) 23 | 24 | Returns: 25 | embeddings: torch.Tensor (*, D) 26 | """ 27 | input_shape = emb_indices.shape 28 | omegas = emb_indices.view(-1, 1, 1) * self.div_term.view(1, -1, 1) # (-1, d_model/2, 1) 29 | sin_embeddings = torch.sin(omegas) 30 | cos_embeddings = torch.cos(omegas) 31 | embeddings = torch.cat([sin_embeddings, cos_embeddings], dim=2) # (-1, d_model/2, 2) 32 | embeddings = embeddings.view(*input_shape, self.d_model) # (*, d_model) 33 | embeddings = embeddings.detach() 34 | return embeddings 35 | 36 | 37 | class LearnablePositionalEmbedding(nn.Module): 38 | def __init__(self, num_embeddings, embedding_dim, dropout=None): 39 | super(LearnablePositionalEmbedding, self).__init__() 40 | self.num_embeddings = num_embeddings 41 | self.embedding_dim = embedding_dim 42 | self.embeddings = nn.Embedding(num_embeddings, embedding_dim) # (L, D) 43 | self.norm = nn.LayerNorm(embedding_dim) 44 | self.dropout = build_dropout_layer(dropout) 45 | 46 | def forward(self, emb_indices): 47 | r"""Learnable Positional Embedding. 48 | 49 | `emb_indices` are truncated to fit the finite embedding space. 
50 | 51 | Args: 52 | emb_indices: torch.LongTensor (*) 53 | 54 | Returns: 55 | embeddings: torch.Tensor (*, D) 56 | """ 57 | input_shape = emb_indices.shape 58 | emb_indices = emb_indices.view(-1) 59 | max_emb_indices = torch.full_like(emb_indices, self.num_embeddings - 1) 60 | emb_indices = torch.minimum(emb_indices, max_emb_indices) 61 | embeddings = self.embeddings(emb_indices) # (*, D) 62 | embeddings = self.norm(embeddings) 63 | embeddings = self.dropout(embeddings) 64 | embeddings = embeddings.view(*input_shape, self.embedding_dim) 65 | return embeddings 66 | -------------------------------------------------------------------------------- /geotransformer/modules/transformer/rpe_transformer.py: -------------------------------------------------------------------------------- 1 | r"""Transformer with Relative Positional Embeddings. 2 | 3 | Relative positional embedding is further projected in each multi-head attention layer. 4 | 5 | The shape of input tensor should be (B, N, C). Implemented with `nn.Linear` and `nn.LayerNorm` (with affine). 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from einops import rearrange 12 | 13 | 14 | from geotransformer.modules.layers import build_dropout_layer 15 | from geotransformer.modules.transformer.output_layer import AttentionOutput 16 | 17 | 18 | class RPEMultiHeadAttention(nn.Module): 19 | def __init__(self, d_model, num_heads, dropout=None): 20 | super(RPEMultiHeadAttention, self).__init__() 21 | if d_model % num_heads != 0: 22 | raise ValueError('`d_model` ({}) must be a multiple of `num_heads` ({}).'.format(d_model, num_heads)) 23 | 24 | self.d_model = d_model 25 | self.num_heads = num_heads 26 | self.d_model_per_head = d_model // num_heads 27 | 28 | self.proj_q = nn.Linear(self.d_model, self.d_model) 29 | self.proj_k = nn.Linear(self.d_model, self.d_model) 30 | self.proj_v = nn.Linear(self.d_model, self.d_model) 31 | self.proj_p = nn.Linear(self.d_model, self.d_model) 32 | 33 | self.dropout = build_dropout_layer(dropout) 34 | 35 | def forward(self, input_q, input_k, input_v, embed_qk, key_weights=None, key_masks=None, attention_factors=None): 36 | r"""Scaled Dot-Product Attention with Pre-computed Relative Positional Embedding (forward) 37 | 38 | Args: 39 | input_q: torch.Tensor (B, N, C) 40 | input_k: torch.Tensor (B, M, C) 41 | input_v: torch.Tensor (B, M, C) 42 | embed_qk: torch.Tensor (B, N, M, C), relative positional embedding 43 | key_weights: torch.Tensor (B, M), soft masks for the keys 44 | key_masks: torch.Tensor (B, M), True if ignored, False if preserved 45 | attention_factors: torch.Tensor (B, N, M) 46 | 47 | Returns: 48 | hidden_states: torch.Tensor (B, N, C) 49 | attention_scores: torch.Tensor (B, H, N, M) 50 | """ 51 | q = rearrange(self.proj_q(input_q), 'b n (h c) -> b h n c', h=self.num_heads) 52 | k = rearrange(self.proj_k(input_k), 'b m (h c) -> b h m c', h=self.num_heads) 53 | v = rearrange(self.proj_v(input_v), 'b m (h c) -> b h m c', h=self.num_heads) 54 | p = rearrange(self.proj_p(embed_qk), 'b n m (h c) -> b h n m c', h=self.num_heads) 55 | 56 | attention_scores_p = torch.einsum('bhnc,bhnmc->bhnm', q, p) 57 | attention_scores_e = torch.einsum('bhnc,bhmc->bhnm', q, k) 58 | attention_scores = (attention_scores_e + attention_scores_p) / self.d_model_per_head ** 0.5 59 | if attention_factors is not None: 60 | attention_scores = attention_factors.unsqueeze(1) * attention_scores 61 | if key_weights is not None: 62 | attention_scores = attention_scores 
* key_weights.unsqueeze(1).unsqueeze(1) 63 | if key_masks is not None: 64 | attention_scores = attention_scores.masked_fill(key_masks.unsqueeze(1).unsqueeze(1), float('-inf')) 65 | attention_scores = F.softmax(attention_scores, dim=-1) 66 | attention_scores = self.dropout(attention_scores) 67 | 68 | hidden_states = torch.matmul(attention_scores, v) 69 | 70 | hidden_states = rearrange(hidden_states, 'b h n c -> b n (h c)') 71 | 72 | return hidden_states, attention_scores 73 | 74 | 75 | class RPEAttentionLayer(nn.Module): 76 | def __init__(self, d_model, num_heads, dropout=None): 77 | super(RPEAttentionLayer, self).__init__() 78 | self.attention = RPEMultiHeadAttention(d_model, num_heads, dropout=dropout) 79 | self.linear = nn.Linear(d_model, d_model) 80 | self.dropout = build_dropout_layer(dropout) 81 | self.norm = nn.LayerNorm(d_model) 82 | 83 | def forward( 84 | self, 85 | input_states, 86 | memory_states, 87 | position_states, 88 | memory_weights=None, 89 | memory_masks=None, 90 | attention_factors=None, 91 | ): 92 | hidden_states, attention_scores = self.attention( 93 | input_states, 94 | memory_states, 95 | memory_states, 96 | position_states, 97 | key_weights=memory_weights, 98 | key_masks=memory_masks, 99 | attention_factors=attention_factors, 100 | ) 101 | hidden_states = self.linear(hidden_states) 102 | hidden_states = self.dropout(hidden_states) 103 | output_states = self.norm(hidden_states + input_states) 104 | return output_states, attention_scores 105 | 106 | 107 | class RPETransformerLayer(nn.Module): 108 | def __init__(self, d_model, num_heads, dropout=None, activation_fn='ReLU'): 109 | super(RPETransformerLayer, self).__init__() 110 | self.attention = RPEAttentionLayer(d_model, num_heads, dropout=dropout) 111 | self.output = AttentionOutput(d_model, dropout=dropout, activation_fn=activation_fn) 112 | 113 | def forward( 114 | self, 115 | input_states, 116 | memory_states, 117 | position_states, 118 | memory_weights=None, 119 | memory_masks=None, 120 | attention_factors=None, 121 | ): 122 | hidden_states, attention_scores = self.attention( 123 | input_states, 124 | memory_states, 125 | position_states, 126 | memory_weights=memory_weights, 127 | memory_masks=memory_masks, 128 | attention_factors=attention_factors, 129 | ) 130 | output_states = self.output(hidden_states) 131 | return output_states, attention_scores 132 | -------------------------------------------------------------------------------- /geotransformer/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/geotransformer/transforms/__init__.py -------------------------------------------------------------------------------- /geotransformer/transforms/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | 6 | 7 | def normalize_points(points): 8 | r"""Normalize point cloud to a unit sphere at origin.""" 9 | points = points - points.mean(axis=0) 10 | points = points / np.max(np.linalg.norm(points, axis=1)) 11 | return points 12 | 13 | 14 | def sample_points(points, num_samples, normals=None): 15 | r"""Sample the first K points.""" 16 | points = points[:num_samples] 17 | if normals is not None: 18 | normals = normals[:num_samples] 19 | return points, normals 20 | else: 21 | return points 22 | 23 | 24 | def random_sample_points(points, num_samples, normals=None): 25 | 
r"""Randomly sample points.""" 26 | num_points = points.shape[0] 27 | sel_indices = np.random.permutation(num_points) 28 | if num_points > num_samples: 29 | sel_indices = sel_indices[:num_samples] 30 | elif num_points < num_samples: 31 | num_iterations = num_samples // num_points 32 | num_paddings = num_samples % num_points 33 | all_sel_indices = [sel_indices for _ in range(num_iterations)] 34 | if num_paddings > 0: 35 | all_sel_indices.append(sel_indices[:num_paddings]) 36 | sel_indices = np.concatenate(all_sel_indices, axis=0) 37 | points = points[sel_indices] 38 | if normals is not None: 39 | normals = normals[sel_indices] 40 | return points, normals 41 | else: 42 | return points 43 | 44 | 45 | def random_scale_shift_points(points, low=2.0 / 3.0, high=3.0 / 2.0, shift=0.2, normals=None): 46 | r"""Randomly scale and shift point cloud.""" 47 | scale = np.random.uniform(low=low, high=high, size=(1, 3)) 48 | bias = np.random.uniform(low=-shift, high=shift, size=(1, 3)) 49 | points = points * scale + bias 50 | if normals is not None: 51 | normals = normals * scale 52 | normals = normals / np.linalg.norm(normals, axis=1, keepdims=True) 53 | return points, normals 54 | else: 55 | return points 56 | 57 | 58 | def random_rotate_points_along_up_axis(points, normals=None): 59 | r"""Randomly rotate point cloud along z-axis.""" 60 | theta = np.random.rand() * 2.0 * math.pi 61 | # fmt: off 62 | rotation_t = np.array([ 63 | [math.cos(theta), math.sin(theta), 0], 64 | [-math.sin(theta), math.cos(theta), 0], 65 | [0, 0, 1], 66 | ]) 67 | # fmt: on 68 | points = np.matmul(points, rotation_t) 69 | if normals is not None: 70 | normals = np.matmul(normals, rotation_t) 71 | return points, normals 72 | else: 73 | return points 74 | 75 | 76 | def random_rescale_points(points, low=0.8, high=1.2): 77 | r"""Randomly rescale point cloud.""" 78 | scale = random.uniform(low, high) 79 | points = points * scale 80 | return points 81 | 82 | 83 | def random_jitter_points(points, scale, noise_magnitude=0.05): 84 | r"""Randomly jitter point cloud.""" 85 | noises = np.clip(np.random.normal(scale=scale, size=points.shape), a_min=-noise_magnitude, a_max=noise_magnitude) 86 | points = points + noises 87 | return points 88 | 89 | 90 | def random_shuffle_points(points, normals=None): 91 | r"""Randomly permute point cloud.""" 92 | indices = np.random.permutation(points.shape[0]) 93 | points = points[indices] 94 | if normals is not None: 95 | normals = normals[indices] 96 | return points, normals 97 | else: 98 | return points 99 | 100 | 101 | def random_dropout_points(points, max_p): 102 | r"""Randomly dropout point cloud proposed in PointNet++.""" 103 | num_points = points.shape[0] 104 | p = np.random.rand(num_points) * max_p 105 | masks = np.random.rand(num_points) < p 106 | points[masks] = points[0] 107 | return points 108 | 109 | 110 | def random_jitter_features(features, mu=0, sigma=0.01): 111 | r"""Randomly jitter features in the original implementation of FCGF.""" 112 | if random.random() < 0.95: 113 | features = features + np.random.normal(mu, sigma, features.shape).astype(np.float32) 114 | return features 115 | 116 | 117 | def random_sample_plane(): 118 | r"""Random sample a plane passing the origin and return its normal.""" 119 | phi = np.random.uniform(0.0, 2 * np.pi) # longitude 120 | theta = np.random.uniform(0.0, np.pi) # latitude 121 | 122 | x = np.sin(theta) * np.cos(phi) 123 | y = np.sin(theta) * np.sin(phi) 124 | z = np.cos(theta) 125 | normal = np.asarray([x, y, z]) 126 | 127 | return normal 128 | 129 | 130 | 
def random_crop_point_cloud_with_plane(points, p_normal=None, keep_ratio=0.7, normals=None): 131 | r"""Random crop a point cloud with a plane and keep num_samples points.""" 132 | num_samples = int(np.floor(points.shape[0] * keep_ratio + 0.5)) 133 | if p_normal is None: 134 | p_normal = random_sample_plane() # (3,) 135 | distances = np.dot(points, p_normal) 136 | sel_indices = np.argsort(-distances)[:num_samples] # select the largest K points 137 | points = points[sel_indices] 138 | if normals is not None: 139 | normals = normals[sel_indices] 140 | return points, normals 141 | else: 142 | return points 143 | 144 | 145 | def random_sample_viewpoint(limit=500): 146 | r"""Randomly sample observing point from 8 directions.""" 147 | return np.random.rand(3) + np.array([limit, limit, limit]) * np.random.choice([1.0, -1.0], size=3) 148 | 149 | 150 | def random_crop_point_cloud_with_point(points, viewpoint=None, keep_ratio=0.7, normals=None): 151 | r"""Random crop point cloud from the observing point.""" 152 | num_samples = int(np.floor(points.shape[0] * keep_ratio + 0.5)) 153 | if viewpoint is None: 154 | viewpoint = random_sample_viewpoint() 155 | distances = np.linalg.norm(viewpoint - points, axis=1) 156 | sel_indices = np.argsort(distances)[:num_samples] 157 | points = points[sel_indices] 158 | if normals is not None: 159 | normals = normals[sel_indices] 160 | return points, normals 161 | else: 162 | return points 163 | -------------------------------------------------------------------------------- /geotransformer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinzheng93/GeoTransformer/e7a135af4c318ff3b8d7f6c963df094d7e4ea540/geotransformer/utils/__init__.py -------------------------------------------------------------------------------- /geotransformer/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AverageMeter: 5 | def __init__(self, last_n=None): 6 | self._records = [] 7 | self.last_n = last_n 8 | 9 | def update(self, result): 10 | if isinstance(result, (list, tuple)): 11 | self._records += result 12 | else: 13 | self._records.append(result) 14 | 15 | def reset(self): 16 | self._records.clear() 17 | 18 | @property 19 | def records(self): 20 | if self.last_n is not None: 21 | return self._records[-self.last_n :] 22 | else: 23 | return self._records 24 | 25 | def sum(self): 26 | return np.sum(self.records) 27 | 28 | def mean(self): 29 | return np.mean(self.records) 30 | 31 | def std(self): 32 | return np.std(self.records) 33 | 34 | def median(self): 35 | return np.median(self.records) 36 | -------------------------------------------------------------------------------- /geotransformer/utils/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pickle 4 | 5 | 6 | def ensure_dir(path): 7 | if not osp.exists(path): 8 | os.makedirs(path) 9 | 10 | 11 | def load_pickle(filename): 12 | with open(filename, 'rb') as f: 13 | data = pickle.load(f) 14 | return data 15 | 16 | 17 | def dump_pickle(data, filename): 18 | with open(filename, 'wb') as f: 19 | pickle.dump(data, f) 20 | 21 | 22 | def get_print_format(value): 23 | if isinstance(value, int): 24 | return 'd' 25 | if isinstance(value, str): 26 | return 's' 27 | if value == 0: 28 | return '.3f' 29 | if value < 1e-6: 30 | return '.3e' 31 | if value < 1e-3: 32 | return '.6f' 33 | 
return '.3f' 33 | 34 | 35 | 36 | def get_format_strings(kv_pairs): 37 | r"""Get format strings for a list of key-value pairs.""" 38 | log_strings = [] 39 | for key, value in kv_pairs: 40 | fmt = get_print_format(value) 41 | format_string = '{}: {:' + fmt + '}' 42 | log_strings.append(format_string.format(key, value)) 43 | return log_strings 44 | 45 | 46 | def get_log_string(result_dict, epoch=None, max_epoch=None, iteration=None, max_iteration=None, lr=None, timer=None): 47 | log_strings = [] 48 | if epoch is not None: 49 | epoch_string = f'Epoch: {epoch}' 50 | if max_epoch is not None: 51 | epoch_string += f'/{max_epoch}' 52 | log_strings.append(epoch_string) 53 | if iteration is not None: 54 | iter_string = f'iter: {iteration}' 55 | if max_iteration is not None: 56 | iter_string += f'/{max_iteration}' 57 | if epoch is None: 58 | iter_string = iter_string.capitalize() 59 | log_strings.append(iter_string) 60 | if 'metadata' in result_dict: 61 | log_strings += result_dict['metadata'] 62 | for key, value in result_dict.items(): 63 | if key != 'metadata': 64 | format_string = '{}: {:' + get_print_format(value) + '}' 65 | log_strings.append(format_string.format(key, value)) 66 | if lr is not None: 67 | log_strings.append('lr: {:.3e}'.format(lr)) 68 | if timer is not None: 69 | log_strings.append(timer.tostring()) 70 | message = ', '.join(log_strings) 71 | return message 72 | -------------------------------------------------------------------------------- /geotransformer/utils/summary_board.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from geotransformer.utils.average_meter import AverageMeter 4 | from geotransformer.utils.common import get_print_format 5 | 6 | 7 | class SummaryBoard: 8 | r"""Summary board.""" 9 | 10 | def __init__(self, names: Optional[List[str]] = None, last_n: Optional[int] = None, adaptive=False): 11 | r"""Instantiate a SummaryBoard. 12 | 13 | Args: 14 | names (List[str]=None): create an AverageMeter for each name. 15 | last_n (int=None): only the last n records are used. 16 | adaptive (bool=False): whether to register basic meters automatically on the fly. 
17 | """ 18 | self.meter_dict = {} 19 | self.meter_names = [] 20 | self.last_n = last_n 21 | self.adaptive = adaptive 22 | 23 | if names is not None: 24 | self.register_all(names) 25 | 26 | def register_meter(self, name): 27 | self.meter_dict[name] = AverageMeter(last_n=self.last_n) 28 | self.meter_names.append(name) 29 | 30 | def register_all(self, names): 31 | for name in names: 32 | self.register_meter(name) 33 | 34 | def reset_meter(self, name): 35 | self.meter_dict[name].reset() 36 | 37 | def reset_all(self): 38 | for name in self.meter_names: 39 | self.reset_meter(name) 40 | 41 | def check_name(self, name): 42 | if name not in self.meter_names: 43 | if self.adaptive: 44 | self.register_meter(name) 45 | else: 46 | raise KeyError('No meter for key "{}".'.format(name)) 47 | 48 | def update(self, name, value): 49 | self.check_name(name) 50 | self.meter_dict[name].update(value) 51 | 52 | def update_from_result_dict(self, result_dict): 53 | if not isinstance(result_dict, dict): 54 | raise TypeError('`result_dict` must be a dict: {}.'.format(type(result_dict))) 55 | for key, value in result_dict.items(): 56 | if key not in self.meter_names and self.adaptive: 57 | self.register_meter(key) 58 | if key in self.meter_names: 59 | self.meter_dict[key].update(value) 60 | 61 | def sum(self, name): 62 | self.check_name(name) 63 | return self.meter_dict[name].sum() 64 | 65 | def mean(self, name): 66 | self.check_name(name) 67 | return self.meter_dict[name].mean() 68 | 69 | def std(self, name): 70 | self.check_name(name) 71 | return self.meter_dict[name].std() 72 | 73 | def median(self, name): 74 | self.check_name(name) 75 | return self.meter_dict[name].median() 76 | 77 | def tostring(self, names=None): 78 | if names is None: 79 | names = self.meter_names 80 | items = [] 81 | for name in names: 82 | value = self.meter_dict[name].mean() 83 | fmt = get_print_format(value) 84 | format_string = '{}: {:' + fmt + '}' 85 | items.append(format_string.format(name, value)) 86 | summary = ', '.join(items) 87 | return summary 88 | 89 | def summary(self, names=None): 90 | if names is None: 91 | names = self.meter_names 92 | summary_dict = {name: self.meter_dict[name].mean() for name in names} 93 | return summary_dict 94 | -------------------------------------------------------------------------------- /geotransformer/utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer: 5 | def __init__(self): 6 | self.total_prepare_time = 0 7 | self.total_process_time = 0 8 | self.count_prepare_time = 0 9 | self.count_process_time = 0 10 | self.last_time = time.time() 11 | 12 | def reset(self): 13 | self.total_prepare_time = 0 14 | self.total_process_time = 0 15 | self.count_prepare_time = 0 16 | self.count_process_time = 0 17 | self.last_time = time.time() 18 | 19 | def record_time(self): 20 | self.last_time = time.time() 21 | 22 | def add_prepare_time(self): 23 | current_time = time.time() 24 | self.total_prepare_time += current_time - self.last_time 25 | self.count_prepare_time += 1 26 | self.last_time = current_time 27 | 28 | def add_process_time(self): 29 | current_time = time.time() 30 | self.total_process_time += current_time - self.last_time 31 | self.count_process_time += 1 32 | self.last_time = current_time 33 | 34 | def get_prepare_time(self): 35 | return self.total_prepare_time / (self.count_prepare_time + 1e-12) 36 | 37 | def get_process_time(self): 38 | return self.total_process_time / (self.count_process_time + 1e-12) 39 | 40 | def 
tostring(self): 41 | summary = 'time: ' 42 | if self.count_prepare_time > 0: 43 | summary += '{:.3f}s/'.format(self.get_prepare_time()) 44 | summary += '{:.3f}s'.format(self.get_process_time()) 45 | return summary 46 | 47 | 48 | class TimerDict: 49 | def __init__(self): 50 | self.total_time = {} 51 | self.count_time = {} 52 | self.last_time = {} 53 | self.timer_keys = [] 54 | 55 | def add_timer(self, key): 56 | self.total_time[key] = 0.0 57 | self.count_time[key] = 0 58 | self.last_time[key] = 0.0 59 | self.timer_keys.append(key) 60 | 61 | def tic(self, key): 62 | if key not in self.timer_keys: 63 | self.add_timer(key) 64 | self.last_time[key] = time.time() 65 | 66 | def toc(self, key): 67 | assert key in self.timer_keys 68 | duration = time.time() - self.last_time[key] 69 | self.total_time[key] += duration 70 | self.count_time[key] += 1 71 | 72 | def get_time(self, key): 73 | assert key in self.timer_keys 74 | return self.total_time[key] / (float(self.count_time[key]) + 1e-12) 75 | 76 | def summary(self, keys): 77 | summary = 'time: ' 78 | summary += '/'.join(['{:.3f}s'.format(self.get_time(key)) for key in keys]) 79 | return summary 80 | -------------------------------------------------------------------------------- /geotransformer/utils/torch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | from typing import Callable 4 | 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | import torch.utils.data 10 | import torch.backends.cudnn as cudnn 11 | 12 | 13 | # Distributed Data Parallel Utilities 14 | 15 | 16 | def all_reduce_tensor(tensor, world_size=1): 17 | r"""Average reduce a tensor across all workers.""" 18 | reduced_tensor = tensor.clone() 19 | dist.all_reduce(reduced_tensor) 20 | reduced_tensor /= world_size 21 | return reduced_tensor 22 | 23 | 24 | def all_reduce_tensors(x, world_size=1): 25 | r"""Average reduce all tensors across all workers.""" 26 | if isinstance(x, list): 27 | x = [all_reduce_tensors(item, world_size=world_size) for item in x] 28 | elif isinstance(x, tuple): 29 | x = tuple(all_reduce_tensors(item, world_size=world_size) for item in x) 30 | elif isinstance(x, dict): 31 | x = {key: all_reduce_tensors(value, world_size=world_size) for key, value in x.items()} 32 | elif isinstance(x, torch.Tensor): 33 | x = all_reduce_tensor(x, world_size=world_size) 34 | return x 35 | 36 | 37 | # Dataloader Utilities 38 | 39 | 40 | def reset_seed_worker_init_fn(worker_id): 41 | r"""Reset seed for data loader worker.""" 42 | seed = torch.initial_seed() % (2 ** 32) 43 | # print(worker_id, seed) 44 | np.random.seed(seed) 45 | random.seed(seed) 46 | 47 | 48 | def build_dataloader( 49 | dataset, 50 | batch_size=1, 51 | num_workers=1, 52 | shuffle=None, 53 | collate_fn=None, 54 | pin_memory=False, 55 | drop_last=False, 56 | distributed=False, 57 | ): 58 | if distributed: 59 | sampler = torch.utils.data.DistributedSampler(dataset) 60 | shuffle = False 61 | else: 62 | sampler = None 63 | shuffle = shuffle 64 | 65 | data_loader = torch.utils.data.DataLoader( 66 | dataset, 67 | batch_size=batch_size, 68 | num_workers=num_workers, 69 | shuffle=shuffle, 70 | sampler=sampler, 71 | collate_fn=collate_fn, 72 | worker_init_fn=reset_seed_worker_init_fn, 73 | pin_memory=pin_memory, 74 | drop_last=drop_last, 75 | ) 76 | 77 | return data_loader 78 | 
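# Usage sketch for build_dataloader (hypothetical values, for illustration only):
#
#   dataset = torch.utils.data.TensorDataset(torch.randn(100, 3))
#   loader = build_dataloader(dataset, batch_size=4, num_workers=0, shuffle=True)
#   for (batch,) in loader:
#       ...  # batch: a (4, 3) float tensor
#
# With distributed=True, a DistributedSampler is attached and `shuffle` is
# forced off, since the sampler itself shuffles once per epoch.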
79 | 80 | # Common Utilities 81 | 82 | 83 | def initialize(seed=None, cudnn_deterministic=True, autograd_anomaly_detection=False): 84 | if seed is not None: 85 | random.seed(seed) 86 | torch.manual_seed(seed) 87 | np.random.seed(seed) 88 | if cudnn_deterministic: 89 | cudnn.benchmark = False 90 | cudnn.deterministic = True 91 | else: 92 | cudnn.benchmark = True 93 | cudnn.deterministic = False 94 | torch.autograd.set_detect_anomaly(autograd_anomaly_detection) 95 | 96 | 97 | def release_cuda(x): 98 | r"""Release all tensors to item or numpy array.""" 99 | if isinstance(x, list): 100 | x = [release_cuda(item) for item in x] 101 | elif isinstance(x, tuple): 102 | x = tuple(release_cuda(item) for item in x) 103 | elif isinstance(x, dict): 104 | x = {key: release_cuda(value) for key, value in x.items()} 105 | elif isinstance(x, torch.Tensor): 106 | if x.numel() == 1: 107 | x = x.item() 108 | else: 109 | x = x.detach().cpu().numpy() 110 | return x 111 | 112 | 113 | def to_cuda(x): 114 | r"""Move all tensors to cuda.""" 115 | if isinstance(x, list): 116 | x = [to_cuda(item) for item in x] 117 | elif isinstance(x, tuple): 118 | x = tuple(to_cuda(item) for item in x) 119 | elif isinstance(x, dict): 120 | x = {key: to_cuda(value) for key, value in x.items()} 121 | elif isinstance(x, torch.Tensor): 122 | x = x.cuda() 123 | return x 124 | 125 | 126 | def load_weights(model, snapshot): 127 | r"""Load weights and check keys.""" 128 | state_dict = torch.load(snapshot) 129 | model_dict = state_dict['model'] 130 | model.load_state_dict(model_dict, strict=False) 131 | 132 | snapshot_keys = set(model_dict.keys()) 133 | model_keys = set(model.state_dict().keys()) 134 | missing_keys = model_keys - snapshot_keys 135 | unexpected_keys = snapshot_keys - model_keys 136 | 137 | return missing_keys, unexpected_keys 138 | 139 | 140 | # Learning Rate Scheduler 141 | 142 | 143 | class CosineAnnealingFunction(Callable): 144 | def __init__(self, max_epoch, eta_min=0.0): 145 | self.max_epoch = max_epoch 146 | self.eta_min = eta_min 147 | 148 | def __call__(self, last_epoch): 149 | next_epoch = last_epoch + 1 150 | return self.eta_min + 0.5 * (1.0 - self.eta_min) * (1.0 + math.cos(math.pi * next_epoch / self.max_epoch)) 151 | 152 | 153 | class WarmUpCosineAnnealingFunction(Callable): 154 | def __init__(self, total_steps, warmup_steps, eta_init=0.1, eta_min=0.1): 155 | self.total_steps = total_steps 156 | self.warmup_steps = warmup_steps 157 | self.normal_steps = total_steps - warmup_steps 158 | self.eta_init = eta_init 159 | self.eta_min = eta_min 160 | 161 | def __call__(self, last_step): 162 | # last_step starts from -1, so next_step=0 corresponds to the first call of lr annealing. 163 | next_step = last_step + 1 164 | if next_step < self.warmup_steps: 165 | return self.eta_init + (1.0 - self.eta_init) / self.warmup_steps * next_step 166 | else: 167 | if next_step > self.total_steps: 168 | return self.eta_min 169 | next_step -= self.warmup_steps 170 | return self.eta_min + 0.5 * (1.0 - self.eta_min) * (1 + np.cos(np.pi * next_step / self.normal_steps)) 171 | 172 | 173 | def build_warmup_cosine_lr_scheduler(optimizer, total_steps, warmup_steps, eta_init=0.1, eta_min=0.1, grad_acc_steps=1): 174 | total_steps //= grad_acc_steps 175 | warmup_steps //= grad_acc_steps 176 | cosine_func = WarmUpCosineAnnealingFunction(total_steps, warmup_steps, eta_init=eta_init, eta_min=eta_min) 177 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_func) 178 | return scheduler 179 | 
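--------------------------------------------------------------------------------
A minimal sketch of driving `build_warmup_cosine_lr_scheduler` from geotransformer/utils/torch.py above; the model, optimizer, and step counts are placeholder values:

import torch

from geotransformer.utils.torch import build_warmup_cosine_lr_scheduler

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# 10000 iterations with 1000 warm-up steps: the LambdaLR multiplier ramps from
# eta_init (0.1) to 1.0 during warm-up, then decays to eta_min (0.1) on a cosine.
scheduler = build_warmup_cosine_lr_scheduler(optimizer, total_steps=10000, warmup_steps=1000)
for _ in range(10000):
    optimizer.step()  # loss computation and backward() omitted
    scheduler.step()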
-------------------------------------------------------------------------------- /geotransformer/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import open3d as o3d 4 | from sklearn.manifold import TSNE 5 | from tqdm import tqdm 6 | 7 | from geotransformer.utils.open3d import ( 8 | make_open3d_point_cloud, 9 | make_open3d_axes, 10 | make_open3d_corr_lines, 11 | ) 12 | 13 | 14 | def draw_point_to_node(points, nodes, point_to_node, node_colors=None): 15 | if node_colors is None: 16 | node_colors = np.random.rand(*nodes.shape) 17 | # point_colors = node_colors[point_to_node] * make_scaling_along_axis(points, alpha=0.3).reshape(-1, 1) 18 | point_colors = node_colors[point_to_node] 19 | node_colors = np.ones_like(nodes) * np.array([[1, 0, 0]]) 20 | 21 | ncd = make_open3d_point_cloud(nodes, colors=node_colors) 22 | pcd = make_open3d_point_cloud(points, colors=point_colors) 23 | axes = make_open3d_axes() 24 | 25 | o3d.visualization.draw([pcd, ncd, axes]) 26 | 27 | 28 | def draw_node_correspondences( 29 | ref_points, 30 | ref_nodes, 31 | ref_point_to_node, 32 | src_points, 33 | src_nodes, 34 | src_point_to_node, 35 | node_correspondences, 36 | ref_node_colors=None, 37 | src_node_colors=None, 38 | offsets=(0, 2, 0), 39 | ): 40 | src_nodes = src_nodes + offsets 41 | src_points = src_points + offsets 42 | 43 | if ref_node_colors is None: 44 | ref_node_colors = np.random.rand(*ref_nodes.shape) 45 | # src_point_colors = src_node_colors[src_point_to_node] * make_scaling_along_axis(src_points).reshape(-1, 1) 46 | ref_point_colors = ref_node_colors[ref_point_to_node] 47 | ref_node_colors = np.ones_like(ref_nodes) * np.array([[1, 0, 0]]) 48 | 49 | if src_node_colors is None: 50 | src_node_colors = np.random.rand(*src_nodes.shape) 51 | # tgt_point_colors = tgt_node_colors[tgt_point_to_node] * make_scaling_along_axis(tgt_points).reshape(-1, 1) 52 | src_point_colors = src_node_colors[src_point_to_node] 53 | src_node_colors = np.ones_like(src_nodes) * np.array([[1, 0, 0]]) 54 | 55 | ref_ncd = make_open3d_point_cloud(ref_nodes, colors=ref_node_colors) 56 | ref_pcd = make_open3d_point_cloud(ref_points, colors=ref_point_colors) 57 | src_ncd = make_open3d_point_cloud(src_nodes, colors=src_node_colors) 58 | src_pcd = make_open3d_point_cloud(src_points, colors=src_point_colors) 59 | corr_lines = make_open3d_corr_lines(ref_nodes, src_nodes, node_correspondences) 60 | axes = make_open3d_axes(scale=0.1) 61 | 62 | o3d.visualization.draw([ref_pcd, ref_ncd, src_pcd, src_ncd, corr_lines, axes]) 63 | 64 | 65 | def get_colors_with_tsne(data): 66 | r""" 67 | Use t-SNE to project high-dimensional features to RGB colors. 68 | 
:param data: (N, C) 69 | :return colors: (N, 3) 70 | """ 71 | tsne = TSNE(n_components=1, perplexity=40, n_iter=300, random_state=0) 72 | tsne_results = tsne.fit_transform(data).reshape(-1) 73 | tsne_min = np.min(tsne_results) 74 | tsne_max = np.max(tsne_results) 75 | normalized_tsne_results = (tsne_results - tsne_min) / (tsne_max - tsne_min) 76 | colors = plt.cm.Spectral(normalized_tsne_results)[:, :3] 77 | return colors 78 | 79 | 80 | def write_points_to_obj(file_name, points, colors=None, radius=0.02, resolution=6): 81 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius, resolution=resolution) 82 | vertices = np.asarray(sphere.vertices) 83 | triangles = np.asarray(sphere.triangles) + 1 84 | 85 | v_lines = [] 86 | f_lines = [] 87 | 88 | num_point = points.shape[0] 89 | for i in tqdm(range(num_point)): 90 | n = i * vertices.shape[0] 91 | 92 | for j in range(vertices.shape[0]): 93 | new_vertex = points[i] + vertices[j] 94 | line = 'v {:.6f} {:.6f} {:.6f}'.format(new_vertex[0], new_vertex[1], new_vertex[2]) 95 | if colors is not None: 96 | line += ' {:.6f} {:.6f} {:.6f}'.format(colors[i, 0], colors[i, 1], colors[i, 2]) 97 | v_lines.append(line + '\n') 98 | 99 | for j in range(triangles.shape[0]): 100 | new_triangle = triangles[j] + n 101 | line = 'f {} {} {}\n'.format(new_triangle[0], new_triangle[1], new_triangle[2]) 102 | f_lines.append(line) 103 | 104 | with open(file_name, 'w') as f: 105 | f.writelines(v_lines) 106 | f.writelines(f_lines) 107 | 108 | 109 | def convert_points_to_mesh(points, colors=None, radius=0.02, resolution=6): 110 | sphere = o3d.geometry.TriangleMesh.create_sphere(radius=radius, resolution=resolution) 111 | vertices = np.asarray(sphere.vertices) 112 | triangles = np.asarray(sphere.triangles) 113 | 114 | new_vertices = points[:, None, :] + vertices[None, :, :] 115 | new_vertices = new_vertices.reshape(-1, 3) 116 | bases = np.arange(points.shape[0]) * vertices.shape[0] 117 | new_triangles = bases[:, None, None] + triangles[None, :, :] 118 | new_triangles = new_triangles.reshape(-1, 3) 119 | 120 | mesh = o3d.geometry.TriangleMesh() 121 | mesh.vertices = o3d.utility.Vector3dVector(new_vertices) 122 | mesh.triangles = o3d.utility.Vector3iVector(new_triangles) 123 | # vertex colors are optional; only broadcast and assign them when given 124 | if colors is not None: 125 | new_vertex_colors = np.broadcast_to(colors[:, None, :], (points.shape[0], vertices.shape[0], 3)).reshape(-1, 3) 126 | mesh.vertex_colors = o3d.utility.Vector3dVector(new_vertex_colors) 127 | 128 | return mesh 129 | 130 | 131 | def write_points_to_ply(file_name, points, colors=None, radius=0.02, resolution=6): 132 | mesh = convert_points_to_mesh(points, colors=colors, radius=radius, resolution=resolution) 133 | o3d.io.write_triangle_mesh(file_name, mesh, write_vertex_normals=False) 134 | 135 | 136 | def write_correspondences_to_obj(file_name, src_corr_points, tgt_corr_points): 137 | v_lines = [] 138 | l_lines = [] 139 | 140 | num_corr = src_corr_points.shape[0] 141 | for i in tqdm(range(num_corr)): 142 | n = i * 2 143 | 144 | src_point = src_corr_points[i] 145 | tgt_point = tgt_corr_points[i] 146 | 147 | line = 'v {:.6f} {:.6f} {:.6f}\n'.format(src_point[0], src_point[1], src_point[2]) 148 | v_lines.append(line) 149 | 150 | line = 'v {:.6f} {:.6f} {:.6f}\n'.format(tgt_point[0], tgt_point[1], tgt_point[2]) 151 | v_lines.append(line) 152 | 153 | line = 'l {} {}\n'.format(n + 1, n + 2) 154 | l_lines.append(line) 155 | 156 | with open(file_name, 'w') as f: 157 | f.writelines(v_lines) 158 | f.writelines(l_lines) 159 | 
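--------------------------------------------------------------------------------
A minimal sketch exercising the sphere-mesh helpers from geotransformer/utils/visualization.py above; the random points and colors are placeholders, and open3d must be installed (it is pinned in requirements.txt):

import numpy as np

from geotransformer.utils.visualization import convert_points_to_mesh, write_points_to_ply

points = np.random.rand(100, 3)  # (N, 3) point cloud
colors = np.random.rand(100, 3)  # (N, 3) RGB values in [0, 1]
mesh = convert_points_to_mesh(points, colors=colors, radius=0.02, resolution=6)
write_points_to_ply('points.ply', points, colors=colors)  # same mesh, written to disk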
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | scipy 4 | tqdm 5 | coloredlogs 6 | easydict 7 | nibabel 8 | open3d==0.11.2 9 | scikit-learn 10 | einops 11 | ipdb 12 | tensorboard 13 | tensorboardX -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='geotransformer', 7 | version='1.0.0', 8 | ext_modules=[ 9 | CUDAExtension( 10 | name='geotransformer.ext', 11 | sources=[ 12 | 'geotransformer/extensions/extra/cloud/cloud.cpp', 13 | 'geotransformer/extensions/cpu/grid_subsampling/grid_subsampling.cpp', 14 | 'geotransformer/extensions/cpu/grid_subsampling/grid_subsampling_cpu.cpp', 15 | 'geotransformer/extensions/cpu/radius_neighbors/radius_neighbors.cpp', 16 | 'geotransformer/extensions/cpu/radius_neighbors/radius_neighbors_cpu.cpp', 17 | 'geotransformer/extensions/pybind.cpp', 18 | ], 19 | ), 20 | ], 21 | cmdclass={'build_ext': BuildExtension}, 22 | ) 23 | --------------------------------------------------------------------------------
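A quick sanity check (a sketch, not a repository file) that the extension declared in setup.py compiled and is importable; the editable install is the standard setuptools route and assumes a working CUDA toolchain:

# Build first, e.g.:  pip install -e .
import geotransformer.ext  # module name taken from CUDAExtension(name='geotransformer.ext', ...)
print('extension loaded from:', geotransformer.ext.__file__)
--------------------------------------------------------------------------------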