├── .gitmodules ├── 3dmatch.py ├── LICENSE ├── README.md ├── data.py ├── figures └── teaser.png ├── fps.py ├── gconstructor.py ├── gfilter.py ├── pipeline_SLAM ├── README.md ├── geotransformer-kitti.pth.tar ├── pipeline.py └── utils │ ├── ICP.py │ ├── PoseGraphManager.py │ ├── ScanContextManager.py │ ├── UtilsMisc.py │ ├── UtilsPointcloud.py │ ├── __pycache__ │ ├── ICP.cpython-310.pyc │ ├── PoseGraphManager.cpython-310.pyc │ ├── ScanContextManager.cpython-310.pyc │ ├── UtilsMisc.cpython-310.pyc │ ├── UtilsPointcloud.cpython-310.pyc │ ├── corr_downsample.cpython-310.pyc │ ├── extract_corr.cpython-310.pyc │ └── registration.cpython-310.pyc │ ├── corr_downsample.py │ ├── extract_corr.py │ ├── fastmac │ ├── __pycache__ │ │ ├── gconstructor.cpython-310.pyc │ │ └── gfilter.cpython-310.pyc │ ├── gconstructor.py │ └── gfilter.py │ └── registration.py ├── requirements.txt └── sota.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "pipeline_SLAM/thirdparty/GeoTransformer"] 2 | path = pipeline_SLAM/thirdparty/GeoTransformer 3 | url = https://github.com/qinzheng93/GeoTransformer.git 4 | -------------------------------------------------------------------------------- /3dmatch.py: -------------------------------------------------------------------------------- 1 | from data import ThreeDLomatch,ThreeDmatch 2 | from tqdm import tqdm 3 | import torch 4 | import torch.nn as nn 5 | from gconstructor import GraphConstructorFor3DMatch 6 | from gfilter import graphFilter,datasample 7 | import time 8 | import numpy as np 9 | torch.manual_seed(42) 10 | 11 | 12 | def normalize(x): 13 | # transform x to [0,1] 14 | x=x-x.min() 15 | x=x/x.max() 16 | return x 17 | 18 | 19 | def Config(): 20 | config={ 21 | "num_points":np.inf, 22 | "resolution":0.006, 23 | "data_dir":'/data/Processed_3dmatch_3dlomatch/', 24 | "name":"3dmatch", 25 | 'descriptor':'fpfh', 26 | 'batch_size':1, 27 | 'inlier_thresh':0.1, 28 | 'device':'cuda', 29 | 'mode':'graph', 30 | 'ratio':0.50, 31 | } 32 | return config 33 | 34 | def main(): 35 | config=Config() 36 | device=config["device"] 37 | mode=config["mode"] 38 | sample_ratio=config["ratio"] 39 | if config["name"]=="3dmatch": 40 | dataset=ThreeDmatch(num_points=config["num_points"],data_dir=config["data_dir"],descriptor=config["descriptor"]) 41 | elif config["name"]=="3dlomatch": 42 | dataset=ThreeDLomatch(num_points=config["num_points"],data_dir=config["data_dir"],descriptor=config["descriptor"]) 43 | trainloader = torch.utils.data.DataLoader( 44 | dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0 45 | ) 46 | if mode == "graph": 47 | gc=GraphConstructorFor3DMatch() 48 | print("Start") 49 | average_time=0 50 | for i, data_ in enumerate(tqdm(trainloader)): 51 | time_start=time.time() 52 | current_points,ground_truth,label,corr_path,gt_path,lb_path=data_ 53 | current_points=current_points.to(device) 54 | ground_truth=ground_truth.to(device) 55 | label=label.to(device) 56 | corr_graph=gc(current_points,config["resolution"],config["name"],config["descriptor"],config["inlier_thresh"]) 57 | degree_signal=torch.sum(corr_graph,dim=-1) 58 | 59 | corr_laplacian=(torch.diag_embed(degree_signal)-corr_graph).squeeze(0) 60 | corr_scores=graphFilter(degree_signal.transpose(0,1),corr_laplacian,is_sparse=False) 61 | 62 | corr_scores=normalize(corr_scores) 63 | total_scores=corr_scores 64 | 65 | k=int(current_points.shape[1]*sample_ratio) 66 | idxs=datasample(k,False,total_scores) 67 | 68 | time_end=time.time() 69 | 
average_time+=time_end-time_start 70 | 71 | samples=current_points.squeeze(0)[idxs,:] 72 | lb=label.squeeze(0)[idxs].long() 73 | samples=samples.cpu().numpy() 74 | 75 | 76 | out_corr_path=corr_path[0].split(".")[0]+"_"+config["mode"]+"_"+str(int(config["ratio"]*100))+".txt" 77 | out_gt_path=gt_path[0].split(".")[0]+"_"+config["mode"]+"_"+str(int(config["ratio"]*100))+".txt" 78 | out_lb_path=lb_path[0].split(".")[0]+"_"+config["mode"]+"_"+str(int(config["ratio"]*100))+".txt" 79 | np.savetxt(out_corr_path,samples) 80 | np.savetxt(out_gt_path,ground_truth.squeeze(0).cpu().numpy()) 81 | np.savetxt(out_lb_path,lb.cpu().numpy().astype(int),fmt="%d") 82 | 83 | print("Average time: ",average_time/len(trainloader)) 84 | else: 85 | raise NotImplementedError 86 | 87 | 88 | if __name__ == "__main__": 89 | main() 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ZHANG Yifei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastMAC: Stochastic Spectral Sampling of Correspondence Graph (CVPR 2024) 2 | Source code of [FastMAC: Stochastic Spectral Sampling of Correspondence Graph](https://arxiv.org/abs/2403.08770) 3 | 4 | ## Introduction 5 | 3D correspondence, i.e., a pair of 3D points, is a fundamental concept in computer vision. A set of 3D correspondences, when equipped with compatibility edges, forms a correspondence graph. This graph is a critical component in several state-of-the-art 3D point cloud registration approaches, e.g., the one based on maximal cliques (MAC), yet its properties have not been well understood. We therefore present the first study that introduces graph signal processing into the domain of the correspondence graph. We exploit the generalized degree signal on the correspondence graph and pursue sampling strategies that preserve the high-frequency components of this signal. To avoid the time-consuming singular value decomposition required by deterministic sampling, we resort to a stochastic approximate sampling strategy. The core of our method is thus the stochastic spectral sampling of the correspondence graph.
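The sketch below distills this idea on random data. It is a minimal, self-contained approximation of what `gconstructor.py` and `gfilter.py` implement; the inlier threshold of 0.6 and the omission of the second-order consistency term are simplifying assumptions.

```python
import torch

corr = torch.randn(1, 500, 6)                      # 500 putative correspondences (src xyz | tgt xyz)
d_src = torch.cdist(corr[..., :3], corr[..., :3])  # pairwise distances in the source cloud
d_tgt = torch.cdist(corr[..., 3:], corr[..., 3:])  # pairwise distances in the target cloud
compat = (d_src - d_tgt).abs()                     # rigid transforms preserve lengths
W = (1 - compat ** 2 / 0.6 ** 2).clamp(min=0)      # compatibility graph (0.6 = assumed inlier threshold)

degree = W.sum(-1)                                 # generalized degree signal
L = torch.diag_embed(degree) - W                   # graph Laplacian acts as a high-pass filter
score = torch.bmm(L, degree.unsqueeze(-1)).squeeze().abs()   # high-frequency response per node
score = (score - score.min()) / (score.max() - score.min())  # normalize to [0, 1]

idx = torch.multinomial(score, num_samples=250)    # stochastic sampling, no SVD required
sampled = corr[0, idx]                             # 50% of the correspondences kept
```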
As an application, we build a complete 3D registration algorithm, termed FastMAC, that reaches real-time speed while causing little to no performance drop. Through extensive experiments, we validate that FastMAC works on both indoor and outdoor benchmarks. For example, FastMAC can accelerate MAC by 80 times while maintaining a high registration success rate on KITTI. ![](figures/teaser.png) 6 | 7 | ## News 8 | - [2024/2/27] Paper is accepted by CVPR 2024. 9 | - [2023/12/4] Code is released. 10 | 11 | ## Installation 12 | Please install [PyTorch](https://pytorch.org/) first, and then install the other dependencies with the following command. The code has been tested with Python 3.8.10, PyTorch 1.12.0, CUDA 11.3 and cuDNN 8302 on Ubuntu 22.04. 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | Finally, install [MAC(3D Registration with Maximal Cliques)](https://github.com/zhangxy0517/3D-Registration-with-Maximal-Cliques/tree/main) as instructed. 17 | 18 | **NOTE: As our PCR method is based on MAC, please install and run MAC first. If you only want the output sampled correspondences, it is fine to install only our code.** 19 | 20 | ## Datasets 21 | The test datasets include KITTI, 3DMatch and 3DLoMatch. Please download them from [MAC(3D Registration with Maximal Cliques)](https://github.com/zhangxy0517/3D-Registration-with-Maximal-Cliques/tree/main). 22 | 23 | ## Usage 24 | To demonstrate how reliably our method boosts MAC, we use the original MAC as the registration module. To run the complete pipeline, downsample the input correspondences with the code presented here, then feed them into MAC using the code in the [MAC repository](https://github.com/zhangxy0517/3D-Registration-with-Maximal-Cliques/tree/main). In the future we plan to integrate the two parts into one codebase to form a complete pipeline for practical usage. 25 | 26 | 27 | 28 | ## KITTI 29 | To run FastMAC on KITTI, please use the following command: 30 | ``` 31 | python sota.py 32 | ``` 33 | In function Config(), set "data_dir", "filename", "gtname", "labelname" and "outpath" to the actual paths on your machine. "ratio" is the downsampling ratio, between 0 and 1. 34 | 35 | 36 | **NOTE: set 'thresh' to 0.999 if using the FCGF descriptor, 0.9 if using the FPFH descriptor.** 37 | 38 | ## 3DMatch 39 | To run FastMAC on 3DMatch, please use the following command: 40 | ``` 41 | python 3dmatch.py 42 | ``` 43 | In function Config(), set "data_dir" to the actual dataset path and "descriptor" to the descriptor you use. For convenience, the outputs are written back into the original dataset directory so that MAC can be applied directly. "ratio" is the downsampling ratio, between 0 and 1. Set "name" to "3dmatch" to run on 3DMatch. 44 | 45 | 46 | ## 3DLoMatch 47 | To run FastMAC on 3DLoMatch, please use the following command: 48 | ``` 49 | python 3dmatch.py 50 | ``` 51 | In function Config(), set "data_dir" to the actual dataset path and "descriptor" to the descriptor you use. For convenience, the outputs are written back into the original dataset directory so that MAC can be applied directly. "ratio" is the downsampling ratio, between 0 and 1. Set "name" to "3dlomatch" to run on 3DLoMatch.
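For example, a 50% run on 3DLoMatch with FPFH only requires editing `Config()` in `3dmatch.py`; the snippet below mirrors that function, with a placeholder data path:

```python
def Config():
    return {
        "num_points": np.inf,
        "resolution": 0.006,
        "data_dir": "/path/to/Processed_3dmatch_3dlomatch/",  # placeholder, set to your dataset
        "name": "3dlomatch",   # "3dmatch" or "3dlomatch"
        "descriptor": "fpfh",  # "fpfh" or "fcgf"
        "batch_size": 1,
        "inlier_thresh": 0.1,
        "device": "cuda",
        "mode": "graph",
        "ratio": 0.50,         # keep 50% of the correspondences
    }
```

The sampled files are then saved next to the inputs with a `_graph_50` suffix (e.g., `...@corr_graph_50.txt`) and can be consumed by MAC directly.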
53 | 54 | ## Results 55 | 56 | ### KITTI 57 | 58 | | Descriptor | Ratio(%) | RR | RE(°) | TE(cm) | 59 | |:----------:|:--------:|:------:|:--------:|:-------:| 60 | | FPFH | 100 | 97.66% | 0.405772 | 8.61193 | 61 | | FPFH | 50 | 97.84% | 0.410393 | 8.61099 | 62 | | FPFH | 20 | 97.84% | 0.415011 | 8.64669 | 63 | | FPFH | 10 | 98.02% | 0.447299 | 9.06907 | 64 | | FPFH | 5 | 97.12% | 0.491153 | 9.64376 | 65 | | FPFH | 1 | 94.05% | 0.831317 | 13.5936 | 66 | 67 | | Descriptor | Ratio(%) | RR | RE(°) | TE(cm) | 68 | |:----------:|:--------:|:------:|:--------:|:-------:| 69 | | FCGF | 100 | 97.12% | 0.355121 | 7.99152 | 70 | | FCGF | 50 | 97.48% | 0.368148 | 8.0161 | 71 | | FCGF | 20 | 97.30% | 0.391029 | 8.45734 | 72 | | FCGF | 10 | 96.94% | 0.445949 | 9.20145 | 73 | | FCGF | 5 | 96.04% | 0.525363 | 10.0375 | 74 | | FCGF | 1 | 71.89% | 0.996978 | 14.8993 | 75 | 76 | ### 3DMatch 77 | 78 | | Descriptor | Ratio(%) | RR | RE(°) | TE(cm) | 79 | |:----------:|:--------:|:------:|:-------:|:-------:| 80 | | FPFH | 100 | 83.86% | 2.10952 | 6.79597 | 81 | | FPFH | 50 | 82.87% | 2.15102 | 6.73052 | 82 | | FPFH | 20 | 80.71% | 2.17369 | 6.80735 | 83 | | FPFH | 10 | 78.87% | 2.28292 | 7.05551 | 84 | | FPFH | 5 | 74.49% | 2.2949 | 6.97654 | 85 | | FPFH | 1 | 58.04% | 2.44924 | 7.28792 | 86 | 87 | 88 | | Descriptor | Ratio(%) | RR | RE(°) | TE(cm) | 89 | |:----------:|:--------:|:------:|:-------:|:-------:| 90 | | FCGF | 100 | 93.72% | 2.02746 | 6.53953 | 91 | | FCGF | 50 | 92.67% | 1.99611 | 6.46513 | 92 | | FCGF | 20 | 92.30% | 2.0205 | 6.51827 | 93 | | FCGF | 10 | 90.94% | 2.02694 | 6.52478 | 94 | | FCGF | 5 | 89.40% | 2.06517 | 6.75127 | 95 | | FCGF | 1 | 58.23% | 2.16245 | 7.10037 | 96 | 97 | ### 3DLoMatch 98 | 99 | | Descriptor | Ratio(%) | RR | RE(°) | TE(cm) | 100 | |:----------:|:--------:|:------:|:-------:|:-------:| 101 | | FPFH | 100 | 41.21% | 4.05137 | 10.6133 | 102 | | FPFH | 50 | 38.46% | 4.03769 | 10.4745 | 103 | | FPFH | 20 | 34.31% | 4.11826 | 10.8244 | 104 | | FPFH | 10 | 31.56% | 4.35467 | 11.3328 | 105 | | FPFH | 5 | 27.40% | 4.44883 | 11.3483 | 106 | | FPFH | 1 | 12.24% | 4.39649 | 12.5056 | 107 | 108 | | Descriptor | Ratio(%) | RR | RE(°) | TE(cm) | 109 | |:----------:|:--------:|:------:|:-------:|:-------:| 110 | | FCGF | 100 | 60.19% | 3.75996 | 10.6147 | 111 | | FCGF | 50 | 58.23% | 3.80416 | 10.8137 | 112 | | FCGF | 20 | 55.25% | 3.83575 | 10.7118 | 113 | | FCGF | 10 | 54.35% | 3.94558 | 10.9791 | 114 | | FCGF | 5 | 51.49% | 4.07549 | 11.0795 | 115 | | FCGF | 1 | 37.06% | 4.4706 | 12.1996 | 116 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | import os 5 | import sys 6 | import open3d as o3d 7 | torch.manual_seed(42) 8 | sys.path.append(os.path.dirname(__file__)) 9 | 10 | class processedKITTI(data.Dataset): 11 | def __init__(self,num_points,data_dir,filename,gtname,labelname,num_samples=-1) -> None: 12 | super().__init__() 13 | self.data_dir=data_dir 14 | self.num_points=num_points 15 | self.filedirs=os.listdir(data_dir) 16 | self.filenames=[os.path.join(data_dir,filedir,filename) for filedir in self.filedirs] 17 | self.gtnames=[os.path.join(data_dir,filedir,gtname) for filedir in self.filedirs] 18 | self.labelnames=[os.path.join(data_dir,filedir,labelname) for filedir in self.filedirs] 19 | self.filenames.sort() 20 | self.gtnames.sort() 21 | self.labelnames.sort() 22 | if 
num_samples>0: 23 | self.filenames=self.filenames[:num_samples] 24 | self.gtnames=self.gtnames[:num_samples] 25 | self.labelnames=self.labelnames[:num_samples] 26 | 27 | 28 | def __getitem__(self, index): 29 | filename=self.filenames[index] 30 | gtname=self.gtnames[index] 31 | labelname=self.labelnames[index] 32 | data=np.loadtxt(filename,delimiter=' ') 33 | ground_truth=np.loadtxt(gtname,delimiter=' ') 34 | label=np.loadtxt(labelname,delimiter=' ') 35 | n_pts=data.shape[0] 36 | num_points=min(n_pts,self.num_points) 37 | pt_idxs=np.arange(0,n_pts) 38 | np.random.shuffle(pt_idxs) 39 | current_points=data[pt_idxs[:num_points],:].copy() 40 | current_points=torch.from_numpy(current_points).type(torch.FloatTensor) 41 | ground_truth=torch.from_numpy(ground_truth).type(torch.FloatTensor) 42 | label=torch.from_numpy(label).type(torch.FloatTensor) 43 | return current_points,ground_truth,label 44 | def __len__(self): 45 | return len(self.filenames) 46 | 47 | class ThreeDmatch(data.Dataset): 48 | def __init__(self,num_points,data_dir,descriptor,num_samples=-1) -> None: 49 | super().__init__() 50 | self.data_dir=data_dir 51 | self.num_points=num_points 52 | 53 | self.data_scenes=[ 54 | 55 | "7-scenes-redkitchen", 56 | "sun3d-home_at-home_at_scan1_2013_jan_1", 57 | "sun3d-home_md-home_md_scan9_2012_sep_30", 58 | "sun3d-hotel_uc-scan3", 59 | "sun3d-hotel_umd-maryland_hotel1", 60 | "sun3d-hotel_umd-maryland_hotel3", 61 | "sun3d-mit_76_studyroom-76-1studyroom2", 62 | "sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika", 63 | 64 | ] 65 | self.descriptor=descriptor 66 | self.filenames=[] 67 | self.gtnames=[] 68 | self.labelnames=[] 69 | self.srcply=[] 70 | self.tgtply=[] 71 | for data_scene in self.data_scenes: 72 | if (descriptor == "fpfh" or descriptor == "spinnet" or descriptor == "d3feat"): 73 | loadertxt=data_scene+"/dataload.txt" 74 | elif descriptor == "fcgf": 75 | loadertxt=data_scene+"/dataload_fcgf.txt" 76 | loadertxt=data_dir+'/'+loadertxt 77 | with open(loadertxt,'r') as f: 78 | lines=f.readlines() 79 | for line in lines: 80 | line=line.strip('\n') 81 | src_ply=data_dir+'/'+data_scene+line.split('+')[0]+".ply" 82 | tgt_ply=data_dir+'/'+data_scene+line.split('+')[1]+".ply" 83 | corr_path=data_dir+'/'+data_scene+'/'+line+("@corr_fcgf.txt" if descriptor == "fcgf" else "@corr.txt") 84 | gt_path=data_dir+'/'+data_scene+'/'+line+("@GTmat_fcgf.txt" if descriptor == "fcgf" else "@GTmat.txt") 85 | label_path=data_dir+'/'+data_scene+'/'+line+("@label_fcgf.txt" if descriptor == "fcgf" else "@label.txt") 86 | self.srcply.append(src_ply) 87 | self.tgtply.append(tgt_ply) 88 | self.filenames.append(corr_path) 89 | self.gtnames.append(gt_path) 90 | self.labelnames.append(label_path) 91 | 92 | if num_samples>0: 93 | self.srcply=self.srcply[:num_samples] 94 | self.tgtply=self.tgtply[:num_samples] 95 | self.filenames=self.filenames[:num_samples] 96 | self.gtnames=self.gtnames[:num_samples] 97 | self.labelnames=self.labelnames[:num_samples] 98 | 99 | 100 | def __getitem__(self, index): 101 | src_ply=self.srcply[index] 102 | tgt_ply=self.tgtply[index] 103 | filename=self.filenames[index] 104 | gtname=self.gtnames[index] 105 | labelname=self.labelnames[index] 106 | data=np.loadtxt(filename,delimiter=' ') 107 | ground_truth=np.loadtxt(gtname,delimiter=' ') 108 | label=np.loadtxt(labelname,delimiter=' ') 109 | n_pts=data.shape[0] 110 | num_points=min(n_pts,self.num_points) 111 | pt_idxs=np.arange(0,n_pts) 112 | np.random.shuffle(pt_idxs) 113 | current_points=data[pt_idxs[:num_points],:].copy() 114 | 
current_points=torch.from_numpy(current_points).type(torch.FloatTensor) 115 | ground_truth=torch.from_numpy(ground_truth).type(torch.FloatTensor) 116 | label=torch.from_numpy(label).type(torch.FloatTensor) 117 | return current_points,ground_truth,label,filename,gtname,labelname 118 | def __len__(self): 119 | return len(self.filenames) 120 | 121 | class ThreeDLomatch(data.Dataset): 122 | def __init__(self,num_points,data_dir,descriptor,num_samples=-1) -> None: 123 | super().__init__() 124 | self.data_dir=data_dir 125 | self.num_points=num_points 126 | 127 | self.data_scenes=[ 128 | "7-scenes-redkitchen_3dlomatch", 129 | "sun3d-home_at-home_at_scan1_2013_jan_1_3dlomatch", 130 | "sun3d-home_md-home_md_scan9_2012_sep_30_3dlomatch", 131 | "sun3d-hotel_uc-scan3_3dlomatch", 132 | "sun3d-hotel_umd-maryland_hotel1_3dlomatch", 133 | "sun3d-hotel_umd-maryland_hotel3_3dlomatch", 134 | "sun3d-mit_76_studyroom-76-1studyroom2_3dlomatch", 135 | "sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika_3dlomatch", 136 | ] 137 | self.descriptor=descriptor 138 | self.filenames=[] 139 | self.gtnames=[] 140 | self.labelnames=[] 141 | self.srcply=[] 142 | self.tgtply=[] 143 | for data_scene in self.data_scenes: 144 | if (descriptor == "fpfh" or descriptor == "spinnet" or descriptor == "d3feat"): 145 | loadertxt=data_scene+"/dataload.txt" 146 | elif descriptor == "fcgf": 147 | loadertxt=data_scene+"/dataload_fcgf.txt" 148 | loadertxt=data_dir+'/'+loadertxt 149 | with open(loadertxt,'r') as f: 150 | lines=f.readlines() 151 | for line in lines: 152 | line=line.strip('\n') 153 | src_ply=data_dir+'/'+data_scene+line.split('+')[0]+".ply" 154 | tgt_ply=data_dir+'/'+data_scene+line.split('+')[1]+".ply" 155 | corr_path=data_dir+'/'+data_scene+'/'+line+("@corr_fcgf.txt" if descriptor == "fcgf" else "@corr.txt") 156 | gt_path=data_dir+'/'+data_scene+'/'+line+("@GTmat_fcgf.txt" if descriptor == "fcgf" else "@GTmat.txt") 157 | label_path=data_dir+'/'+data_scene+'/'+line+("@label_fcgf.txt" if descriptor == "fcgf" else "@label.txt") 158 | self.srcply.append(src_ply) 159 | self.tgtply.append(tgt_ply) 160 | self.filenames.append(corr_path) 161 | self.gtnames.append(gt_path) 162 | self.labelnames.append(label_path) 163 | 164 | if num_samples>0: 165 | self.srcply=self.srcply[:num_samples] 166 | self.tgtply=self.tgtply[:num_samples] 167 | self.filenames=self.filenames[:num_samples] 168 | self.gtnames=self.gtnames[:num_samples] 169 | self.labelnames=self.labelnames[:num_samples] 170 | 171 | 172 | def __getitem__(self, index): 173 | src_ply=self.srcply[index] 174 | tgt_ply=self.tgtply[index] 175 | filename=self.filenames[index] 176 | gtname=self.gtnames[index] 177 | labelname=self.labelnames[index] 178 | data=np.loadtxt(filename,delimiter=' ') 179 | ground_truth=np.loadtxt(gtname,delimiter=' ') 180 | label=np.loadtxt(labelname,delimiter=' ') 181 | n_pts=data.shape[0] 182 | num_points=min(n_pts,self.num_points) 183 | pt_idxs=np.arange(0,n_pts) 184 | np.random.shuffle(pt_idxs) 185 | current_points=data[pt_idxs[:num_points],:].copy() 186 | current_points=torch.from_numpy(current_points).type(torch.FloatTensor) 187 | ground_truth=torch.from_numpy(ground_truth).type(torch.FloatTensor) 188 | label=torch.from_numpy(label).type(torch.FloatTensor) 189 | return current_points,ground_truth,label,filename,gtname,labelname 190 | def __len__(self): 191 | return len(self.filenames) 192 | 193 | 194 | 195 | if __name__ == "__main__": 196 | from tqdm import tqdm 197 | # num_points=5000 198 | # data_dir='/data/Processed_KITTI/correspondence_fpfh/' 199 | # 
filename='fpfh@corr.txt' 200 | # gtname='fpfh@gtmat.txt' 201 | # labelname='fpfh@gtlabel.txt' 202 | # dataset=processedKITTI(num_points,data_dir,filename,gtname,labelname) 203 | # trainloader = torch.utils.data.DataLoader( 204 | # dataset, batch_size=1, shuffle=False, num_workers=0 205 | # ) 206 | # for i, data_ in enumerate(tqdm(trainloader)): 207 | # pts,gt,label=data_ 208 | # print(label.shape) 209 | # break 210 | 211 | data_dir='/data/Processed_3dmatch_3dlomatch/' 212 | descriptor='fpfh' 213 | num_points=np.inf 214 | dataset=ThreeDmatch(num_points,data_dir,descriptor) 215 | trainloader = torch.utils.data.DataLoader( 216 | dataset, batch_size=1, shuffle=False, num_workers=0 217 | ) 218 | for i, data_ in enumerate(tqdm(trainloader)): 219 | pts,gt,label,corr_path,gt_path,lb_path=data_ 220 | 221 | print(label.shape) 222 | print(gt.shape) 223 | print(pts.shape) 224 | 225 | break -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/figures/teaser.png -------------------------------------------------------------------------------- /fps.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def farthest_point_sample(xyz, npoint): 4 | """ 5 | Input: 6 | xyz: pointcloud data, [B, N, 6] 7 | npoint: number of samples 8 | Return: 9 | centroids: sampled pointcloud index, [B, npoint] 10 | """ 11 | device = xyz.device 12 | B, N, C = xyz.shape 13 | # initialize a centroids matrix of size B x npoint to store the indices of the npoint sampled points, 14 | # where B is the batch size 15 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 16 | # the distance matrix (B x N) records each point's distance to the sampled set; initialized large and updated iteratively 17 | distance = torch.ones(B, N).to(device) * 1e10 18 | # farthest is the current farthest point, randomly initialized in [0, N); each batch gets its own random initial point 19 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 20 | # batch_indices is initialized as the array 0..B-1 21 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 22 | # iterate until npoint samples have been drawn: 23 | for i in range(npoint): 24 | # take the current farthest point as the next centroid 25 | centroids[:, i] = farthest 26 | # fetch the coordinates of this centroid 27 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 6) 28 | # squared Euclidean distance from every point to this centroid, stored in dist 29 | dist = torch.sum((xyz - centroid) ** 2, -1) 30 | # build a mask: where dist is smaller than the stored distance, update distance; 31 | # as the iterations proceed the entries of distance shrink, 32 | # so it ends up recording each point's minimum distance to all centroids chosen so far 33 | mask = dist < distance  # an already-chosen centroid keeps distance 0, and nearby points keep their distance to it 34 | distance[mask] = dist[mask] 35 | # the point with the largest stored distance becomes the next farthest point 36 | farthest = torch.max(distance, -1)[1] 37 | return centroids 38 | 39 | if __name__ == "__main__": 40 | points=torch.rand(1,100,6) 41 | npoint=10 42 | centroids=farthest_point_sample(points,npoint) 43 | print(centroids) 44 | print(centroids.shape) -------------------------------------------------------------------------------- /gconstructor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def euclidean(a, b): 5 | return torch.norm(a - b, dim=-1, keepdim=True) 6 | 7 | def compatibility(a,b): 8 | assert(a.shape[-1]==6) 9 | assert(b.shape[-1]==6) 10 | n1=torch.norm(a[...,:3]-b[...,:3],dim=-1,keepdim=True) 11 | n2=torch.norm(a[...,3:]-b[...,3:],dim=-1,keepdim=True) 12 | return torch.abs(n1-n2) 13 |
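# Note: compatibility() implements the rigid length-preservation test. For two
# correspondences a=(p1|q1) and b=(p2|q2), any rigid transform preserves
# ||p1-p2||, so inlier pairs satisfy ||p1-p2|| ~ ||q1-q2|| and yield a small
# residual. Illustrative example (values are made up, not from this file):
#   a = torch.tensor([[0., 0., 0., 1., 0., 0.]])
#   b = torch.tensor([[1., 0., 0., 2., 0., 0.]])
#   compatibility(a, b)  # tensor([[0.]]) -> perfectly compatible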
14 | def Dmatrix(a,type): 15 | if type=="euclidean": 16 | return torch.cdist(a,a) 17 | 18 | elif type=="compatibility": 19 | a1=a[...,:3] 20 | a2=a[...,3:] 21 | return torch.abs(Dmatrix(a1,"euclidean")-Dmatrix(a2,"euclidean")) 22 | 23 | class GraphConstructor(nn.Module): 24 | def __init__(self,inlier_thresh,thresh,trainable,device="cuda",sigma=None,tau=None) -> None: 25 | ''' 26 | inlier thresh: KITTI 0.6, 3dmatch 0.1 27 | thresh: fpfh 0.9, fcgf 0.999 28 | ''' 29 | super().__init__() 30 | self.device=device 31 | self.inlier_thresh=nn.Parameter(torch.tensor(inlier_thresh,requires_grad=trainable,dtype=torch.float32)).to(device) 32 | self.thresh=nn.Parameter(torch.tensor(thresh,requires_grad=trainable,dtype=torch.float32)).to(device) 33 | if sigma is not None: 34 | self.sigma=nn.Parameter(torch.tensor(sigma,requires_grad=trainable,dtype=torch.float32)).to(device) 35 | else: 36 | self.sigma=self.inlier_thresh 37 | if tau is not None: 38 | self.tau=nn.Parameter(torch.tensor(tau,requires_grad=trainable,dtype=torch.float32)).to(device) 39 | else: 40 | self.tau=self.thresh 41 | def forward(self,points,mode,k1=2,k2=1): 42 | ''' 43 | points: B x M x 6 44 | output: B x M x M 45 | ''' 46 | if mode=="correspondence": 47 | points=points.to(self.device) 48 | dmatrix=Dmatrix(points,"compatibility") 49 | score=1-dmatrix**2/self.inlier_thresh**2 50 | # score=torch.exp(-dmatrix**2/self.inlier_thresh**2) 51 | score[score<self.thresh]=0 52 | return score*torch.einsum("bmn,bnk->bmk",score,score) 53 | 54 | 55 | 56 | elif mode=="pointcloud": 57 | ''' 58 | points: B x N x 3 59 | output: B x N x N 60 | ''' 61 | points=points.to(self.device) 62 | dmatrix=Dmatrix(points,"euclidean") 63 | 64 | # score=1-dmatrix**2/self.inlier_thresh**2 65 | score=torch.exp(-dmatrix**2/self.sigma**2) 66 | score[score<self.tau]=0 67 | return score*torch.einsum("bmn,bnk->bmk",score,score) 68 | 69 | 70 | 71 | 72 | class GraphConstructorFor3DMatch(nn.Module): 73 | def __init__(self) -> None: 74 | super().__init__() 75 | pass 76 | def forward(self,correspondence, resolution, name, descriptor, inlier_thresh): 77 | self.device="cuda" 78 | correspondence=correspondence.to(self.device) 79 | dmatrix=Dmatrix(correspondence,"compatibility") 80 | 81 | if descriptor=="predator": 82 | score=1-dmatrix**2/inlier_thresh**2 83 | score[score<0.999]=0 84 | else: 85 | alpha_dis = 10 * resolution 86 | score = torch.exp(-dmatrix**2 / (2 * alpha_dis * alpha_dis)) 87 | if (name == "3dmatch" and descriptor == "fcgf"): 88 | score[score<0.999]=0 89 | elif (name == "3dmatch" and descriptor == "fpfh") : 90 | score[score<0.995]=0 91 | elif (descriptor == "spinnet" or descriptor == "d3feat") : 92 | score[score<0.85]=0 93 | #spinnet 5000 2500 1000 500 250 94 | # 0.99 0.99 0.95 0.9 0.85 95 | else: 96 | score[score<0.99]=0 #3dlomatch 0.99, 3dmatch fcgf 0.999 fpfh 0.995 97 | return score*torch.einsum("bmn,bnk->bmk",score,score) 98 | 99 |
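# Note on the return value used by both constructors above: with A the
# thresholded first-order compatibility matrix, einsum("bmn,bnk->bmk", A, A)
# is the batched product A @ A, whose (i, j) entry aggregates compatibility
# paths through shared neighbours. Multiplying it elementwise by A keeps an
# edge only when the pair is both directly compatible and supported by
# common neighbours -- a soft second-order consistency check.
# Illustrative sanity check (not from the original file):
#   A = torch.ones(1, 4, 4)
#   torch.allclose(A * torch.einsum("bmn,bnk->bmk", A, A), 4 * A)  # True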
100 | class Graph: 101 | def __init__(self): 102 | pass 103 | 104 | @staticmethod 105 | def construct_graph(pcloud, nb_neighbors): 106 | """ 107 | Construct a directed nearest neighbor graph on the input point cloud. 108 | 109 | Parameters 110 | ---------- 111 | pcloud : torch.Tensor 112 | Input point cloud. Size B x N x 3. 113 | nb_neighbors : int 114 | Number of nearest neighbors per point. 115 | 116 | Returns 117 | ------- 118 | graph : flot.models.graph.Graph 119 | Graph built on the input point cloud containing the list of nearest 120 | neighbors (NN) for each point and all edge features (relative 121 | coordinates with NN). 122 | 123 | """ 124 | 125 | # Size 126 | nb_points = pcloud.shape[1] 127 | size_batch = pcloud.shape[0] 128 | 129 | # Distance between points 130 | distance_matrix = torch.sum(pcloud ** 2, -1, keepdim=True) 131 | distance_matrix = distance_matrix + distance_matrix.transpose(1, 2) 132 | distance_matrix = distance_matrix - 2 * torch.bmm( 133 | pcloud, pcloud.transpose(1, 2) 134 | ) 135 | # exclude self-distances 136 | distance_matrix = distance_matrix + 1e6 * torch.eye(nb_points).unsqueeze(0).to(pcloud.device) 137 | 138 | # Find nearest neighbors 139 | neighbors = torch.argsort(distance_matrix, -1)[..., :nb_neighbors] 140 | 141 | # # directly construct a dense adjacency matrix 142 | # adj_dense=torch.zeros((nb_points,nb_points)).to(pcloud.device) 143 | # for i in range(nb_points): 144 | # for j in range(nb_neighbors): 145 | # adj_dense[i,neighbors[0,i,j]]=1 146 | 147 | 148 | # construct sparse adjacency matrix 149 | neighbors_flat = neighbors.reshape(-1) 150 | idx=torch.arange(nb_points).repeat(nb_neighbors,1).transpose(0,1).reshape(-1) 151 | idx=idx.to(pcloud.device) 152 | neighbors_flat=neighbors_flat.to(pcloud.device) 153 | i=torch.stack([idx,neighbors_flat],dim=0) 154 | v=torch.ones(i.shape[1]).to(pcloud.device) 155 | # print(i)  # leftover debug output, disabled 156 | adj=torch.sparse_coo_tensor(i,v,(nb_points,nb_points)) 157 | 158 | # assert(torch.all(torch.eq(adj.to_dense(),adj_dense))) 159 | 160 | return adj 161 | 162 | 163 | 164 | 165 | if __name__ == "__main__": 166 | from plyfile import PlyData,PlyElement 167 | import numpy as np 168 | def write_ply(save_path,points,text=True): 169 | """ 170 | save_path : path to save: '/yy/XX.ply' 171 | pt: point_cloud: size (N,3) 172 | """ 173 | points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] 174 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) 175 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 176 | PlyData([el], text=text).write(save_path) 177 | def read_ply(filename): 178 | """ read XYZ point cloud from filename PLY file """ 179 | plydata = PlyData.read(filename) 180 | pc = plydata['vertex'].data 181 | pc_array = np.array([[x, y, z] for x,y,z in pc]) 182 | return pc_array 183 | 184 | pc=read_ply('/data/plane2.ply') 185 | num_pts=pc.shape[0] 186 | sample_rate=0.01 187 | k=int(np.floor(sample_rate*num_pts)) 188 | pc_tensor=torch.from_numpy(pc).type(torch.FloatTensor).unsqueeze(0).cuda() 189 | g=GraphConstructor(0.6,0,False) 190 | adj=g(pc_tensor,"pointcloud") 191 | 192 | degree=torch.diag_embed(torch.sum(adj,dim=-1)) 193 | laplacian=(degree-adj).squeeze(0) 194 | low_shift=(torch.diag_embed(1/torch.sum(adj,dim=-1))*adj).squeeze(0) 195 | from gfilter import graphLowFilter,datasample 196 | scores=graphLowFilter(pc_tensor.squeeze(0),low_shift) 197 | idxs=datasample(k,False,scores) 198 | sampled_pc=pc_tensor.squeeze(0)[idxs,:] 199 | write_ply('/data/plane2_sampled_low.ply',sampled_pc.cpu().numpy()) 200 | -------------------------------------------------------------------------------- /gfilter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.manual_seed(42) 4 | 5 | def graphFilter(points,adjacent_matrix,is_sparse): 6 | ''' 7 | points: n x 3 8 | adjacent_matrix: graph shift matrix (sparse or dense) 9 | 10 | return: 11 | score: n x 1 12 | ''' 13 | if is_sparse: 14 | xyz=torch.sparse.mm(adjacent_matrix,points) 15 | else: 16 | xyz=torch.mm(adjacent_matrix,points) 17 | return torch.norm(xyz,dim=-1) 18 | 19 | def graphLowFilter(points,adjacent_matrix): 20 | '''
points: n x 3 21 | adjacent_matrix: low-pass graph shift matrix (dense) 22 | 23 | return: 24 | score: n x 1 25 | ''' 26 | r=torch.matmul(torch.eye(points.shape[0]).to(adjacent_matrix.device)+adjacent_matrix, points) 27 | return torch.norm(r,p=2,dim=-1) 28 | 29 | def graphAllPassFilter(points): 30 | ''' 31 | points: n x 3 32 | 33 | return: 34 | score: n x 1 35 | ''' 36 | return torch.norm(points,p=2,dim=-1) 37 | 38 | 39 | 40 | def datasample(k,replace,weights): 41 | ''' 42 | k: int, number of samples to draw 43 | replace: bool, sample with replacement or not 44 | weights: n, sampling weights 45 | 46 | return: sampled indices 47 | ''' 48 | return torch.multinomial(weights,k,replacement=replace) 49 | 50 | 51 | if __name__ == "__main__": 52 | from plyfile import PlyData,PlyElement 53 | import numpy as np 54 | def write_ply(save_path,points,text=True): 55 | """ 56 | save_path : path to save: '/yy/XX.ply' 57 | pt: point_cloud: size (N,3) 58 | """ 59 | points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] 60 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) 61 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 62 | PlyData([el], text=text).write(save_path) 63 | def read_ply(filename): 64 | """ read XYZ point cloud from filename PLY file """ 65 | plydata = PlyData.read(filename) 66 | pc = plydata['vertex'].data 67 | pc_array = np.array([[x, y, z] for x,y,z in pc]) 68 | return pc_array 69 | 70 | pc=read_ply('/data/cubic.ply') 71 | num_pts=pc.shape[0] 72 | sample_rate=0.25 73 | k=int(np.floor(sample_rate*num_pts)) 74 | print(k) 75 | 76 | -------------------------------------------------------------------------------- /pipeline_SLAM/README.md: -------------------------------------------------------------------------------- 1 | ### Overview 2 | - The pipeline is modified from [PyICP-SLAM](https://github.com/gisbi-kim/PyICP-SLAM) and is composed of three parts: 3 | 1. Odometry: [GeoTransformer](https://github.com/qinzheng93/GeoTransformer/) and [FastMAC](https://github.com/Forrest-110/FastMAC) 4 | - Registration here is point-to-point and frame-to-frame (i.e., no local mapping) 5 | 2. Loop detection: [Scan Context (IROS 18)](https://github.com/irapkaist/scancontext) 6 | - Reverse loop detection is supported. 7 | 3. Back-end (graph optimizer): [GTSAM](https://gtsam.org/) 8 | - Python API (the original PyICP-SLAM used [miniSAM](https://github.com/dongjing3309/minisam)) 9 | 10 | - This is a simple Python usage example without parameter tuning or efficiency optimization, so the results may not be optimal. 11 | 12 | ### How to use 13 | Just run 14 | 15 | ```sh 16 | $ python3 pipeline.py 17 | ``` 18 | 19 | Details of the parameters are easily found in the argparser in that .py file.
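In `pipeline.py` (included below), each frame goes through a correspondence-extraction, FastMAC-downsampling, registration, and pose-graph step. A condensed sketch of that odometry step, with loop closure and visualization omitted:

```python
# Condensed from pipeline.py; the helper objects come from utils/ as set up there.
corr, score = CORR_EXTRACTOR.extract_corr(curr_scan_down_pts, prev_scan_down_pts)
down_corr = downsample(corr)  # FastMAC spectral sampling of the correspondences
odom_transform = REGISTRATOR.registration(down_corr[..., :3], down_corr[..., 3:])
odom_transform = odom_transform.cpu().numpy()

PGM.curr_se3 = np.matmul(PGM.curr_se3, odom_transform)  # accumulate the pose
PGM.addOdometryFactor(odom_transform)                   # add an odometry edge to the graph
```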
20 | -------------------------------------------------------------------------------- /pipeline_SLAM/geotransformer-kitti.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/geotransformer-kitti.pth.tar -------------------------------------------------------------------------------- /pipeline_SLAM/pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import csv 4 | import copy 5 | import time 6 | import random 7 | import argparse 8 | 9 | import numpy as np 10 | np.set_printoptions(precision=4) 11 | from matplotlib.animation import FFMpegWriter 12 | 13 | from tqdm import tqdm 14 | 15 | from utils.ScanContextManager import * 16 | from utils.PoseGraphManager import * 17 | from utils.UtilsMisc import * 18 | import utils.UtilsPointcloud as Ptutils 19 | import utils.ICP as ICP 20 | import open3d as o3d 21 | 22 | from utils.extract_corr import CorrExtractor 23 | from utils.corr_downsample import downsample 24 | from utils.registration import Registrator 25 | 26 | # params 27 | parser = argparse.ArgumentParser(description='SLAM arguments') 28 | 29 | parser.add_argument('--num_points', type=int, default=5000) # 5000 is enough for real time 30 | 31 | parser.add_argument('--num_rings', type=int, default=20) # same as the original paper 32 | parser.add_argument('--num_sectors', type=int, default=60) # same as the original paper 33 | parser.add_argument('--num_candidates', type=int, default=10) # must be int 34 | parser.add_argument('--try_gap_loop_detection', type=int, default=10) # same as the original paper 35 | 36 | parser.add_argument('--loop_threshold', type=float, default=0.11) # 0.11 is usually safe (for avoiding false loop closure) 37 | 38 | parser.add_argument('--data_base_dir', type=str, 39 | default='/Datasets/SLAM/data_odometry_velodyne/dataset/sequences/') 40 | parser.add_argument('--sequence_idx', type=str, default='00') 41 | 42 | parser.add_argument('--save_gap', type=int, default=300) 43 | 44 | 45 | args = parser.parse_args() 46 | 47 | 48 | # dataset 49 | sequence_dir = os.path.join(args.data_base_dir, args.sequence_idx, 'velodyne') 50 | sequence_manager = Ptutils.KittiScanDirManager(sequence_dir) 51 | scan_paths = sequence_manager.scan_fullpaths 52 | num_frames = len(scan_paths) 53 | print("Number of frames: ", num_frames) 54 | 55 | PGM = PoseGraphManager() 56 | PGM.addPriorFactor() 57 | SCM = ScanContextManager(shape=[args.num_rings, args.num_sectors], 58 | num_candidates=args.num_candidates, 59 | threshold=args.loop_threshold) 60 | 61 | CORR_EXTRACTOR = CorrExtractor() 62 | REGISTRATOR = Registrator() 63 | 64 | save_dir = "result/" + args.sequence_idx 65 | ResultSaver = PoseGraphResultSaver(init_pose=PGM.curr_se3, 66 | save_gap=args.save_gap, 67 | num_frames=num_frames, 68 | seq_idx=args.sequence_idx, 69 | save_dir=save_dir) 70 | 71 | fig_idx = 1 72 | fig = plt.figure(fig_idx) 73 | writer = FFMpegWriter(fps=15) 74 | video_name = args.sequence_idx + "_" + str(args.num_points) + ".mp4" 75 | num_frames_to_skip_to_show = 5 76 | num_frames_to_save = np.floor(num_frames/num_frames_to_skip_to_show) 77 | with writer.saving(fig, video_name, num_frames_to_save): 78 | for for_idx, scan_path in tqdm(enumerate(scan_paths), total=num_frames, mininterval=5.0): 79 | # get current information 80 | curr_scan_pts = Ptutils.readScan(scan_path) 81 | curr_scan_down_pts = 
Ptutils.random_sampling(curr_scan_pts, num_points=args.num_points) 82 | 83 | PGM.curr_node_idx = for_idx # node indices start from 0 84 | SCM.addNode(node_idx=PGM.curr_node_idx, ptcloud=curr_scan_down_pts) 85 | if(PGM.curr_node_idx == 0): 86 | PGM.prev_node_idx = PGM.curr_node_idx 87 | prev_scan_pts = copy.deepcopy(curr_scan_pts) 88 | icp_initial = np.eye(4) 89 | continue 90 | 91 | prev_scan_down_pts = Ptutils.random_sampling(prev_scan_pts, num_points=args.num_points) 92 | 93 | corr, score = CORR_EXTRACTOR.extract_corr(curr_scan_down_pts, prev_scan_down_pts) 94 | # print(odom_transform) 95 | down_corr = downsample(corr) 96 | odom_transform = REGISTRATOR.registration(down_corr[..., :3], down_corr[..., 3:]) 97 | # print(odom_transform) 98 | # exit() 99 | 100 | 101 | 102 | 103 | odom_transform = odom_transform.cpu().numpy() 104 | # update the current (moved) pose 105 | PGM.curr_se3 = np.matmul(PGM.curr_se3, odom_transform) 106 | icp_initial = odom_transform # assumption: constant velocity model (helps the next ICP converge) 107 | 108 | # add the odometry factor to the graph 109 | PGM.addOdometryFactor(odom_transform) 110 | 111 | # renew the previous-frame information 112 | PGM.prev_node_idx = PGM.curr_node_idx 113 | prev_scan_pts = copy.deepcopy(curr_scan_pts) 114 | 115 | 116 | 117 | # loop detection and optimize the graph 118 | if(PGM.curr_node_idx > 1 and PGM.curr_node_idx % args.try_gap_loop_detection == 0): 119 | # 1/ loop detection 120 | loop_idx, loop_dist, yaw_diff_deg = SCM.detectLoop() 121 | if(loop_idx == None): # NOT FOUND 122 | pass 123 | else: 124 | print("Loop event detected: ", PGM.curr_node_idx, loop_idx, loop_dist) 125 | # 2-1/ add the loop factor 126 | loop_scan_down_pts = SCM.getPtcloud(loop_idx) 127 | loop_transform, _, _ = ICP.icp(curr_scan_down_pts, loop_scan_down_pts, init_pose=yawdeg2se3(yaw_diff_deg), max_iterations=20) 128 | PGM.addLoopFactor(loop_transform, loop_idx) 129 | 130 | # 2-2/ graph optimization 131 | PGM.optimizePoseGraph() 132 | 133 | # 2-3/ save optimized poses 134 | ResultSaver.saveOptimizedPoseGraphResult(PGM.curr_node_idx, PGM.graph_optimized) 135 | 136 | # save the ICP odometry pose result (no loop closure) 137 | ResultSaver.saveUnoptimizedPoseGraphResult(PGM.curr_se3, PGM.curr_node_idx) 138 | if(for_idx % num_frames_to_skip_to_show == 0): 139 | ResultSaver.vizCurrentTrajectory(fig_idx=fig_idx) 140 | writer.grab_frame() 141 | -------------------------------------------------------------------------------- /pipeline_SLAM/utils/ICP.py: -------------------------------------------------------------------------------- 1 | """ 2 | ref: https://github.com/ClayFlannigan/icp/blob/master/icp.py 3 | 4 | try this later: https://github.com/agnivsen/icp/blob/master/basicICP.py 5 | """ 6 | 7 | import numpy as np 8 | from sklearn.neighbors import NearestNeighbors 9 | 10 | def best_fit_transform(A, B): 11 | ''' 12 | Calculates the least-squares best-fit transform that maps corresponding points A to B in m spatial dimensions 13 | Input: 14 | A: Nxm numpy array of corresponding points 15 | B: Nxm numpy array of corresponding points 16 | Returns: 17 | T: (m+1)x(m+1) homogeneous transformation matrix that maps A on to B 18 | R: mxm rotation matrix 19 | t: mx1 translation vector 20 | ''' 21 | 22 | assert A.shape == B.shape 23 | 24 | # get number of dimensions 25 | m = A.shape[1] 26 | 27 | # translate points to their centroids 28 | centroid_A = np.mean(A, axis=0) 29 | centroid_B = np.mean(B, axis=0) 30 | AA = A - centroid_A 31 | BB = B - centroid_B 32 | 33 | # rotation matrix 34 | H = 
np.dot(AA.T, BB) 35 | U, S, Vt = np.linalg.svd(H) 36 | R = np.dot(Vt.T, U.T) 37 | 38 | # special reflection case 39 | if np.linalg.det(R) < 0: 40 | Vt[m-1,:] *= -1 41 | R = np.dot(Vt.T, U.T) 42 | 43 | # translation 44 | t = centroid_B.T - np.dot(R,centroid_A.T) 45 | 46 | # homogeneous transformation 47 | T = np.identity(m+1) 48 | T[:m, :m] = R 49 | T[:m, m] = t 50 | 51 | return T, R, t 52 | 53 | 54 | def nearest_neighbor(src, dst): 55 | ''' 56 | Find the nearest (Euclidean) neighbor in dst for each point in src 57 | Input: 58 | src: Nxm array of points 59 | dst: Nxm array of points 60 | Output: 61 | distances: Euclidean distances of the nearest neighbor 62 | indices: dst indices of the nearest neighbor 63 | ''' 64 | 65 | assert src.shape == dst.shape 66 | 67 | neigh = NearestNeighbors(n_neighbors=1) 68 | neigh.fit(dst) 69 | distances, indices = neigh.kneighbors(src, return_distance=True) 70 | return distances.ravel(), indices.ravel() 71 | 72 | 73 | def icp(A, B, init_pose=None, max_iterations=20, tolerance=0.001): 74 | ''' 75 | The Iterative Closest Point method: finds best-fit transform that maps points A on to points B 76 | Input: 77 | A: Nxm numpy array of source mD points 78 | B: Nxm numpy array of destination mD point 79 | init_pose: (m+1)x(m+1) homogeneous transformation 80 | max_iterations: exit algorithm after max_iterations 81 | tolerance: convergence criteria 82 | Output: 83 | T: final homogeneous transformation that maps A on to B 84 | distances: Euclidean distances (errors) of the nearest neighbor 85 | i: number of iterations to converge 86 | ''' 87 | 88 | assert A.shape == B.shape 89 | 90 | # get number of dimensions 91 | m = A.shape[1] 92 | 93 | # make points homogeneous, copy them to maintain the originals 94 | src = np.ones((m+1,A.shape[0])) 95 | dst = np.ones((m+1,B.shape[0])) 96 | src[:m,:] = np.copy(A.T) 97 | dst[:m,:] = np.copy(B.T) 98 | 99 | # apply the initial pose estimation 100 | if init_pose is not None: 101 | src = np.dot(init_pose, src) 102 | 103 | prev_error = 0 104 | 105 | for i in range(max_iterations): 106 | # find the nearest neighbors between the current source and destination points 107 | distances, indices = nearest_neighbor(src[:m,:].T, dst[:m,:].T) 108 | 109 | # compute the transformation between the current source and nearest destination points 110 | T,_,_ = best_fit_transform(src[:m,:].T, dst[:m,indices].T) 111 | 112 | # update the current source 113 | src = np.dot(T, src) 114 | 115 | # check error 116 | mean_error = np.mean(distances) 117 | if np.abs(prev_error - mean_error) < tolerance: 118 | break 119 | prev_error = mean_error 120 | 121 | # calculate final transformation 122 | T,_,_ = best_fit_transform(A, src[:m,:].T) 123 | 124 | return T, distances, i 125 | 126 | -------------------------------------------------------------------------------- /pipeline_SLAM/utils/PoseGraphManager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.set_printoptions(precision=4) 3 | 4 | # import minisam 5 | import gtsam 6 | from utils.UtilsMisc import * 7 | 8 | class PoseGraphManager: 9 | def __init__(self): 10 | 11 | self.prior_cov = gtsam.noiseModel.Diagonal.Sigmas(np.array([1e-6, 1e-6, 1e-6, 1e-4, 1e-4, 1e-4])) 12 | self.const_cov = np.array([0.5, 0.5, 0.5, 0.1, 0.1, 0.1]) 13 | self.odom_cov = gtsam.noiseModel.Diagonal.Sigmas(self.const_cov) 14 | self.loop_cov = gtsam.noiseModel.Diagonal.Sigmas(self.const_cov) 15 | 16 | self.graph_factors = gtsam.NonlinearFactorGraph() 17 | self.graph_initials = 
gtsam.Values() 18 | 19 | self.opt_param = gtsam.LevenbergMarquardtParams() 20 | self.opt = gtsam.LevenbergMarquardtOptimizer(self.graph_factors, self.graph_initials, self.opt_param) 21 | 22 | self.curr_se3 = None 23 | self.curr_node_idx = None 24 | self.prev_node_idx = None 25 | 26 | self.graph_optimized = None 27 | 28 | def addPriorFactor(self): 29 | self.curr_node_idx = 0 30 | self.prev_node_idx = 0 31 | 32 | self.curr_se3 = np.eye(4) 33 | 34 | self.graph_initials.insert(gtsam.symbol('x', self.curr_node_idx), gtsam.Pose3(self.curr_se3)) 35 | self.graph_factors.add(gtsam.PriorFactorPose3( 36 | gtsam.symbol('x', self.curr_node_idx), 37 | gtsam.Pose3(self.curr_se3), 38 | self.prior_cov)) 39 | 40 | def addOdometryFactor(self, odom_transform): 41 | 42 | self.graph_initials.insert(gtsam.symbol('x', self.curr_node_idx), gtsam.Pose3(self.curr_se3)) 43 | self.graph_factors.add(gtsam.BetweenFactorPose3( 44 | gtsam.symbol('x', self.prev_node_idx), 45 | gtsam.symbol('x', self.curr_node_idx), 46 | gtsam.Pose3(odom_transform), 47 | self.odom_cov)) 48 | 49 | def addLoopFactor(self, loop_transform, loop_idx): 50 | 51 | self.graph_factors.add(gtsam.BetweenFactorPose3( 52 | gtsam.symbol('x', loop_idx), 53 | gtsam.symbol('x', self.curr_node_idx), 54 | gtsam.Pose3(loop_transform), 55 | self.odom_cov)) 56 | 57 | def optimizePoseGraph(self): 58 | 59 | self.opt = gtsam.LevenbergMarquardtOptimizer(self.graph_factors, self.graph_initials, self.opt_param) 60 | self.graph_optimized = self.opt.optimize() 61 | 62 | # status = self.opt.optimize(self.graph_factors, self.graph_initials, self.graph_optimized) 63 | # if status != minisam.NonlinearOptimizationStatus.SUCCESS: 64 | # print("optimization error: ", status) 65 | 66 | # correct current pose 67 | pose_trans, pose_rot = getGraphNodePose(self.graph_optimized, self.curr_node_idx) 68 | self.curr_se3[:3, :3] = pose_rot 69 | self.curr_se3[:3, 3] = pose_trans 70 | 71 | -------------------------------------------------------------------------------- /pipeline_SLAM/utils/ScanContextManager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | np.set_printoptions(precision=4) 3 | 4 | import time 5 | from scipy import spatial 6 | 7 | def xy2theta(x, y): 8 | if (x >= 0 and y >= 0): 9 | theta = 180/np.pi * np.arctan(y/x); 10 | if (x < 0 and y >= 0): 11 | theta = 180 - ((180/np.pi) * np.arctan(y/(-x))); 12 | if (x < 0 and y < 0): 13 | theta = 180 + ((180/np.pi) * np.arctan(y/x)); 14 | if ( x >= 0 and y < 0): 15 | theta = 360 - ((180/np.pi) * np.arctan((-y)/x)); 16 | 17 | return theta 18 | 19 | 20 | def pt2rs(point, gap_ring, gap_sector, num_ring, num_sector): 21 | x = point[0] 22 | y = point[1] 23 | # z = point[2] 24 | 25 | if(x == 0.0): 26 | x = 0.001 27 | if(y == 0.0): 28 | y = 0.001 29 | 30 | theta = xy2theta(x, y) 31 | faraway = np.sqrt(x*x + y*y) 32 | 33 | idx_ring = np.divmod(faraway, gap_ring)[0] 34 | idx_sector = np.divmod(theta, gap_sector)[0] 35 | 36 | if(idx_ring >= num_ring): 37 | idx_ring = num_ring-1 # python starts with 0 and ends with N-1 38 | 39 | return int(idx_ring), int(idx_sector) 40 | 41 | 42 | def ptcloud2sc(ptcloud, sc_shape, max_length): 43 | num_ring = sc_shape[0] 44 | num_sector = sc_shape[1] 45 | 46 | gap_ring = max_length/num_ring 47 | gap_sector = 360/num_sector 48 | 49 | enough_large = 500 50 | sc_storage = np.zeros([enough_large, num_ring, num_sector]) 51 | sc_counter = np.zeros([num_ring, num_sector]) 52 | 53 | num_points = ptcloud.shape[0] 54 | for pt_idx in range(num_points): 55 | 
point = ptcloud[pt_idx, :] 56 | point_height = point[2] + 2.0 # offset so that the ground is roughly at zero 57 | 58 | idx_ring, idx_sector = pt2rs(point, gap_ring, gap_sector, num_ring, num_sector) 59 | 60 | if sc_counter[idx_ring, idx_sector] >= enough_large: 61 | continue 62 | sc_storage[int(sc_counter[idx_ring, idx_sector]), idx_ring, idx_sector] = point_height 63 | sc_counter[idx_ring, idx_sector] = sc_counter[idx_ring, idx_sector] + 1 64 | 65 | sc = np.amax(sc_storage, axis=0) 66 | 67 | return sc 68 | 69 | 70 | def sc2rk(sc): 71 | return np.mean(sc, axis=1) 72 | 73 | def distance_sc(sc1, sc2): 74 | num_sectors = sc1.shape[1] 75 | 76 | # repeatedly shift by one column and compare 77 | _one_step = 1 # const 78 | sim_for_each_cols = np.zeros(num_sectors) 79 | for i in range(num_sectors): 80 | # Shift 81 | sc1 = np.roll(sc1, _one_step, axis=1) # column shift 82 | 83 | # compare 84 | sum_of_cossim = 0 85 | num_col_engaged = 0 86 | for j in range(num_sectors): 87 | col_j_1 = sc1[:, j] 88 | col_j_2 = sc2[:, j] 89 | if (~np.any(col_j_1) or ~np.any(col_j_2)): 90 | # to avoid division by zero when calculating the cosine similarity 91 | # - but this part is quite slow in Python; you can omit it. 92 | continue 93 | 94 | cossim = np.dot(col_j_1, col_j_2) / (np.linalg.norm(col_j_1) * np.linalg.norm(col_j_2)) 95 | sum_of_cossim = sum_of_cossim + cossim 96 | 97 | num_col_engaged = num_col_engaged + 1 98 | 99 | # save 100 | sim_for_each_cols[i] = sum_of_cossim / num_col_engaged 101 | 102 | yaw_diff = np.argmax(sim_for_each_cols) + 1 # +1 because sc1 is shifted once before the first comparison 103 | sim = np.max(sim_for_each_cols) 104 | dist = 1 - sim 105 | 106 | return dist, yaw_diff 107 | 108 | 109 | class ScanContextManager: 110 | def __init__(self, shape=[20,60], num_candidates=10, threshold=0.15): # default configs are the same as in the original paper 111 | self.shape = shape 112 | self.num_candidates = num_candidates 113 | self.threshold = threshold 114 | 115 | self.max_length = 80 # recommended, but other values (e.g., 100 m) are also fine.
116 | 117 | self.ENOUGH_LARGE = 15000 # capable of up to ENOUGH_LARGE number of nodes 118 | self.ptclouds = [None] * self.ENOUGH_LARGE 119 | self.scancontexts = [None] * self.ENOUGH_LARGE 120 | self.ringkeys = [None] * self.ENOUGH_LARGE 121 | 122 | self.curr_node_idx = 0 123 | 124 | 125 | def addNode(self, node_idx, ptcloud): 126 | sc = ptcloud2sc(ptcloud, self.shape, self.max_length) 127 | rk = sc2rk(sc) 128 | 129 | self.curr_node_idx = node_idx 130 | self.ptclouds[node_idx] = ptcloud 131 | self.scancontexts[node_idx] = sc 132 | self.ringkeys[node_idx] = rk 133 | 134 | def getPtcloud(self, node_idx): 135 | return self.ptclouds[node_idx] 136 | 137 | 138 | def detectLoop(self): 139 | exclude_recent_nodes = 30 140 | valid_recent_node_idx = self.curr_node_idx - exclude_recent_nodes 141 | 142 | if(valid_recent_node_idx < 1): 143 | return None, None, None 144 | else: 145 | # step 1 146 | ringkey_history = np.array(self.ringkeys[:valid_recent_node_idx]) 147 | ringkey_tree = spatial.KDTree(ringkey_history) 148 | 149 | ringkey_query = self.ringkeys[self.curr_node_idx] 150 | _, nncandidates_idx = ringkey_tree.query(ringkey_query, k=self.num_candidates) 151 | 152 | # step 2 153 | query_sc = self.scancontexts[self.curr_node_idx] 154 | 155 | nn_dist = 1.0 # initialize with the largest value of distance 156 | nn_idx = None 157 | nn_yawdiff = None 158 | for ith in range(self.num_candidates): 159 | candidate_idx = nncandidates_idx[ith] 160 | candidate_sc = self.scancontexts[candidate_idx] 161 | dist, yaw_diff = distance_sc(candidate_sc, query_sc) 162 | if(dist < nn_dist): 163 | nn_dist = dist 164 | nn_yawdiff = yaw_diff 165 | nn_idx = candidate_idx 166 | 167 | if(nn_dist < self.threshold): 168 | nn_yawdiff_deg = nn_yawdiff * (360/self.shape[1]) 169 | return nn_idx, nn_dist, nn_yawdiff_deg # loop detected! 
170 | else: 171 | return None, None, None 172 | -------------------------------------------------------------------------------- /pipeline_SLAM/utils/UtilsMisc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import copy 4 | import time 5 | import math 6 | 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | import gtsam 11 | 12 | def getConstDigitsNumber(val, num_digits): 13 | return "{:.{}f}".format(val, num_digits) 14 | 15 | def getUnixTime(): 16 | return int(time.time()) 17 | 18 | def eulerAnglesToRotationMatrix(theta) : 19 | 20 | R_x = np.array([[1, 0, 0 ], 21 | [0, math.cos(theta[0]), -math.sin(theta[0]) ], 22 | [0, math.sin(theta[0]), math.cos(theta[0]) ] 23 | ]) 24 | 25 | R_y = np.array([[math.cos(theta[1]), 0, math.sin(theta[1]) ], 26 | [0, 1, 0 ], 27 | [-math.sin(theta[1]), 0, math.cos(theta[1]) ] 28 | ]) 29 | 30 | R_z = np.array([[math.cos(theta[2]), -math.sin(theta[2]), 0], 31 | [math.sin(theta[2]), math.cos(theta[2]), 0], 32 | [0, 0, 1] 33 | ]) 34 | 35 | R = np.dot(R_z, np.dot( R_y, R_x )) 36 | 37 | return R 38 | 39 | def yawdeg2so3(yaw_deg): 40 | yaw_rad = np.deg2rad(yaw_deg) 41 | return eulerAnglesToRotationMatrix([0, 0, yaw_rad]) 42 | 43 | def yawdeg2se3(yaw_deg): 44 | se3 = np.eye(4) 45 | se3[:3, :3] = yawdeg2so3(yaw_deg) 46 | return se3 47 | 48 | 49 | def getGraphNodePose(graph, idx): 50 | 51 | pose = graph.atPose3(gtsam.symbol('x', idx)) 52 | pose_trans = np.array([pose.x(), pose.y(), pose.z()]) 53 | pose_rot = pose.rotation().matrix() 54 | 55 | return pose_trans, pose_rot 56 | 57 | def saveOptimizedGraphPose(curr_node_idx, graph_optimized, filename): 58 | 59 | for opt_idx in range(curr_node_idx): 60 | pose_trans, pose_rot = getGraphNodePose(graph_optimized, opt_idx) 61 | pose_trans = np.reshape(pose_trans, (-1, 3)).squeeze() 62 | pose_rot = np.reshape(pose_rot, (-1, 9)).squeeze() 63 | optimized_pose_ith = np.array([ pose_rot[0], pose_rot[1], pose_rot[2], pose_trans[0], 64 | pose_rot[3], pose_rot[4], pose_rot[5], pose_trans[1], 65 | pose_rot[6], pose_rot[7], pose_rot[8], pose_trans[2], 66 | 0.0, 0.0, 0.0, 0.1 ]) 67 | if(opt_idx == 0): 68 | optimized_pose_list = optimized_pose_ith 69 | else: 70 | optimized_pose_list = np.vstack((optimized_pose_list, optimized_pose_ith)) 71 | 72 | np.savetxt(filename, optimized_pose_list, delimiter=",") 73 | 74 | 75 | class PoseGraphResultSaver: 76 | def __init__(self, init_pose, save_gap, num_frames, seq_idx, save_dir): 77 | self.pose_list = np.reshape(init_pose, (-1, 16)) 78 | self.save_gap = save_gap 79 | self.num_frames = num_frames 80 | 81 | self.seq_idx = seq_idx 82 | self.save_dir = save_dir 83 | 84 | def saveUnoptimizedPoseGraphResult(self, cur_pose, cur_node_idx): 85 | # save 86 | self.pose_list = np.vstack((self.pose_list, np.reshape(cur_pose, (-1, 16)))) 87 | 88 | # write 89 | if(cur_node_idx % self.save_gap == 0 or cur_node_idx == self.num_frames): 90 | # save odometry-only poses 91 | filename = "pose" + self.seq_idx + "unoptimized_" + str(getUnixTime()) + ".csv" 92 | filename = os.path.join(self.save_dir, filename) 93 | np.savetxt(filename, self.pose_list, delimiter=",") 94 | 95 | def saveOptimizedPoseGraphResult(self, cur_node_idx, graph_optimized): 96 | filename = "pose" + self.seq_idx + "optimized_" + str(getUnixTime()) + ".csv" 97 | filename = os.path.join(self.save_dir, filename) 98 | saveOptimizedGraphPose(cur_node_idx, graph_optimized, filename) 99 | 100 | optimized_pose_list = np.loadtxt(open(filename, "rb"), delimiter=",", skiprows=1) 
101 | self.pose_list = optimized_pose_list # update with optimized pose 102 | 103 | def vizCurrentTrajectory(self, fig_idx): 104 | x = self.pose_list[:,3] 105 | y = self.pose_list[:,7] 106 | z = self.pose_list[:,11] 107 | 108 | fig = plt.figure(fig_idx) 109 | plt.clf() 110 | plt.plot(-y, x, color='blue') # kitti camera coord for clarity 111 | plt.axis('equal') 112 | plt.xlabel('x', labelpad=10) 113 | plt.ylabel('y', labelpad=10) 114 | plt.draw() 115 | plt.pause(0.01) #is necessary for the plot to update for some reason 116 | 117 | -------------------------------------------------------------------------------- /pipeline_SLAM/utils/UtilsPointcloud.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | 5 | def random_sampling(orig_points, num_points): 6 | assert orig_points.shape[0] > num_points 7 | 8 | points_down_idx = random.sample(range(orig_points.shape[0]), num_points) 9 | down_points = orig_points[points_down_idx, :] 10 | 11 | return down_points 12 | 13 | def readScan(bin_path, dataset='KITTI'): 14 | if(dataset == 'KITTI'): 15 | return readKittiScan(bin_path) 16 | 17 | 18 | def readKittiScan(bin_path): 19 | scan = np.fromfile(bin_path, dtype=np.float32) 20 | scan = scan.reshape((-1, 4)) 21 | ptcloud_xyz = scan[:, :-1] 22 | return ptcloud_xyz 23 | 24 | 25 | class KittiScanDirManager: 26 | def __init__(self, scan_dir): 27 | self.scan_dir = scan_dir 28 | 29 | self.scan_names = os.listdir(scan_dir) 30 | self.scan_names.sort() 31 | 32 | self.scan_fullpaths = [os.path.join(self.scan_dir, name) for name in self.scan_names] 33 | 34 | self.num_scans = len(self.scan_names) 35 | 36 | def __repr__(self): 37 | return ' ' + str(self.num_scans) + ' scans in the sequence (' + self.scan_dir + '/)' 38 | 39 | def getScanNames(self): 40 | return self.scan_names 41 | def getScanFullPaths(self): 42 | return self.scan_fullpaths 43 | def printScanFullPaths(self): 44 | return print("\n".join(self.scan_fullpaths)) 45 | 46 | -------------------------------------------------------------------------------- /pipeline_SLAM/utils/__pycache__/ICP.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/__pycache__/ICP.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/__pycache__/PoseGraphManager.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/__pycache__/PoseGraphManager.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/__pycache__/ScanContextManager.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/__pycache__/ScanContextManager.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/__pycache__/UtilsMisc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/__pycache__/UtilsMisc.cpython-310.pyc 
-------------------------------------------------------------------------------- /pipeline_SLAM/utils/__pycache__/UtilsPointcloud.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/__pycache__/UtilsPointcloud.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/__pycache__/corr_downsample.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/__pycache__/corr_downsample.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/__pycache__/extract_corr.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/__pycache__/extract_corr.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/__pycache__/registration.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/__pycache__/registration.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/corr_downsample.py: -------------------------------------------------------------------------------- 1 | from .fastmac.gconstructor import GraphConstructor 2 | from .fastmac.gfilter import graphFilter,datasample 3 | import torch 4 | def Config(): 5 | config={ 6 | "num_points":5000, 7 | "data_dir":'/data/Processed_KITTI/correspondence_fcgf/', 8 | "filename":'fcgf@corr.txt', 9 | 'gtname':'fcgf@gtmat.txt', 10 | 'labelname':'fcgf@gtlabel.txt', 11 | 'batch_size':1, 12 | 'inlier_thresh':0.6, 13 | 'thresh':0.999, 14 | 'sigma':0.6, 15 | 'tau':0., 16 | 'device':'cuda', 17 | 'mode':'graph', 18 | 'ratio':0.01, 19 | 'outpath':'/home/Zero/mac/ablation/score_weight/111/1/sample', 20 | 'pc1_weight':1, 21 | 'pc2_weight':1, 22 | 'degree_weight':1, 23 | } 24 | return config 25 | 26 | config = Config() 27 | gc=GraphConstructor(config["inlier_thresh"],config["thresh"],trainable=False,sigma=config["sigma"],tau=config["tau"]) 28 | 29 | def normalize(x): 30 | # transform x to [0,1] 31 | x=x-x.min() 32 | x=x/x.max() 33 | return x 34 | 35 | def downsample(corr): 36 | if len(corr.shape)==2: 37 | corr=corr.unsqueeze(0) 38 | corr_graph=gc(corr,mode="correspondence") 39 | pc1_signal=corr[:,:,:3] 40 | pc2_signal=corr[:,:,3:] 41 | degree_signal=torch.sum(corr_graph,dim=-1) 42 | pc_graph1=gc(pc1_signal,mode="pointcloud") 43 | pc_graph2=gc(pc2_signal,mode="pointcloud") 44 | 45 | pc_laplacian1=(torch.diag_embed(torch.sum(pc_graph1,dim=-1))-pc_graph1).squeeze(0) 46 | pc_laplacian2=(torch.diag_embed(torch.sum(pc_graph2,dim=-1))-pc_graph2).squeeze(0) 47 | 48 | pc1_scores=graphFilter(pc1_signal.squeeze(0),pc_laplacian1,is_sparse=False) 49 | pc2_scores=graphFilter(pc2_signal.squeeze(0),pc_laplacian2,is_sparse=False) 50 | 51 | corr_laplacian=(torch.diag_embed(degree_signal)-corr_graph).squeeze(0) 52 | 53 | 54 | corr_scores=graphFilter(degree_signal.transpose(0,1),corr_laplacian,is_sparse=False) 55 | # 
corr_scores=graphFilter(torch.ones_like(degree_signal.transpose(0,1)).cuda(),torch.matmul(corr_laplacian,corr_laplacian),is_sparse=False) 56 | 57 | pc1_scores=normalize(pc1_scores) 58 | pc2_scores=normalize(pc2_scores) 59 | corr_scores=normalize(corr_scores) 60 | 61 | total_scores=config["pc1_weight"]*pc1_scores+config["pc2_weight"]*pc2_scores+config["degree_weight"]*corr_scores 62 | idxs=datasample(1000,False,total_scores) 63 | samples=corr.squeeze(0)[idxs,:] 64 | return samples -------------------------------------------------------------------------------- /pipeline_SLAM/utils/extract_corr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from easydict import EasyDict as edict 3 | from functools import partial 4 | import torch 5 | from geotransformer.utils.data import precompute_data_stack_mode 6 | from geotransformer.utils.torch import to_cuda, release_cuda 7 | 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | from geotransformer.modules.ops import point_to_node_partition, index_select 13 | from geotransformer.modules.registration import get_node_correspondences 14 | from geotransformer.modules.sinkhorn import LearnableLogOptimalTransport 15 | from geotransformer.modules.geotransformer import ( 16 | GeometricTransformer, 17 | SuperPointMatching, 18 | SuperPointTargetGenerator, 19 | LocalGlobalRegistration, 20 | ) 21 | from geotransformer.modules.kpconv import ConvBlock, ResidualBlock, UnaryBlock, LastUnaryBlock, nearest_upsample 22 | 23 | 24 | class KPConvFPN(nn.Module): 25 | def __init__(self, input_dim, output_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm): 26 | super(KPConvFPN, self).__init__() 27 | 28 | self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size, init_radius, init_sigma, group_norm) 29 | self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm) 30 | 31 | self.encoder2_1 = ResidualBlock( 32 | init_dim * 2, init_dim * 2, kernel_size, init_radius, init_sigma, group_norm, strided=True 33 | ) 34 | self.encoder2_2 = ResidualBlock( 35 | init_dim * 2, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm 36 | ) 37 | self.encoder2_3 = ResidualBlock( 38 | init_dim * 4, init_dim * 4, kernel_size, init_radius * 2, init_sigma * 2, group_norm 39 | ) 40 | 41 | self.encoder3_1 = ResidualBlock( 42 | init_dim * 4, 43 | init_dim * 4, 44 | kernel_size, 45 | init_radius * 2, 46 | init_sigma * 2, 47 | group_norm, 48 | strided=True, 49 | ) 50 | self.encoder3_2 = ResidualBlock( 51 | init_dim * 4, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm 52 | ) 53 | self.encoder3_3 = ResidualBlock( 54 | init_dim * 8, init_dim * 8, kernel_size, init_radius * 4, init_sigma * 4, group_norm 55 | ) 56 | 57 | self.encoder4_1 = ResidualBlock( 58 | init_dim * 8, 59 | init_dim * 8, 60 | kernel_size, 61 | init_radius * 4, 62 | init_sigma * 4, 63 | group_norm, 64 | strided=True, 65 | ) 66 | self.encoder4_2 = ResidualBlock( 67 | init_dim * 8, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm 68 | ) 69 | self.encoder4_3 = ResidualBlock( 70 | init_dim * 16, init_dim * 16, kernel_size, init_radius * 8, init_sigma * 8, group_norm 71 | ) 72 | 73 | self.encoder5_1 = ResidualBlock( 74 | init_dim * 16, 75 | init_dim * 16, 76 | kernel_size, 77 | init_radius * 8, 78 | init_sigma * 8, 79 | group_norm, 80 | strided=True, 81 | ) 82 | self.encoder5_2 = ResidualBlock( 83 |
init_dim * 16, init_dim * 32, kernel_size, init_radius * 16, init_sigma * 16, group_norm 84 | ) 85 | self.encoder5_3 = ResidualBlock( 86 | init_dim * 32, init_dim * 32, kernel_size, init_radius * 16, init_sigma * 16, group_norm 87 | ) 88 | 89 | self.decoder4 = UnaryBlock(init_dim * 48, init_dim * 16, group_norm) 90 | self.decoder3 = UnaryBlock(init_dim * 24, init_dim * 8, group_norm) 91 | self.decoder2 = LastUnaryBlock(init_dim * 12, output_dim) 92 | 93 | def forward(self, feats, data_dict): 94 | feats_list = [] 95 | 96 | points_list = data_dict['points'] 97 | neighbors_list = data_dict['neighbors'] 98 | subsampling_list = data_dict['subsampling'] 99 | upsampling_list = data_dict['upsampling'] 100 | 101 | feats_s1 = feats 102 | feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0], neighbors_list[0]) 103 | feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0], neighbors_list[0]) 104 | 105 | feats_s2 = self.encoder2_1(feats_s1, points_list[1], points_list[0], subsampling_list[0]) 106 | feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1], neighbors_list[1]) 107 | feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1], neighbors_list[1]) 108 | 109 | feats_s3 = self.encoder3_1(feats_s2, points_list[2], points_list[1], subsampling_list[1]) 110 | feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2], neighbors_list[2]) 111 | feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2], neighbors_list[2]) 112 | 113 | feats_s4 = self.encoder4_1(feats_s3, points_list[3], points_list[2], subsampling_list[2]) 114 | feats_s4 = self.encoder4_2(feats_s4, points_list[3], points_list[3], neighbors_list[3]) 115 | feats_s4 = self.encoder4_3(feats_s4, points_list[3], points_list[3], neighbors_list[3]) 116 | 117 | feats_s5 = self.encoder5_1(feats_s4, points_list[4], points_list[3], subsampling_list[3]) 118 | feats_s5 = self.encoder5_2(feats_s5, points_list[4], points_list[4], neighbors_list[4]) 119 | feats_s5 = self.encoder5_3(feats_s5, points_list[4], points_list[4], neighbors_list[4]) 120 | 121 | latent_s5 = feats_s5 122 | feats_list.append(feats_s5) 123 | 124 | latent_s4 = nearest_upsample(latent_s5, upsampling_list[3]) 125 | latent_s4 = torch.cat([latent_s4, feats_s4], dim=1) 126 | latent_s4 = self.decoder4(latent_s4) 127 | feats_list.append(latent_s4) 128 | 129 | latent_s3 = nearest_upsample(latent_s4, upsampling_list[2]) 130 | latent_s3 = torch.cat([latent_s3, feats_s3], dim=1) 131 | latent_s3 = self.decoder3(latent_s3) 132 | feats_list.append(latent_s3) 133 | 134 | latent_s2 = nearest_upsample(latent_s3, upsampling_list[1]) 135 | latent_s2 = torch.cat([latent_s2, feats_s2], dim=1) 136 | latent_s2 = self.decoder2(latent_s2) 137 | feats_list.append(latent_s2) 138 | 139 | feats_list.reverse() 140 | 141 | return feats_list 142 | 143 | 144 | class GeoTransformer(nn.Module): 145 | def __init__(self, cfg): 146 | super(GeoTransformer, self).__init__() 147 | self.num_points_in_patch = cfg.model.num_points_in_patch 148 | self.matching_radius = cfg.model.ground_truth_matching_radius 149 | 150 | self.backbone = KPConvFPN( 151 | cfg.backbone.input_dim, 152 | cfg.backbone.output_dim, 153 | cfg.backbone.init_dim, 154 | cfg.backbone.kernel_size, 155 | cfg.backbone.init_radius, 156 | cfg.backbone.init_sigma, 157 | cfg.backbone.group_norm, 158 | ) 159 | 160 | self.transformer = GeometricTransformer( 161 | cfg.geotransformer.input_dim, 162 | cfg.geotransformer.output_dim, 163 | cfg.geotransformer.hidden_dim, 164 | cfg.geotransformer.num_heads, 
165 | cfg.geotransformer.blocks, 166 | cfg.geotransformer.sigma_d, 167 | cfg.geotransformer.sigma_a, 168 | cfg.geotransformer.angle_k, 169 | reduction_a=cfg.geotransformer.reduction_a, 170 | ) 171 | 172 | self.coarse_target = SuperPointTargetGenerator( 173 | cfg.coarse_matching.num_targets, cfg.coarse_matching.overlap_threshold 174 | ) 175 | 176 | self.coarse_matching = SuperPointMatching( 177 | cfg.coarse_matching.num_correspondences, cfg.coarse_matching.dual_normalization 178 | ) 179 | 180 | self.fine_matching = LocalGlobalRegistration( 181 | cfg.fine_matching.topk, 182 | cfg.fine_matching.acceptance_radius, 183 | mutual=cfg.fine_matching.mutual, 184 | confidence_threshold=cfg.fine_matching.confidence_threshold, 185 | use_dustbin=cfg.fine_matching.use_dustbin, 186 | use_global_score=cfg.fine_matching.use_global_score, 187 | correspondence_threshold=cfg.fine_matching.correspondence_threshold, 188 | correspondence_limit=cfg.fine_matching.correspondence_limit, 189 | num_refinement_steps=cfg.fine_matching.num_refinement_steps, 190 | ) 191 | 192 | self.optimal_transport = LearnableLogOptimalTransport(cfg.model.num_sinkhorn_iterations) 193 | 194 | def forward(self, data_dict): 195 | output_dict = {} 196 | 197 | # Downsample point clouds 198 | feats = data_dict['features'].detach() 199 | transform = data_dict['transform'].detach() 200 | 201 | ref_length_c = data_dict['lengths'][-1][0].item() 202 | ref_length_f = data_dict['lengths'][1][0].item() 203 | ref_length = data_dict['lengths'][0][0].item() 204 | points_c = data_dict['points'][-1].detach() 205 | points_f = data_dict['points'][1].detach() 206 | points = data_dict['points'][0].detach() 207 | 208 | 209 | ref_points_c = points_c[:ref_length_c] 210 | src_points_c = points_c[ref_length_c:] 211 | ref_points_f = points_f[:ref_length_f] 212 | src_points_f = points_f[ref_length_f:] 213 | ref_points = points[:ref_length] 214 | src_points = points[ref_length:] 215 | 216 | output_dict['ref_points_c'] = ref_points_c 217 | output_dict['src_points_c'] = src_points_c 218 | output_dict['ref_points_f'] = ref_points_f 219 | output_dict['src_points_f'] = src_points_f 220 | output_dict['ref_points'] = ref_points 221 | output_dict['src_points'] = src_points 222 | 223 | # 1. Generate ground truth node correspondences 224 | _, ref_node_masks, ref_node_knn_indices, ref_node_knn_masks = point_to_node_partition( 225 | ref_points_f, ref_points_c, self.num_points_in_patch 226 | ) 227 | _, src_node_masks, src_node_knn_indices, src_node_knn_masks = point_to_node_partition( 228 | src_points_f, src_points_c, self.num_points_in_patch 229 | ) 230 | 231 | ref_padded_points_f = torch.cat([ref_points_f, torch.zeros_like(ref_points_f[:1])], dim=0) 232 | src_padded_points_f = torch.cat([src_points_f, torch.zeros_like(src_points_f[:1])], dim=0) 233 | ref_node_knn_points = index_select(ref_padded_points_f, ref_node_knn_indices, dim=0) 234 | src_node_knn_points = index_select(src_padded_points_f, src_node_knn_indices, dim=0) 235 | 236 | gt_node_corr_indices, gt_node_corr_overlaps = get_node_correspondences( 237 | ref_points_c, 238 | src_points_c, 239 | ref_node_knn_points, 240 | src_node_knn_points, 241 | transform, 242 | self.matching_radius, 243 | ref_masks=ref_node_masks, 244 | src_masks=src_node_masks, 245 | ref_knn_masks=ref_node_knn_masks, 246 | src_knn_masks=src_node_knn_masks, 247 | ) 248 | 249 | output_dict['gt_node_corr_indices'] = gt_node_corr_indices 250 | output_dict['gt_node_corr_overlaps'] = gt_node_corr_overlaps 251 | 252 | # 2. 
KPFCNN Encoder 253 | feats_list = self.backbone(feats, data_dict) 254 | 255 | feats_c = feats_list[-1] 256 | feats_f = feats_list[0] 257 | 258 | # 3. Conditional Transformer 259 | ref_feats_c = feats_c[:ref_length_c] 260 | src_feats_c = feats_c[ref_length_c:] 261 | ref_feats_c, src_feats_c = self.transformer( 262 | ref_points_c.unsqueeze(0), 263 | src_points_c.unsqueeze(0), 264 | ref_feats_c.unsqueeze(0), 265 | src_feats_c.unsqueeze(0), 266 | ) 267 | ref_feats_c_norm = F.normalize(ref_feats_c.squeeze(0), p=2, dim=1) 268 | src_feats_c_norm = F.normalize(src_feats_c.squeeze(0), p=2, dim=1) 269 | 270 | output_dict['ref_feats_c'] = ref_feats_c_norm 271 | output_dict['src_feats_c'] = src_feats_c_norm 272 | 273 | # 5. Head for fine level matching 274 | ref_feats_f = feats_f[:ref_length_f] 275 | src_feats_f = feats_f[ref_length_f:] 276 | output_dict['ref_feats_f'] = ref_feats_f 277 | output_dict['src_feats_f'] = src_feats_f 278 | 279 | # 6. Select topk nearest node correspondences 280 | with torch.no_grad(): 281 | ref_node_corr_indices, src_node_corr_indices, node_corr_scores = self.coarse_matching( 282 | ref_feats_c_norm, src_feats_c_norm, ref_node_masks, src_node_masks 283 | ) 284 | 285 | output_dict['ref_node_corr_indices'] = ref_node_corr_indices 286 | output_dict['src_node_corr_indices'] = src_node_corr_indices 287 | 288 | # 7. Randomly select ground truth node correspondences during training 289 | if self.training: 290 | ref_node_corr_indices, src_node_corr_indices, node_corr_scores = self.coarse_target( 291 | gt_node_corr_indices, gt_node_corr_overlaps 292 | ) 293 | 294 | # 7.2 Generate batched node points & feats 295 | ref_node_corr_knn_indices = ref_node_knn_indices[ref_node_corr_indices] # (P, K) 296 | src_node_corr_knn_indices = src_node_knn_indices[src_node_corr_indices] # (P, K) 297 | ref_node_corr_knn_masks = ref_node_knn_masks[ref_node_corr_indices] # (P, K) 298 | src_node_corr_knn_masks = src_node_knn_masks[src_node_corr_indices] # (P, K) 299 | ref_node_corr_knn_points = ref_node_knn_points[ref_node_corr_indices] # (P, K, 3) 300 | src_node_corr_knn_points = src_node_knn_points[src_node_corr_indices] # (P, K, 3) 301 | 302 | ref_padded_feats_f = torch.cat([ref_feats_f, torch.zeros_like(ref_feats_f[:1])], dim=0) 303 | src_padded_feats_f = torch.cat([src_feats_f, torch.zeros_like(src_feats_f[:1])], dim=0) 304 | ref_node_corr_knn_feats = index_select(ref_padded_feats_f, ref_node_corr_knn_indices, dim=0) # (P, K, C) 305 | src_node_corr_knn_feats = index_select(src_padded_feats_f, src_node_corr_knn_indices, dim=0) # (P, K, C) 306 | 307 | output_dict['ref_node_corr_knn_points'] = ref_node_corr_knn_points 308 | output_dict['src_node_corr_knn_points'] = src_node_corr_knn_points 309 | output_dict['ref_node_corr_knn_masks'] = ref_node_corr_knn_masks 310 | output_dict['src_node_corr_knn_masks'] = src_node_corr_knn_masks 311 | 312 | # 8. Optimal transport 313 | matching_scores = torch.einsum('bnd,bmd->bnm', ref_node_corr_knn_feats, src_node_corr_knn_feats) # (P, K, K) 314 | matching_scores = matching_scores / feats_f.shape[1] ** 0.5 315 | matching_scores = self.optimal_transport(matching_scores, ref_node_corr_knn_masks, src_node_corr_knn_masks) 316 | 317 | output_dict['matching_scores'] = matching_scores 318 | 319 | # 9.
Generate final correspondences during testing 320 | with torch.no_grad(): 321 | if not self.fine_matching.use_dustbin: 322 | matching_scores = matching_scores[:, :-1, :-1] 323 | 324 | ref_corr_points, src_corr_points, corr_scores, estimated_transform = self.fine_matching( 325 | ref_node_corr_knn_points, 326 | src_node_corr_knn_points, 327 | ref_node_corr_knn_masks, 328 | src_node_corr_knn_masks, 329 | matching_scores, 330 | node_corr_scores, 331 | ) 332 | 333 | output_dict['ref_corr_points'] = ref_corr_points 334 | output_dict['src_corr_points'] = src_corr_points 335 | output_dict['corr_scores'] = corr_scores 336 | output_dict['estimated_transform'] = estimated_transform 337 | 338 | return output_dict 339 | 340 | def load_data(src_points, ref_points): 341 | src_feats = np.ones_like(src_points[:, :1]) 342 | ref_feats = np.ones_like(ref_points[:, :1]) 343 | 344 | data_dict = { 345 | "ref_points": ref_points.astype(np.float32), 346 | "src_points": src_points.astype(np.float32), 347 | "ref_feats": ref_feats.astype(np.float32), 348 | "src_feats": src_feats.astype(np.float32), 349 | 'transform': np.eye(4).astype(np.float32) 350 | } 351 | return data_dict 352 | 353 | def make_cfg(): 354 | _C = edict() 355 | # common 356 | _C.seed = 7351 357 | 358 | # model - backbone 359 | _C.backbone = edict() 360 | _C.backbone.num_stages = 5 361 | _C.backbone.init_voxel_size = 0.3 362 | _C.backbone.kernel_size = 15 363 | _C.backbone.base_radius = 4.25 364 | _C.backbone.base_sigma = 2.0 365 | _C.backbone.init_radius = _C.backbone.base_radius * _C.backbone.init_voxel_size 366 | _C.backbone.init_sigma = _C.backbone.base_sigma * _C.backbone.init_voxel_size 367 | _C.backbone.group_norm = 32 368 | _C.backbone.input_dim = 1 369 | _C.backbone.init_dim = 64 370 | _C.backbone.output_dim = 256 371 | 372 | # model - Global 373 | _C.model = edict() 374 | _C.model.ground_truth_matching_radius = 0.6 375 | _C.model.num_points_in_patch = 128 376 | _C.model.num_sinkhorn_iterations = 100 377 | 378 | # model - Coarse Matching 379 | _C.coarse_matching = edict() 380 | _C.coarse_matching.num_targets = 128 381 | _C.coarse_matching.overlap_threshold = 0.1 382 | _C.coarse_matching.num_correspondences = 256 383 | _C.coarse_matching.dual_normalization = True 384 | 385 | # model - GeoTransformer 386 | _C.geotransformer = edict() 387 | _C.geotransformer.input_dim = 2048 388 | _C.geotransformer.hidden_dim = 128 389 | _C.geotransformer.output_dim = 256 390 | _C.geotransformer.num_heads = 4 391 | _C.geotransformer.blocks = ['self', 'cross', 'self', 'cross', 'self', 'cross'] 392 | _C.geotransformer.sigma_d = 4.8 393 | _C.geotransformer.sigma_a = 15 394 | _C.geotransformer.angle_k = 3 395 | _C.geotransformer.reduction_a = 'max' 396 | 397 | # model - Fine Matching 398 | _C.fine_matching = edict() 399 | _C.fine_matching.topk = 2 400 | _C.fine_matching.acceptance_radius = 0.6 401 | _C.fine_matching.mutual = True 402 | _C.fine_matching.confidence_threshold = 0.05 403 | _C.fine_matching.use_dustbin = False 404 | _C.fine_matching.use_global_score = False 405 | _C.fine_matching.correspondence_threshold = 3 406 | _C.fine_matching.correspondence_limit = None 407 | _C.fine_matching.num_refinement_steps = 5 408 | 409 | # loss - Coarse level 410 | _C.coarse_loss = edict() 411 | _C.coarse_loss.positive_margin = 0.1 412 | _C.coarse_loss.negative_margin = 1.4 413 | _C.coarse_loss.positive_optimal = 0.1 414 | _C.coarse_loss.negative_optimal = 1.4 415 | _C.coarse_loss.log_scale = 40 416 | _C.coarse_loss.positive_overlap = 0.1 417 | 418 | # loss - Fine level 419 
| _C.fine_loss = edict() 420 | _C.fine_loss.positive_radius = 0.6 421 | 422 | # loss - Overall 423 | _C.loss = edict() 424 | _C.loss.weight_coarse_loss = 1.0 425 | _C.loss.weight_fine_loss = 1.0 426 | 427 | return _C 428 | 429 | 430 | 431 | def registration_collate_fn_stack_mode( 432 | data_dicts, num_stages, voxel_size, search_radius, neighbor_limits, precompute_data=True 433 | ): 434 | r"""Collate function for registration in stack mode. 435 | 436 | Points are organized in the following order: [ref_1, ..., ref_B, src_1, ..., src_B]. 437 | The correspondence indices are within each point cloud without accumulation. 438 | 439 | Args: 440 | data_dicts (List[Dict]) 441 | num_stages (int) 442 | voxel_size (float) 443 | search_radius (float) 444 | neighbor_limits (List[int]) 445 | precompute_data (bool) 446 | 447 | Returns: 448 | collated_dict (Dict) 449 | """ 450 | batch_size = len(data_dicts) 451 | # merge data with the same key from different samples into a list 452 | collated_dict = {} 453 | for data_dict in data_dicts: 454 | for key, value in data_dict.items(): 455 | if isinstance(value, np.ndarray): 456 | value = torch.from_numpy(value) 457 | if key not in collated_dict: 458 | collated_dict[key] = [] 459 | collated_dict[key].append(value) 460 | 461 | # handle special keys: [ref_feats, src_feats] -> feats, [ref_points, src_points] -> points, lengths 462 | feats = torch.cat(collated_dict.pop('ref_feats') + collated_dict.pop('src_feats'), dim=0) 463 | points_list = collated_dict.pop('ref_points') + collated_dict.pop('src_points') 464 | lengths = torch.LongTensor([points.shape[0] for points in points_list]) 465 | points = torch.cat(points_list, dim=0) 466 | 467 | if batch_size == 1: 468 | # remove wrapping brackets if batch_size is 1 469 | for key, value in collated_dict.items(): 470 | collated_dict[key] = value[0] 471 | 472 | collated_dict['features'] = feats 473 | if precompute_data: 474 | input_dict = precompute_data_stack_mode(points, lengths, num_stages, voxel_size, search_radius, neighbor_limits) 475 | collated_dict.update(input_dict) 476 | else: 477 | collated_dict['points'] = points 478 | collated_dict['lengths'] = lengths 479 | collated_dict['batch_size'] = batch_size 480 | 481 | return collated_dict 482 | 483 | class CorrExtractor: 484 | def __init__(self): 485 | cfg = make_cfg() 486 | self.neighbor_limits = [64 ,65 ,74, 80 ,79] 487 | model = GeoTransformer(cfg).cuda() 488 | state_dict = torch.load('geotransformer-kitti.pth.tar') 489 | model.load_state_dict(state_dict["model"]) 490 | self.model = model 491 | self.cfg = cfg 492 | 493 | def extract_corr(self, src_points, ref_points): 494 | data_dict = load_data(src_points, ref_points) 495 | data_dict = registration_collate_fn_stack_mode( 496 | [data_dict], self.cfg.backbone.num_stages, self.cfg.backbone.init_voxel_size, self.cfg.backbone.init_radius, self.neighbor_limits 497 | ) 498 | data_dict = to_cuda(data_dict) 499 | output_dict = self.model(data_dict) 500 | # data_dict = release_cuda(data_dict) 501 | # output_dict = release_cuda(output_dict) 502 | ref_corr_points = output_dict['ref_corr_points'] 503 | src_corr_points = output_dict['src_corr_points'] 504 | corr_scores = output_dict['corr_scores'] 505 | # estimated_transform = output_dict['estimated_transform'] 506 | 507 | return torch.cat([src_corr_points, ref_corr_points], dim=-1).detach(), corr_scores.detach() #, estimated_transform.detach() 508 | 509 | if __name__ == '__main__': 510 | src_points = np.load('example/scan_pts_0.npy') 511 | ref_points = 
np.load('example/scan_pts_1.npy') 512 | data_dict = load_data(src_points, ref_points) 513 | corr_extractor = CorrExtractor() 514 | corr_extractor.extract_corr(src_points, ref_points) -------------------------------------------------------------------------------- /pipeline_SLAM/utils/fastmac/__pycache__/gconstructor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/fastmac/__pycache__/gconstructor.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/fastmac/__pycache__/gfilter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forrest-110/FastMAC/b6aa26f07abf08a30f8be387657bdcebf63cb37d/pipeline_SLAM/utils/fastmac/__pycache__/gfilter.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline_SLAM/utils/fastmac/gconstructor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def euclidean(a, b): 5 | return torch.norm(a - b, dim=-1, keepdim=True) 6 | 7 | def compatibility(a,b): 8 | assert(a.shape[-1]==6) 9 | assert(b.shape[-1]==6) 10 | n1=torch.norm(a[...,:3]-b[...,:3],dim=-1,keepdim=True) 11 | n2=torch.norm(a[...,3:]-b[...,3:],dim=-1,keepdim=True) 12 | return torch.abs(n1-n2) 13 | 14 | def Dmatrix(a,type): 15 | if type=="euclidean": 16 | return torch.cdist(a,a) 17 | 18 | elif type=="compatibility": 19 | a1=a[...,:3] 20 | a2=a[...,3:] 21 | return torch.abs(Dmatrix(a1,"euclidean")-Dmatrix(a2,"euclidean")) 22 | 23 | class GraphConstructor(nn.Module): 24 | def __init__(self,inlier_thresh,thresh,trainable,device="cuda",sigma=None,tau=None) -> None: 25 | ''' 26 | inlier thresh: KITTI 0.6, 3dmatch 0.1 27 | thresh: fpfh 0.9, fcgf 0.999 28 | ''' 29 | super().__init__() 30 | self.device=device 31 | self.inlier_thresh=nn.Parameter(torch.tensor(inlier_thresh,requires_grad=trainable,dtype=torch.float32)).to(device) 32 | self.thresh=nn.Parameter(torch.tensor(thresh,requires_grad=trainable,dtype=torch.float32)).to(device) 33 | if sigma is not None: 34 | self.sigma=nn.Parameter(torch.tensor(sigma,requires_grad=trainable,dtype=torch.float32)).to(device) 35 | else: 36 | self.sigma=self.inlier_thresh 37 | if tau is not None: 38 | self.tau=nn.Parameter(torch.tensor(tau,requires_grad=trainable,dtype=torch.float32)).to(device) 39 | else: 40 | self.tau=self.thresh 41 | def forward(self,points,mode,k1=2,k2=1): 42 | ''' 43 | points: B x M x 6 44 | output: B x M x M 45 | ''' 46 | if mode=="correspondence": 47 | points=points.to(self.device) 48 | dmatrix=Dmatrix(points,"compatibility") 49 | score=1-dmatrix**2/self.inlier_thresh**2 50 | # score=torch.exp(-dmatrix**2/self.inlier_thresh**2) 51 | score[score<self.thresh]=0 52 | # second-order consistency: (score @ score)[m,k] accumulates the 53 | # compatibility that m and k share with common neighbors, so edges 54 | # without mutual support are suppressed by the element-wise product 55 | return score*torch.einsum("bmn,bnk->bmk",score,score) 56 | elif mode=="pointcloud": 57 | ''' 58 | points: B x N x 3 59 | output: B x N x N 60 | ''' 61 | points=points.to(self.device) 62 | dmatrix=Dmatrix(points,"euclidean") 63 | 64 | # score=1-dmatrix**2/self.inlier_thresh**2 65 | score=torch.exp(-dmatrix**2/self.sigma**2) 66 | score[score<self.tau]=0 67 | # same second-order re-weighting as in the "correspondence" branch, 68 | # applied to the proximity graph of a single point cloud 69 | 70 | return score*torch.einsum("bmn,bnk->bmk",score,score) 71 | 72 | class GraphConstructorFor3DMatch(nn.Module): 73 | def __init__(self) -> None: 74 | super().__init__() 75 | pass 76 | def forward(self,correspondence, resolution, name, descriptor, inlier_thresh): 77 | self.device="cuda" 78 |
correspondence=correspondence.to(self.device) 79 | dmatrix=Dmatrix(correspondence,"compatibility") 80 | 81 | if descriptor=="predator": 82 | score=1-dmatrix**2/inlier_thresh**2 83 | score[score<0.999]=0 84 | else: 85 | alpha_dis = 10 * resolution 86 | score = torch.exp(-dmatrix**2 / (2 * alpha_dis * alpha_dis)) 87 | if (name == "3dmatch" and descriptor == "fcgf"): 88 | score[score<0.999]=0 89 | elif (name == "3dmatch" and descriptor == "fpfh") : 90 | score[score<0.995]=0 91 | elif (descriptor == "spinnet" or descriptor == "d3feat") : 92 | score[score<0.85]=0 93 | #spinnet 5000 2500 1000 500 250 94 | # 0.99 0.99 0.95 0.9 0.85 95 | else: 96 | score[score<0.99]=0 #3dlomatch 0.99, 3dmatch fcgf 0.999 fpfh 0.995 97 | return score*torch.einsum("bmn,bnk->bmk",score,score) 98 | 99 | 100 | class Graph: 101 | def __init__(self): 102 | pass 103 | 104 | @staticmethod 105 | def construct_graph(pcloud, nb_neighbors): 106 | """ 107 | Construct a directed nearest neighbor graph on the input point cloud. 108 | 109 | Parameters 110 | ---------- 111 | pcloud : torch.Tensor 112 | Input point cloud. Size B x N x 3. 113 | nb_neighbors : int 114 | Number of nearest neighbors per point. 115 | 116 | Returns 117 | ------- 118 | adj : torch.Tensor 119 | Sparse (nb_points x nb_points) adjacency matrix of the directed 120 | nearest neighbor graph, with a 1 at (i, j) when j is one of the 121 | nb_neighbors nearest neighbors of i. 122 | 123 | """ 124 | 125 | # Size 126 | nb_points = pcloud.shape[1] 127 | size_batch = pcloud.shape[0] 128 | 129 | # Distance between points 130 | distance_matrix = torch.sum(pcloud ** 2, -1, keepdim=True) 131 | distance_matrix = distance_matrix + distance_matrix.transpose(1, 2) 132 | distance_matrix = distance_matrix - 2 * torch.bmm( 133 | pcloud, pcloud.transpose(1, 2) 134 | ) 135 | # except self distance 136 | distance_matrix = distance_matrix + 1e6 * torch.eye(nb_points).unsqueeze(0).to(pcloud.device) 137 | 138 | # Find nearest neighbors 139 | neighbors = torch.argsort(distance_matrix, -1)[..., :nb_neighbors] 140 | 141 | # # directly construct dense adjacency matrix 142 | # adj_dense=torch.zeros((nb_points,nb_points)).to(pcloud.device) 143 | # for i in range(nb_points): 144 | # for j in range(nb_neighbors): 145 | # adj_dense[i,neighbors[0,i,j]]=1 146 | 147 | 148 | # construct sparse adjacency matrix 149 | neighbors_flat = neighbors.reshape( -1) 150 | idx=torch.arange(nb_points).repeat(nb_neighbors,1).transpose(0,1).reshape(-1) 151 | idx=idx.to(pcloud.device) 152 | neighbors_flat=neighbors_flat.to(pcloud.device) 153 | i=torch.stack([idx,neighbors_flat],dim=0) 154 | v=torch.ones(i.shape[1]).to(pcloud.device) 155 | # i is a 2 x (nb_points * nb_neighbors) COO index tensor 156 | adj=torch.sparse_coo_tensor(i,v,(nb_points,nb_points)) 157 | 158 | # assert(torch.all(torch.eq(adj.to_dense(),adj_dense))) 159 | 160 | return adj 161 | 162 | 163 | 164 | 165 | if __name__ == "__main__": 166 | from plyfile import PlyData,PlyElement 167 | import numpy as np 168 | def write_ply(save_path,points,text=True): 169 | """ 170 | save_path : path to save: '/yy/XX.ply' 171 | points : point cloud, size (N, 3) 172 | """ 173 | points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] 174 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) 175 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 176 | PlyData([el], text=text).write(save_path) 177 | def read_ply(filename): 178 | """ read XYZ point cloud from filename PLY file """ 179 | plydata = PlyData.read(filename) 180 | pc = plydata['vertex'].data 181 | pc_array
= np.array([[x, y, z] for x,y,z in pc]) 182 | return pc_array 183 | 184 | pc=read_ply('/data/plane2.ply') 185 | num_pts=pc.shape[0] 186 | sample_rate=0.01 187 | k=int(np.floor(sample_rate*num_pts)) 188 | pc_tensor=torch.from_numpy(pc).type(torch.FloatTensor).unsqueeze(0).cuda() 189 | g=GraphConstructor(0.6,0,False) 190 | adj=g(pc_tensor,"pointcloud") 191 | 192 | degree=torch.diag_embed(torch.sum(adj,dim=-1)) 193 | laplacian=(degree-adj).squeeze(0) 194 | low_shift=(torch.diag_embed(1/torch.sum(adj,dim=-1))*adj).squeeze(0) 195 | from gfilter import graphLowFilter,datasample 196 | scores=graphLowFilter(pc_tensor.squeeze(0),low_shift) 197 | idxs=datasample(k,False,scores) 198 | sampled_pc=pc_tensor.squeeze(0)[idxs,:] 199 | write_ply('/data/plane2_sampled_low.ply',sampled_pc.cpu().numpy()) 200 | -------------------------------------------------------------------------------- /pipeline_SLAM/utils/fastmac/gfilter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | torch.manual_seed(42) 4 | 5 | def graphFilter(points,adjacent_matrix,is_sparse): 6 | ''' 7 | points: n x 3 8 | adjacent_matrix: n x n graph shift operator (sparse or dense, see is_sparse) 9 | 10 | return: 11 | score: n 12 | ''' 13 | if is_sparse: 14 | xyz=torch.sparse.mm(adjacent_matrix,points) 15 | else: 16 | xyz=torch.mm(adjacent_matrix,points) 17 | return torch.norm(xyz,dim=-1) 18 | 19 | def graphLowFilter(points,adjacent_matrix): 20 | ''' 21 | points: n x 3 22 | adjacent_matrix: n x n graph shift operator 23 | 24 | return: 25 | score: n 26 | ''' 27 | r=torch.matmul(torch.eye(points.shape[0]).to(adjacent_matrix.device)+adjacent_matrix, points) 28 | return torch.norm(r,p=2,dim=-1) 29 | 30 | def graphAllPassFilter(points): 31 | ''' 32 | points: n x 3 33 | (identity filter: no adjacency matrix needed) 34 | 35 | return: 36 | score: n 37 | ''' 38 | return torch.norm(points,p=2,dim=-1) 39 | 40 | 41 | def datasample(k,replace,weights): 42 | ''' 43 | k: int, number of samples 44 | replace: bool, sample with replacement 45 | weights: n, non-negative sampling weights 46 | return: k sampled indices 47 | ''' 48 | return torch.multinomial(weights,k,replacement=replace) 49 | 50 | 51 | if __name__ == "__main__": 52 | from plyfile import PlyData,PlyElement 53 | import numpy as np 54 | def write_ply(save_path,points,text=True): 55 | """ 56 | save_path : path to save: '/yy/XX.ply' 57 | points : point cloud, size (N, 3) 58 | """ 59 | points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] 60 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) 61 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 62 | PlyData([el], text=text).write(save_path) 63 | def read_ply(filename): 64 | """ read XYZ point cloud from filename PLY file """ 65 | plydata = PlyData.read(filename) 66 | pc = plydata['vertex'].data 67 | pc_array = np.array([[x, y, z] for x,y,z in pc]) 68 | return pc_array 69 | 70 | pc=read_ply('/data/cubic.ply') 71 | num_pts=pc.shape[0] 72 | sample_rate=0.25 73 | k=int(np.floor(sample_rate*num_pts)) 74 | print(k) 75 | 76 | -------------------------------------------------------------------------------- /pipeline_SLAM/utils/registration.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | import igraph 5 | import os 6 | def integrate_trans(R, t): 7 | """ 8 | Integrate SE3 transformations from R and t, support torch.Tensor and np.ndarray.
9 | Input 10 | - R: [3, 3] or [bs, 3, 3], rotation matrix 11 | - t: [3, 1] or [bs, 3, 1], translation matrix 12 | Output 13 | - trans: [4, 4] or [bs, 4, 4], SE3 transformation matrix 14 | """ 15 | if len(R.shape) == 3: 16 | if isinstance(R, torch.Tensor): 17 | trans = torch.eye(4)[None].repeat(R.shape[0], 1, 1).to(R.device) 18 | else: 19 | trans = np.eye(4)[None] 20 | trans[:, :3, :3] = R 21 | trans[:, :3, 3:4] = t.view([-1, 3, 1]) 22 | else: 23 | if isinstance(R, torch.Tensor): 24 | trans = torch.eye(4).to(R.device) 25 | else: 26 | trans = np.eye(4) 27 | trans[:3, :3] = R 28 | trans[:3, 3:4] = t 29 | return trans 30 | 31 | 32 | def transform(pts, trans): 33 | if len(pts.shape) == 3: 34 | trans_pts = torch.einsum('bnm,bmk->bnk', trans[:, :3, :3], 35 | pts.permute(0, 2, 1)) + trans[:, :3, 3:4] 36 | return trans_pts.permute(0, 2, 1) 37 | else: 38 | trans_pts = torch.einsum('nm,mk->nk', trans[:3, :3], 39 | pts.T) + trans[:3, 3:4] 40 | return trans_pts.T 41 | def rigid_transform_3d(A, B, weights=None, weight_threshold=0): 42 | """ 43 | Input: 44 | - A: [bs, num_corr, 3], source point cloud 45 | - B: [bs, num_corr, 3], target point cloud 46 | - weights: [bs, num_corr] weight for each correspondence 47 | - weight_threshold: float, clips points with weight below threshold 48 | Output: 49 | - R, t 50 | """ 51 | bs = A.shape[0] 52 | if weights is None: 53 | weights = torch.ones_like(A[:, :, 0]) 54 | weights[weights < weight_threshold] = 0 55 | # weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-6) 56 | 57 | # find mean of point cloud 58 | centroid_A = torch.sum(A * weights[:, :, None], dim=1, keepdim=True) / ( 59 | torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6) 60 | centroid_B = torch.sum(B * weights[:, :, None], dim=1, keepdim=True) / ( 61 | torch.sum(weights, dim=1, keepdim=True)[:, :, None] + 1e-6) 62 | 63 | # subtract mean 64 | Am = A - centroid_A 65 | Bm = B - centroid_B 66 | 67 | # construct weight covariance matrix 68 | Weight = torch.diag_embed(weights) # lift the weight vector to a batch of diagonal matrices 69 | H = Am.permute(0, 2, 1) @ Weight @ Bm # permute transposes each matrix in the batch 70 | 71 | # find rotation 72 | U, S, Vt = torch.svd(H.cpu()) 73 | U, S, Vt = U.to(weights.device), S.to(weights.device), Vt.to(weights.device) 74 | delta_UV = torch.det(Vt @ U.permute(0, 2, 1)) 75 | eye = torch.eye(3)[None, :, :].repeat(bs, 1, 1).to(A.device) 76 | eye[:, -1, -1] = delta_UV 77 | R = Vt @ eye @ U.permute(0, 2, 1) 78 | t = centroid_B.permute(0, 2, 1) - R @ centroid_A.permute(0, 2, 1) 79 | # warp_A = transform(A, integrate_trans(R,t)) 80 | # RMSE = torch.sum( (warp_A - B) ** 2, dim=-1).mean() 81 | return integrate_trans(R, t) 82 | def post_refinement(initial_trans, src_kpts, tgt_kpts, iters, weights=None): 83 | inlier_threshold = 0.1 84 | pre_inlier_count = 0 85 | for i in range(iters): 86 | pred_tgt = transform(src_kpts, initial_trans) 87 | L2_dis = torch.norm(pred_tgt - tgt_kpts, dim=-1) 88 | pred_inlier = (L2_dis < inlier_threshold)[0] 89 | inlier_count = torch.sum(pred_inlier) 90 | if inlier_count <= pre_inlier_count: 91 | break 92 | pre_inlier_count = inlier_count 93 | initial_trans = rigid_transform_3d( 94 | A=src_kpts[:, pred_inlier, :], 95 | B=tgt_kpts[:, pred_inlier, :], 96 | weights=1 / (1 + (L2_dis / inlier_threshold) ** 2)[:, pred_inlier] 97 | ) 98 | return initial_trans 99 | class Registrator: 100 | def __init__(self, device='cuda') -> None: 101 | self.device = device 102 | 103 | def graph_construction(self, src_pts, tgt_pts): 104 | src_dist = ((src_pts[:, None, :] - src_pts[None, :, :]) ** 2).sum(-1) ** 0.5
105 | tgt_dist = ((tgt_pts[:, None, :] - tgt_pts[None, :, :]) ** 2).sum(-1) ** 0.5 106 | cross_dis = torch.abs(src_dist - tgt_dist) 107 | FCG = torch.clamp(1 - cross_dis ** 2 / 0.6 ** 2, min=0) 108 | FCG = FCG - torch.diag_embed(torch.diag(FCG)) 109 | FCG[FCG < 0.999] = 0 110 | SCG = torch.matmul(FCG, FCG) * FCG 111 | return SCG 112 | 113 | def registration(self, src_pts, tgt_pts): 114 | ''' 115 | src_pts: N x 3, torch.Tensor 116 | tgt_pts: N x 3, torch.Tensor 117 | ''' 118 | num_pts = src_pts.shape[0] 119 | src_pts = src_pts.to(self.device) 120 | tgt_pts = tgt_pts.to(self.device) 121 | 122 | SCG = self.graph_construction(src_pts, tgt_pts).cpu().numpy() 123 | graph = igraph.Graph.Adjacency((SCG > 0).tolist()) 124 | graph.es['weight'] = SCG[SCG.nonzero()] 125 | graph.vs['label'] = range(0, num_pts) 126 | graph.to_undirected() 127 | macs = graph.maximal_cliques(min=3) 128 | 129 | clique_weight = np.zeros(len(macs), dtype=float) 130 | for ind in range(len(macs)): 131 | mac = list(macs[ind]) 132 | if len(mac) >= 3: 133 | for i in range(len(mac)): 134 | for j in range(i + 1, len(mac)): 135 | clique_weight[ind] = clique_weight[ind] + SCG[mac[i], mac[j]] 136 | 137 | 138 | clique_ind_of_node = np.ones(num_pts, dtype=int) * -1 139 | max_clique_weight = np.zeros(num_pts, dtype=float) 140 | max_size = 3 141 | for ind in range(len(macs)): 142 | mac = list(macs[ind]) 143 | weight = clique_weight[ind] 144 | if weight > 0: 145 | for i in range(len(mac)): 146 | if weight > max_clique_weight[mac[i]]: 147 | max_clique_weight[mac[i]] = weight 148 | clique_ind_of_node[mac[i]] = ind 149 | max_size = max(max_size, len(mac)) 150 | 151 | filtered_clique_ind = list(set(clique_ind_of_node)) 152 | filtered_clique_ind.remove(-1) 153 | # print('After filtered: %d' % len(filtered_clique_ind)) 154 | 155 | 156 | group = [] 157 | for s in range(3, max_size + 1): 158 | group.append([]) 159 | for ind in filtered_clique_ind: 160 | mac = list(macs[ind]) 161 | group[len(mac) - 3].append(ind) 162 | 163 | tensor_list_A = [] 164 | tensor_list_B = [] 165 | for i in range(len(group)): 166 | if len(group[i]) == 0: 167 | continue 168 | batch_A = src_pts[list(macs[group[i][0]])][None] 169 | batch_B = tgt_pts[list(macs[group[i][0]])][None] 170 | if len(group) == 1: 171 | continue 172 | for j in range(1, len(group[i])): 173 | mac = list(macs[group[i][j]]) 174 | src_corr = src_pts[mac][None] 175 | tgt_corr = tgt_pts[mac][None] 176 | batch_A = torch.cat((batch_A, src_corr), 0) 177 | batch_B = torch.cat((batch_B, tgt_corr), 0) 178 | tensor_list_A.append(batch_A) 179 | tensor_list_B.append(batch_B) 180 | 181 | inlier_threshold = 0.6 182 | max_score = 0 183 | final_trans = torch.eye(4).to(self.device) 184 | for i in range(len(tensor_list_A)): 185 | trans = rigid_transform_3d(tensor_list_A[i], tensor_list_B[i], None, 0) 186 | pred_tgt = transform(src_pts[None], trans) # [bs, num_corr, 3] 187 | L2_dis = torch.norm(pred_tgt - tgt_pts[None], dim=-1) # [bs, num_corr] 188 | MAE_score = torch.div(torch.sub(inlier_threshold, L2_dis), inlier_threshold) 189 | MAE_score = torch.sum(MAE_score * (L2_dis < inlier_threshold), dim=-1) 190 | max_batch_score_ind = MAE_score.argmax(dim=-1) 191 | max_batch_score = MAE_score[max_batch_score_ind] 192 | if max_batch_score > max_score: 193 | max_score = max_batch_score 194 | final_trans = trans[max_batch_score_ind] 195 | 196 | final_trans = post_refinement(initial_trans=final_trans[None], src_kpts=src_pts[None], tgt_kpts=tgt_pts[None], iters=20) 197 | return final_trans[0]
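A note on usage: `Registrator.registration` consumes two row-aligned N x 3 correspondence point sets and returns a single 4 x 4 SE(3) matrix. Internally it builds the second-order compatibility graph, enumerates maximal cliques with igraph, solves a weighted SVD (Kabsch) per clique batch, keeps the hypothesis with the highest inlier score, and polishes it with up to 20 refinement iterations. A minimal sketch of how it might be driven, assuming two such arrays saved as text files (the file names below are placeholders, not files from this repository):

import numpy as np
import torch
from utils.registration import Registrator

# N x 3 source and target points; rows must correspond one-to-one
src_pts = torch.from_numpy(np.loadtxt('src_corr_pts.txt')).float()  # placeholder path
tgt_pts = torch.from_numpy(np.loadtxt('tgt_corr_pts.txt')).float()  # placeholder path

registrator = Registrator(device='cuda')
trans = registrator.registration(src_pts, tgt_pts)  # 4 x 4 transform mapping src -> tgt
print(trans.cpu().numpy())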
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | tqdm 4 | open3d==0.10.0 -------------------------------------------------------------------------------- /sota.py: -------------------------------------------------------------------------------- 1 | from data import processedKITTI 2 | from tqdm import tqdm 3 | import torch 4 | from gconstructor import GraphConstructor 5 | from gfilter import graphFilter,datasample 6 | import time 7 | import numpy as np 8 | import os 9 | torch.manual_seed(42) 10 | 11 | 12 | def normalize(x): 13 | # transform x to [0,1] 14 | x=x-x.min() 15 | x=x/x.max() 16 | return x 17 | 18 | def Config(): 19 | config={ 20 | "num_points":5000, 21 | "data_dir":'/data/Processed_KITTI/correspondence_fcgf/', 22 | "filename":'fcgf@corr.txt', 23 | 'gtname':'fcgf@gtmat.txt', 24 | 'labelname':'fcgf@gtlabel.txt', 25 | 'batch_size':1, 26 | 'inlier_thresh':0.6, 27 | 'thresh':0.999, 28 | 'sigma':0.6, 29 | 'tau':0., 30 | 'device':'cuda', 31 | 'mode':'graph', 32 | 'ratio':0.01, 33 | 'outpath':'', 34 | } 35 | return config 36 | 37 | def main(): 38 | config=Config() 39 | device=config["device"] 40 | mode=config["mode"] 41 | sample_ratio=config["ratio"] 42 | dataset=processedKITTI(config["num_points"],config["data_dir"],config["filename"],config["gtname"],config["labelname"]) 43 | trainloader = torch.utils.data.DataLoader( 44 | dataset, batch_size=config["batch_size"], shuffle=False, num_workers=0 45 | ) 46 | if mode == "graph": 47 | gc=GraphConstructor(config["inlier_thresh"],config["thresh"],trainable=False,sigma=config["sigma"],tau=config["tau"]) 48 | print("Start") 49 | average_time=0 50 | for i, data_ in enumerate(tqdm(trainloader)): 51 | time_start=time.time() 52 | pts,gt,lb=data_ 53 | pts=pts.to(device) 54 | gt=gt.to(device) 55 | lb=lb.to(device) 56 | corr_graph=gc(pts,mode="correspondence") 57 | degree_signal=torch.sum(corr_graph,dim=-1) 58 | 59 | 60 | 61 | corr_laplacian=(torch.diag_embed(degree_signal)-corr_graph).squeeze(0) 62 | corr_scores=graphFilter(degree_signal.transpose(0,1),corr_laplacian,is_sparse=False) 63 | 64 | 65 | total_scores=corr_scores 66 | k=int(config["num_points"]*sample_ratio) 67 | idxs=datasample(k,False,total_scores) 68 | 69 | time_end=time.time() 70 | 71 | average_time+=time_end-time_start 72 | 73 | samples=pts.squeeze(0)[idxs,:] 74 | lb=lb.squeeze(0)[idxs].long() 75 | samples=samples.cpu().numpy() 76 | 77 | outdir=os.path.join(config["outpath"],str(i)) 78 | if not os.path.exists(outdir): 79 | os.makedirs(outdir) 80 | 81 | np.savetxt(outdir+'/'+config["filename"],samples) 82 | np.savetxt(outdir+'/'+config["gtname"],gt.squeeze(0).cpu().numpy()) 83 | np.savetxt(outdir+'/'+config["labelname"],lb.cpu().numpy().astype(int),fmt="%d") 84 | 85 | print("Average time: ",average_time/len(trainloader)) 86 | else: 87 | raise NotImplementedError 88 | 89 | 90 | if __name__ == "__main__": 91 | main() 92 | --------------------------------------------------------------------------------
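Taken together, sota.py is the distilled form of the method: build the second-order compatibility graph over the putative correspondences, read off the generalized degree signal, push it through the Laplacian high-pass filter, and sample correspondences in proportion to the filter response. A self-contained sketch of that loop on random stand-in data follows; the thresholds mirror Config() above, except that the KITTI threshold of 0.999 assumes real descriptor-based correspondences, so a permissive thresh=0.0 is used here to keep the graph of random points non-empty. Real inputs would be the N x 6 rows of fcgf@corr.txt.

import torch
from gconstructor import GraphConstructor
from gfilter import graphFilter, datasample

corr = torch.rand(1, 1000, 6, device='cuda')  # stand-in for one batch of correspondences (src xyz | tgt xyz)
gc = GraphConstructor(inlier_thresh=0.6, thresh=0.0, trainable=False)  # thresh=0.999 for real KITTI fcgf data
adj = gc(corr, mode="correspondence")  # (1, N, N) second-order compatibility graph
degree = torch.sum(adj, dim=-1)  # generalized degree signal, (1, N)
laplacian = (torch.diag_embed(degree) - adj).squeeze(0)  # combinatorial graph Laplacian, (N, N)
scores = graphFilter(degree.transpose(0, 1), laplacian, is_sparse=False)  # high-pass response per node
idxs = datasample(int(0.01 * corr.shape[1]), False, scores)  # stochastic spectral sampling at ratio 0.01
sampled = corr.squeeze(0)[idxs, :]  # (k, 6) retained correspondences

The sampled subset then feeds any downstream estimator (e.g. MAC-style clique search or the Registrator in pipeline_SLAM/utils/registration.py) at a fraction of the original graph size, which is where the reported speedups come from.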