├── .gitignore ├── KITTI ├── Test │ ├── test_kitti.py │ └── test_kitti.txt └── Train │ ├── dataloader.py │ ├── dataset.py │ ├── train.py │ └── trainer.py ├── LICENSE ├── README.md ├── ThreeDMatch ├── Test │ ├── 3dmatch │ │ ├── evaluate.m │ │ └── external │ │ │ ├── ElasticReconstruction │ │ │ ├── mrDrawTrajectory.m │ │ │ ├── mrEvaluateRegistration.m │ │ │ ├── mrEvaluateTrajectory.m │ │ │ ├── mrLoadInfo.m │ │ │ ├── mrLoadLog.m │ │ │ ├── mrMatchDepthColor.m │ │ │ ├── mrWriteInfo.m │ │ │ └── mrWriteLog.m │ │ │ └── npy-matlab │ │ │ ├── constructNPYheader.m │ │ │ ├── datToNPY.m │ │ │ ├── readNPY.m │ │ │ ├── readNPYheader.m │ │ │ └── writeNPY.m │ ├── evaluate.py │ ├── gt_result │ │ ├── 7-scenes-redkitchen-evaluation │ │ │ └── gt.info │ │ ├── sun3d-home_at-home_at_scan1_2013_jan_1-evaluation │ │ │ └── gt.info │ │ ├── sun3d-home_md-home_md_scan9_2012_sep_30-evaluation │ │ │ └── gt.info │ │ ├── sun3d-hotel_uc-scan3-evaluation │ │ │ └── gt.info │ │ ├── sun3d-hotel_umd-maryland_hotel1-evaluation │ │ │ └── gt.info │ │ ├── sun3d-hotel_umd-maryland_hotel3-evaluation │ │ │ └── gt.info │ │ ├── sun3d-mit_76_studyroom-76-1studyroom2-evaluation │ │ │ └── gt.info │ │ └── sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika-evaluation │ │ │ └── gt.info │ ├── preparation.py │ └── tools.py └── Train │ ├── dataloader.py │ ├── dataset.py │ ├── train.py │ └── trainer.py ├── figs ├── Fig1.png ├── Fig2.png ├── Fig3.png ├── Fig4.png ├── Fig5.png ├── Table1.png ├── Table2.png ├── Table3.png ├── Table4.png ├── Table5.png ├── Table6.png └── Table7.png ├── generalization ├── KITTI-to-ThreeDMatch │ ├── evaluate.py │ └── preparation.py ├── ThreeDMatch-to-ETH │ ├── evaluate.py │ └── preparation.py └── ThreeDMatch-to-KITTI │ ├── test.py │ └── test_kitti.txt ├── loss └── desc_loss.py ├── network ├── SpinNet.py └── ThreeDCCN.py ├── pre-trained_models ├── 3DMatch_best.pkl └── KITTI_best.pkl └── script ├── cal_overlap.py ├── common.py ├── download.sh ├── fuse_fragments_3DMatch.py └── io.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /KITTI/Test/test_kitti.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 4 | import logging 5 | import numpy as np 6 | import open3d as o3d 7 | import torch 8 | import torch.nn as nn 9 | import glob 10 | import time 11 | import gc 12 | import shutil 13 | import pointnet2_ops.pointnet2_utils as pnt2 14 | import copy 15 | import importlib 16 | import sys 17 | 18 | sys.path.append('../../') 19 | import script.common as cm 20 | 21 | kitti_icp_cache = {} 22 | kitti_cache = {} 23 | 24 | 25 | class Timer(object): 26 | """A simple timer.""" 27 | 28 | def __init__(self, binary_fn=None, init_val=0): 29 | self.total_time = 0. 30 | self.calls = 0 31 | self.start_time = 0. 32 | self.diff = 0. 
33 | self.binary_fn = binary_fn 34 | self.tmp = init_val 35 | 36 | def reset(self): 37 | self.total_time = 0 38 | self.calls = 0 39 | self.start_time = 0 40 | self.diff = 0 41 | 42 | @property 43 | def avg(self): 44 | return self.total_time / self.calls 45 | 46 | def tic(self): 47 | # using time.time instead of time.clock because time time.clock 48 | # does not normalize for multithreading 49 | self.start_time = time.time() 50 | 51 | def toc(self, average=True): 52 | self.diff = time.time() - self.start_time 53 | self.total_time += self.diff 54 | self.calls += 1 55 | if self.binary_fn: 56 | self.tmp = self.binary_fn(self.tmp, self.diff) 57 | if average: 58 | return self.avg 59 | else: 60 | return self.diff 61 | 62 | 63 | class AverageMeter(object): 64 | """Computes and stores the average and current value""" 65 | 66 | def __init__(self): 67 | self.reset() 68 | 69 | def reset(self): 70 | self.val = 0 71 | self.avg = 0 72 | self.sum = 0.0 73 | self.sq_sum = 0.0 74 | self.count = 0 75 | 76 | def update(self, val, n=1): 77 | self.val = val 78 | self.sum += val * n 79 | self.count += n 80 | self.avg = self.sum / self.count 81 | self.sq_sum += val ** 2 * n 82 | self.var = self.sq_sum / self.count - self.avg ** 2 83 | 84 | 85 | def get_desc(descpath, filename): 86 | desc = np.load(os.path.join(descpath, filename + '.npy')) 87 | return desc 88 | 89 | 90 | def get_keypts(keypts_path, filename): 91 | keypts = np.load(os.path.join(keypts_path, filename + '.npy')) 92 | return keypts 93 | 94 | 95 | def make_open3d_feature(data, dim, npts): 96 | feature = o3d.pipelines.registration.Feature() 97 | feature.resize(dim, npts) 98 | feature.data = data.astype('d').transpose() 99 | return feature 100 | 101 | 102 | def make_open3d_point_cloud(xyz, color=None): 103 | pcd = o3d.geometry.PointCloud() 104 | pcd.points = o3d.utility.Vector3dVector(xyz) 105 | if color is not None: 106 | pcd.paint_uniform_color(color) 107 | return pcd 108 | 109 | 110 | def get_matching_indices(source, target, trans, search_voxel_size, K=None): 111 | source_copy = copy.deepcopy(source) 112 | target_copy = copy.deepcopy(target) 113 | source_copy.transform(trans) 114 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy) 115 | 116 | match_inds = [] 117 | for i, point in enumerate(source_copy.points): 118 | [_, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size) 119 | if K is not None: 120 | idx = idx[:K] 121 | for j in idx: 122 | match_inds.append((i, j)) 123 | return match_inds 124 | 125 | 126 | class KITTI(object): 127 | DATA_FILES = { 128 | 'train': 'train_kitti.txt', 129 | 'val': 'val_kitti.txt', 130 | 'test': 'test_kitti.txt' 131 | } 132 | """ 133 | Given point cloud fragments and corresponding pose in '{root}'. 134 | 1. Save the aligned point cloud pts in '{savepath}/3DMatch_{downsample}_points.pkl' 135 | 2. Calculate the overlap ratio and save in '{savepath}/3DMatch_{downsample}_overlap.pkl' 136 | 3. 
Save the ids of anchor keypoints and positive keypoints in '{savepath}/3DMatch_{downsample}_keypts.pkl' 137 | """ 138 | 139 | def __init__(self, root, descpath, icp_path, split, model, num_points_per_patch, use_random_points): 140 | self.root = root 141 | self.descpath = descpath 142 | self.split = split 143 | self.num_points_per_patch = num_points_per_patch 144 | self.icp_path = icp_path 145 | self.use_random_points = use_random_points 146 | self.model = model 147 | if not os.path.exists(self.icp_path): 148 | os.makedirs(self.icp_path) 149 | 150 | # list: anc & pos 151 | self.patches = [] 152 | self.pose = [] 153 | # Initiate containers 154 | self.files = {'train': [], 'val': [], 'test': []} 155 | 156 | self.prepare_kitti_ply(split=self.split) 157 | 158 | def prepare_kitti_ply(self, split='train'): 159 | subset_names = open(self.DATA_FILES[split]).read().split() 160 | for dirname in subset_names: 161 | drive_id = int(dirname) 162 | fnames = glob.glob(self.root + '/sequences/%02d/velodyne/*.bin' % drive_id) 163 | assert len(fnames) > 0, f"Make sure that the path {self.root} has data {dirname}" 164 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 165 | 166 | all_odo = self.get_video_odometry(drive_id, return_all=True) 167 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo]) 168 | Ts = all_pos[:, :3, 3] 169 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3)) ** 2 170 | pdist = np.sqrt(pdist.sum(-1)) 171 | more_than_10 = pdist > 10 172 | curr_time = inames[0] 173 | while curr_time in inames: 174 | next_time = np.where(more_than_10[curr_time][curr_time:curr_time + 100])[0] 175 | if len(next_time) == 0: 176 | curr_time += 1 177 | else: 178 | next_time = next_time[0] + curr_time - 1 179 | 180 | if next_time in inames: 181 | self.files[split].append((drive_id, curr_time, next_time)) 182 | curr_time = next_time + 1 183 | # Remove problematic sequence 184 | for item in [ 185 | (8, 15, 58), 186 | ]: 187 | if item in self.files[split]: 188 | self.files[split].pop(self.files[split].index(item)) 189 | 190 | if split == 'train': 191 | self.num_train = len(self.files[split]) 192 | print("Num_train", self.num_train) 193 | elif split == 'val': 194 | self.num_val = len(self.files[split]) 195 | print("Num_val", self.num_val) 196 | elif split == 'test': 197 | self.num_test = len(self.files[split]) 198 | print("Num_test", self.num_test) 199 | 200 | for idx in range(len(self.files[split])): 201 | drive = self.files[split][idx][0] 202 | t0, t1 = self.files[split][idx][1], self.files[split][idx][2] 203 | all_odometry = self.get_video_odometry(drive, [t0, t1]) 204 | positions = [self.odometry_to_positions(odometry) for odometry in all_odometry] 205 | fname0 = self._get_velodyne_fn(drive, t0) 206 | fname1 = self._get_velodyne_fn(drive, t1) 207 | 208 | # XYZ and reflectance 209 | xyzr0 = np.fromfile(fname0, dtype=np.float32).reshape(-1, 4) 210 | xyzr1 = np.fromfile(fname1, dtype=np.float32).reshape(-1, 4) 211 | 212 | xyz0 = xyzr0[:, :3] 213 | xyz1 = xyzr1[:, :3] 214 | 215 | key = '%d_%d_%d' % (drive, t0, t1) 216 | filename = self.icp_path + '/' + key + '.npy' 217 | if key not in kitti_icp_cache: 218 | if not os.path.exists(filename): 219 | M = (self.velo2cam @ positions[0].T @ np.linalg.inv(positions[1].T) 220 | @ np.linalg.inv(self.velo2cam)).T 221 | xyz0_t = self.apply_transform(xyz0, M) 222 | pcd0 = make_open3d_point_cloud(xyz0_t, [0.5, 0.5, 0.5]) 223 | pcd1 = make_open3d_point_cloud(xyz1, [0, 1, 0]) 224 | reg = o3d.pipelines.registration.registration_icp(pcd0, 
pcd1, 0.10, np.eye(4), 225 | o3d.pipelines.registration.TransformationEstimationPointToPoint(), 226 | o3d.pipelines.registration.ICPConvergenceCriteria( 227 | max_iteration=400)) 228 | pcd0.transform(reg.transformation) 229 | M2 = M @ reg.transformation 230 | # write to a file 231 | np.save(filename, M2) 232 | else: 233 | M2 = np.load(filename) 234 | kitti_icp_cache[key] = M2 235 | else: 236 | M2 = kitti_icp_cache[key] 237 | trans = M2 238 | # extract patches for anc&pos 239 | np.random.shuffle(xyz0) 240 | np.random.shuffle(xyz1) 241 | 242 | if is_rotate_dataset: 243 | # Add arbitrary rotation 244 | # rotate the terminal fragment by an arbitrary angle around the z-axis 245 | angles_3d = np.random.rand(3) * np.pi * 2 246 | R = cm.angles2rotation_matrix(angles_3d) 247 | T = np.identity(4) 248 | T[:3, :3] = R 249 | pcd1 = make_open3d_point_cloud(xyz1) 250 | pcd1.transform(T) 251 | xyz1 = np.array(pcd1.points) 252 | all_trans_matrix[key] = T 253 | 254 | if not os.path.exists(self.descpath + str(drive)): 255 | os.makedirs(self.descpath + str(drive)) 256 | if self.use_random_points: 257 | num_keypts = 5000 258 | step_size = 50 259 | desc_len = 32 260 | model = self.model.cuda() 261 | # calc t0 descriptors 262 | desc_t0_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t0) + f".desc.bin.npy") 263 | keypts_t0_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t0) + f".keypts.npy") 264 | if not os.path.exists(desc_t0_path): 265 | keypoints_id = np.random.choice(xyz0.shape[0], num_keypts) 266 | keypts = xyz0[keypoints_id] 267 | np.save(keypts_t0_path, keypts.astype(np.float32)) 268 | local_patches = self.select_patches(xyz0, keypts, vicinity=vicinity, 269 | num_points_per_patch=self.num_points_per_patch) 270 | B = local_patches.shape[0] 271 | # run the model in chunks to avoid CUDA out-of-memory errors 272 | desc_list = [] 273 | start_time = time.time() 274 | iter_num = int(np.ceil(B / step_size)) 275 | for k in range(iter_num): 276 | if k == iter_num - 1: 277 | desc = model(local_patches[k * step_size:, :, :]) 278 | else: 279 | desc = model(local_patches[k * step_size: (k + 1) * step_size, :, :]) 280 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy()) 281 | del desc 282 | step_time = time.time() - start_time 283 | print(f'Finish {B} descriptors spend {step_time:.4f}s') 284 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len]) 285 | np.save(desc_t0_path, desc.astype(np.float32)) 286 | else: 287 | print(f"{desc_t0_path} already exists.") 288 | 289 | # calc t1 descriptors 290 | desc_t1_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t1) + f".desc.bin.npy") 291 | keypts_t1_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t1) + f".keypts.npy") 292 | if not os.path.exists(desc_t1_path): 293 | keypoints_id = np.random.choice(xyz1.shape[0], num_keypts) 294 | keypts = xyz1[keypoints_id] 295 | np.save(keypts_t1_path, keypts.astype(np.float32)) 296 | local_patches = self.select_patches(xyz1, keypts, vicinity=vicinity, 297 | num_points_per_patch=self.num_points_per_patch) 298 | B = local_patches.shape[0] 299 | # calculate descriptors 300 | desc_list = [] 301 | start_time = time.time() 302 | iter_num = int(np.ceil(B / step_size)) 303 | for k in range(iter_num): 304 | if k == iter_num - 1: 305 | desc = model(local_patches[k * step_size:, :, :]) 306 | else: 307 | desc = model(local_patches[k * step_size: (k + 1) * step_size, :, :]) 308 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy()) 309 | del desc 310 | step_time = 
time.time() - start_time 311 | print(f'Finish {B} descriptors spend {step_time:.4f}s') 312 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len]) 313 | np.save(desc_t1_path, desc.astype(np.float32)) 314 | else: 315 | print(f"{desc_t1_path} already exists.") 316 | else: 317 | num_keypts = 512 318 | 319 | def select_patches(self, pts, refer_pts, vicinity, num_points_per_patch=1024): 320 | gc.collect() 321 | pts = torch.FloatTensor(pts).cuda().unsqueeze(0) 322 | refer_pts = torch.FloatTensor(refer_pts).cuda().unsqueeze(0) 323 | group_idx = pnt2.ball_query(vicinity, num_points_per_patch, pts, refer_pts) 324 | pts_trans = pts.transpose(1, 2).contiguous() 325 | new_points = pnt2.grouping_operation( 326 | pts_trans, group_idx 327 | ) 328 | new_points = new_points.permute([0, 2, 3, 1]) 329 | mask = group_idx[:, :, 0].unsqueeze(2).repeat(1, 1, num_points_per_patch) 330 | mask = (group_idx == mask).float() 331 | mask[:, :, 0] = 0 332 | mask[:, :, num_points_per_patch - 1] = 1 333 | mask = mask.unsqueeze(3).repeat([1, 1, 1, 3]) 334 | new_pts = refer_pts.unsqueeze(2).repeat([1, 1, num_points_per_patch, 1]) 335 | local_patches = new_points * (1 - mask).float() + new_pts * mask.float() 336 | local_patches = local_patches.squeeze(0) 337 | del mask 338 | del new_points 339 | del group_idx 340 | del new_pts 341 | del pts 342 | del pts_trans 343 | 344 | return local_patches 345 | 346 | def apply_transform(self, pts, trans): 347 | R = trans[:3, :3] 348 | T = trans[:3, 3] 349 | pts = pts @ R.T + T 350 | return pts 351 | 352 | @property 353 | def velo2cam(self): 354 | try: 355 | velo2cam = self._velo2cam 356 | except AttributeError: 357 | R = np.array([ 358 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04, 359 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02 360 | ]).reshape(3, 3) 361 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1) 362 | velo2cam = np.hstack([R, T]) 363 | self._velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T 364 | return self._velo2cam 365 | 366 | def get_video_odometry(self, drive, indices=None, ext='.txt', return_all=False): 367 | data_path = self.root + '/poses/%02d.txt' % drive 368 | if data_path not in kitti_cache: 369 | kitti_cache[data_path] = np.genfromtxt(data_path) 370 | if return_all: 371 | return kitti_cache[data_path] 372 | else: 373 | return kitti_cache[data_path][indices] 374 | 375 | def odometry_to_positions(self, odometry): 376 | T_w_cam0 = odometry.reshape(3, 4) 377 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1])) 378 | return T_w_cam0 379 | 380 | def _get_velodyne_fn(self, drive, t): 381 | fname = self.root + '/sequences/%02d/velodyne/%06d.bin' % (drive, t) 382 | return fname 383 | 384 | 385 | if __name__ == '__main__': 386 | is_rotate_dataset = False 387 | all_trans_matrix = {} 388 | experiment_id = time.strftime('%m%d%H%M') # '11210201'# 389 | model_str = experiment_id 390 | reg_timer = Timer() 391 | success_meter, rte_meter, rre_meter = AverageMeter(), AverageMeter(), AverageMeter() 392 | ch = logging.StreamHandler(sys.stdout) 393 | logging.getLogger().setLevel(logging.INFO) 394 | logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch]) 395 | 396 | # dynamically load the model from snapshot 397 | module_file_path = '../model.py' 398 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), module_file_path) 399 | module_name = '' 400 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path) 401 | module = 
importlib.util.module_from_spec(module_spec) 402 | module_spec.loader.exec_module(module) 403 | 404 | vicinity = 2.0 405 | model = module.Descriptor_Net(vicinity, 9, 60, 30, 0.3, 30, 'KITTI') 406 | model = nn.DataParallel(model, device_ids=[0]) 407 | model.load_state_dict(torch.load('../../pre-trained_models/KITTI_best.pkl')) 408 | 409 | test_data = KITTI(root='../../data/KITTI/dataset', 410 | descpath=f'SpinNet_desc_{model_str}/', 411 | icp_path='../../data/KITTI/icp', 412 | split='test', 413 | model=model, 414 | num_points_per_patch=2048, 415 | use_random_points=True 416 | ) 417 | 418 | files = test_data.files[test_data.split] 419 | for idx in range(len(files)): 420 | drive = files[idx][0] 421 | t0, t1 = files[idx][1], files[idx][2] 422 | key = '%d_%d_%d' % (drive, t0, t1) 423 | filename = test_data.icp_path + '/' + key + '.npy' 424 | T_gth = kitti_icp_cache[key] 425 | if is_rotate_dataset: 426 | T_gth = np.matmul(all_trans_matrix[key], T_gth) 427 | 428 | descpath = os.path.join(test_data.descpath, str(drive)) 429 | fname0 = test_data._get_velodyne_fn(drive, t0) 430 | fname1 = test_data._get_velodyne_fn(drive, t1) 431 | # XYZ and reflectance 432 | xyz0 = get_keypts(descpath, f"cloud_bin_" + str(t0) + f".keypts") 433 | xyz1 = get_keypts(descpath, f"cloud_bin_" + str(t1) + f".keypts") 434 | pcd0 = make_open3d_point_cloud(xyz0) 435 | pcd1 = make_open3d_point_cloud(xyz1) 436 | 437 | source_desc = get_desc(descpath, f"cloud_bin_" + str(t0) + f".desc.bin") 438 | target_desc = get_desc(descpath, f"cloud_bin_" + str(t1) + f".desc.bin") 439 | feat0 = make_open3d_feature(source_desc, 32, source_desc.shape[0]) 440 | feat1 = make_open3d_feature(target_desc, 32, target_desc.shape[0]) 441 | 442 | reg_timer.tic() 443 | distance_threshold = 0.3 444 | ransac_result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching( 445 | pcd0, pcd1, feat0, feat1, distance_threshold, 446 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False), 4, [ 447 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), 448 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold) 449 | ], o3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000)) 450 | T_ransac = torch.from_numpy(ransac_result.transformation.astype(np.float32)) 451 | reg_timer.toc() 452 | 453 | # Translation error 454 | rte = np.linalg.norm(T_ransac[:3, 3] - T_gth[:3, 3]) 455 | rre = np.arccos((np.trace(T_ransac[:3, :3].t() @ T_gth[:3, :3]) - 1) / 2) 456 | 457 | if rte < 2: 458 | rte_meter.update(rte) 459 | 460 | if not np.isnan(rre) and rre < np.pi / 180 * 5: 461 | rre_meter.update(rre * 180 / np.pi) 462 | 463 | if rte < 2 and not np.isnan(rre) and rre < np.pi / 180 * 5: 464 | success_meter.update(1) 465 | else: 466 | success_meter.update(0) 467 | logging.info(f"Failed with RTE: {rte}, RRE: {rre}") 468 | 469 | if (idx + 1) % 10 == 0: 470 | logging.info( 471 | f" RRE: {rre_meter.avg}, Success: {success_meter.sum} / {success_meter.count}" + 472 | f" ({success_meter.avg * 100} %)" 473 | ) 474 | reg_timer.reset() 475 | 476 | logging.info( 477 | f"RTE: {rte_meter.avg}, var: {rte_meter.var}," + 478 | f" RRE: {rre_meter.avg}, var: {rre_meter.var}, Success: {success_meter.sum} " + 479 | f"/ {success_meter.count} ({success_meter.avg * 100} %)" 480 | ) 481 | -------------------------------------------------------------------------------- /KITTI/Test/test_kitti.txt: -------------------------------------------------------------------------------- 1 | 8 2 | 9 3 | 10 4 | 
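The test loop above accepts a registration when the relative translation error (RTE) is below 2 m and the relative rotation error (RRE) is below 5°. For reference, the two metrics can be reproduced in isolation; the sketch below is not part of the repository, and the `np.clip` guard is an added safeguard against `arccos` domain errors that the script above omits.

```python
import numpy as np

def registration_errors(T_est, T_gt):
    """RTE / RRE between two 4x4 rigid transforms (as scored in test_kitti.py)."""
    # RTE: Euclidean distance between the two translation vectors (metres).
    rte = np.linalg.norm(T_est[:3, 3] - T_gt[:3, 3])
    # RRE: geodesic angle between the two rotation matrices.
    cos_angle = (np.trace(T_est[:3, :3].T @ T_gt[:3, :3]) - 1) / 2
    rre = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))
    return rte, rre

# A pair counts as successfully registered when RTE < 2 m and RRE < 5 degrees.
T_gt = np.eye(4)
T_est = np.eye(4)
T_est[:3, 3] = [0.5, 0.0, 0.0]
rte, rre = registration_errors(T_est, T_gt)
print(f"RTE={rte:.2f} m, RRE={rre:.2f} deg, success={rte < 2 and rre < 5}")
```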
-------------------------------------------------------------------------------- /KITTI/Train/dataloader.py: -------------------------------------------------------------------------------- 1 | import time 2 | from KITTI.Train.dataset import KITTIDataset 3 | import torch 4 | 5 | 6 | def get_dataloader(root, split, batch_size=1, num_workers=0, shuffle=True, drop_last=True): 7 | dataset = KITTIDataset( 8 | root=root, 9 | split=split, 10 | batch_size=batch_size, 11 | shuffle=shuffle, 12 | drop_last=drop_last 13 | ) 14 | dataset.initial() 15 | dataloader = torch.utils.data.DataLoader( 16 | dataset=dataset, 17 | batch_size=batch_size, 18 | num_workers=num_workers, 19 | drop_last=drop_last 20 | ) 21 | 22 | return dataloader 23 | -------------------------------------------------------------------------------- /KITTI/Train/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as Data 2 | import os 3 | import random 4 | import glob 5 | import pickle 6 | import open3d as o3d 7 | import numpy as np 8 | 9 | 10 | class KITTIDataset(Data.Dataset): 11 | def __init__(self, root, split, batch_size, shuffle, drop_last): 12 | """ 13 | Create a KITTIDataset that reads multiple training pickle files 14 | Args: 15 | root: the path to the dataset file 16 | shuffle: whether to shuffle the data 17 | """ 18 | self.patches_path = os.path.join(root, split) 19 | self.split = split 20 | # Get name of all training pkl files 21 | training_data_files = glob.glob(self.patches_path + '/*.pkl') 22 | ids = [file.split("/")[-1] for file in training_data_files] 23 | ids = sorted(ids, key=lambda x: int(x.split("_")[-1].split(".")[0])) 24 | ids = [file for file in ids if file.split("_")[1] == 'anc&pos'] 25 | self.training_data_files = ids 26 | # Get info of training files 27 | self.per_num_patch = int(training_data_files[0].split("/")[-1].split("_")[2]) 28 | self.dataset_len = int(ids[-1].split("_")[-1].split(".")[0]) * self.per_num_patch 29 | self.batch_size = batch_size 30 | self.shuffle = shuffle 31 | self.drop_last = drop_last 32 | # Record the loaded i-th training file 33 | self.num_file = 0 34 | # load poses for each type of patches 35 | self.per_patch_points = int(self.training_data_files[-1].split("_")[3]) 36 | self.num_fragments = int(self.training_data_files[-1].split("_")[4].split(".")[0]) 37 | with open(os.path.join(root, 38 | f'{self.split}/{self.split}_poses_{self.per_num_patch}_{self.per_patch_points}_{self.num_fragments}.pkl'), 39 | 'rb') as file: 40 | self.poses = pickle.load(file) 41 | print( 42 | f"load training poses {os.path.join(root, f'{self.split}_poses_{self.per_num_patch}_{self.per_patch_points}_{self.num_fragments}.pkl')}") 43 | self.cur_pose_ind = 0 44 | 45 | def initial(self): 46 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file: 47 | self.patches = pickle.load(file) 48 | print(f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}") 49 | 50 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1]) 51 | poses = self.poses[self.cur_pose_ind:next_pose_ind] 52 | for i in range(len(self.patches)): 53 | ind = int(np.floor(i / self.per_num_patch)) 54 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6).astype(np.float32) 55 | self.patches[i] = np.concatenate([pose, self.patches[i]]) 56 | self.cur_pose_ind = next_pose_ind 57 | 58 | self.current_patches_num = len(self.patches) 59 | self.index = 
list(range(self.current_patches_num)) 60 | if self.shuffle: 61 | random.shuffle(self.patches) 62 | 63 | def __len__(self): 64 | return self.dataset_len 65 | 66 | def __getitem__(self, item): 67 | idx = self.index[0] 68 | patches = self.patches[idx] 69 | self.index = self.index[1:] 70 | self.current_patches_num -= 1 71 | 72 | if self.drop_last: 73 | if self.current_patches_num <= (len(self.patches) % self.batch_size): # reach the end of training file 74 | self.num_file = self.num_file + 1 75 | if self.num_file < len(self.training_data_files): 76 | remain_patches = [self.patches[i] for i in self.index] # the remained training patches 77 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file: 78 | self.patches = pickle.load(file) 79 | print( 80 | f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}") 81 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1]) 82 | poses = self.poses[self.cur_pose_ind:next_pose_ind] 83 | for i in range(len(self.patches)): 84 | ind = int(np.floor(i / self.per_num_patch)) 85 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6).astype( 86 | np.float32) 87 | self.patches[i] = np.concatenate([pose, self.patches[i]]) 88 | self.cur_pose_ind = next_pose_ind 89 | self.patches += remain_patches # add the remained patches to compose a set of new patches 90 | self.current_patches_num = len(self.patches) 91 | self.index = list(range(self.current_patches_num)) 92 | if self.shuffle: 93 | random.shuffle(self.patches) 94 | else: 95 | self.num_file = 0 96 | self.cur_pose_ind = 0 97 | self.initial() 98 | else: 99 | if self.current_patches_num <= 0: 100 | self.num_file = self.num_file + 1 101 | if self.num_file < len(self.training_data_files): 102 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file: 103 | self.patches = pickle.load(file) 104 | print( 105 | f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}") 106 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1]) 107 | poses = self.poses[self.cur_pose_ind:next_pose_ind] 108 | for i in range(len(self.patches)): 109 | ind = int(np.floor(i / self.per_num_patch)) 110 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6).astype( 111 | np.float32) 112 | self.patches[i] = np.concatenate([pose, self.patches[i]]) 113 | self.cur_pose_ind = next_pose_ind 114 | self.current_patches_num = len(self.patches) 115 | self.index = list(range(self.current_patches_num)) 116 | if self.shuffle: 117 | random.shuffle(self.patches) 118 | else: 119 | self.num_file = 0 120 | self.cur_pose_ind = 0 121 | self.initial() 122 | 123 | anc_local_patch = patches[2:, :3] 124 | pos_local_patch = patches[2:, 3:] 125 | rotate = patches[:2, :].reshape(12)[:9].reshape(3, 3) 126 | shift = patches[:2, :].reshape(12)[9:] 127 | 128 | # np.random.shuffle(anc_local_patch) 129 | # np.random.shuffle(pos_local_patch) 130 | 131 | return anc_local_patch, pos_local_patch, rotate, shift 132 | 133 | 134 | if __name__ == "__main__": 135 | data_root = "../../data/KITTI_patches/" 136 | batch_size = 48 137 | epoch = 1 138 | train_dataset = KITTIDataset(root=data_root, split='train', batch_size=batch_size, shuffle=True, drop_last=True) 139 | train_dataset.initial() 140 | for _ in range(epoch): 141 | train_iter = Data.DataLoader(dataset=train_dataset, 
batch_size=batch_size, drop_last=True) 142 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(train_iter): 143 | B = anc_local_patch.shape[0] 144 | -------------------------------------------------------------------------------- /KITTI/Train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import time 5 | import shutil 6 | import sys 7 | 8 | sys.path.append('../../') 9 | from KITTI.Train.dataloader import get_dataloader 10 | from KITTI.Train.trainer import Trainer 11 | from network.SpinNet import Descriptor_Net 12 | from torch import optim 13 | 14 | 15 | class Args(object): 16 | def __init__(self): 17 | self.experiment_id = "Proposal" + time.strftime('%m%d%H%M') 18 | snapshot_root = 'snapshot/%s' % self.experiment_id 19 | tensorboard_root = 'tensorboard/%s' % self.experiment_id 20 | os.makedirs(snapshot_root, exist_ok=True) 21 | os.makedirs(tensorboard_root, exist_ok=True) 22 | shutil.copy2(os.path.join('', 'train.py'), os.path.join(snapshot_root, 'train.py')) 23 | shutil.copy2(os.path.join('', 'trainer.py'), os.path.join(snapshot_root, 'trainer.py')) 24 | shutil.copy2(os.path.join('', '../../network/SpinNet.py'), os.path.join(snapshot_root, 'SpinNet.py')) 25 | shutil.copy2(os.path.join('', '../../network/ThreeDCCN.py'), os.path.join(snapshot_root, 'ThreeDCCN.py')) 26 | shutil.copy2(os.path.join('', '../../loss/desc_loss.py'), os.path.join(snapshot_root, 'loss.py')) 27 | self.epoch = 20 28 | self.num_patches = 10 29 | self.num_points_per_patch = 2048 # num of points per patches 30 | self.batch_size = 60 31 | self.rad_n = 9 32 | self.azi_n = 60 33 | self.ele_n = 30 34 | self.des_r = 2.0 35 | self.voxel_r = 0.3 36 | self.voxel_sample = 30 37 | 38 | self.dataset = 'KITTI' 39 | self.data_train_dir = '../../data/KITTI/patches' 40 | self.data_val_dir = '../../data/KITTI/patches' 41 | 42 | self.gpu_mode = True 43 | self.verbose = True 44 | self.freeze_epoch = 5 45 | 46 | # model & optimizer 47 | self.model = Descriptor_Net(self.des_r, self.rad_n, self.azi_n, self.ele_n, 48 | self.voxel_r, self.voxel_sample, self.dataset) 49 | self.pretrain = '' 50 | self.parameter = self.model.get_parameter() 51 | self.optimizer = optim.Adam(self.parameter, lr=0.001, betas=(0.9, 0.999), weight_decay=1e-6) 52 | self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.5) 53 | self.scheduler_interval = 7 54 | 55 | # dataloader 56 | self.train_loader = get_dataloader(root=self.data_train_dir, 57 | batch_size=self.batch_size, 58 | split='train', 59 | shuffle=True, 60 | num_workers=0, 61 | ) 62 | self.val_loader = get_dataloader(root=self.data_val_dir, 63 | batch_size=self.batch_size, 64 | split='val', 65 | shuffle=False, 66 | num_workers=0, 67 | ) 68 | 69 | print("Training set size:", self.train_loader.dataset.__len__()) 70 | print("Validate set size:", self.val_loader.dataset.__len__()) 71 | 72 | # snapshot 73 | self.snapshot_interval = int(self.train_loader.dataset.__len__() / self.batch_size / 2) 74 | self.save_dir = os.path.join(snapshot_root, 'models/') 75 | self.result_dir = os.path.join(snapshot_root, 'results/') 76 | self.tboard_dir = tensorboard_root 77 | 78 | # evaluate 79 | self.evaluate_interval = 1 80 | 81 | self.check_args() 82 | 83 | def check_args(self): 84 | """checking arguments""" 85 | if not os.path.exists(self.save_dir): 86 | os.makedirs(self.save_dir) 87 | if not os.path.exists(self.result_dir): 88 | os.makedirs(self.result_dir) 89 | if not 
os.path.exists(self.tboard_dir): 90 | os.makedirs(self.tboard_dir) 91 | return self 92 | 93 | 94 | if __name__ == '__main__': 95 | args = Args() 96 | trainer = Trainer(args) 97 | trainer.train() 98 | -------------------------------------------------------------------------------- /KITTI/Train/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import time, os 4 | import numpy as np 5 | from loss.desc_loss import ContrastiveLoss 6 | from tensorboardX import SummaryWriter 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, args): 11 | # parameters 12 | self.epoch = args.epoch 13 | self.num_points_per_patch = args.num_points_per_patch 14 | self.batch_size = args.batch_size 15 | self.dataset = args.dataset 16 | self.save_dir = args.save_dir 17 | self.result_dir = args.result_dir 18 | self.gpu_mode = args.gpu_mode 19 | self.verbose = args.verbose 20 | self.freeze_epoch = args.freeze_epoch 21 | 22 | self.rad_n = args.rad_n 23 | self.azi_n = args.azi_n 24 | self.ele_n = args.ele_n 25 | self.des_r = args.des_r 26 | self.voxel_r = args.voxel_r 27 | self.voxel_sample = args.voxel_sample 28 | 29 | self.model = args.model 30 | self.optimizer = args.optimizer 31 | self.scheduler = args.scheduler 32 | self.scheduler_interval = args.scheduler_interval 33 | self.snapshot_interval = args.snapshot_interval 34 | self.evaluate_interval = args.evaluate_interval 35 | self.writer = SummaryWriter(log_dir=args.tboard_dir) 36 | 37 | self.train_loader = args.train_loader 38 | self.val_loader = args.val_loader 39 | 40 | self.desc_loss = ContrastiveLoss() 41 | 42 | if self.gpu_mode: 43 | self.model = self.model.cuda() 44 | self.model = torch.nn.DataParallel(self.model, device_ids=[0]) 45 | 46 | if args.pretrain != '': 47 | self._load_pretrain(args.pretrain) 48 | 49 | def train(self): 50 | self.train_hist = { 51 | 'loss': [], 52 | 'per_epoch_time': [], 53 | 'total_time': [] 54 | } 55 | best_loss = 1000000000 56 | print('training start!!') 57 | start_time = time.time() 58 | 59 | self.model.train() 60 | freeze_sign = 1 61 | for epoch in range(self.epoch): 62 | 63 | self.train_epoch(epoch) 64 | 65 | if epoch % self.evaluate_interval == 0 or epoch == 0: 66 | res = self.evaluate() 67 | print(f'Evaluation: Epoch {epoch}: Loss {res["loss"]}') 68 | 69 | if res['loss'] < best_loss: 70 | best_loss = res['loss'] 71 | self._snapshot('best') 72 | if self.writer: 73 | self.writer.add_scalar('Loss', res['loss'], epoch) 74 | 75 | if epoch % self.scheduler_interval == 0: 76 | old_lr = self.optimizer.param_groups[0]['lr'] 77 | self.scheduler.step() 78 | new_lr = self.optimizer.param_groups[0]['lr'] 79 | print('update learning rate: %f -> %f' % (old_lr, new_lr)) 80 | 81 | if self.writer: 82 | self.writer.add_scalar('Learning Rate', self._get_lr(), epoch) 83 | self.writer.add_scalar('Train Loss', self.train_hist['loss'][-1], epoch) 84 | 85 | # finished all epochs 86 | self.train_hist['total_time'].append(time.time() - start_time) 87 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), 88 | self.epoch, self.train_hist['total_time'][0])) 89 | print("Training finish!... 
save training results") 90 | 91 | def train_epoch(self, epoch): 92 | epoch_start_time = time.time() 93 | loss_buf = [] 94 | num_batch = int(len(self.train_loader.dataset) / self.batch_size) 95 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(self.train_loader): 96 | 97 | B = anc_local_patch.shape[0] 98 | anc_local_patch = anc_local_patch.float() 99 | pos_local_patch = pos_local_patch.float() 100 | rotate = rotate.float() 101 | shift = shift.float() 102 | 103 | if self.gpu_mode: 104 | anc_local_patch = anc_local_patch.cuda() 105 | pos_local_patch = pos_local_patch.cuda() 106 | 107 | # forward 108 | self.optimizer.zero_grad() 109 | a_desc = self.model(anc_local_patch) 110 | p_desc = self.model(pos_local_patch) 111 | anc_desc = F.normalize(a_desc.view(B, -1), p=2, dim=1) 112 | pos_desc = F.normalize(p_desc.view(B, -1), p=2, dim=1) 113 | 114 | # calculate the contrastive loss 115 | des_loss, accuracy = self.desc_loss(anc_desc, pos_desc) 116 | loss = des_loss 117 | 118 | # backward 119 | loss.backward() 120 | self.optimizer.step() 121 | loss_buf.append(float(loss)) 122 | 123 | if iter % self.snapshot_interval == 0: 124 | self._snapshot(f'{epoch}_{iter + 1}') 125 | 126 | if iter % 200 == 0 and self.verbose: 127 | iter_time = time.time() - epoch_start_time 128 | print(f"Epoch: {epoch} [{iter:4d}/{num_batch}] loss: {loss:.2f} time: {iter_time:.2f}s") 129 | print(f"Epoch: {epoch} [{iter:4d}/{num_batch}] des loss: {des_loss:.2f} time: {iter_time:.2f}s") 130 | print(f"Accuracy: {accuracy.item():.4f}\n") 131 | del loss 132 | del anc_local_patch 133 | del pos_local_patch 134 | # finish one epoch 135 | epoch_time = time.time() - epoch_start_time 136 | self.train_hist['per_epoch_time'].append(epoch_time) 137 | self.train_hist['loss'].append(np.mean(loss_buf)) 138 | print(f'Epoch {epoch}: Loss {np.mean(loss_buf)}, time {epoch_time:.4f}s') 139 | 140 | del loss_buf 141 | 142 | def evaluate(self): 143 | self.model.eval() 144 | loss_buf = [] 145 | with torch.no_grad(): 146 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(self.val_loader): 147 | 148 | B = anc_local_patch.shape[0] 149 | anc_local_patch = anc_local_patch.float() 150 | pos_local_patch = pos_local_patch.float() 151 | rotate = rotate.float() 152 | shift = shift.float() 153 | 154 | if self.gpu_mode: 155 | anc_local_patch = anc_local_patch.cuda() 156 | pos_local_patch = pos_local_patch.cuda() 157 | 158 | # forward 159 | a_des = self.model(anc_local_patch) 160 | p_des = self.model(pos_local_patch) 161 | anc_des = F.normalize(a_des.view(B, -1), p=2, dim=1) 162 | pos_des = F.normalize(p_des.view(B, -1), p=2, dim=1) 163 | 164 | # calculate the contrastive loss 165 | des_loss, accuracy = self.desc_loss(anc_des, pos_des) 166 | loss = des_loss 167 | loss_buf.append(float(loss)) 168 | 169 | del loss 170 | del anc_local_patch 171 | del pos_local_patch 172 | 173 | self.model.train() 174 | 175 | res = { 176 | 'loss': np.mean(loss_buf) 177 | } 178 | del loss_buf 179 | return res 180 | 181 | def _snapshot(self, epoch): 182 | save_dir = os.path.join(self.save_dir, self.dataset) 183 | torch.save(self.model.state_dict(), save_dir + "_" + str(epoch) + '.pkl') 184 | print(f"Save model to {save_dir}_{str(epoch)}.pkl") 185 | 186 | def _load_pretrain(self, pretrain): 187 | state_dict = torch.load(pretrain, map_location='cpu') 188 | self.model.load_state_dict(state_dict) 189 | print(f"Load model from {pretrain}.pkl") 190 | 191 | def _get_lr(self, group=0): 192 | return self.optimizer.param_groups[group]['lr'] 193 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Qingyong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/spinnet-learning-a-general-surface-descriptor/point-cloud-registration-on-3dmatch-benchmark)](https://paperswithcode.com/sota/point-cloud-registration-on-3dmatch-benchmark?p=spinnet-learning-a-general-surface-descriptor) 2 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) 3 | [![arXiv](https://img.shields.io/badge/arXiv-2011.12149-b31b1b.svg)](https://arxiv.org/abs/2011.12149) 4 | # SpinNet: Learning a General Surface Descriptor for 3D Point Cloud Registration (CVPR 2021) 5 | 6 | This is the official repository of **SpinNet**, a conceptually simple neural architecture to extract local 7 | features which are rotationally invariant whilst sufficiently informative to enable accurate registration. For technical details, please refer to: 8 | 9 | **[SpinNet: Learning a General Surface Descriptor for 3D Point Cloud Registration](https://arxiv.org/abs/2011.12149)**
10 | [Sheng Ao*](http://scholar.google.com/citations?user=cvS1yuMAAAAJ&hl=zh-CN), [Qingyong Hu*](https://www.cs.ox.ac.uk/people/qingyong.hu/), [Bo Yang](https://yang7879.github.io/), [Andrew Markham](https://www.cs.ox.ac.uk/people/andrew.markham/), [Yulan Guo](http://yulanguo.me/).
11 | (* *indicates equal contribution*) 12 | 13 | **[[Paper](https://arxiv.org/abs/2011.12149)] [Video] [Project page]**
14 | 15 | 16 | ### (1) Overview 17 | 18 | *(The rendered README displays overview figures from the `figs/` folder here.)* 19 | 20 | 
21 | 22 | 23 | 24 | ### (2) Setup 25 | This code has been tested with Python 3.6, Pytorch 1.6.0, CUDA 10.2 on Ubuntu 18.04. 26 | 27 | - Clone the repository 28 | ``` 29 | git clone https://github.com/QingyongHu/SpinNet && cd SpinNet 30 | ``` 31 | - Setup conda virtual environment 32 | ``` 33 | conda create -n spinnet python=3.6 34 | source activate spinnet 35 | conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.2 -c pytorch 36 | conda install -c open3d-admin open3d==0.11.1 37 | pip install "git+git://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib" 38 | ``` 39 | 40 | ### (3) 3DMatch 41 | Download the processed dataset from [Google Drive](https://drive.google.com/file/d/1PrkSE0nY79gOF_VJcKv2VpxQ8s7DOITg/view?usp=sharing), [Baidu Yun](https://pan.baidu.com/s/1FB7IUbKAAlk7RVnB_AgwcQ) (Verification code:d1vn) and put the folder into `data`. 42 | Then the structure should be as follows: 43 | ``` 44 | --data--3DMatch--fragments 45 | |--intermediate-files-real 46 | |--patches 47 | 48 | ``` 49 | 50 | **Training** 51 | 52 | Training SpinNet on the 3DMatch dataset: 53 | ``` 54 | cd ./ThreeDMatch/Train 55 | python train.py 56 | ``` 57 | **Testing** 58 | 59 | Evaluate the performance of the trained models on the 3DMatch dataset: 60 | 61 | ``` 62 | cd ./ThreeDMatch/Test 63 | python preparation.py 64 | ``` 65 | The learned descriptors for each point will be saved in `ThreeDMatch/Test/SpinNet_{timestr}/` folder. 66 | Then the `Feature Matching Recall(FMR)` and `Inlier Ratio(IR)` can be calculated by running: 67 | ``` 68 | python evaluate.py [timestr] 69 | ``` 70 | The ground truth poses have been put in the `ThreeDMatch/Test/gt_result` folder. 71 | The `Registration Recall` can be calculated by running the `evaluate.m` in `ThreeDMatch/Test/3dmatch` which are provided by [3DMatch](https://github.com/andyzeng/3dmatch-toolbox/tree/master/evaluation/geometric-registration). 72 | Note that, you need to modify the `descriptorName` to `SpinNet_{timestr}` in the `ThreeDMatch/Test/3dmatch/evaluate.m` file. 73 | 74 | 75 | ### (4) KITTI 76 | Download the processed dataset from [Google Drive](https://drive.google.com/file/d/1fuJiQwAay23BUKtxBG3__MwStyMuvrMQ/view?usp=sharing), [Baidu Yun](https://pan.baidu.com/s/1FB7IUbKAAlk7RVnB_AgwcQ) (Verification code:d1vn), and put the folder into `data`. 77 | Then the structure is as follows: 78 | ``` 79 | --data--KITTI--dataset 80 | |--icp 81 | |--patches 82 | 83 | ``` 84 | 85 | **Training** 86 | 87 | Training SpinNet on the KITTI dataset: 88 | 89 | ``` 90 | cd ./KITTI/Train/ 91 | python train.py 92 | ``` 93 | 94 | **Testing** 95 | 96 | Evaluate the performance of the trained models on the KITTI dataset: 97 | 98 | ``` 99 | cd ./KITTI/Test/ 100 | python test_kitti.py 101 | ``` 102 | 103 | 104 | ### (5) ETH 105 | 106 | The test set can be downloaded from [here](https://share.phys.ethz.ch/~gsg/3DSmoothNet/data/ETH.rar), and put the folder into `data`, then the structure is as follows: 107 | ``` 108 | --data--ETH--gazebo_summer 109 | |--gazebo_winter 110 | |--wood_autmn 111 | |--wood_summer 112 | ``` 113 | 114 | ### (6) Generalization across Unseen Datasets 115 | 116 | **3DMatch to ETH** 117 | 118 | Generalization from 3DMatch dataset to ETH dataset: 119 | ``` 120 | cd ./generalization/ThreeDMatch-to-ETH 121 | python preparation.py 122 | ``` 123 | The descriptors for each point will be generated and saved in the `generalization/ThreeDMatch-to-ETH/SpinNet_{timestr}/` folder. 
124 | Then the `Feature Matching Recall` and `Inlier Ratio` can be calculated by running 125 | ``` 126 | python evaluate.py [timestr] 127 | ``` 128 | 129 | **3DMatch to KITTI** 130 | 131 | Generalization from 3DMatch dataset to KITTI dataset: 132 | 133 | ``` 134 | cd ./generalization/ThreeDMatch-to-KITTI 135 | python test.py 136 | ``` 137 | 138 | **KITTI to 3DMatch** 139 | 140 | Generalization from KITTI dataset to 3DMatch dataset: 141 | ``` 142 | cd ./generalization/KITTI-to-ThreeDMatch 143 | python preparation.py 144 | ``` 145 | The descriptors for each point will be generated and saved in the `generalization/KITTI-to-3DMatch/SpinNet_{timestr}/` folder. 146 | Then the `Feature Matching Recall` and `Inlier Ratio` can be calculated by running 147 | ``` 148 | python evaluate.py [timestr] 149 | ``` 150 | 151 | ## Acknowledgement 152 | 153 | In this project, we use (parts of) the implementations of the following works: 154 | 155 | * [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch) 156 | * [PPF-FoldNet](https://github.com/XuyangBai/PPF-FoldNet) 157 | * [Spherical CNNs](https://github.com/jonas-koehler/s2cnn) 158 | * [FCGF](https://github.com/chrischoy/FCGF) 159 | * [r2d2](https://github.com/naver/r2d2) 160 | * [D3Feat](https://github.com/XuyangBai/D3Feat) 161 | * [D3Feat.pytorch](https://github.com/XuyangBai/D3Feat.pytorch) 162 | 163 | 164 | ### Citation 165 | If you find our work useful in your research, please consider citing: 166 | 167 | @inproceedings{ao2020SpinNet, 168 | title={SpinNet: Learning a General Surface Descriptor for 3D Point Cloud Registration}, 169 | author={Ao, Sheng and Hu, Qingyong and Yang, Bo and Markham, Andrew and Guo, Yulan}, 170 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 171 | year={2021} 172 | } 173 | 174 | ### References 175 | 176 | 177 | [1] 3DMatch: Learning Local Geometric Descriptors from RGB-D Reconstructions, Andy Zeng, Shuran Song, Matthias Nießner, Matthew Fisher, Jianxiong Xiao, and Thomas Funkhouser, CVPR 2017. 178 | 179 | 180 | 181 | ### Updates 182 | * 03/04/2021: The code is released! 183 | * 01/03/2021: This paper has been accepted by CVPR 2021! 184 | * 25/11/2020: Initial release! 185 | 186 | ## Related Repos 187 | 1. [RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds](https://github.com/QingyongHu/RandLA-Net) ![GitHub stars](https://img.shields.io/github/stars/QingyongHu/RandLA-Net.svg?style=flat&label=Star) 188 | 2. [SoTA-Point-Cloud: Deep Learning for 3D Point Clouds: A Survey](https://github.com/QingyongHu/SoTA-Point-Cloud) ![GitHub stars](https://img.shields.io/github/stars/QingyongHu/SoTA-Point-Cloud.svg?style=flat&label=Star) 189 | 3. [3D-BoNet: Learning Object Bounding Boxes for 3D Instance Segmentation on Point Clouds](https://github.com/Yang7879/3D-BoNet) ![GitHub stars](https://img.shields.io/github/stars/Yang7879/3D-BoNet.svg?style=flat&label=Star) 190 | 4. [SensatUrban: Learning Semantics from Urban-Scale Photogrammetric Point Clouds](https://github.com/QingyongHu/SensatUrban) ![GitHub stars](https://img.shields.io/github/stars/QingyongHu/SensatUrban.svg?style=flat&label=Star) 191 | 5. 
[SQN: Weakly-Supervised Semantic Segmentation of Large-Scale 3D Point Clouds with 1000x Fewer Labels](https://github.com/QingyongHu/SQN) ![GitHub stars](https://img.shields.io/github/stars/QingyongHu/SQN.svg?style=flat&label=Star) 192 | 193 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/evaluate.m: -------------------------------------------------------------------------------- 1 | % Script to evaluate .log files for the geometric registration benchmarks, 2 | % in the same spirit as Choi et al 2015. Please see: 3 | % 4 | % http://redwood-data.org/indoor/regbasic.html 5 | % https://github.com/qianyizh/ElasticReconstruction/tree/master/Matlab_Toolbox 6 | 7 | 8 | descriptorName = 'SpinNet_10051828'; % 9 | 10 | % Locations of evaluation files 11 | dataPath = '../log_result'; 12 | 13 | % Real data benchmark 14 | sceneList = { 15 | '7-scenes-redkitchen-evaluation', ... 16 | 'sun3d-home_at-home_at_scan1_2013_jan_1-evaluation', ... 17 | 'sun3d-home_md-home_md_scan9_2012_sep_30-evaluation', ... 18 | 'sun3d-hotel_uc-scan3-evaluation', ... 19 | 'sun3d-hotel_umd-maryland_hotel1-evaluation', ... 20 | 'sun3d-hotel_umd-maryland_hotel3-evaluation', ... 21 | 'sun3d-mit_76_studyroom-76-1studyroom2-evaluation', ... 22 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika-evaluation' 23 | }; 24 | 25 | % Load Elastic Reconstruction toolbox 26 | addpath(genpath('external')); 27 | 28 | % Compute precision and recall 29 | totalRecall = []; totalPrecision = []; 30 | totalGt = 0; 31 | totalTP = 0; 32 | for sceneIdx = 1:length(sceneList) 33 | scenePath = fullfile(dataPath,sceneList{sceneIdx}); 34 | gtPath = fullfile('../gt_result',sceneList{sceneIdx}); 35 | 36 | % Compute registration error 37 | gt = mrLoadLog(fullfile(gtPath,'gt.log')); 38 | gt_info = mrLoadInfo(fullfile(gtPath,'gt.info')); 39 | result = mrLoadLog(fullfile(scenePath,sprintf('%s.log',descriptorName))); 40 | [recall,precision,gt_num] = mrEvaluateRegistration(result,gt,gt_info); 41 | totalRecall = [totalRecall;recall]; 42 | totalPrecision = [totalPrecision;precision]; 43 | totalGt = totalGt + gt_num; 44 | totalTP = totalTP + round(gt_num * recall); 45 | end 46 | totalRecall 47 | fprintf('Mean registration recall: %f precision: %f\n',mean(totalRecall),mean(totalPrecision)); 48 | fprintf('True average recall: %f (%d/%d)\n',totalTP/totalGt,totalTP, totalGt); 49 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrDrawTrajectory.m: -------------------------------------------------------------------------------- 1 | function mrDrawTraj( traj, c, init_trans ) 2 | if ~exist( 'c', 'var' ) 3 | c = 'b-'; 4 | end 5 | 6 | if ~exist( 'init_trans', 'var' ) 7 | init_trans = traj( 1 ).trans; 8 | end 9 | 10 | n = size( traj, 2 ); 11 | x = zeros( 2, n ); 12 | init_inverse = init_trans ^ -1; 13 | 14 | for k = 1 : n 15 | m = init_inverse * traj( k ).trans; 16 | x( :, k ) = [ m( 1, 4 ); m( 3, 4 ) ]; 17 | end 18 | 19 | plot( -x( 1, : ), -x( 2, : ), c, 'LineWidth',2 ); 20 | axis equal; 21 | end -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrEvaluateRegistration.m: -------------------------------------------------------------------------------- 1 | function [ recall, precision, gt_num ] = mrEvaluateRegistration( result, gt, gt_info, err2 ) 2 | if ~exist( 'err2', 'var' ) 3 | err2 = 0.04; 4 | end 5 | num = gt( 1 ).info( 3 ); 6 | 7 | mask 
= zeros( 1, num * num ); 8 | gt_num = 0; 9 | for i = 1 : size( gt, 2 ) 10 | if ( gt( i ).info( 2 ) - gt( i ).info( 1 ) > 1 ) 11 | mask( gt( i ).info( 1 ) + gt( i ).info( 2 ) * num + 1 ) = i; 12 | gt_num = gt_num + 1; 13 | end 14 | end 15 | 16 | rs_num = 0; 17 | good = 0; 18 | bad = 0; 19 | false_pos = 0; 20 | error_dis = []; 21 | for i = 1 : size( result, 2 ) 22 | if ( result( i ).info( 2 ) - result( i ).info( 1 ) > 1 ) 23 | rs_num = rs_num + 1; 24 | idx = mask( result( i ).info( 1 ) + result( i ).info( 2 ) * num + 1 ); 25 | if idx == 0 26 | false_pos = false_pos + 1; 27 | else 28 | p = mrComputeTransformationError( gt( idx ).trans ^ -1 * result( i ).trans, gt_info( idx ).mat ); 29 | error_dis = [ error_dis, p ]; 30 | if ( p <= err2 ) 31 | good = good + 1; 32 | else 33 | bad = bad + 1; 34 | end 35 | end 36 | end 37 | end 38 | recall = good / gt_num; 39 | precision = good / rs_num; 40 | disp( [ 'recall : ' num2str( recall ) ' ( ' num2str( good ) ' / ' num2str( gt_num ) ' )' ] ); 41 | %disp( [ 'precision : ' num2str( precision ) ' ( ' num2str( good ) ' / ' num2str( rs_num ) ' )' ] ); 42 | end 43 | 44 | function [ p ] = mrComputeTransformationError( trans, info ) 45 | te = trans( 1 : 3, 4 ); 46 | qt = dcm2quat( trans( 1 : 3, 1 : 3 ) ); 47 | er = [ te; - qt( 2 : 4 )' ]; 48 | p = er' * info * er / info( 1, 1 ); 49 | end 50 | 51 | function [qout] = dcm2quat(DCM) 52 | % this is consistent with the matlab function in 53 | % the Aerospace Toolbox 54 | qout = zeros(1,4); 55 | qout(1) = 0.5 * sqrt(1 + DCM(1,1) + DCM(2,2) + DCM(3,3)); 56 | qout(2) = - (DCM(3,2) - DCM(2,3)) / ( 4 * qout(1) ); 57 | qout(3) = - (DCM(1,3) - DCM(3,1)) / ( 4 * qout(1) ); 58 | qout(4) = - (DCM(2,1) - DCM(1,2)) / ( 4 * qout(1) ); 59 | end -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrEvaluateTrajectory.m: -------------------------------------------------------------------------------- 1 | function [ rmse, trans ] = mrEvaluateTraj( traj_et, traj_gt ) 2 | gt_n = size( traj_gt, 2 ); 3 | et_n = size( traj_et, 2 ); 4 | if (gt_n ~= et_n) 5 | fprintf('WARNING: There are Lost Frames!\n'); 6 | fprintf('ground truth traj : %d frames\n', gt_n); 7 | fprintf('estimated traj : %d frames\n', et_n); 8 | gt_n = min( [ gt_n, et_n ] ); 9 | et_n = gt_n; 10 | end 11 | n = et_n; 12 | 13 | trans = mrAlignTraj( traj_et, traj_gt ); 14 | err = zeros( 1, n ); 15 | 16 | for i = 1 : n 17 | assert( traj_et( i ).info( 3 ) == traj_gt( i ).info( 3 ),... 18 | 'bad trajectory file format or asynchronized frame.' 
); 19 | trans_et = trans * traj_et( i ).trans; 20 | trans_gt = traj_gt( i ).trans; 21 | err( i ) = norm( trans_gt( 1 : 3, 4 ) - trans_et( 1 : 3, 4 ) ); 22 | end 23 | 24 | rmse = sqrt( err * err' / size( err, 2 ) ); 25 | fprintf( 'median absolute translational error %f m\n', median( err ) ); 26 | fprintf( 'rmse %f m\n', rmse ); 27 | end 28 | 29 | function [ trans ] = mrAlignTraj( traj_et, traj_gt ) 30 | n = size( traj_et, 2 ); 31 | gt_trans = zeros( 3, n ); 32 | et_trans = zeros( 3, n ); 33 | 34 | for i = 1 : n 35 | gt_trans( :, i ) = traj_gt( i ).trans( 1 : 3, 4 ); 36 | et_trans( :, i ) = traj_et( i ).trans( 1 : 3, 4 ); 37 | end 38 | 39 | gt_mean = mean( gt_trans, 2 ); 40 | et_mean = mean( et_trans, 2 ); 41 | gt_centered = gt_trans - repmat( gt_mean, 1, n ); 42 | et_centered = et_trans - repmat( et_mean, 1, n ); 43 | 44 | W = zeros( 3, 3 ); 45 | for i = 1 : n 46 | W = W + et_centered( :, i ) * gt_centered( :, i )'; 47 | end 48 | 49 | [ U, ~, V ] = svd( W' ); 50 | Vh = V'; 51 | S = eye( 3 ); 52 | if ( det( U ) * det( Vh ) < 0 ) 53 | S( 3, 3 ) = -1; 54 | end 55 | 56 | r = U * S * Vh; 57 | t = gt_mean - r * et_mean; 58 | 59 | trans = [ r, t; 0, 0, 0, 1 ]; 60 | end -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrLoadInfo.m: -------------------------------------------------------------------------------- 1 | function [ info ] = mrLoadInfo( filename ) 2 | fid = fopen( filename ); 3 | k = 1; 4 | x = fscanf( fid, '%d', [ 1, 3 ] ); 5 | while ( size( x, 2 ) == 3 ) 6 | m = fscanf( fid, '%f', [ 6, 6 ] ); 7 | info( k ) = struct( 'info', x, 'mat', m' ); 8 | k = k + 1; 9 | x = fscanf( fid, '%d', [ 1, 3 ] ); 10 | end 11 | fclose( fid ); 12 | %disp( [ num2str( size( info, 2 ) ), ' matrices have been read.' ] ); 13 | end 14 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrLoadLog.m: -------------------------------------------------------------------------------- 1 | function [ traj ] = mrLoadLog( filename ) 2 | fid = fopen( filename ); 3 | k = 1; 4 | x = fscanf( fid, '%d', [1 3] ); 5 | while ( size( x, 2 ) == 3 ) 6 | m = fscanf( fid, '%f', [4 4] ); 7 | traj( k ) = struct( 'info', x, 'trans', m' ); 8 | k = k + 1; 9 | x = fscanf( fid, '%d', [1 3] ); 10 | end 11 | fclose( fid ); 12 | %disp( [ num2str( size( traj, 2 ) ), ' frames have been read.' ] ); 13 | end 14 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrMatchDepthColor.m: -------------------------------------------------------------------------------- 1 | function mrMatchDepthColor( basepath, unique, depthdir, imagedir, matchfile ) 2 | if ~exist( 'matchfile', 'var' ) 3 | matchfile = 'match'; 4 | end 5 | if ~exist( 'imagedir', 'var' ) 6 | imagedir = 'rgb'; 7 | end 8 | if ~exist( 'depthdir', 'var' ) 9 | depthdir = 'depth'; 10 | end 11 | if ~exist( 'unique', 'var' ) 12 | unique = 1; 13 | end 14 | 15 | depth_file_list = dir( [ basepath, depthdir, '/*.png' ] ); 16 | if ( size( depth_file_list, 1 ) <= 1 ) 17 | disp( 'Error: path not found' ); 18 | return; 19 | end 20 | disp( [ num2str( size( depth_file_list, 1 ) ) ' depth images detected.' 
] );
21 | depth_timestamp = parseTimestamp( depth_file_list );
22 | 
23 | color_file_list = dir( [ basepath, imagedir, '/*.jpg' ] );
24 | if ( size( color_file_list, 1 ) <= 1 )
25 |     disp( 'Error: path not found' );
26 |     return;
27 | end
28 | disp( [ num2str( size( color_file_list, 1 ) ) ' color images detected.' ] );
29 | color_timestamp = parseTimestamp( color_file_list );
30 | color_timestamp_mat = cell2mat( color_timestamp( :, 1 ) );
31 | 
32 | fid = fopen( [ basepath, matchfile ], 'w' );
33 | used_color = zeros( size( color_timestamp, 1 ), 1 );
34 | k = 0;
35 | for i = 1 : size( depth_timestamp, 1 )
36 |     idx = findClosestColor( depth_timestamp{ i, 1 }, color_timestamp_mat );
37 |     if ( unique == 0 || used_color( idx ) == 0 )
38 |         used_color( idx ) = 1;
39 |         fprintf( fid, '%s/%s %s/%s\n', depthdir, depth_timestamp{ i, 2 }, imagedir, color_timestamp{ idx, 2 } );
40 |         k = k + 1;
41 |     end
42 | end
43 | fclose( fid );
44 | disp( [ num2str( k ) ' pairs have been written.' ] );
45 | end
46 | 
47 | function [ i ] = findClosestColor( depth_ts, color_ts_mat )
48 | [ ~, i ] = min( abs( color_ts_mat - depth_ts ) );
49 | end
50 | 
51 | function [ timestamp ] = parseTimestamp( filelist )
52 | num = size( filelist, 1 );
53 | timestamp = cell( num, 2 );
54 | for i = 1 : num
55 |     x = sscanf( filelist( i ).name, '%f-%f.' )';
56 |     timestamp{ i, 1 } = x( 2 );
57 |     timestamp{ i, 2 } = filelist( i ).name;
58 | end
59 | [ ~, order ] = sort( cell2mat( timestamp( :, 1 ) ) ); % sort the rows by timestamp
60 | timestamp = timestamp( order, : );
61 | end
-------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrWriteInfo.m: --------------------------------------------------------------------------------
1 | function mrWriteInfo( info, filename )
2 | fid = fopen( filename, 'w' );
3 | for i = 1 : size( info, 2 )
4 |     mrWriteInfoStruct( fid, info( i ).info, info( i ).mat );
5 | end
6 | fclose( fid );
7 | %disp( [ num2str( size( info, 2 ) ), ' matrices have been written.' ] );
8 | end
9 | 
10 | function mrWriteInfoStruct( fid, x, m )
11 | fprintf( fid, '%d\t%d\t%d\n', x(1), x(2), x(3) );
12 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
13 |     m(1,1), m(1,2), m(1,3), m(1,4), m(1,5), m(1,6) );
14 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
15 |     m(2,1), m(2,2), m(2,3), m(2,4), m(2,5), m(2,6) );
16 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
17 |     m(3,1), m(3,2), m(3,3), m(3,4), m(3,5), m(3,6) );
18 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
19 |     m(4,1), m(4,2), m(4,3), m(4,4), m(4,5), m(4,6) );
20 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
21 |     m(5,1), m(5,2), m(5,3), m(5,4), m(5,5), m(5,6) );
22 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
23 |     m(6,1), m(6,2), m(6,3), m(6,4), m(6,5), m(6,6) );
24 | end
25 | 
-------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrWriteLog.m: --------------------------------------------------------------------------------
1 | function mrWriteLog( traj, filename )
2 | fid = fopen( filename, 'w' );
3 | for i = 1 : size( traj, 2 )
4 |     mrWriteLogStruct( fid, traj( i ).info, traj( i ).trans );
5 | end
6 | fclose( fid );
7 | %disp( [ num2str( size( traj, 2 ) ), ' frames have been written.' 
] ); 8 | end 9 | 10 | function mrWriteLogStruct( fid, x, m ) 11 | fprintf( fid, '%d\t%d\t%d\n', x(1), x(2), x(3) ); 12 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\n', m(1,1), m(1,2), m(1,3), m(1,4) ); 13 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\n', m(2,1), m(2,2), m(2,3), m(2,4) ); 14 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\n', m(3,1), m(3,2), m(3,3), m(3,4) ); 15 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\n', m(4,1), m(4,2), m(4,3), m(4,4) ); 16 | end 17 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/npy-matlab/constructNPYheader.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | function header = constructNPYheader(dataType, shape, varargin) 5 | 6 | if ~isempty(varargin) 7 | fortranOrder = varargin{1}; % must be true/false 8 | littleEndian = varargin{2}; % must be true/false 9 | else 10 | fortranOrder = true; 11 | littleEndian = true; 12 | end 13 | 14 | dtypesMatlab = {'uint8','uint16','uint32','uint64','int8','int16','int32','int64','single','double', 'logical'}; 15 | dtypesNPY = {'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'f4', 'f8', 'b1'}; 16 | 17 | magicString = uint8([147 78 85 77 80 89]); %x93NUMPY 18 | 19 | majorVersion = uint8(1); 20 | minorVersion = uint8(0); 21 | 22 | % build the dict specifying data type, array order, endianness, and 23 | % shape 24 | dictString = '{''descr'': '''; 25 | 26 | if littleEndian 27 | dictString = [dictString '<']; 28 | else 29 | dictString = [dictString '>']; 30 | end 31 | 32 | dictString = [dictString dtypesNPY{strcmp(dtypesMatlab,dataType)} ''', ']; 33 | 34 | dictString = [dictString '''fortran_order'': ']; 35 | 36 | if fortranOrder 37 | dictString = [dictString 'True, ']; 38 | else 39 | dictString = [dictString 'False, ']; 40 | end 41 | 42 | dictString = [dictString '''shape'': (']; 43 | 44 | % if length(shape)==1 && shape==1 45 | % 46 | % else 47 | % for s = 1:length(shape) 48 | % if s==length(shape) && shape(s)==1 49 | % 50 | % else 51 | % dictString = [dictString num2str(shape(s))]; 52 | % if length(shape)>1 && s+1==length(shape) && shape(s+1)==1 53 | % dictString = [dictString ',']; 54 | % elseif length(shape)>1 && s %s', tempFilename, inFilename, outFilename)); 38 | 39 | otherwise 40 | fprintf(1, 'I don''t know how to concatenate files for your OS, but you can finish making the NPY youself by concatenating %s with %s.\n', tempFilename, inFilename); 41 | end 42 | 43 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/npy-matlab/readNPY.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | function data = readNPY(filename) 4 | % Function to read NPY files into matlab. 5 | % *** Only reads a subset of all possible NPY files, specifically N-D arrays of certain data types. 6 | % See https://github.com/kwikteam/npy-matlab/blob/master/tests/npy.ipynb for 7 | % more. 
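% A minimal usage sketch (hypothetical file name; the pattern matches the
% descriptor arrays that ThreeDMatch/Test/preparation.py saves via np.save):
%   desc = readNPY('cloud_bin_0.desc.SpinNet.bin.npy');   % -> num_keypts x 32 single matrix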
8 | % 9 | 10 | [shape, dataType, fortranOrder, littleEndian, totalHeaderLength, ~] = readNPYheader(filename); 11 | 12 | if littleEndian 13 | fid = fopen(filename, 'r', 'l'); 14 | else 15 | fid = fopen(filename, 'r', 'b'); 16 | end 17 | 18 | try 19 | 20 | [~] = fread(fid, totalHeaderLength, 'uint8'); 21 | 22 | % read the data 23 | data = fread(fid, prod(shape), [dataType '=>' dataType]); 24 | 25 | if length(shape)>1 && ~fortranOrder 26 | data = reshape(data, shape(end:-1:1)); 27 | data = permute(data, [length(shape):-1:1]); 28 | elseif length(shape)>1 29 | data = reshape(data, shape); 30 | end 31 | 32 | fclose(fid); 33 | 34 | catch me 35 | fclose(fid); 36 | rethrow(me); 37 | end 38 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/npy-matlab/readNPYheader.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | function [arrayShape, dataType, fortranOrder, littleEndian, totalHeaderLength, npyVersion] = readNPYheader(filename) 4 | % function [arrayShape, dataType, fortranOrder, littleEndian, ... 5 | % totalHeaderLength, npyVersion] = readNPYheader(filename) 6 | % 7 | % parse the header of a .npy file and return all the info contained 8 | % therein. 9 | % 10 | % Based on spec at http://docs.scipy.org/doc/numpy-dev/neps/npy-format.html 11 | 12 | fid = fopen(filename); 13 | 14 | % verify that the file exists 15 | if (fid == -1) 16 | if ~isempty(dir(filename)) 17 | error('Permission denied: %s', filename); 18 | else 19 | error('File not found: %s', filename); 20 | end 21 | end 22 | 23 | try 24 | 25 | dtypesMatlab = {'uint8','uint16','uint32','uint64','int8','int16','int32','int64','single','double', 'logical'}; 26 | dtypesNPY = {'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'f4', 'f8', 'b1'}; 27 | 28 | 29 | magicString = fread(fid, [1 6], 'uint8=>uint8'); 30 | 31 | if ~all(magicString == [147,78,85,77,80,89]) 32 | error('readNPY:NotNUMPYFile', 'Error: This file does not appear to be NUMPY format based on the header.'); 33 | end 34 | 35 | majorVersion = fread(fid, [1 1], 'uint8=>uint8'); 36 | minorVersion = fread(fid, [1 1], 'uint8=>uint8'); 37 | 38 | npyVersion = [majorVersion minorVersion]; 39 | 40 | headerLength = fread(fid, [1 1], 'uint16=>uint16'); 41 | 42 | totalHeaderLength = 10+headerLength; 43 | 44 | arrayFormat = fread(fid, [1 headerLength], 'char=>char'); 45 | 46 | % to interpret the array format info, we make some fairly strict 47 | % assumptions about its format... 48 | 49 | r = regexp(arrayFormat, '''descr''\s*:\s*''(.*?)''', 'tokens'); 50 | dtNPY = r{1}{1}; 51 | 52 | littleEndian = ~strcmp(dtNPY(1), '>'); 53 | 54 | dataType = dtypesMatlab{strcmp(dtNPY(2:3), dtypesNPY)}; 55 | 56 | r = regexp(arrayFormat, '''fortran_order''\s*:\s*(\w+)', 'tokens'); 57 | fortranOrder = strcmp(r{1}{1}, 'True'); 58 | 59 | r = regexp(arrayFormat, '''shape''\s*:\s*\((.*?)\)', 'tokens'); 60 | shapeStr = r{1}{1}; 61 | arrayShape = str2num(shapeStr(shapeStr~='L')); 62 | 63 | 64 | fclose(fid); 65 | 66 | catch me 67 | fclose(fid); 68 | rethrow(me); 69 | end 70 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/3dmatch/external/npy-matlab/writeNPY.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | function writeNPY(var, filename) 4 | % function writeNPY(var, filename) 5 | % 6 | % Only writes little endian, fortran (column-major) ordering; only writes 7 | % with NPY version number 1.0. 
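% A minimal round-trip sketch (hypothetical variable/file names):
%   a = single(rand(5000, 32));
%   writeNPY(a, 'a.npy');       % readable from python with numpy.load('a.npy')
%   b = readNPY('a.npy');       % and back into matlab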
8 | %
9 | % Always outputs a shape according to matlab's convention, e.g. (10, 1)
10 | % rather than (10,).
11 | 
12 | 
13 | shape = size(var);
14 | dataType = class(var);
15 | 
16 | header = constructNPYheader(dataType, shape);
17 | 
18 | fid = fopen(filename, 'w');
19 | fwrite(fid, header, 'uint8');
20 | fwrite(fid, var, dataType);
21 | fclose(fid);
22 | 
23 | 
24 | end
25 | 
26 | 
-------------------------------------------------------------------------------- /ThreeDMatch/Test/evaluate.py: --------------------------------------------------------------------------------
1 | import sys
2 | 
3 | sys.path.append('../../')
4 | import open3d
5 | import numpy as np
6 | import time
7 | import os
8 | from ThreeDMatch.Test.tools import get_pcd, get_keypts, get_desc, loadlog
9 | from sklearn.neighbors import KDTree
10 | 
11 | 
12 | def calculate_M(source_desc, target_desc):
13 |     """
14 |     Find the mutually closest point pairs in feature space.
15 |     source and target are the descriptors of key points from two point clouds. [5000, 512]
16 |     """
17 | 
18 |     kdtree_s = KDTree(target_desc)
19 |     sourceNNdis, sourceNNidx = kdtree_s.query(source_desc, 1)
20 |     kdtree_t = KDTree(source_desc)
21 |     targetNNdis, targetNNidx = kdtree_t.query(target_desc, 1)
22 |     result = []
23 |     for i in range(len(sourceNNidx)):
24 |         if targetNNidx[sourceNNidx[i]] == i:
25 |             result.append([i, sourceNNidx[i][0]])
26 |     return np.array(result)
27 | 
28 | 
29 | def register2Fragments(id1, id2, keyptspath, descpath, resultpath, desc_name='SpinNet'):
30 |     cloud_bin_s = f'cloud_bin_{id1}'
31 |     cloud_bin_t = f'cloud_bin_{id2}'
32 |     write_file = f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'
33 |     if os.path.exists(os.path.join(resultpath, write_file)):
34 |         print(f"{write_file} already exists.")
35 |         return 0, 0, 0
36 | 
37 |     if is_D3Feat_keypts:
38 |         keypts_path = './D3Feat_contralo-54-pred/keypoints/' + keyptspath.split('/')[-2] + '/' + cloud_bin_s + '.npy'
39 |         source_keypts = np.load(keypts_path)
40 |         source_keypts = source_keypts[-num_keypoints:, :]
41 |         keypts_path = './D3Feat_contralo-54-pred/keypoints/' + keyptspath.split('/')[-2] + '/' + cloud_bin_t + '.npy'
42 |         target_keypts = np.load(keypts_path)
43 |         target_keypts = target_keypts[-num_keypoints:, :]
44 |         source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
45 |         target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
46 |         source_desc = np.nan_to_num(source_desc)
47 |         target_desc = np.nan_to_num(target_desc)
48 |         source_desc = source_desc[-num_keypoints:, :]
49 |         target_desc = target_desc[-num_keypoints:, :]
50 |     else:
51 |         source_keypts = get_keypts(keyptspath, cloud_bin_s)
52 |         target_keypts = get_keypts(keyptspath, cloud_bin_t)
53 |         source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
54 |         target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
55 |         source_desc = np.nan_to_num(source_desc)
56 |         target_desc = np.nan_to_num(target_desc)
57 |         if source_desc.shape[0] > num_keypoints:
58 |             rand_ind = np.random.choice(source_desc.shape[0], num_keypoints, replace=False)
59 |             source_keypts = source_keypts[rand_ind]
60 |             target_keypts = target_keypts[rand_ind]
61 |             source_desc = source_desc[rand_ind]
62 |             target_desc = target_desc[rand_ind]
63 | 
64 |     key = f'{cloud_bin_s.split("_")[-1]}_{cloud_bin_t.split("_")[-1]}'
65 |     if key not in gtLog.keys():
66 |         num_inliers = 0
67 |         inlier_ratio = 0
68 |         gt_flag = 0
69 |     else:
70 |         # find the mutually closest points.
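        # calculate_M keeps a pair (i, j) only when j is i's nearest neighbor in
        # feature space AND i is j's nearest neighbor (mutual matching); below, a
        # correspondence counts as an inlier when the two keypoints lie within
        # 0.10 m of each other after applying the ground-truth transform.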
71 | corr = calculate_M(source_desc, target_desc) 72 | 73 | gtTrans = gtLog[key] 74 | frag1 = source_keypts[corr[:, 0]] 75 | frag2_pc = open3d.geometry.PointCloud() 76 | frag2_pc.points = open3d.utility.Vector3dVector(target_keypts[corr[:, 1]]) 77 | frag2_pc.transform(gtTrans) 78 | frag2 = np.asarray(frag2_pc.points) 79 | distance = np.sqrt(np.sum(np.power(frag1 - frag2, 2), axis=1)) 80 | num_inliers = np.sum(distance < 0.10) 81 | inlier_ratio = num_inliers / len(distance) 82 | gt_flag = 1 83 | 84 | # calculate the transformation matrix using RANSAC, this is for Registration Recall. 85 | source_pcd = open3d.geometry.PointCloud() 86 | source_pcd.points = open3d.utility.Vector3dVector(source_keypts) 87 | target_pcd = open3d.geometry.PointCloud() 88 | target_pcd.points = open3d.utility.Vector3dVector(target_keypts) 89 | s_desc = open3d.pipelines.registration.Feature() 90 | s_desc.data = source_desc.T 91 | t_desc = open3d.pipelines.registration.Feature() 92 | t_desc.data = target_desc.T 93 | 94 | # Another registration method 95 | corr_v = open3d.utility.Vector2iVector(corr) 96 | result = open3d.pipelines.registration.registration_ransac_based_on_correspondence( 97 | source_pcd, target_pcd, corr_v, 98 | 0.05, 99 | open3d.pipelines.registration.TransformationEstimationPointToPoint(False), 3, 100 | open3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000)) 101 | 102 | # write the transformation matrix into .log file for evaluation. 103 | with open(os.path.join(logpath, f'{desc_name}_{timestr}.log'), 'a+') as f: 104 | trans = result.transformation 105 | trans = np.linalg.inv(trans) 106 | s1 = f'{id1}\t {id2}\t 37\n' 107 | f.write(s1) 108 | f.write(f"{trans[0, 0]}\t {trans[0, 1]}\t {trans[0, 2]}\t {trans[0, 3]}\t \n") 109 | f.write(f"{trans[1, 0]}\t {trans[1, 1]}\t {trans[1, 2]}\t {trans[1, 3]}\t \n") 110 | f.write(f"{trans[2, 0]}\t {trans[2, 1]}\t {trans[2, 2]}\t {trans[2, 3]}\t \n") 111 | f.write(f"{trans[3, 0]}\t {trans[3, 1]}\t {trans[3, 2]}\t {trans[3, 3]}\t \n") 112 | 113 | s = f"{cloud_bin_s}\t{cloud_bin_t}\t{num_inliers}\t{inlier_ratio:.8f}\t{gt_flag}" 114 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'w+') as f: 115 | f.write(s) 116 | return num_inliers, inlier_ratio, gt_flag 117 | 118 | 119 | def read_register_result(id1, id2): 120 | cloud_bin_s = f'cloud_bin_{id1}' 121 | cloud_bin_t = f'cloud_bin_{id2}' 122 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'r') as f: 123 | content = f.readlines() 124 | nums = content[0].replace("\n", "").split("\t")[2:5] 125 | return nums 126 | 127 | 128 | if __name__ == '__main__': 129 | scene_list = [ 130 | '7-scenes-redkitchen', 131 | 'sun3d-home_at-home_at_scan1_2013_jan_1', 132 | 'sun3d-home_md-home_md_scan9_2012_sep_30', 133 | 'sun3d-hotel_uc-scan3', 134 | 'sun3d-hotel_umd-maryland_hotel1', 135 | 'sun3d-hotel_umd-maryland_hotel3', 136 | 'sun3d-mit_76_studyroom-76-1studyroom2', 137 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika' 138 | ] 139 | desc_name = 'SpinNet' 140 | timestr = sys.argv[1] 141 | inliers_list = [] 142 | recall_list = [] 143 | inliers_ratio_list = [] 144 | num_keypoints = 5000 145 | is_D3Feat_keypts = False 146 | for scene in scene_list: 147 | pcdpath = f"../../data/3DMatch/fragments/{scene}/" 148 | interpath = f"../../data/3DMatch/intermediate-files-real/{scene}/" 149 | gtpath = f'../../data/3DMatch/fragments/{scene}-evaluation/' 150 | keyptspath = interpath # os.path.join(interpath, "keypoints/") 151 | descpath = os.path.join(".", 
f"{desc_name}_desc_{timestr}/{scene}") 152 | gtLog = loadlog(gtpath) 153 | logpath = f"log_result/{scene}-evaluation" 154 | resultpath = os.path.join(".", f"pred_result/{scene}/{desc_name}_result_{timestr}") 155 | if not os.path.exists(resultpath): 156 | os.makedirs(resultpath) 157 | if not os.path.exists(logpath): 158 | os.makedirs(logpath) 159 | 160 | # register each pair 161 | num_frag = len(os.listdir(pcdpath)) 162 | print(f"Start Evaluate Descriptor {desc_name} for {scene}") 163 | start_time = time.time() 164 | for id1 in range(num_frag): 165 | for id2 in range(id1 + 1, num_frag): 166 | num_inliers, inlier_ratio, gt_flag = register2Fragments(id1, id2, keyptspath, descpath, resultpath, 167 | desc_name) 168 | print(f"Finish Evaluation, time: {time.time() - start_time:.2f}s") 169 | 170 | # evaluate 171 | result = [] 172 | for id1 in range(num_frag): 173 | for id2 in range(id1 + 1, num_frag): 174 | line = read_register_result(id1, id2) 175 | result.append([int(line[0]), float(line[1]), int(line[2])]) 176 | result = np.array(result) 177 | indices_results = np.sum(result[:, 2] == 1) 178 | correct_match = np.sum(result[:, 1] > 0.05) 179 | recall = float(correct_match / indices_results) * 100 180 | print(f"Correct Match {correct_match}, ground truth Match {indices_results}") 181 | print(f"Recall {recall}%") 182 | ave_num_inliers = np.sum(np.where(result[:, 1] > 0.05, result[:, 0], np.zeros(result.shape[0]))) / correct_match 183 | print(f"Average Num Inliners: {ave_num_inliers}") 184 | ave_inlier_ratio = np.sum( 185 | np.where(result[:, 1] > 0.05, result[:, 1], np.zeros(result.shape[0]))) / correct_match 186 | print(f"Average Num Inliner Ratio: {ave_inlier_ratio}") 187 | recall_list.append(recall) 188 | inliers_list.append(ave_num_inliers) 189 | inliers_ratio_list.append(ave_inlier_ratio) 190 | print(recall_list) 191 | average_recall = sum(recall_list) / len(recall_list) 192 | print(f"All 8 scene, average recall: {average_recall}%") 193 | average_inliers = sum(inliers_list) / len(inliers_list) 194 | print(f"All 8 scene, average num inliers: {average_inliers}") 195 | average_inliers_ratio = sum(inliers_ratio_list) / len(inliers_list) 196 | print(f"All 8 scene, average num inliers ratio: {average_inliers_ratio}") 197 | -------------------------------------------------------------------------------- /ThreeDMatch/Test/preparation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import time 5 | import numpy as np 6 | import torch 7 | import shutil 8 | import torch.nn as nn 9 | import sys 10 | 11 | sys.path.append('../../') 12 | import script.common as cm 13 | from ThreeDMatch.Test.tools import get_pcd, get_keypts 14 | from sklearn.neighbors import KDTree 15 | import importlib 16 | import open3d 17 | 18 | 19 | def make_open3d_point_cloud(xyz, color=None): 20 | pcd = open3d.geometry.PointCloud() 21 | pcd.points = open3d.utility.Vector3dVector(xyz) 22 | if color is not None: 23 | pcd.paint_uniform_color(color) 24 | return pcd 25 | 26 | 27 | def build_patch_input(pcd, keypts, vicinity=0.3, num_points_per_patch=2048): 28 | refer_pts = keypts.astype(np.float32) 29 | pts = np.array(pcd.points).astype(np.float32) 30 | num_patches = refer_pts.shape[0] 31 | tree = KDTree(pts[:, 0:3]) 32 | ind_local = tree.query_radius(refer_pts[:, 0:3], r=vicinity) 33 | local_patches = np.zeros([num_patches, num_points_per_patch, 3], dtype=float) 34 | for i in range(num_patches): 35 | local_neighbors = pts[ind_local[i], :] 
36 |         if local_neighbors.shape[0] >= num_points_per_patch:
37 |             temp = np.random.choice(range(local_neighbors.shape[0]), num_points_per_patch, replace=False)
38 |             local_neighbors = local_neighbors[temp]
39 |             local_neighbors[-1, :] = refer_pts[i, :]
40 |         else:
41 |             fix_idx = np.asarray(range(local_neighbors.shape[0]))
42 |             while local_neighbors.shape[0] + fix_idx.shape[0] < num_points_per_patch:
43 |                 fix_idx = np.concatenate((fix_idx, np.asarray(range(local_neighbors.shape[0]))), axis=0)
44 |             random_idx = np.random.choice(local_neighbors.shape[0], num_points_per_patch - fix_idx.shape[0],
45 |                                           replace=False)
46 |             choice_idx = np.concatenate((fix_idx, random_idx), axis=0)
47 |             local_neighbors = local_neighbors[choice_idx]
48 |             local_neighbors[-1, :] = refer_pts[i, :]
49 |         local_patches[i] = local_neighbors
50 | 
51 |     return local_patches
52 | 
53 | 
54 | def prepare_patch(pcdpath, filename, keyptspath, trans_matrix):
55 |     pcd = get_pcd(pcdpath, filename)
56 |     keypts = get_keypts(keyptspath, filename)
57 |     # load D3Feat keypts
58 |     if is_D3Feat_keypts:
59 |         keypts_path = './D3Feat_contralo-54-pred/keypoints/' + pcdpath.split('/')[-2] + '/' + filename + '.npy'
60 |         keypts = np.load(keypts_path)
61 |         keypts = keypts[-5000:, :]
62 |     if is_rotate_dataset:
63 |         # Add arbitrary rotation
64 |         # rotate the fragment by an arbitrary rotation (random angles about all three axes)
65 |         angles_3d = np.random.rand(3) * np.pi * 2
66 |         R = cm.angles2rotation_matrix(angles_3d)
67 |         T = np.identity(4)
68 |         T[:3, :3] = R
69 |         pcd.transform(T)
70 |         keypts_pcd = make_open3d_point_cloud(keypts)
71 |         keypts_pcd.transform(T)
72 |         keypts = np.array(keypts_pcd.points)
73 |         trans_matrix.append(T)
74 | 
75 |     local_patches = build_patch_input(pcd, keypts)  # [num_keypts, 2048, 3]
76 |     return local_patches
77 | 
78 | 
79 | def generate_descriptor(model, desc_name, pcdpath, keyptspath, descpath):
80 |     model.eval()
81 |     num_frag = len(os.listdir(pcdpath))
82 |     num_desc = len(os.listdir(descpath))
83 |     trans_matrix = []
84 |     if num_frag == num_desc:
85 |         print("Descriptor already prepared.")
86 |         return
87 |     for j in range(num_frag):
88 |         local_patches = prepare_patch(pcdpath, 'cloud_bin_' + str(j), keyptspath, trans_matrix)
89 |         input_ = torch.tensor(local_patches.astype(np.float32))
90 |         B = input_.shape[0]
91 |         input_ = input_.cuda()
92 |         model = model.cuda()
93 |         # calculate descriptors
94 |         desc_list = []
95 |         start_time = time.time()
96 |         desc_len = 32
97 |         step_size = 100
98 |         iter_num = int(np.ceil(B / step_size))
99 |         for k in range(iter_num):
100 |             if k == iter_num - 1:
101 |                 desc = model(input_[k * step_size:, :, :])
102 |             else:
103 |                 desc = model(input_[k * step_size: (k + 1) * step_size, :, :])
104 |             desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy())
105 |             del desc
106 |         step_time = time.time() - start_time
107 |         print(f'Finished {B} descriptors in {step_time:.4f}s')
108 |         desc = np.concatenate(desc_list, 0).reshape([B, desc_len])
109 |         np.save(descpath + 'cloud_bin_' + str(j) + f".desc.{desc_name}.bin", desc.astype(np.float32))
110 |     if is_rotate_dataset:
111 |         scene_name = pcdpath.split('/')[-2]
112 |         all_trans_matrix[scene_name] = trans_matrix
113 | 
114 | 
115 | if __name__ == '__main__':
116 |     scene_list = [
117 |         '7-scenes-redkitchen',
118 |         'sun3d-home_at-home_at_scan1_2013_jan_1',
119 |         'sun3d-home_md-home_md_scan9_2012_sep_30',
120 |         'sun3d-hotel_uc-scan3',
121 |         'sun3d-hotel_umd-maryland_hotel1',
122 |         'sun3d-hotel_umd-maryland_hotel3',
123 |         'sun3d-mit_76_studyroom-76-1studyroom2',
124 |         'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika'
125 |     ]
126 | 
127 |     experiment_id = time.strftime('%m%d%H%M')
128 |     model_str = experiment_id  # sys.argv[1]
129 |     if not os.path.exists(f"SpinNet_desc_{model_str}/"):
130 |         os.mkdir(f"SpinNet_desc_{model_str}")
131 | 
132 |     # dynamically load the model
133 |     module_file_path = '../model.py'
134 |     shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), module_file_path)
135 |     module_name = ''
136 |     module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
137 |     module = importlib.util.module_from_spec(module_spec)
138 |     module_spec.loader.exec_module(module)
139 |     model = module.Descriptor_Net(0.30, 9, 80, 40, 0.04, 30, '3DMatch')
140 |     model = nn.DataParallel(model, device_ids=[0])
141 |     model.load_state_dict(torch.load('../../pre-trained_models/3DMatch_best.pkl'))
142 | 
143 |     all_trans_matrix = {}
144 |     is_rotate_dataset = False
145 |     is_D3Feat_keypts = False
146 |     for scene in scene_list:
147 |         pcdpath = f"../../data/3DMatch/fragments/{scene}/"
148 |         interpath = f"../../data/3DMatch/intermediate-files-real/{scene}/"
149 |         keyptspath = interpath
150 |         descpath = os.path.join('.', f"SpinNet_desc_{model_str}/{scene}/")
151 |         if not os.path.exists(descpath):
152 |             os.makedirs(descpath)
153 |         start_time = time.time()
154 |         print(f"Begin Processing {scene}")
155 |         generate_descriptor(model, desc_name='SpinNet', pcdpath=pcdpath, keyptspath=keyptspath, descpath=descpath)
156 |         print(f"Finish in {time.time() - start_time}s")
157 |     if is_rotate_dataset:
158 |         np.save("trans_matrix", all_trans_matrix)
159 | 
-------------------------------------------------------------------------------- /ThreeDMatch/Test/tools.py: --------------------------------------------------------------------------------
1 | import os
2 | import open3d
3 | import numpy as np
4 | 
5 | 
6 | def get_pcd(pcdpath, filename):
7 |     return open3d.io.read_point_cloud(os.path.join(pcdpath, filename + '.ply'))
8 | 
9 | 
10 | def get_keypts(keyptspath, filename):
11 |     keypts = np.fromfile(os.path.join(keyptspath, filename + '.keypts.bin'), dtype=np.float32)
12 |     num_keypts = int(keypts[0])
13 |     keypts = keypts[1:].reshape([num_keypts, 3])
14 |     return keypts
15 | 
16 | 
17 | def get_ETH_keypts(pcd, keyptspath, filename):
18 |     pts = np.array(pcd.points)
19 |     key_ind = np.loadtxt(os.path.join(keyptspath, filename + '_Keypoints.txt'), dtype=int)
20 |     keypts = pts[key_ind]
21 |     return keypts
22 | 
23 | 
24 | def get_keypts_(keyptspath, filename):
25 |     keypts = np.load(os.path.join(keyptspath, filename + '.keypts.bin.npy'))
26 |     return keypts
27 | 
28 | 
29 | def get_desc(descpath, filename, desc_name):
30 |     if desc_name == '3dmatch':
31 |         desc = np.fromfile(os.path.join(descpath, filename + '.desc.3dmatch.bin'), dtype=np.float32)
32 |         num_desc = int(desc[0])
33 |         desc_size = int(desc[1])
34 |         desc = desc[2:].reshape([num_desc, desc_size])
35 |     elif desc_name == 'SpinNet':
36 |         desc = np.load(os.path.join(descpath, filename + '.desc.SpinNet.bin.npy'))
37 |     else:
38 |         print("No such descriptor")
39 |         exit(-1)
40 |     return desc
41 | 
42 | 
43 | def loadlog(gtpath):
44 |     with open(os.path.join(gtpath, 'gt.log')) as f:
45 |         content = f.readlines()
46 |     result = {}
47 |     i = 0
48 |     while i < len(content):
49 |         line = content[i].replace("\n", "").split("\t")[0:3]
50 |         trans = np.zeros([4, 4])
51 |         trans[0] = [float(x) for x in content[i + 1].replace("\n", "").split("\t")[0:4]]
52 |         trans[1] = [float(x) for x in content[i + 2].replace("\n", "").split("\t")[0:4]]
53 |         trans[2] = 
[float(x) for x in content[i + 3].replace("\n", "").split("\t")[0:4]] 54 | trans[3] = [float(x) for x in content[i + 4].replace("\n", "").split("\t")[0:4]] 55 | i = i + 5 56 | result[f'{int(line[0])}_{int(line[1])}'] = trans 57 | 58 | return result 59 | -------------------------------------------------------------------------------- /ThreeDMatch/Train/dataloader.py: -------------------------------------------------------------------------------- 1 | import time 2 | from ThreeDMatch.Train.dataset import ThreeDMatchDataset 3 | import torch 4 | 5 | 6 | def get_dataloader(root, split, batch_size=1, num_workers=4, shuffle=True, drop_last=True): 7 | dataset = ThreeDMatchDataset( 8 | root=root, 9 | split=split, 10 | batch_size=batch_size, 11 | shuffle=shuffle, 12 | drop_last=drop_last 13 | ) 14 | dataset.initial() 15 | dataloader = torch.utils.data.DataLoader( 16 | dataset=dataset, 17 | batch_size=batch_size, 18 | num_workers=num_workers, 19 | drop_last=drop_last 20 | ) 21 | 22 | return dataloader 23 | 24 | 25 | if __name__ == '__main__': 26 | dataset = 'sun3d' 27 | dataroot = "/data/3DMatch/whole" 28 | trainloader = get_dataloader(dataroot, split='test', batch_size=32) 29 | start_time = time.time() 30 | print(f"Totally {len(trainloader)} iter.") 31 | for iter, (patches, ids) in enumerate(trainloader): 32 | if iter % 100 == 0: 33 | print(f"Iter {iter}: {time.time() - start_time} s") 34 | print(f"On the fly: {time.time() - start_time}") 35 | -------------------------------------------------------------------------------- /ThreeDMatch/Train/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as Data 2 | import os 3 | import random 4 | import glob 5 | import pickle 6 | import open3d as o3d 7 | import numpy as np 8 | 9 | 10 | class ThreeDMatchDataset(Data.Dataset): 11 | def __init__(self, root, split, batch_size, shuffle, drop_last): 12 | """ 13 | Create ThreeDMatchDataset to read multiple training files 14 | Args: 15 | root: the path to the dataset file 16 | shuffle: whether the data need to shuffle 17 | """ 18 | self.patches_path = os.path.join(root, split) 19 | self.split = split 20 | # Get name of all training pkl files 21 | training_data_files = glob.glob(self.patches_path + '/*.pkl') 22 | ids = [file.split("/")[-1] for file in training_data_files] 23 | ids = sorted(ids, key=lambda x: int(x.split("_")[-1].split(".")[0])) 24 | ids = [file for file in ids if file.split("_")[1] == 'anc&pos'] 25 | self.training_data_files = ids 26 | # Get info of training files 27 | self.per_num_patch = int(training_data_files[0].split("/")[-1].split("_")[2]) 28 | self.dataset_len = int(ids[-1].split("_")[-1].split(".")[0]) * self.per_num_patch 29 | self.batch_size = batch_size 30 | self.shuffle = shuffle 31 | self.drop_last = drop_last 32 | # Record the loaded i-th training file 33 | self.num_file = 0 34 | # load poses for each type of patches 35 | self.per_patch_points = int(self.training_data_files[-1].split("_")[3]) 36 | self.num_framents = int(self.training_data_files[-1].split("_")[4].split(".")[0]) 37 | with open(os.path.join(root, 38 | f'{self.split}/{self.split}_poses_{self.per_num_patch}_{self.per_patch_points}_{self.num_framents}.pkl'), 39 | 'rb') as file: 40 | self.poses = pickle.load(file) 41 | print( 42 | f"load training poses {os.path.join(root, f'{self.split}_poses_{self.per_num_patch}_{self.per_patch_points}_{self.num_framents}.pkl')}") 43 | self.cur_pose_ind = 0 44 | 45 | def initial(self): 46 | with 
open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file: 47 | self.patches = pickle.load(file) 48 | print(f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}") 49 | 50 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1]) 51 | poses = self.poses[self.cur_pose_ind:next_pose_ind] 52 | for i in range(len(self.patches)): 53 | ind = int(np.floor(i / self.per_num_patch)) 54 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6) 55 | self.patches[i] = np.concatenate([pose, self.patches[i]]) 56 | self.cur_pose_ind = next_pose_ind 57 | 58 | self.current_patches_num = len(self.patches) 59 | self.index = list(range(self.current_patches_num)) 60 | if self.shuffle: 61 | random.shuffle(self.patches) 62 | 63 | def __len__(self): 64 | return self.dataset_len 65 | 66 | def __getitem__(self, item): 67 | idx = self.index[0] 68 | patches = self.patches[idx] 69 | self.index = self.index[1:] 70 | self.current_patches_num -= 1 71 | 72 | if self.drop_last: 73 | if self.current_patches_num <= (len(self.patches) % self.batch_size): # reach the end of training file 74 | self.num_file = self.num_file + 1 75 | if self.num_file < len(self.training_data_files): 76 | remain_patches = [self.patches[i] for i in self.index] # the remained training patches 77 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file: 78 | self.patches = pickle.load(file) 79 | print( 80 | f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}") 81 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1]) 82 | poses = self.poses[self.cur_pose_ind:next_pose_ind] 83 | for i in range(len(self.patches)): 84 | ind = int(np.floor(i / self.per_num_patch)) 85 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6) 86 | self.patches[i] = np.concatenate([pose, self.patches[i]]) 87 | self.cur_pose_ind = next_pose_ind 88 | self.patches += remain_patches # add the remained patches to compose a set of new patches 89 | self.current_patches_num = len(self.patches) 90 | self.index = list(range(self.current_patches_num)) 91 | if self.shuffle: 92 | random.shuffle(self.patches) 93 | else: 94 | self.num_file = 0 95 | self.cur_pose_ind = 0 96 | self.initial() 97 | else: 98 | if self.current_patches_num <= 0: 99 | self.num_file = self.num_file + 1 100 | if self.num_file < len(self.training_data_files): 101 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file: 102 | self.patches = pickle.load(file) 103 | print( 104 | f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}") 105 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1]) 106 | poses = self.poses[self.cur_pose_ind:next_pose_ind] 107 | for i in range(len(self.patches)): 108 | ind = int(np.floor(i / self.per_num_patch)) 109 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6) 110 | self.patches[i] = np.concatenate([pose, self.patches[i]]) 111 | self.cur_pose_ind = next_pose_ind 112 | self.current_patches_num = len(self.patches) 113 | self.index = list(range(self.current_patches_num)) 114 | if self.shuffle: 115 | random.shuffle(self.patches) 116 | else: 117 | self.num_file = 0 118 | self.cur_pose_ind = 0 119 | self.initial() 120 | 121 | 
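        # rows 0-1 of every stored patch were prepended in initial(): the 12 pose
        # values (9 rotation entries + 3 translation entries) reshaped to 2 x 6,
        # followed by the raw points, anchor xyz in columns 0:3 and positive xyz
        # in columns 3:6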
anc_local_patch = patches[2:, :3] 122 | pos_local_patch = patches[2:, 3:] 123 | rotate = patches[:2, :].reshape(12)[:9].reshape(3, 3) 124 | shift = patches[:2, :].reshape(12)[9:] 125 | 126 | # np.random.shuffle(anc_local_patch) 127 | # np.random.shuffle(pos_local_patch) 128 | 129 | return anc_local_patch, pos_local_patch, rotate, shift 130 | 131 | 132 | if __name__ == "__main__": 133 | data_root = "../data/3DMatch_patches/" 134 | batch_size = 48 135 | epoch = 1 136 | train_dataset = ThreeDMatchDataset(root=data_root, split='train', batch_size=batch_size, shuffle=True, 137 | drop_last=True) 138 | train_dataset.initial() 139 | for _ in range(epoch): 140 | train_iter = Data.DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True) 141 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(train_iter): 142 | B = anc_local_patch.shape[0] 143 | -------------------------------------------------------------------------------- /ThreeDMatch/Train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import time 5 | import shutil 6 | import sys 7 | 8 | sys.path.append('../../') 9 | from ThreeDMatch.Train.dataloader import get_dataloader 10 | from ThreeDMatch.Train.trainer import Trainer 11 | from network.SpinNet import Descriptor_Net 12 | from torch import optim 13 | 14 | 15 | class Args(object): 16 | def __init__(self): 17 | self.experiment_id = "Proposal" + time.strftime('%m%d%H%M') 18 | snapshot_root = 'snapshot/%s' % self.experiment_id 19 | tensorboard_root = 'tensorboard/%s' % self.experiment_id 20 | os.makedirs(snapshot_root, exist_ok=True) 21 | os.makedirs(tensorboard_root, exist_ok=True) 22 | shutil.copy2(os.path.join('', 'train.py'), os.path.join(snapshot_root, 'train.py')) 23 | shutil.copy2(os.path.join('', 'trainer.py'), os.path.join(snapshot_root, 'trainer.py')) 24 | shutil.copy2(os.path.join('', '../../network/SpinNet.py'), os.path.join(snapshot_root, 'SpinNet.py')) 25 | shutil.copy2(os.path.join('', '../../network/ThreeDCCN.py'), os.path.join(snapshot_root, 'ThreeDCCN.py')) 26 | shutil.copy2(os.path.join('', '../../loss/desc_loss.py'), os.path.join(snapshot_root, 'loss.py')) 27 | self.epoch = 20 28 | self.batch_size = 76 29 | self.rad_n = 9 30 | self.azi_n = 80 31 | self.ele_n = 40 32 | self.des_r = 0.30 33 | self.voxel_r = 0.04 34 | self.voxel_sample = 30 35 | 36 | self.dataset = '3DMatch' 37 | self.data_train_dir = '../../data/3DMatch/patches' 38 | self.data_val_dir = '../../data/3DMatch/patches' 39 | 40 | self.gpu_mode = True 41 | self.verbose = True 42 | self.freeze_epoch = 5 43 | 44 | # model & optimizer 45 | self.model = Descriptor_Net(self.des_r, self.rad_n, self.azi_n, self.ele_n, 46 | self.voxel_r, self.voxel_sample, self.dataset) 47 | self.pretrain = '' 48 | self.parameter = self.model.get_parameter() 49 | self.optimizer = optim.Adam(self.parameter, lr=0.001, betas=(0.9, 0.999), weight_decay=1e-6) 50 | self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.5) 51 | self.scheduler_interval = 5 52 | 53 | # dataloader 54 | self.train_loader = get_dataloader(root=self.data_train_dir, 55 | batch_size=self.batch_size, 56 | split='train', 57 | shuffle=True, 58 | num_workers=0, # if the dataset is offline generated, must 0 59 | ) 60 | self.val_loader = get_dataloader(root=self.data_val_dir, 61 | batch_size=self.batch_size, 62 | split='val', 63 | shuffle=False, 64 | num_workers=0, # if the dataset is offline generated, must 0 65 | 
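                                     # (the dataset object advances through its
                                     # pickle files and mutates its own index in
                                     # __getitem__, so multi-worker loading would
                                     # desynchronise that state)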
) 66 | 67 | print("Training set size:", self.train_loader.dataset.__len__()) 68 | print("Validate set size:", self.val_loader.dataset.__len__()) 69 | 70 | # snapshot 71 | self.snapshot_interval = int(self.train_loader.dataset.__len__() / self.batch_size / 2) 72 | self.save_dir = os.path.join(snapshot_root, 'models/') 73 | self.result_dir = os.path.join(snapshot_root, 'results/') 74 | self.tboard_dir = tensorboard_root 75 | 76 | # evaluate 77 | self.evaluate_interval = 1 78 | 79 | self.check_args() 80 | 81 | def check_args(self): 82 | """checking arguments""" 83 | if not os.path.exists(self.save_dir): 84 | os.makedirs(self.save_dir) 85 | if not os.path.exists(self.result_dir): 86 | os.makedirs(self.result_dir) 87 | if not os.path.exists(self.tboard_dir): 88 | os.makedirs(self.tboard_dir) 89 | return self 90 | 91 | 92 | if __name__ == '__main__': 93 | args = Args() 94 | trainer = Trainer(args) 95 | trainer.train() 96 | -------------------------------------------------------------------------------- /ThreeDMatch/Train/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import time, os 4 | import numpy as np 5 | from loss.desc_loss import ContrastiveLoss 6 | from tensorboardX import SummaryWriter 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, args): 11 | # parameters 12 | self.epoch = args.epoch 13 | self.batch_size = args.batch_size 14 | self.dataset = args.dataset 15 | self.save_dir = args.save_dir 16 | self.result_dir = args.result_dir 17 | self.gpu_mode = args.gpu_mode 18 | self.verbose = args.verbose 19 | self.model = args.model 20 | self.optimizer = args.optimizer 21 | self.scheduler = args.scheduler 22 | self.scheduler_interval = args.scheduler_interval 23 | self.snapshot_interval = args.snapshot_interval 24 | self.evaluate_interval = args.evaluate_interval 25 | self.writer = SummaryWriter(log_dir=args.tboard_dir) 26 | 27 | self.train_loader = args.train_loader 28 | self.val_loader = args.val_loader 29 | 30 | self.desc_loss = ContrastiveLoss() 31 | 32 | if self.gpu_mode: 33 | self.model = self.model.cuda() 34 | self.model = torch.nn.DataParallel(self.model, device_ids=[0]) 35 | 36 | if args.pretrain != '': 37 | self._load_pretrain(args.pretrain) 38 | 39 | def train(self): 40 | self.train_hist = { 41 | 'loss': [], 42 | 'per_epoch_time': [], 43 | 'total_time': [] 44 | } 45 | best_loss = 1000000000 46 | print('training start!!') 47 | start_time = time.time() 48 | 49 | self.model.train() 50 | for epoch in range(self.epoch): 51 | 52 | self.train_epoch(epoch) 53 | 54 | if epoch % self.evaluate_interval == 0 or epoch == 0: 55 | res = self.evaluate() 56 | print(f'Evaluation: Epoch {epoch}: Loss {res["loss"]}') 57 | 58 | if res['loss'] < best_loss: 59 | best_loss = res['loss'] 60 | self._snapshot('best') 61 | if self.writer: 62 | self.writer.add_scalar('Loss', res['loss'], epoch) 63 | 64 | if epoch % self.scheduler_interval == 0: 65 | old_lr = self.optimizer.param_groups[0]['lr'] 66 | self.scheduler.step() 67 | new_lr = self.optimizer.param_groups[0]['lr'] 68 | print('update detector learning rate: %f -> %f' % (old_lr, new_lr)) 69 | 70 | if self.writer: 71 | self.writer.add_scalar('Learning Rate', self._get_lr(), epoch) 72 | self.writer.add_scalar('Train Loss', self.train_hist['loss'][-1], epoch) 73 | 74 | # finish all epoch 75 | self.train_hist['total_time'].append(time.time() - start_time) 76 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % 
(np.mean(self.train_hist['per_epoch_time']), 77 | self.epoch, self.train_hist['total_time'][0])) 78 | print("Training finish!... save training results") 79 | 80 | def train_epoch(self, epoch): 81 | epoch_start_time = time.time() 82 | loss_buf = [] 83 | num_batch = int(len(self.train_loader.dataset) / self.batch_size) 84 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(self.train_loader): 85 | 86 | B = anc_local_patch.shape[0] 87 | anc_local_patch = anc_local_patch.float() 88 | pos_local_patch = pos_local_patch.float() 89 | rotate = rotate.float() 90 | shift = shift.float() 91 | 92 | if self.gpu_mode: 93 | anc_local_patch = anc_local_patch.cuda() 94 | pos_local_patch = pos_local_patch.cuda() 95 | 96 | # forward 97 | self.optimizer.zero_grad() 98 | a_des = self.model(anc_local_patch) 99 | p_des = self.model(pos_local_patch) 100 | anc_des = F.normalize(a_des.view(B, -1), p=2, dim=1) 101 | pos_des = F.normalize(p_des.view(B, -1), p=2, dim=1) 102 | 103 | # calculate the contrastive loss 104 | des_loss, accuracy = self.desc_loss(anc_des, pos_des) 105 | loss = des_loss 106 | 107 | # backward 108 | loss.backward() 109 | self.optimizer.step() 110 | loss_buf.append(float(loss)) 111 | 112 | if iter % self.snapshot_interval == 0: 113 | self._snapshot(f'{epoch}_{iter + 1}') 114 | 115 | if iter % 200 == 0 and self.verbose: 116 | iter_time = time.time() - epoch_start_time 117 | print(f"Epoch: {epoch} [{iter:4d}/{num_batch}] loss: {loss:.2f} time: {iter_time:.2f}s") 118 | print(f"Accuracy: {accuracy.item():.4f}\n") 119 | del loss 120 | del anc_local_patch 121 | del pos_local_patch 122 | # finish one epoch 123 | epoch_time = time.time() - epoch_start_time 124 | self.train_hist['per_epoch_time'].append(epoch_time) 125 | self.train_hist['loss'].append(np.mean(loss_buf)) 126 | print(f'Epoch {epoch}: Loss {np.mean(loss_buf)}, time {epoch_time:.4f}s') 127 | 128 | del loss_buf 129 | 130 | def evaluate(self): 131 | self.model.eval() 132 | loss_buf = [] 133 | with torch.no_grad(): 134 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(self.val_loader): 135 | 136 | B = anc_local_patch.shape[0] 137 | anc_local_patch = anc_local_patch.float() 138 | pos_local_patch = pos_local_patch.float() 139 | rotate = rotate.float() 140 | shift = shift.float() 141 | 142 | if self.gpu_mode: 143 | anc_local_patch = anc_local_patch.cuda() 144 | pos_local_patch = pos_local_patch.cuda() 145 | 146 | # forward 147 | a_des = self.model(anc_local_patch) 148 | p_des = self.model(pos_local_patch) 149 | 150 | # descriptor loss 151 | anc_des = F.normalize(a_des.view(B, -1), p=2, dim=1) 152 | pos_des = F.normalize(p_des.view(B, -1), p=2, dim=1) 153 | 154 | # calculate the contrastive loss 155 | des_loss, accuracy = self.desc_loss(anc_des, pos_des) 156 | loss = des_loss 157 | loss_buf.append(float(loss)) 158 | 159 | del loss 160 | del anc_local_patch 161 | del pos_local_patch 162 | 163 | self.model.train() 164 | 165 | res = { 166 | 'loss': np.mean(loss_buf), 167 | } 168 | del loss_buf 169 | return res 170 | 171 | def _snapshot(self, epoch): 172 | save_dir = os.path.join(self.save_dir, self.dataset) 173 | torch.save(self.model.state_dict(), save_dir + "_" + str(epoch) + '.pkl') 174 | print(f"Save model to {save_dir}_{str(epoch)}.pkl") 175 | 176 | def _load_pretrain(self, pretrain): 177 | state_dict = torch.load(pretrain) 178 | self.model.load_state_dict(state_dict) 179 | print(f"Load model from {pretrain}.pkl") 180 | 181 | def _get_lr(self, group=0): 182 | return 
self.optimizer.param_groups[group]['lr'] 183 | -------------------------------------------------------------------------------- /figs/Fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig1.png -------------------------------------------------------------------------------- /figs/Fig2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig2.png -------------------------------------------------------------------------------- /figs/Fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig3.png -------------------------------------------------------------------------------- /figs/Fig4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig4.png -------------------------------------------------------------------------------- /figs/Fig5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig5.png -------------------------------------------------------------------------------- /figs/Table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table1.png -------------------------------------------------------------------------------- /figs/Table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table2.png -------------------------------------------------------------------------------- /figs/Table3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table3.png -------------------------------------------------------------------------------- /figs/Table4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table4.png -------------------------------------------------------------------------------- /figs/Table5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table5.png -------------------------------------------------------------------------------- /figs/Table6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table6.png -------------------------------------------------------------------------------- /figs/Table7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table7.png 
-------------------------------------------------------------------------------- /generalization/KITTI-to-ThreeDMatch/evaluate.py: --------------------------------------------------------------------------------
1 | import sys
2 | 
3 | sys.path.append('../../')
4 | import open3d
5 | import numpy as np
6 | import time
7 | import os
8 | from ThreeDMatch.Test.tools import get_pcd, get_keypts, get_desc, loadlog
9 | from sklearn.neighbors import KDTree
10 | 
11 | 
12 | def calculate_M(source_desc, target_desc):
13 |     """
14 |     Find the mutually closest point pairs in feature space.
15 |     source and target are the descriptors of key points from two point clouds. [5000, 512]
16 |     """
17 | 
18 |     kdtree_s = KDTree(target_desc)
19 |     sourceNNdis, sourceNNidx = kdtree_s.query(source_desc, 1)
20 |     kdtree_t = KDTree(source_desc)
21 |     targetNNdis, targetNNidx = kdtree_t.query(target_desc, 1)
22 |     result = []
23 |     for i in range(len(sourceNNidx)):
24 |         if targetNNidx[sourceNNidx[i]] == i:
25 |             result.append([i, sourceNNidx[i][0]])
26 |     return np.array(result)
27 | 
28 | 
29 | def register2Fragments(id1, id2, keyptspath, descpath, resultpath, desc_name='SpinNet'):
30 |     cloud_bin_s = f'cloud_bin_{id1}'
31 |     cloud_bin_t = f'cloud_bin_{id2}'
32 |     write_file = f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'
33 |     if os.path.exists(os.path.join(resultpath, write_file)):
34 |         print(f"{write_file} already exists.")
35 |         return 0, 0, 0
36 | 
37 |     if is_D3Feat_keypts:
38 |         keypts_path = './D3Feat_contralo-54-pred/keypoints/' + keyptspath.split('/')[-2] + '/' + cloud_bin_s + '.npy'
39 |         source_keypts = np.load(keypts_path)
40 |         source_keypts = source_keypts[-num_keypoints:, :]
41 |         keypts_path = './D3Feat_contralo-54-pred/keypoints/' + keyptspath.split('/')[-2] + '/' + cloud_bin_t + '.npy'
42 |         target_keypts = np.load(keypts_path)
43 |         target_keypts = target_keypts[-num_keypoints:, :]
44 |         source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
45 |         target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
46 |         source_desc = np.nan_to_num(source_desc)
47 |         target_desc = np.nan_to_num(target_desc)
48 |         source_desc = source_desc[-num_keypoints:, :]
49 |         target_desc = target_desc[-num_keypoints:, :]
50 |     else:
51 |         source_keypts = get_keypts(keyptspath, cloud_bin_s)
52 |         target_keypts = get_keypts(keyptspath, cloud_bin_t)
53 |         # print(source_keypts.shape)
54 |         source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
55 |         target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
56 |         source_desc = np.nan_to_num(source_desc)
57 |         target_desc = np.nan_to_num(target_desc)
58 |         if source_desc.shape[0] > num_keypoints:
59 |             rand_ind = np.random.choice(source_desc.shape[0], num_keypoints, replace=False)
60 |             source_keypts = source_keypts[rand_ind]
61 |             target_keypts = target_keypts[rand_ind]
62 |             source_desc = source_desc[rand_ind]
63 |             target_desc = target_desc[rand_ind]
64 | 
65 |     key = f'{cloud_bin_s.split("_")[-1]}_{cloud_bin_t.split("_")[-1]}'
66 |     if key not in gtLog.keys():
67 |         num_inliers = 0
68 |         inlier_ratio = 0
69 |         gt_flag = 0
70 |     else:
71 |         # find the mutually closest points.
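        # same protocol as ThreeDMatch/Test/evaluate.py: correspondences are the
        # mutual nearest neighbors in feature space, and a pair is an inlier when
        # the keypoints land within 0.10 m of each other under the ground-truth
        # transform.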
72 | corr = calculate_M(source_desc, target_desc) 73 | 74 | gtTrans = gtLog[key] 75 | frag1 = source_keypts[corr[:, 0]] 76 | frag2_pc = open3d.geometry.PointCloud() 77 | frag2_pc.points = open3d.utility.Vector3dVector(target_keypts[corr[:, 1]]) 78 | frag2_pc.transform(gtTrans) 79 | frag2 = np.asarray(frag2_pc.points) 80 | distance = np.sqrt(np.sum(np.power(frag1 - frag2, 2), axis=1)) 81 | num_inliers = np.sum(distance < 0.10) 82 | inlier_ratio = num_inliers / len(distance) 83 | gt_flag = 1 84 | 85 | # calculate the transformation matrix using RANSAC, this is for Registration Recall. 86 | source_pcd = open3d.geometry.PointCloud() 87 | source_pcd.points = open3d.utility.Vector3dVector(source_keypts) 88 | target_pcd = open3d.geometry.PointCloud() 89 | target_pcd.points = open3d.utility.Vector3dVector(target_keypts) 90 | s_desc = open3d.pipelines.registration.Feature() 91 | s_desc.data = source_desc.T 92 | t_desc = open3d.pipelines.registration.Feature() 93 | t_desc.data = target_desc.T 94 | 95 | # Another registration method 96 | corr_v = open3d.utility.Vector2iVector(corr) 97 | result = open3d.pipelines.registration.registration_ransac_based_on_correspondence( 98 | source_pcd, target_pcd, corr_v, 99 | 0.05, 100 | open3d.pipelines.registration.TransformationEstimationPointToPoint(False), 3, 101 | open3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000)) 102 | 103 | # write the transformation matrix into .log file for evaluation. 104 | with open(os.path.join(logpath, f'{desc_name}_{timestr}.log'), 'a+') as f: 105 | trans = result.transformation 106 | trans = np.linalg.inv(trans) 107 | s1 = f'{id1}\t {id2}\t 37\n' 108 | f.write(s1) 109 | f.write(f"{trans[0, 0]}\t {trans[0, 1]}\t {trans[0, 2]}\t {trans[0, 3]}\t \n") 110 | f.write(f"{trans[1, 0]}\t {trans[1, 1]}\t {trans[1, 2]}\t {trans[1, 3]}\t \n") 111 | f.write(f"{trans[2, 0]}\t {trans[2, 1]}\t {trans[2, 2]}\t {trans[2, 3]}\t \n") 112 | f.write(f"{trans[3, 0]}\t {trans[3, 1]}\t {trans[3, 2]}\t {trans[3, 3]}\t \n") 113 | 114 | s = f"{cloud_bin_s}\t{cloud_bin_t}\t{num_inliers}\t{inlier_ratio:.8f}\t{gt_flag}" 115 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'w+') as f: 116 | f.write(s) 117 | return num_inliers, inlier_ratio, gt_flag 118 | 119 | 120 | def read_register_result(id1, id2): 121 | cloud_bin_s = f'cloud_bin_{id1}' 122 | cloud_bin_t = f'cloud_bin_{id2}' 123 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'r') as f: 124 | content = f.readlines() 125 | nums = content[0].replace("\n", "").split("\t")[2:5] 126 | return nums 127 | 128 | 129 | if __name__ == '__main__': 130 | scene_list = [ 131 | '7-scenes-redkitchen', 132 | 'sun3d-home_at-home_at_scan1_2013_jan_1', 133 | 'sun3d-home_md-home_md_scan9_2012_sep_30', 134 | 'sun3d-hotel_uc-scan3', 135 | 'sun3d-hotel_umd-maryland_hotel1', 136 | 'sun3d-hotel_umd-maryland_hotel3', 137 | 'sun3d-mit_76_studyroom-76-1studyroom2', 138 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika' 139 | ] 140 | desc_name = 'SpinNet' 141 | timestr = sys.argv[1] 142 | inliers_list = [] 143 | recall_list = [] 144 | inliers_ratio_list = [] 145 | num_keypoints = 5000 146 | is_D3Feat_keypts = False 147 | for scene in scene_list: 148 | pcdpath = f"../../data/3DMatch/fragments/{scene}/" 149 | interpath = f"../../data/3DMatch/intermediate-files-real/{scene}/" 150 | gtpath = f'../../data/3DMatch/fragments/{scene}-evaluation/' 151 | keyptspath = interpath # os.path.join(interpath, "keypoints/") 152 | descpath = os.path.join(".", 
f"{desc_name}_desc_{timestr}/{scene}") 153 | gtLog = loadlog(gtpath) 154 | logpath = f"log_result/{scene}-evaluation" 155 | resultpath = os.path.join(".", f"pred_result/{scene}/{desc_name}_result_{timestr}") 156 | if not os.path.exists(resultpath): 157 | os.makedirs(resultpath) 158 | if not os.path.exists(logpath): 159 | os.makedirs(logpath) 160 | 161 | # register each pair 162 | num_frag = len(os.listdir(pcdpath)) 163 | print(f"Start Evaluate Descriptor {desc_name} for {scene}") 164 | start_time = time.time() 165 | for id1 in range(num_frag): 166 | for id2 in range(id1 + 1, num_frag): 167 | num_inliers, inlier_ratio, gt_flag = register2Fragments(id1, id2, keyptspath, descpath, resultpath, 168 | desc_name) 169 | print(f"Finish Evaluation, time: {time.time() - start_time:.2f}s") 170 | 171 | # evaluate 172 | result = [] 173 | for id1 in range(num_frag): 174 | for id2 in range(id1 + 1, num_frag): 175 | line = read_register_result(id1, id2) 176 | result.append([int(line[0]), float(line[1]), int(line[2])]) 177 | result = np.array(result) 178 | indices_results = np.sum(result[:, 2] == 1) 179 | correct_match = np.sum(result[:, 1] > 0.05) 180 | recall = float(correct_match / indices_results) * 100 181 | print(f"Correct Match {correct_match}, ground truth Match {indices_results}") 182 | print(f"Recall {recall}%") 183 | ave_num_inliers = np.sum(np.where(result[:, 1] > 0.05, result[:, 0], np.zeros(result.shape[0]))) / correct_match 184 | print(f"Average Num Inliners: {ave_num_inliers}") 185 | ave_inlier_ratio = np.sum( 186 | np.where(result[:, 1] > 0.05, result[:, 1], np.zeros(result.shape[0]))) / correct_match 187 | print(f"Average Num Inliner Ratio: {ave_inlier_ratio}") 188 | recall_list.append(recall) 189 | inliers_list.append(ave_num_inliers) 190 | inliers_ratio_list.append(ave_inlier_ratio) 191 | print(recall_list) 192 | average_recall = sum(recall_list) / len(recall_list) 193 | print(f"All 8 scene, average recall: {average_recall}%") 194 | average_inliers = sum(inliers_list) / len(inliers_list) 195 | print(f"All 8 scene, average num inliers: {average_inliers}") 196 | average_inliers_ratio = sum(inliers_ratio_list) / len(inliers_list) 197 | print(f"All 8 scene, average num inliers ratio: {average_inliers_ratio}") 198 | -------------------------------------------------------------------------------- /generalization/KITTI-to-ThreeDMatch/preparation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import time 5 | import numpy as np 6 | import torch 7 | import shutil 8 | import torch.nn as nn 9 | import sys 10 | 11 | sys.path.append('../../') 12 | from ThreeDMatch.Test.tools import get_pcd, get_keypts 13 | from sklearn.neighbors import KDTree 14 | import importlib 15 | import script.common as cm 16 | import open3d 17 | 18 | 19 | def make_open3d_point_cloud(xyz, color=None): 20 | pcd = open3d.geometry.PointCloud() 21 | pcd.points = open3d.utility.Vector3dVector(xyz) 22 | if color is not None: 23 | pcd.paint_uniform_color(color) 24 | return pcd 25 | 26 | 27 | def build_patch_input(pcd, keypts, vicinity=0.3, num_points_per_patch=2048): 28 | refer_pts = keypts.astype(np.float32) 29 | pts = np.array(pcd.points).astype(np.float32) 30 | num_patches = refer_pts.shape[0] 31 | tree = KDTree(pts[:, 0:3]) 32 | ind_local = tree.query_radius(refer_pts[:, 0:3], r=vicinity) 33 | local_patches = np.zeros([num_patches, num_points_per_patch, 3], dtype=float) 34 | for i in range(num_patches): 35 | local_neighbors = 
pts[ind_local[i], :] 36 | if local_neighbors.shape[0] >= num_points_per_patch: 37 | temp = np.random.choice(range(local_neighbors.shape[0]), num_points_per_patch, replace=False) 38 | local_neighbors = local_neighbors[temp] 39 | local_neighbors[-1, :] = refer_pts[i, :] 40 | else: 41 | fix_idx = np.asarray(range(local_neighbors.shape[0])) 42 | while local_neighbors.shape[0] + fix_idx.shape[0] < num_points_per_patch: 43 | fix_idx = np.concatenate((fix_idx, np.asarray(range(local_neighbors.shape[0]))), axis=0) 44 | random_idx = np.random.choice(local_neighbors.shape[0], num_points_per_patch - fix_idx.shape[0], 45 | replace=False) 46 | choice_idx = np.concatenate((fix_idx, random_idx), axis=0) 47 | local_neighbors = local_neighbors[choice_idx] 48 | local_neighbors[-1, :] = refer_pts[i, :] 49 | local_patches[i] = local_neighbors 50 | 51 | return local_patches 52 | 53 | 54 | def prepare_patch(pcdpath, filename, keyptspath, trans_matrix): 55 | pcd = get_pcd(pcdpath, filename) 56 | keypts = get_keypts(keyptspath, filename) 57 | # load D3Feat keypts 58 | if is_D3Feat_keypts: 59 | keypts_path = './D3Feat_contralo-54-pred/keypoints/' + pcdpath.split('/')[-2] + '/' + filename + '.npy' 60 | keypts = np.load(keypts_path) 61 | keypts = keypts[-5000:, :] 62 | if is_rotate_dataset: 63 | # Add arbitrary rotation 64 | # rotate the terminal fragment by an arbitrary angle 65 | angles_3d = np.random.rand(3) * np.pi * 2 66 | R = cm.angles2rotation_matrix(angles_3d) 67 | T = np.identity(4) 68 | T[:3, :3] = R 69 | pcd.transform(T) 70 | keypts_pcd = make_open3d_point_cloud(keypts) 71 | keypts_pcd.transform(T) 72 | keypts = np.array(keypts_pcd.points) 73 | trans_matrix.append(T) 74 | 75 | local_patches = build_patch_input(pcd, keypts, des_r) 76 | return local_patches 77 | 78 | 79 | def generate_descriptor(model, desc_name, pcdpath, keyptspath, descpath): 80 | model.eval() 81 | num_frag = len(os.listdir(pcdpath)) 82 | num_desc = len(os.listdir(descpath)) 83 | trans_matrix = [] 84 | if num_frag == num_desc: 85 | print("Descriptor already prepared.") 86 | return 87 | for j in range(num_frag): 88 | local_patches = prepare_patch(pcdpath, 'cloud_bin_' + str(j), keyptspath, trans_matrix) 89 | input_ = torch.tensor(local_patches.astype(np.float32)) 90 | B = input_.shape[0] 91 | input_ = input_.cuda() 92 | model = model.cuda() 93 | # calculate descriptors 94 | desc_list = [] 95 | start_time = time.time() 96 | desc_len = 32 97 | step_size = 100 98 | iter_num = int(np.ceil(B / step_size)) 99 | for k in range(iter_num): 100 | if k == iter_num - 1: 101 | desc = model(input_[k * step_size:, :, :]) 102 | else: 103 | desc = model(input_[k * step_size: (k + 1) * step_size, :, :]) 104 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy()) 105 | del desc 106 | step_time = time.time() - start_time 107 | print(f'Finish {B} descriptors spend {step_time:.4f}s') 108 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len]) 109 | np.save(descpath + 'cloud_bin_' + str(j) + f".desc.{desc_name}.bin", desc.astype(np.float32)) 110 | if is_rotate_dataset: 111 | scene_name = pcdpath.split('/')[-2] 112 | all_trans_matrix[scene_name] = trans_matrix 113 | 114 | 115 | if __name__ == '__main__': 116 | scene_list = [ 117 | '7-scenes-redkitchen', 118 | 'sun3d-home_at-home_at_scan1_2013_jan_1', 119 | 'sun3d-home_md-home_md_scan9_2012_sep_30', 120 | 'sun3d-hotel_uc-scan3', 121 | 'sun3d-hotel_umd-maryland_hotel1', 122 | 'sun3d-hotel_umd-maryland_hotel3', 123 | 'sun3d-mit_76_studyroom-76-1studyroom2', 124 |
'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika' 125 | ] 126 | 127 | experiment_id = time.strftime('%m%d%H%M') 128 | model_str = experiment_id # sys.argv[1] 129 | if not os.path.exists(f"SpinNet_desc_{model_str}/"): 130 | os.mkdir(f"SpinNet_desc_{model_str}") 131 | 132 | # dynamically load the model 133 | module_file_path = '../model.py' 134 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), '../model.py') 135 | module_name = '' 136 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path) 137 | module = importlib.util.module_from_spec(module_spec) 138 | module_spec.loader.exec_module(module) 139 | 140 | des_r = 0.45 141 | model = module.Descriptor_Net(des_r, 9, 60, 30, 0.15, 30, '3DMatch') 142 | model = nn.DataParallel(model, device_ids=[0]) 143 | model.load_state_dict(torch.load('../../pre-trained_models/KITTI_best.pkl')) 144 | all_trans_matrix = {} 145 | is_rotate_dataset = False 146 | is_D3Feat_keypts = False 147 | for scene in scene_list: 148 | pcdpath = f"../../data/3DMatch/fragments/{scene}/" 149 | interpath = f"../../data/3DMatch/intermediate-files-real/{scene}/" 150 | keyptspath = interpath 151 | descpath = os.path.join('.', f"SpinNet_desc_{model_str}/{scene}/") 152 | if not os.path.exists(descpath): 153 | os.makedirs(descpath) 154 | start_time = time.time() 155 | print(f"Begin Processing {scene}") 156 | generate_descriptor(model, desc_name='SpinNet', pcdpath=pcdpath, keyptspath=keyptspath, descpath=descpath) 157 | print(f"Finished in {time.time() - start_time}s") 158 | if is_rotate_dataset: 159 | np.save("trans_matrix", all_trans_matrix) 160 | -------------------------------------------------------------------------------- /generalization/ThreeDMatch-to-ETH/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../../') 4 | import open3d 5 | import numpy as np 6 | import time 7 | import os 8 | from ThreeDMatch.Test.tools import get_pcd, get_ETH_keypts, get_desc, loadlog 9 | from sklearn.neighbors import KDTree 10 | import glob 11 | 12 | 13 | def calculate_M(source_desc, target_desc): 14 | """ 15 | Find the mutually closest point pairs in feature space. 16 | source and target are the keypoint descriptors of the two point clouds, each of shape
[5000, 512] 17 | """ 18 | 19 | kdtree_s = KDTree(target_desc) # tree over target descriptors, queried with source descriptors 20 | sourceNNdis, sourceNNidx = kdtree_s.query(source_desc, 1) 21 | kdtree_t = KDTree(source_desc) # tree over source descriptors, queried with target descriptors 22 | targetNNdis, targetNNidx = kdtree_t.query(target_desc, 1) 23 | result = [] 24 | for i in range(len(sourceNNidx)): 25 | if targetNNidx[sourceNNidx[i]] == i: 26 | result.append([i, sourceNNidx[i][0]]) 27 | return np.array(result) 28 | 29 | 30 | def register2Fragments(id1, id2, keyptspath, descpath, resultpath, desc_name='ppf'): 31 | cloud_bin_s = f'Hokuyo_{id1}' 32 | cloud_bin_t = f'Hokuyo_{id2}' 33 | write_file = f'{cloud_bin_s}_{cloud_bin_t}.rt.txt' 34 | if os.path.exists(os.path.join(resultpath, write_file)): 35 | # print(f"{write_file} already exists.") 36 | return 0, 0, 0 37 | pcd_s = get_pcd(pcdpath, cloud_bin_s) 38 | source_keypts = get_ETH_keypts(pcd_s, keyptspath, cloud_bin_s) 39 | pcd_t = get_pcd(pcdpath, cloud_bin_t) 40 | target_keypts = get_ETH_keypts(pcd_t, keyptspath, cloud_bin_t) 41 | # print(source_keypts.shape) 42 | source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name) 43 | target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name) 44 | source_desc = np.nan_to_num(source_desc) 45 | target_desc = np.nan_to_num(target_desc) 46 | 47 | key = f'{cloud_bin_s.split("_")[-1]}_{cloud_bin_t.split("_")[-1]}' 48 | if key not in gtLog.keys(): 49 | num_inliers = 0 50 | inlier_ratio = 0 51 | gt_flag = 0 52 | else: 53 | # find mutually closest points in feature space 54 | corr = calculate_M(source_desc, target_desc) 55 | 56 | gtTrans = gtLog[key] 57 | frag1 = source_keypts[corr[:, 0]] 58 | frag2_pc = open3d.geometry.PointCloud() 59 | frag2_pc.points = open3d.utility.Vector3dVector(target_keypts[corr[:, 1]]) 60 | frag2_pc.transform(gtTrans) 61 | frag2 = np.asarray(frag2_pc.points) 62 | distance = np.sqrt(np.sum(np.power(frag1 - frag2, 2), axis=1)) 63 | num_inliers = np.sum(distance < 0.1) 64 | inlier_ratio = num_inliers / len(distance) 65 | gt_flag = 1 66 | 67 | # calculate the transformation matrix using RANSAC; this is for Registration Recall. 68 | source_pcd = open3d.geometry.PointCloud() 69 | source_pcd.points = open3d.utility.Vector3dVector(source_keypts) 70 | target_pcd = open3d.geometry.PointCloud() 71 | target_pcd.points = open3d.utility.Vector3dVector(target_keypts) 72 | s_desc = open3d.pipelines.registration.Feature() 73 | s_desc.data = source_desc.T 74 | t_desc = open3d.pipelines.registration.Feature() 75 | t_desc.data = target_desc.T 76 | result = open3d.pipelines.registration.registration_ransac_based_on_feature_matching( 77 | source_pcd, target_pcd, s_desc, t_desc, 78 | 0.05, 79 | open3d.pipelines.registration.TransformationEstimationPointToPoint(False), 3, 80 | [open3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), 81 | open3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(0.05)], 82 | open3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000)) 83 | # write the transformation matrix into .log file for evaluation.
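# The block below follows the 3DMatch trajectory .log convention: a header line
# "source_id  target_id  37" (the third field nominally gives the number of
# fragments in the scene, hard-coded here), followed by the estimated 4x4
# transformation, one matrix row per line; note the RANSAC result is inverted
# (np.linalg.inv) before it is written.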
84 | with open(os.path.join(logpath, f'{desc_name}_{timestr}.log'), 'a+') as f: 85 | trans = result.transformation 86 | trans = np.linalg.inv(trans) 87 | s1 = f'{id1}\t {id2}\t 37\n' 88 | f.write(s1) 89 | f.write(f"{trans[0, 0]}\t {trans[0, 1]}\t {trans[0, 2]}\t {trans[0, 3]}\t \n") 90 | f.write(f"{trans[1, 0]}\t {trans[1, 1]}\t {trans[1, 2]}\t {trans[1, 3]}\t \n") 91 | f.write(f"{trans[2, 0]}\t {trans[2, 1]}\t {trans[2, 2]}\t {trans[2, 3]}\t \n") 92 | f.write(f"{trans[3, 0]}\t {trans[3, 1]}\t {trans[3, 2]}\t {trans[3, 3]}\t \n") 93 | 94 | s = f"{cloud_bin_s}\t{cloud_bin_t}\t{num_inliers}\t{inlier_ratio:.8f}\t{gt_flag}" 95 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'w+') as f: 96 | f.write(s) 97 | return num_inliers, inlier_ratio, gt_flag 98 | 99 | 100 | def read_register_result(id1, id2): 101 | cloud_bin_s = f'Hokuyo_{id1}' 102 | cloud_bin_t = f'Hokuyo_{id2}' 103 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'r') as f: 104 | content = f.readlines() 105 | nums = content[0].replace("\n", "").split("\t")[2:5] 106 | return nums 107 | 108 | 109 | if __name__ == '__main__': 110 | scene_list = [ 111 | 'gazebo_summer', 112 | 'gazebo_winter', 113 | 'wood_autmn', 114 | 'wood_summer', 115 | ] 116 | desc_name = 'SpinNet' 117 | timestr = sys.argv[1] 118 | inliers_list = [] 119 | recall_list = [] 120 | for scene in scene_list: 121 | pcdpath = f"../../data/ETH/{scene}/" 122 | interpath = f"../../data/ETH/{scene}/01_Keypoints/" 123 | gtpath = f'../../data/ETH/{scene}/' 124 | keyptspath = interpath # os.path.join(interpath, "keypoints/") 125 | descpath = os.path.join(".", f"{desc_name}_desc_{timestr}/{scene}") 126 | logpath = f"log_result/{scene}-evaluation" 127 | gtLog = loadlog(gtpath) 128 | resultpath = os.path.join(".", f"pred_result/{scene}/{desc_name}_result_{timestr}") 129 | if not os.path.exists(resultpath): 130 | os.makedirs(resultpath) 131 | if not os.path.exists(logpath): 132 | os.makedirs(logpath) 133 | 134 | # register each pair 135 | fragments = glob.glob(pcdpath + '*.ply') 136 | num_frag = len(fragments) 137 | print(f"Start Evaluating Descriptor {desc_name} for {scene}") 138 | start_time = time.time() 139 | for id1 in range(num_frag): 140 | for id2 in range(id1 + 1, num_frag): 141 | num_inliers, inlier_ratio, gt_flag = register2Fragments(id1, id2, keyptspath, descpath, resultpath, 142 | desc_name) 143 | print(f"Finish Evaluation, time: {time.time() - start_time:.2f}s") 144 | 145 | # evaluate 146 | result = [] 147 | for id1 in range(num_frag): 148 | for id2 in range(id1 + 1, num_frag): 149 | line = read_register_result(id1, id2) 150 | result.append([int(line[0]), float(line[1]), int(line[2])]) 151 | result = np.array(result) 152 | indices_results = np.sum(result[:, 2] == 1) 153 | correct_match = np.sum(result[:, 1] > 0.05) 154 | recall = float(correct_match / indices_results) * 100 155 | print(f"Correct Match {correct_match}, ground truth Match {indices_results}") 156 | print(f"Recall {recall}%") 157 | ave_num_inliers = np.sum(np.where(result[:, 1] > 0.05, result[:, 0], np.zeros(result.shape[0]))) / correct_match 158 | print(f"Average Num Inliers: {ave_num_inliers}") 159 | recall_list.append(recall) 160 | inliers_list.append(ave_num_inliers) 161 | print(recall_list) 162 | average_recall = sum(recall_list) / len(recall_list) 163 | print(f"All 4 scenes, average recall: {average_recall}%") 164 | average_inliers = sum(inliers_list) / len(inliers_list) 165 | print(f"All 4 scenes, average num inliers: {average_inliers}") 166 |
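# A minimal, self-contained sketch (not from the repo) of the mutual-nearest-
# neighbour matching and inlier-ratio metric implemented by calculate_M() and
# register2Fragments() above; the toy descriptors, keypoints and seed are made
# up for illustration, while the 0.1 m inlier threshold mirrors this file.
import numpy as np
from sklearn.neighbors import KDTree

rng = np.random.RandomState(0)
source_desc = rng.rand(100, 32)                          # [num_keypts, dim] toy descriptors
target_desc = source_desc + 0.01 * rng.randn(100, 32)    # noisy copies -> mostly mutual NNs

s2t = KDTree(target_desc).query(source_desc, 1)[1].squeeze(1)   # NN of each source in target
t2s = KDTree(source_desc).query(target_desc, 1)[1].squeeze(1)   # NN of each target in source
corr = np.array([[i, s2t[i]] for i in range(len(s2t)) if t2s[s2t[i]] == i])

source_keypts = rng.rand(100, 3)
target_keypts = source_keypts.copy()                     # pretend the gt transform is identity
dist = np.linalg.norm(source_keypts[corr[:, 0]] - target_keypts[corr[:, 1]], axis=1)
print(len(corr), float(np.mean(dist < 0.1)))             # number of matches, inlier ratio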
-------------------------------------------------------------------------------- /generalization/ThreeDMatch-to-ETH/preparation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import time 5 | import numpy as np 6 | import torch 7 | import shutil 8 | import torch.nn as nn 9 | import glob 10 | import sys 11 | 12 | sys.path.append('../../') 13 | import script.common as cm 14 | import open3d 15 | from ThreeDMatch.Test.tools import get_pcd, get_ETH_keypts, get_desc, loadlog 16 | from sklearn.neighbors import KDTree 17 | import importlib 18 | 19 | 20 | def make_open3d_point_cloud(xyz, color=None): 21 | pcd = open3d.geometry.PointCloud() 22 | pcd.points = open3d.utility.Vector3dVector(xyz) 23 | if color is not None: 24 | pcd.paint_uniform_color(color) 25 | return pcd 26 | 27 | 28 | def build_patch_input(pcd, keypts, vicinity=0.3, num_points_per_patch=2048): 29 | refer_pts = keypts.astype(np.float32) 30 | pts = np.array(pcd.points).astype(np.float32) 31 | num_patches = refer_pts.shape[0] 32 | tree = KDTree(pts[:, 0:3]) 33 | ind_local = tree.query_radius(refer_pts[:, 0:3], r=vicinity) 34 | local_patches = np.zeros([num_patches, num_points_per_patch, 3], dtype=float) 35 | for i in range(num_patches): 36 | local_neighbors = pts[ind_local[i], :] 37 | if local_neighbors.shape[0] >= num_points_per_patch: 38 | temp = np.random.choice(range(local_neighbors.shape[0]), num_points_per_patch, replace=False) 39 | local_neighbors = local_neighbors[temp] 40 | local_neighbors[-1, :] = refer_pts[i, :] 41 | else: 42 | fix_idx = np.asarray(range(local_neighbors.shape[0])) 43 | while local_neighbors.shape[0] + fix_idx.shape[0] < num_points_per_patch: 44 | fix_idx = np.concatenate((fix_idx, np.asarray(range(local_neighbors.shape[0]))), axis=0) 45 | random_idx = np.random.choice(local_neighbors.shape[0], num_points_per_patch - fix_idx.shape[0], 46 | replace=False) 47 | choice_idx = np.concatenate((fix_idx, random_idx), axis=0) 48 | local_neighbors = local_neighbors[choice_idx] 49 | local_neighbors[-1, :] = refer_pts[i, :] 50 | local_patches[i] = local_neighbors 51 | 52 | return local_patches 53 | 54 | 55 | def prepare_patch(pcdpath, filename, keyptspath, trans_matrix): 56 | pcd = get_pcd(pcdpath, filename) 57 | keypts = get_ETH_keypts(pcd, keyptspath, filename) 58 | if is_rotate_dataset: 59 | # Add arbitrary rotation 60 | # rotate the terminal fragment by arbitrary angles around all three axes 61 | angles_3d = np.random.rand(3) * np.pi * 2 62 | R = cm.angles2rotation_matrix(angles_3d) 63 | T = np.identity(4) 64 | T[:3, :3] = R 65 | pcd.transform(T) 66 | keypts_pcd = make_open3d_point_cloud(keypts) 67 | keypts_pcd.transform(T) 68 | keypts = np.array(keypts_pcd.points) 69 | trans_matrix.append(T) 70 | local_patches = build_patch_input(pcd, keypts, des_r) # [num_keypts, num_points_per_patch, 3] 71 | return local_patches 72 | 73 | 74 | def generate_descriptor(model, desc_name, pcdpath, keyptspath, descpath): 75 | model.eval() 76 | fragments = glob.glob(pcdpath + '*.ply') 77 | num_frag = len(fragments) 78 | num_desc = len(os.listdir(descpath)) 79 | trans_matrix = [] 80 | if num_frag == num_desc: 81 | print("Descriptor already prepared.") 82 | return 83 | for j in range(num_frag): 84 | local_patches = prepare_patch(pcdpath, 'Hokuyo_' + str(j), keyptspath, trans_matrix) 85 | input_ = torch.tensor(local_patches.astype(np.float32)) 86 | B = input_.shape[0] 87 | input_ = input_.cuda() 88 | model = model.cuda() 89 | # calculate descriptors 90 |
desc_list = [] 91 | start_time = time.time() 92 | desc_len = 32 93 | step_size = 100 94 | iter_num = int(np.ceil(B / step_size)) 95 | for k in range(iter_num): 96 | if k == iter_num - 1: 97 | desc = model(input_[k * step_size:, :, :]) 98 | else: 99 | desc = model(input_[k * step_size: (k + 1) * step_size, :, :]) 100 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy()) 101 | del desc 102 | step_time = time.time() - start_time 103 | print(f'Finish {B} descriptors spend {step_time:.4f}s') 104 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len]) 105 | np.save(descpath + 'Hokuyo_' + str(j) + f".desc.{desc_name}.bin", desc.astype(np.float32)) 106 | if is_rotate_dataset: 107 | scene_name = pcdpath.split('/')[-2] 108 | all_trans_matrix[scene_name] = trans_matrix 109 | 110 | 111 | if __name__ == '__main__': 112 | scene_list = [ 113 | 'gazebo_summer', 114 | 'gazebo_winter', 115 | 'wood_autmn', 116 | 'wood_summer', 117 | ] 118 | experiment_id = time.strftime('%m%d%H%M') 119 | model_str = experiment_id # sys.argv[1] 120 | if not os.path.exists(f"SpinNet_desc_{model_str}/"): 121 | os.mkdir(f"SpinNet_desc_{model_str}") 122 | 123 | # dynamically load the model 124 | module_file_path = '../model.py' 125 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), '../model.py') 126 | module_name = '' 127 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path) 128 | module = importlib.util.module_from_spec(module_spec) 129 | module_spec.loader.exec_module(module) 130 | 131 | des_r = 0.8 132 | model = module.Descriptor_Net(des_r, 9, 80, 40, 0.10, 30, '3DMatch') 133 | model = nn.DataParallel(model, device_ids=[0]) 134 | model.load_state_dict(torch.load('../../pre-trained_models/3DMatch_best.pkl')) 135 | all_trans_matrix = {} 136 | is_rotate_dataset = False 137 | 138 | for scene in scene_list: 139 | pcdpath = f"../../data/ETH/{scene}/" 140 | interpath = f"../../data/ETH/{scene}/01_Keypoints/" 141 | keyptspath = interpath 142 | descpath = os.path.join('.', f"SpinNet_desc_{model_str}/{scene}/") 143 | if not os.path.exists(descpath): 144 | os.makedirs(descpath) 145 | start_time = time.time() 146 | print(f"Begin Processing {scene}") 147 | generate_descriptor(model, desc_name='SpinNet', pcdpath=pcdpath, keyptspath=keyptspath, descpath=descpath) 148 | print(f"Finished in {time.time() - start_time}s") 149 | if is_rotate_dataset: 150 | np.save("trans_matrix", all_trans_matrix) 151 | -------------------------------------------------------------------------------- /generalization/ThreeDMatch-to-KITTI/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | import logging 5 | import numpy as np 6 | import open3d as o3d 7 | import torch 8 | import torch.nn as nn 9 | import glob 10 | import time 11 | import gc 12 | import shutil 13 | import pointnet2_ops.pointnet2_utils as pnt2 14 | import copy 15 | import importlib 16 | import sys 17 | 18 | sys.path.append('../../') 19 | import script.common as cm 20 | 21 | kitti_icp_cache = {} 22 | kitti_cache = {} 23 | 24 | 25 | class Timer(object): 26 | """A simple timer.""" 27 | 28 | def __init__(self, binary_fn=None, init_val=0): 29 | self.total_time = 0. 30 | self.calls = 0 31 | self.start_time = 0. 32 | self.diff = 0.
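# Typical use: timer = Timer(); timer.tic(); <work>; timer.toc(average=False)
# returns the last interval, while the avg property below gives the running mean.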
33 | self.binary_fn = binary_fn 34 | self.tmp = init_val 35 | 36 | def reset(self): 37 | self.total_time = 0 38 | self.calls = 0 39 | self.start_time = 0 40 | self.diff = 0 41 | 42 | @property 43 | def avg(self): 44 | return self.total_time / self.calls 45 | 46 | def tic(self): 47 | # using time.time instead of time.clock because time time.clock 48 | # does not normalize for multithreading 49 | self.start_time = time.time() 50 | 51 | def toc(self, average=True): 52 | self.diff = time.time() - self.start_time 53 | self.total_time += self.diff 54 | self.calls += 1 55 | if self.binary_fn: 56 | self.tmp = self.binary_fn(self.tmp, self.diff) 57 | if average: 58 | return self.avg 59 | else: 60 | return self.diff 61 | 62 | 63 | class AverageMeter(object): 64 | """Computes and stores the average and current value""" 65 | 66 | def __init__(self): 67 | self.reset() 68 | 69 | def reset(self): 70 | self.val = 0 71 | self.avg = 0 72 | self.sum = 0.0 73 | self.sq_sum = 0.0 74 | self.count = 0 75 | 76 | def update(self, val, n=1): 77 | self.val = val 78 | self.sum += val * n 79 | self.count += n 80 | self.avg = self.sum / self.count 81 | self.sq_sum += val ** 2 * n 82 | self.var = self.sq_sum / self.count - self.avg ** 2 83 | 84 | 85 | def get_desc(descpath, filename): 86 | desc = np.load(os.path.join(descpath, filename + '.npy')) 87 | return desc 88 | 89 | 90 | def get_keypts(keypts_path, filename): 91 | keypts = np.load(os.path.join(keypts_path, filename + '.npy')) 92 | return keypts 93 | 94 | 95 | def make_open3d_feature(data, dim, npts): 96 | feature = o3d.pipelines.registration.Feature() 97 | feature.resize(dim, npts) 98 | feature.data = data.astype('d').transpose() 99 | return feature 100 | 101 | 102 | def make_open3d_point_cloud(xyz, color=None): 103 | pcd = o3d.geometry.PointCloud() 104 | pcd.points = o3d.utility.Vector3dVector(xyz) 105 | if color is not None: 106 | pcd.paint_uniform_color(color) 107 | return pcd 108 | 109 | 110 | def get_matching_indices(source, target, trans, search_voxel_size, K=None): 111 | source_copy = copy.deepcopy(source) 112 | target_copy = copy.deepcopy(target) 113 | source_copy.transform(trans) 114 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy) 115 | 116 | match_inds = [] 117 | for i, point in enumerate(source_copy.points): 118 | [_, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size) 119 | if K is not None: 120 | idx = idx[:K] 121 | for j in idx: 122 | match_inds.append((i, j)) 123 | return match_inds 124 | 125 | 126 | class KITTI(object): 127 | DATA_FILES = { 128 | 'train': 'train_kitti.txt', 129 | 'val': 'val_kitti.txt', 130 | 'test': 'test_kitti.txt' 131 | } 132 | """ 133 | Given point cloud fragments and corresponding pose in '{root}'. 134 | 1. Save the aligned point cloud pts in '{savepath}/3DMatch_{downsample}_points.pkl' 135 | 2. Calculate the overlap ratio and save in '{savepath}/3DMatch_{downsample}_overlap.pkl' 136 | 3. 
Save the ids of anchor keypoints and positive keypoints in '{savepath}/3DMatch_{downsample}_keypts.pkl' 137 | """ 138 | 139 | def __init__(self, root, descpath, icp_path, split, model, num_points_per_patch, use_random_points): 140 | self.root = root 141 | self.descpath = descpath 142 | self.split = split 143 | self.num_points_per_patch = num_points_per_patch 144 | self.icp_path = icp_path 145 | self.use_random_points = use_random_points 146 | self.model = model 147 | if not os.path.exists(self.icp_path): 148 | os.makedirs(self.icp_path) 149 | 150 | # list: anc & pos 151 | self.patches = [] 152 | self.pose = [] 153 | # Initiate containers 154 | self.files = {'train': [], 'val': [], 'test': []} 155 | 156 | self.prepare_kitti_ply(split=self.split) 157 | 158 | def prepare_kitti_ply(self, split='train'): 159 | subset_names = open(self.DATA_FILES[split]).read().split() 160 | for dirname in subset_names: 161 | drive_id = int(dirname) 162 | fnames = glob.glob(self.root + '/sequences/%02d/velodyne/*.bin' % drive_id) 163 | assert len(fnames) > 0, f"Make sure that the path {self.root} has data {dirname}" 164 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 165 | 166 | all_odo = self.get_video_odometry(drive_id, return_all=True) 167 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo]) 168 | Ts = all_pos[:, :3, 3] 169 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3)) ** 2 170 | pdist = np.sqrt(pdist.sum(-1)) 171 | more_than_10 = pdist > 10 172 | curr_time = inames[0] 173 | while curr_time in inames: 174 | next_time = np.where(more_than_10[curr_time][curr_time:curr_time + 100])[0] 175 | if len(next_time) == 0: 176 | curr_time += 1 177 | else: 178 | next_time = next_time[0] + curr_time - 1 179 | 180 | if next_time in inames: 181 | self.files[split].append((drive_id, curr_time, next_time)) 182 | curr_time = next_time + 1 183 | # Remove problematic sequence 184 | for item in [ 185 | (8, 15, 58), 186 | ]: 187 | if item in self.files[split]: 188 | self.files[split].pop(self.files[split].index(item)) 189 | 190 | if split == 'train': 191 | self.num_train = len(self.files[split]) 192 | print("Num_train", self.num_train) 193 | elif split == 'val': 194 | self.num_val = len(self.files[split]) 195 | print("Num_val", self.num_val) 196 | elif split == 'test': 197 | self.num_test = len(self.files[split]) 198 | print("Num_test", self.num_test) 199 | 200 | for idx in range(len(self.files[split])): 201 | drive = self.files[split][idx][0] 202 | t0, t1 = self.files[split][idx][1], self.files[split][idx][2] 203 | all_odometry = self.get_video_odometry(drive, [t0, t1]) 204 | positions = [self.odometry_to_positions(odometry) for odometry in all_odometry] 205 | fname0 = self._get_velodyne_fn(drive, t0) 206 | fname1 = self._get_velodyne_fn(drive, t1) 207 | 208 | # XYZ and reflectance 209 | xyzr0 = np.fromfile(fname0, dtype=np.float32).reshape(-1, 4) 210 | xyzr1 = np.fromfile(fname1, dtype=np.float32).reshape(-1, 4) 211 | 212 | xyz0 = xyzr0[:, :3] 213 | xyz1 = xyzr1[:, :3] 214 | 215 | key = '%d_%d_%d' % (drive, t0, t1) 216 | filename = self.icp_path + '/' + key + '.npy' 217 | if key not in kitti_icp_cache: 218 | if not os.path.exists(filename): 219 | M = (self.velo2cam @ positions[0].T @ np.linalg.inv(positions[1].T) 220 | @ np.linalg.inv(self.velo2cam)).T 221 | xyz0_t = self.apply_transform(xyz0, M) 222 | pcd0 = make_open3d_point_cloud(xyz0_t, [0.5, 0.5, 0.5]) 223 | pcd1 = make_open3d_point_cloud(xyz1, [0, 1, 0]) 224 | reg = o3d.pipelines.registration.registration_icp(pcd0, 
pcd1, 0.10, np.eye(4), 225 | o3d.pipelines.registration.TransformationEstimationPointToPoint(), 226 | o3d.pipelines.registration.ICPConvergenceCriteria( 227 | max_iteration=400)) 228 | pcd0.transform(reg.transformation) 229 | M2 = M @ reg.transformation 230 | # write to a file 231 | np.save(filename, M2) 232 | else: 233 | M2 = np.load(filename) 234 | kitti_icp_cache[key] = M2 235 | else: 236 | M2 = kitti_icp_cache[key] 237 | trans = M2 238 | # extract patches for anc&pos 239 | np.random.shuffle(xyz0) 240 | np.random.shuffle(xyz1) 241 | 242 | if is_rotate_dataset: 243 | # Add arbitrary rotation 244 | # rotate the terminal fragment by an arbitrary angle 245 | angles_3d = np.random.rand(3) * np.pi * 2 246 | R = cm.angles2rotation_matrix(angles_3d) 247 | T = np.identity(4) 248 | T[:3, :3] = R 249 | pcd1 = make_open3d_point_cloud(xyz1) 250 | pcd1.transform(T) 251 | xyz1 = np.array(pcd1.points) 252 | all_trans_matrix[key] = T 253 | 254 | if not os.path.exists(self.descpath + str(drive)): 255 | os.makedirs(self.descpath + str(drive)) 256 | if self.use_random_points: 257 | num_keypts = 5000 258 | step_size = 100 259 | desc_len = 32 260 | model = self.model.cuda() 261 | # calc t0 descriptors 262 | desc_t0_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t0) + f".desc.bin.npy") 263 | keypts_t0_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t0) + f".keypts.npy") 264 | if not os.path.exists(desc_t0_path): 265 | keypoints_id = np.random.choice(xyz0.shape[0], num_keypts) 266 | keypts = xyz0[keypoints_id] 267 | np.save(keypts_t0_path, keypts.astype(np.float32)) 268 | local_patches = self.select_patches(xyz0, keypts, vicinity=vicinity, 269 | num_points_per_patch=self.num_points_per_patch) 270 | B = local_patches.shape[0] 271 | # process patches in chunks to avoid CUDA out of memory 272 | desc_list = [] 273 | start_time = time.time() 274 | iter_num = int(np.ceil(B / step_size)) 275 | for k in range(iter_num): 276 | if k == iter_num - 1: 277 | desc = model(local_patches[k * step_size:, :, :]) 278 | else: 279 | desc = model(local_patches[k * step_size: (k + 1) * step_size, :, :]) 280 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy()) 281 | del desc 282 | step_time = time.time() - start_time 283 | print(f'Finish {B} descriptors spend {step_time:.4f}s') 284 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len]) 285 | np.save(desc_t0_path, desc.astype(np.float32)) 286 | else: 287 | print(f"{desc_t0_path} already exists.") 288 | 289 | # calc t1 descriptors 290 | desc_t1_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t1) + f".desc.bin.npy") 291 | keypts_t1_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t1) + f".keypts.npy") 292 | if not os.path.exists(desc_t1_path): 293 | keypoints_id = np.random.choice(xyz1.shape[0], num_keypts) 294 | keypts = xyz1[keypoints_id] 295 | np.save(keypts_t1_path, keypts.astype(np.float32)) 296 | local_patches = self.select_patches(xyz1, keypts, vicinity=vicinity, 297 | num_points_per_patch=self.num_points_per_patch) 298 | B = local_patches.shape[0] 299 | # process patches in chunks to avoid CUDA out of memory 300 | desc_list = [] 301 | start_time = time.time() 302 | iter_num = int(np.ceil(B / step_size)) 303 | for k in range(iter_num): 304 | if k == iter_num - 1: 305 | desc = model(local_patches[k * step_size:, :, :]) 306 | else: 307 | desc = model(local_patches[k * step_size: (k + 1) * step_size, :, :]) 308 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy()) 309 | del desc 310 | step_time = time.time() -
start_time 311 | print(f'Finish {B} descriptors spend {step_time:.4f}s') 312 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len]) 313 | np.save(desc_t1_path, desc.astype(np.float32)) 314 | else: 315 | print(f"{desc_t1_path} already exists.") 316 | else: 317 | num_keypts = 512 318 | 319 | def select_patches(self, pts, refer_pts, vicinity, num_points_per_patch=1024): 320 | gc.collect() 321 | pts = torch.FloatTensor(pts).cuda().unsqueeze(0) 322 | refer_pts = torch.FloatTensor(refer_pts).cuda().unsqueeze(0) 323 | group_idx = pnt2.ball_query(vicinity, num_points_per_patch, pts, refer_pts) 324 | pts_trans = pts.transpose(1, 2).contiguous() 325 | new_points = pnt2.grouping_operation( 326 | pts_trans, group_idx 327 | ) 328 | new_points = new_points.permute([0, 2, 3, 1]) 329 | mask = group_idx[:, :, 0].unsqueeze(2).repeat(1, 1, num_points_per_patch) 330 | mask = (group_idx == mask).float() 331 | mask[:, :, 0] = 0 332 | mask[:, :, num_points_per_patch - 1] = 1 333 | mask = mask.unsqueeze(3).repeat([1, 1, 1, 3]) 334 | new_pts = refer_pts.unsqueeze(2).repeat([1, 1, num_points_per_patch, 1]) 335 | local_patches = new_points * (1 - mask).float() + new_pts * mask.float() 336 | # local_patches = list(local_patches.squeeze(0).detach().cpu().numpy()) 337 | local_patches = local_patches.squeeze(0) 338 | del mask 339 | del new_points 340 | del group_idx 341 | del new_pts 342 | del pts 343 | del pts_trans 344 | 345 | return local_patches 346 | 347 | def apply_transform(self, pts, trans): 348 | R = trans[:3, :3] 349 | T = trans[:3, 3] 350 | pts = pts @ R.T + T 351 | return pts 352 | 353 | @property 354 | def velo2cam(self): 355 | try: 356 | velo2cam = self._velo2cam 357 | except AttributeError: 358 | R = np.array([ 359 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04, 360 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02 361 | ]).reshape(3, 3) 362 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1) 363 | velo2cam = np.hstack([R, T]) 364 | self._velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T 365 | return self._velo2cam 366 | 367 | def get_video_odometry(self, drive, indices=None, ext='.txt', return_all=False): 368 | data_path = self.root + '/poses/%02d.txt' % drive 369 | if data_path not in kitti_cache: 370 | kitti_cache[data_path] = np.genfromtxt(data_path) 371 | if return_all: 372 | return kitti_cache[data_path] 373 | else: 374 | return kitti_cache[data_path][indices] 375 | 376 | def odometry_to_positions(self, odometry): 377 | T_w_cam0 = odometry.reshape(3, 4) 378 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1])) 379 | return T_w_cam0 380 | 381 | def _get_velodyne_fn(self, drive, t): 382 | fname = self.root + '/sequences/%02d/velodyne/%06d.bin' % (drive, t) 383 | return fname 384 | 385 | 386 | if __name__ == '__main__': 387 | is_rotate_dataset = False 388 | all_trans_matrix = {} 389 | experiment_id = time.strftime('%m%d%H%M') # '11210201'# 390 | model_str = experiment_id 391 | reg_timer = Timer() 392 | success_meter, rte_meter, rre_meter = AverageMeter(), AverageMeter(), AverageMeter() 393 | ch = logging.StreamHandler(sys.stdout) 394 | logging.getLogger().setLevel(logging.INFO) 395 | logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch]) 396 | 397 | # dynamically load the model from snapshot 398 | module_file_path = '../model.py' 399 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), module_file_path) 400 | module_name = '' 401 | module_spec = importlib.util.spec_from_file_location(module_name, 
module_file_path) 402 | module = importlib.util.module_from_spec(module_spec) 403 | module_spec.loader.exec_module(module) 404 | 405 | vicinity = 3.0 406 | model = module.Descriptor_Net(vicinity, 9, 80, 40, 0.5, 30, 'KITTI') 407 | model = nn.DataParallel(model, device_ids=[0]) 408 | model.load_state_dict(torch.load('../../pre-trained_models/3DMatch_best.pkl')) 409 | 410 | test_data = KITTI(root='../../data/KITTI/dataset', 411 | descpath=f'SpinNet_desc_{model_str}/', 412 | icp_path='../../data/KITTI/icp', 413 | split='test', 414 | model=model, 415 | num_points_per_patch=2048, 416 | use_random_points=True 417 | ) 418 | 419 | files = test_data.files[test_data.split] 420 | for idx in range(len(files)): 421 | drive = files[idx][0] 422 | t0, t1 = files[idx][1], files[idx][2] 423 | key = '%d_%d_%d' % (drive, t0, t1) 424 | filename = test_data.icp_path + '/' + key + '.npy' 425 | T_gth = kitti_icp_cache[key] 426 | if is_rotate_dataset: 427 | T_gth = np.matmul(all_trans_matrix[key], T_gth) 428 | 429 | descpath = os.path.join(test_data.descpath, str(drive)) 430 | fname0 = test_data._get_velodyne_fn(drive, t0) 431 | fname1 = test_data._get_velodyne_fn(drive, t1) 432 | # XYZ and reflectance 433 | xyz0 = get_keypts(descpath, f"cloud_bin_" + str(t0) + f".keypts") 434 | xyz1 = get_keypts(descpath, f"cloud_bin_" + str(t1) + f".keypts") 435 | pcd0 = make_open3d_point_cloud(xyz0) 436 | pcd1 = make_open3d_point_cloud(xyz1) 437 | 438 | source_desc = get_desc(descpath, f"cloud_bin_" + str(t0) + f".desc.bin") 439 | target_desc = get_desc(descpath, f"cloud_bin_" + str(t1) + f".desc.bin") 440 | feat0 = make_open3d_feature(source_desc, 32, source_desc.shape[0]) 441 | feat1 = make_open3d_feature(target_desc, 32, target_desc.shape[0]) 442 | 443 | reg_timer.tic() 444 | distance_threshold = 0.3 445 | ransac_result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching( 446 | pcd0, pcd1, feat0, feat1, distance_threshold, 447 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False), 4, [ 448 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), 449 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold) 450 | ], o3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000)) 451 | T_ransac = torch.from_numpy(ransac_result.transformation.astype(np.float32)) 452 | reg_timer.toc() 453 | 454 | # Translation error 455 | rte = np.linalg.norm(T_ransac[:3, 3] - T_gth[:3, 3]) 456 | rre = np.arccos((np.trace(T_ransac[:3, :3].t() @ T_gth[:3, :3]) - 1) / 2) 457 | 458 | if rte < 2: 459 | rte_meter.update(rte) 460 | 461 | if not np.isnan(rre) and rre < np.pi / 180 * 5: 462 | rre_meter.update(rre * 180 / np.pi) 463 | 464 | if rte < 2 and not np.isnan(rre) and rre < np.pi / 180 * 5: 465 | success_meter.update(1) 466 | else: 467 | success_meter.update(0) 468 | logging.info(f"Failed with RTE: {rte}, RRE: {rre}") 469 | 470 | if (idx + 1) % 10 == 0: 471 | logging.info( 472 | f" RRE: {rre_meter.avg}, Success: {success_meter.sum} / {success_meter.count}" + 473 | f" ({success_meter.avg * 100} %)" 474 | ) 475 | reg_timer.reset() 476 | 477 | logging.info( 478 | f"RTE: {rte_meter.avg}, var: {rte_meter.var}," + 479 | f" RRE: {rre_meter.avg}, var: {rre_meter.var}, Success: {success_meter.sum} " + 480 | f"/ {success_meter.count} ({success_meter.avg * 100} %)" 481 | ) 482 | -------------------------------------------------------------------------------- /generalization/ThreeDMatch-to-KITTI/test_kitti.txt: 
-------------------------------------------------------------------------------- 1 | 8 2 | 9 3 | 10 4 | -------------------------------------------------------------------------------- /loss/desc_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def all_diffs(a, b): 7 | """ Returns a tensor of all combinations of a - b. 8 | 9 | Args: 10 | a (2D tensor): A batch of vectors shaped (B1, F). 11 | b (2D tensor): A batch of vectors shaped (B2, F). 12 | 13 | Returns: 14 | The matrix of all pairwise differences between all vectors in `a` and in 15 | `b`, will be of shape (B1, B2). 16 | 17 | """ 18 | return torch.unsqueeze(a, dim=1) - torch.unsqueeze(b, dim=0) 19 | 20 | 21 | def cdist(a, b, metric='euclidean'): 22 | """Similar to scipy.spatial's cdist, but symbolic. 23 | 24 | The currently supported metrics can be listed as `cdist.supported_metrics` and are: 25 | - 'euclidean', although with a fudge-factor epsilon. 26 | - 'sqeuclidean', the squared euclidean. 27 | - 'cityblock', the manhattan or L1 distance. 28 | 29 | Args: 30 | a (2D tensor): The left-hand side, shaped (B1, F). 31 | b (2D tensor): The right-hand side, shaped (B2, F). 32 | metric (string): Which distance metric to use, see notes. 33 | 34 | Returns: 35 | The matrix of all pairwise distances between all vectors in `a` and in 36 | `b`, will be of shape (B1, B2). 37 | 38 | Note: 39 | When a square root is taken (such as in the Euclidean case), a small 40 | epsilon is added because the gradient of the square-root at zero is 41 | undefined. Thus, it will never return exact zero in these cases. 42 | """ 43 | 44 | diffs = all_diffs(a, b) 45 | if metric == 'sqeuclidean': 46 | return torch.sum(diffs ** 2, dim=-1) 47 | elif metric == 'euclidean': 48 | return torch.sqrt(torch.sum(diffs ** 2, dim=-1) + 1e-12) 49 | elif metric == 'cityblock': 50 | return torch.sum(torch.abs(diffs), dim=-1) 51 | else: 52 | raise NotImplementedError( 53 | 'The following metric is not implemented by `cdist` yet: {}'.format(metric)) 54 | 55 | 56 | class ContrastiveLoss(nn.Module): 57 | def __init__(self, pos_margin=0.1, neg_margin=1.4, metric='euclidean', safe_radius=0.25): 58 | super(ContrastiveLoss, self).__init__() 59 | self.pos_margin = pos_margin 60 | self.neg_margin = neg_margin 61 | self.metric = metric 62 | self.safe_radius = safe_radius 63 | 64 | def forward(self, anchor, positive): 65 | pids = torch.FloatTensor(np.arange(len(anchor))).to(anchor.device) 66 | dist = cdist(anchor, positive, metric=self.metric) 67 | return self.calculate_loss(dist, pids) 68 | 69 | def calculate_loss(self, dists, pids): 70 | """Computes the batch-hard loss from arxiv.org/abs/1703.07737. 71 | 72 | Args: 73 | dists (2D tensor): A square all-to-all distance matrix as given by cdist. 74 | pids (1D tensor): The identities of the entries in `batch`, shape (B,). 75 | This can be of any type that can be compared, thus also a string. 76 | Note: the positive and negative margins are read from self.pos_margin 77 | and self.neg_margin; unlike the original formulation there is no 78 | separate margin argument here. 79 | 80 | Returns: 81 | The mean loss over the batch, and the percentage of samples whose furthest positive is closer than their closest negative (accuracy). 82 | """ 83 | # generate the mask where mask[i, j] represents whether the i-th and j-th samples are from the same identity.
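# Illustration with B = 2 and pids = [0, 1]: same_identity_mask is
# [[True, False], [False, True]], so for each anchor i the "furthest positive"
# is simply dists[i, i] (its own positive) and the "closest negative" is
# min over j != i of dists[i, j]; the loss below then pulls dists[i, i] under
# pos_margin and pushes every cross-pair distance above neg_margin.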
84 | # torch.equal checks whether two tensors have the same size and elements 85 | # torch.eq computes element-wise equality 86 | same_identity_mask = torch.eq(torch.unsqueeze(pids, dim=1), torch.unsqueeze(pids, dim=0)) 87 | 88 | # dists * same_identity_mask gets the distance of each valid anchor-positive pair. 89 | furthest_positive, _ = torch.max(dists * same_identity_mask.float(), dim=1) 90 | closest_negative, _ = torch.min(dists + 1e5 * same_identity_mask.float(), dim=1) 91 | diff = furthest_positive - closest_negative 92 | accuracy = (diff < 0).sum() * 100.0 / diff.shape[0] 93 | loss = torch.max(furthest_positive - self.pos_margin, torch.zeros_like(diff)) + torch.max( 94 | self.neg_margin - closest_negative, torch.zeros_like(diff)) 95 | 96 | return torch.mean(loss), accuracy 97 | -------------------------------------------------------------------------------- /network/SpinNet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import network.ThreeDCCN as pn 8 | import script.common as cm 9 | from script.common import switch 10 | 11 | 12 | class Descriptor_Net(nn.Module): 13 | def __init__(self, des_r, rad_n, azi_n, ele_n, voxel_r, voxel_sample, dataset): 14 | super(Descriptor_Net, self).__init__() 15 | self.des_r = des_r 16 | self.rad_n = rad_n 17 | self.azi_n = azi_n 18 | self.ele_n = ele_n 19 | self.voxel_r = voxel_r 20 | self.voxel_sample = voxel_sample 21 | self.dataset = dataset 22 | 23 | self.bn_xyz_raising = nn.BatchNorm2d(16) 24 | self.bn_mapping = nn.BatchNorm2d(16) 25 | self.activation = nn.ReLU() 26 | self.xyz_raising = nn.Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1)) 27 | self.conv_net = pn.Cylindrical_Net(inchan=16, dim=32) 28 | 29 | def forward(self, input): 30 | center = input[:, -1, :].unsqueeze(1) 31 | delta_x = input[:, :, 0:3] - center[:, :, 0:3] # (B, npoint, 3), normalized coordinates 32 | for case in switch(self.dataset): 33 | if case('3DMatch'): 34 | z_axis = cm.cal_Z_axis(delta_x, ref_point=input[:, -1, :3]) 35 | z_axis = cm.l2_norm(z_axis, axis=1) 36 | R = cm.RodsRotatFormula(z_axis, torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(z_axis.shape[0], 1)) 37 | delta_x = torch.matmul(delta_x, R) 38 | break 39 | if case('KITTI'): 40 | break 41 | 42 | # partition the local surface along the elevation, azimuth and radial dimensions 43 | S2_xyz = torch.FloatTensor(cm.get_voxel_coordinate(radius=self.des_r, 44 | rad_n=self.rad_n, 45 | azi_n=self.azi_n, 46 | ele_n=self.ele_n)) 47 | 48 | pts_xyz = S2_xyz.view(1, -1, 3).repeat([delta_x.shape[0], 1, 1]).cuda() 49 | # query points in sphere 50 | new_points = cm.sphere_query(delta_x, pts_xyz, radius=self.voxel_r, 51 | nsample=self.voxel_sample) 52 | # transform rotation-variant coords into rotation-invariant coords 53 | new_points = new_points - pts_xyz.unsqueeze(2).repeat([1, 1, self.voxel_sample, 1]) 54 | new_points = cm.var_to_invar(new_points, self.rad_n, self.azi_n, self.ele_n) 55 | 56 | new_points = new_points.permute(0, 3, 1, 2) # (B, C_in, npoint, nsample), input features 57 | C_in = new_points.size()[1] 58 | nsample = new_points.size()[3] 59 | x = self.activation(self.bn_xyz_raising(self.xyz_raising(new_points))) 60 | x = F.max_pool2d(x, kernel_size=(1, nsample)).squeeze(3) # (B, C_in, npoint) 61 | del new_points 62 | del pts_xyz 63 | x = x.view(x.shape[0], x.shape[1], self.rad_n, self.ele_n, self.azi_n) 64 | 65 | x = self.conv_net(x) 66 | x =
F.max_pool2d(x, kernel_size=(x.shape[2], x.shape[3])) 67 | 68 | return x 69 | 70 | def get_parameter(self): 71 | return list(self.parameters()) 72 | -------------------------------------------------------------------------------- /network/ThreeDCCN.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import script.common as cm 6 | 7 | 8 | class BaseNet(nn.Module): 9 | """ Takes a list of images as input, and returns for each image: 10 | - a pixelwise descriptor 11 | - a pixelwise confidence 12 | """ 13 | 14 | def forward_one(self, x): 15 | raise NotImplementedError() 16 | 17 | def forward(self, imgs): 18 | res = self.forward_one(imgs) 19 | return res 20 | 21 | 22 | class Cyclindrical_ConvNet(BaseNet): 23 | def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False): 24 | BaseNet.__init__(self) 25 | self.inchan = inchan 26 | self.curchan = inchan 27 | self.dilated = dilated 28 | self.dilation = dilation 29 | self.bn = bn 30 | self.bn_affine = bn_affine 31 | self.ops = nn.ModuleList([]) 32 | 33 | def _make_bn_2d(self, outd): 34 | return nn.BatchNorm2d(outd, affine=self.bn_affine) 35 | 36 | def _make_bn_3d(self, outd): 37 | return nn.BatchNorm3d(outd, affine=self.bn_affine) 38 | 39 | def _add_conv_2d(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True): 40 | d = self.dilation * dilation 41 | self.dilation *= stride 42 | self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=(k, k), dilation=d)) 43 | if bn and self.bn: self.ops.append(self._make_bn_2d(outd)) 44 | if relu: self.ops.append(nn.ReLU(inplace=True)) 45 | self.curchan = outd 46 | 47 | def _add_conv_3d(self, outd, k, stride=1, dilation=1, bn=True, relu=True): 48 | d = self.dilation * dilation 49 | self.dilation *= stride 50 | self.ops.append(nn.Conv3d(self.curchan, outd, kernel_size=(k[0], k[1], k[2]), dilation=d)) 51 | if bn and self.bn: self.ops.append(self._make_bn_3d(outd)) 52 | if relu: self.ops.append(nn.ReLU(inplace=True)) 53 | self.curchan = outd 54 | 55 | def forward_one(self, x): 56 | assert self.ops, "You need to add convolutions first" 57 | for n, op in enumerate(self.ops): 58 | k_exist = hasattr(op, 'kernel_size') 59 | if k_exist: 60 | if len(op.kernel_size) == 3: 61 | x = cm.pad_image_3d(x, op.kernel_size[1] + (op.kernel_size[1] - 1) * (op.dilation[0] - 1)) 62 | else: 63 | if len(x.shape) == 5: 64 | x = x.squeeze(2) 65 | x = cm.pad_image(x, op.kernel_size[0] + (op.kernel_size[0] - 1) * (op.dilation[0] - 1)) 66 | x = op(x) 67 | return x 68 | 69 | 70 | class Cylindrical_Net(Cyclindrical_ConvNet): 71 | """ Compute a descriptor for all overlapping patches. 72 | From the L2Net paper (CVPR'17). 
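Input is a 5-D tensor of per-voxel features, (B, inchan, rad_n, ele_n, azi_n);
four 3-D convolutions first mix the radial/elevation/azimuth bins, after which
strided 2-D convolutions over the remaining cylindrical map yield a
dim-channel feature map that the caller (SpinNet.py) max-pools into a
single descriptor.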
73 | """ 74 | 75 | def __init__(self, inchan=16, dim=32, **kw): 76 | Cyclindrical_ConvNet.__init__(self, inchan=inchan, **kw) 77 | add_conv_2d = lambda n, **kw: self._add_conv_2d(n, **kw) 78 | add_conv_3d = lambda n, **kw: self._add_conv_3d(n, **kw) 79 | add_conv_3d(32, k=[3, 3, 3]) 80 | add_conv_3d(32, k=[3, 3, 3]) 81 | add_conv_3d(64, k=[3, 3, 3]) 82 | add_conv_3d(64, k=[3, 3, 3]) 83 | add_conv_2d(128, stride=2) 84 | add_conv_2d(128) 85 | add_conv_2d(64, stride=2) 86 | add_conv_2d(64) 87 | add_conv_2d(32, k=2, stride=2, relu=False) 88 | add_conv_2d(32, k=2, stride=2, relu=False) 89 | add_conv_2d(dim, k=2, stride=2, bn=False, relu=False) 90 | self.out_dim = dim 91 | -------------------------------------------------------------------------------- /pre-trained_models/3DMatch_best.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/pre-trained_models/3DMatch_best.pkl -------------------------------------------------------------------------------- /pre-trained_models/KITTI_best.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/pre-trained_models/KITTI_best.pkl -------------------------------------------------------------------------------- /script/cal_overlap.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import exists, join 3 | import pickle 4 | import numpy as np 5 | import open3d 6 | import cv2 7 | import time 8 | 9 | 10 | class ThreeDMatch(object): 11 | """ 12 | Given point cloud fragments and corresponding pose in '{root}'. 13 | 1. Save the aligned point cloud pts in '{savepath}/3DMatch_{downsample}_points.pkl' 14 | 2. Calculate the overlap ratio and save in '{savepath}/3DMatch_{downsample}_overlap.pkl' 15 | 3. Save the ids of anchor keypoints and positive keypoints in '{savepath}/3DMatch_{downsample}_keypts.pkl' 16 | """ 17 | 18 | def __init__(self, root, savepath, split, downsample): 19 | self.root = root 20 | self.savepath = savepath 21 | self.split = split 22 | self.downsample = downsample 23 | 24 | # dict: from id to pts. 
25 | self.pts = {} 26 | 27 | # dict: from id_id to overlap_ratio 28 | self.overlap_ratio = {} 29 | # dict: from id_id to anc_keypts id & pos_keypts id 30 | self.keypts_pairs = {} 31 | 32 | with open(os.path.join(root, f'scene_list_{split}.txt')) as f: 33 | scene_list = f.readlines() 34 | self.ids_list = [] 35 | self.scene_to_ids = {} 36 | for scene in scene_list: 37 | scene = scene.replace("\n", "") 38 | self.scene_to_ids[scene] = [] 39 | for seq in sorted(os.listdir(os.path.join(self.root, scene))): 40 | if not seq.startswith('seq'): 41 | continue 42 | scene_path = os.path.join(self.root, scene + f'/{seq}') 43 | ids = [scene + f"/{seq}/" + str(filename.split(".")[0]) for filename in os.listdir(scene_path) if 44 | filename.endswith('ply')] 45 | ids = sorted(ids, key=lambda x: int(x.split("_")[-1])) 46 | self.ids_list += ids 47 | self.scene_to_ids[scene] += ids 48 | print(f"Scene {scene}, seq {seq}: num ply: {len(ids)}") 49 | print(f"Total {len(scene_list)} scenes, {len(self.ids_list)} point cloud fragments.") 50 | self.idpair_list = [] 51 | self.load_all_ply(downsample) 52 | self.cal_overlap(downsample) 53 | 54 | def load_ply(self, data_dir, ind, downsample, aligned=True): 55 | pcd = open3d.io.read_point_cloud(join(data_dir, f'{ind}.ply')) 56 | pcd = open3d.geometry.PointCloud.voxel_down_sample(pcd, voxel_size=downsample) 57 | if aligned is True: 58 | matrix = np.load(join(data_dir, f'{ind}.pose.npy')) 59 | pcd.transform(matrix) 60 | return pcd 61 | 62 | def load_all_ply(self, downsample): 63 | pts_filename = join(self.savepath, f'3DMatch_{self.split}_{downsample:.3f}_points.pkl') 64 | if exists(pts_filename): 65 | with open(pts_filename, 'rb') as file: 66 | self.pts = pickle.load(file) 67 | print(f"Load pts file from {self.savepath}") 68 | return 69 | self.pts = {} 70 | for i, anc_id in enumerate(self.ids_list): 71 | anc_pcd = self.load_ply(self.root, anc_id, downsample=downsample, aligned=True) 72 | points = np.array(anc_pcd.points) 73 | print(len(points)) 74 | self.pts[anc_id] = points 75 | print('processing ply: {:.1f}%'.format(100 * i / len(self.ids_list))) 76 | with open(pts_filename, 'wb') as file: 77 | pickle.dump(self.pts, file) 78 | 79 | def get_matching_indices(self, anc_pts, pos_pts, search_voxel_size, K=None): 80 | match_inds = [] 81 | bf_matcher = cv2.BFMatcher(cv2.NORM_L2) 82 | match = bf_matcher.match(anc_pts, pos_pts) 83 | for match_val in match: 84 | if match_val.distance < search_voxel_size: 85 | match_inds.append([match_val.queryIdx, match_val.trainIdx]) 86 | return np.array(match_inds) 87 | 88 | def cal_overlap(self, downsample): 89 | overlap_filename = join(self.savepath, f'3DMatch_{self.split}_{downsample:.3f}_overlap.pkl') 90 | keypts_filename = join(self.savepath, f'3DMatch_{self.split}_{downsample:.3f}_keypts.pkl') 91 | if exists(overlap_filename) and exists(keypts_filename): 92 | with open(overlap_filename, 'rb') as file: 93 | self.overlap_ratio = pickle.load(file) 94 | print(f"Reload overlap info from {overlap_filename}") 95 | with open(keypts_filename, 'rb') as file: 96 | self.keypts_pairs = pickle.load(file) 97 | print(f"Reload keypts info from {keypts_filename}") 98 | # import pdb 99 | # pdb.set_trace() 100 | return 101 | t0 = time.time() 102 | for scene, scene_ids in self.scene_to_ids.items(): 103 | scene_overlap = {} 104 | print(f"Begin processing scene {scene}") 105 | for i in range(0, len(scene_ids)): 106 | anc_id = scene_ids[i] 107 | for j in range(i + 1, len(scene_ids)): 108 | pos_id = scene_ids[j] 109 | anc_pts = self.pts[anc_id].astype(np.float32) 110 |
pos_pts = self.pts[pos_id].astype(np.float32) 111 | 112 | try: 113 | matching_01 = self.get_matching_indices(anc_pts, pos_pts, self.downsample) 114 | except BaseException as e: 115 | print(f"Something wrong with get_matching_indices {e} for {anc_id}, {pos_id}") 116 | matching_01 = np.array([]) 117 | overlap_ratio = len(matching_01) / len(anc_pts) 118 | scene_overlap[f'{anc_id}@{pos_id}'] = overlap_ratio 119 | if overlap_ratio > 0.30: 120 | self.keypts_pairs[f'{anc_id}@{pos_id}'] = matching_01.astype(np.int32) 121 | self.overlap_ratio[f'{anc_id}@{pos_id}'] = overlap_ratio 122 | print(f'\t {anc_id}, {pos_id} overlap ratio: {overlap_ratio}') 123 | print('processing {:s} ply: {:.1f}%'.format(scene, 100 * i / len(scene_ids))) 124 | print('Finish {:s}, Done in {:.1f}s'.format(scene, time.time() - t0)) 125 | 126 | with open(overlap_filename, 'wb') as file: 127 | pickle.dump(self.overlap_ratio, file) 128 | with open(keypts_filename, 'wb') as file: 129 | pickle.dump(self.keypts_pairs, file) 130 | 131 | 132 | if __name__ == '__main__': 133 | ThreeDMatch(root='path to your ply file.', 134 | savepath='data/3DMatch', 135 | split='train', 136 | downsample=0.030 137 | ) 138 | -------------------------------------------------------------------------------- /script/common.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | import numpy as np 3 | import os 4 | import time 5 | import torch 6 | from sklearn.neighbors import KDTree 7 | import pointnet2_ops.pointnet2_utils as pnt2 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | 12 | class switch(object): 13 | def __init__(self, value): 14 | self.value = value 15 | self.fall = False 16 | 17 | def __iter__(self): 18 | """Return the match method once, then stop""" 19 | yield self.match 20 | return  # PEP 479: raising StopIteration inside a generator is an error in Python 3.7+ 21 | 22 | def match(self, *args): 23 | """Indicate whether or not to enter a case suite""" 24 | if self.fall or not args: 25 | return True 26 | elif self.value in args: # changed for v1.5, see below 27 | self.fall = True 28 | return True 29 | else: 30 | return False 31 | 32 | 33 | def select_patches(pts, ind, num_patches=1024, vicinity=0.15, num_points_per_patch=1024, is_rand=True): 34 | # A point sampling algorithm for 3d matching of irregular geometries.
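# For each selected keypoint: gather all neighbours within `vicinity` via a
# radius search; if the ball contains at least num_points_per_patch points,
# subsample without replacement, otherwise top the patch up by repeating
# existing points; the last slot is always overwritten with the keypoint
# itself so it is guaranteed to stay in the patch.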
35 | tree = KDTree(pts[:, 0:3]) 36 | num_points = pts.shape[0] 37 | if is_rand: 38 | out_inds = np.random.choice(range(ind.shape[0]), num_patches, replace=False) 39 | inds = ind[out_inds] 40 | else: 41 | inds = ind 42 | refer_pts = pts[inds] 43 | 44 | ind_local = tree.query_radius(refer_pts[:, 0:3], r=vicinity) 45 | local_patches = [] 46 | for i in range(np.size(ind_local)): 47 | local_neighbors = pts[ind_local[i], :] 48 | if local_neighbors.shape[0] >= num_points_per_patch: 49 | temp = np.random.choice(range(local_neighbors.shape[0]), num_points_per_patch, replace=False) 50 | local_neighbors = local_neighbors[temp] 51 | local_neighbors[-1, :] = refer_pts[i, :] 52 | else: 53 | fix_idx = np.asarray(range(local_neighbors.shape[0])) 54 | while local_neighbors.shape[0] + fix_idx.shape[0] < num_points_per_patch: 55 | fix_idx = np.concatenate((fix_idx, np.asarray(range(local_neighbors.shape[0]))), axis=0) 56 | random_idx = np.random.choice(local_neighbors.shape[0], num_points_per_patch - fix_idx.shape[0], 57 | replace=False) 58 | choice_idx = np.concatenate((fix_idx, random_idx), axis=0) 59 | local_neighbors = local_neighbors[choice_idx] 60 | local_neighbors[-1, :] = refer_pts[i, :] 61 | 62 | # fill_num = num_points_per_patch-local_neighbors.shape[0] 63 | # local_neighbors = np.concatenate((local_neighbors, np.tile(refer_pts[i,:],(fill_num,1))), axis=0) 64 | local_patches.append(local_neighbors) 65 | if is_rand: 66 | return local_patches, out_inds 67 | else: 68 | return local_patches 69 | 70 | 71 | def transform_pc_pytorch(pc, sn): 72 | ''' 73 | 74 | :param pc: Nx3 array 75 | :param sn: Nx4 array (e.g. surface normal + curvature) 76 | :return: jittered and rotated pc and sn, plus the random rotation 77 | angles and translation shift that were applied 78 | ''' 79 | angles_3d = np.random.rand(3) * np.pi * 2 80 | shift = np.random.uniform(-1, 1, (1, 3)) 81 | 82 | sigma, clip = 0.010, 0.02 83 | N, C = pc.shape 84 | jitter_pc = np.clip(sigma * np.random.randn(N, 3), -1 * clip, clip) 85 | sigma, clip = 0.010, 0.02 86 | jitter_sn = np.clip(sigma * np.random.randn(N, 4), -1 * clip, clip) 87 | pc += jitter_pc 88 | sn += jitter_sn 89 | 90 | pc = pc_rotate_translate(pc, angles_3d, shift) 91 | sn[:, 0:3] = vec_rotate(sn[:, 0:3], angles_3d) # Nx3 * 3x3 -> Nx3 92 | 93 | return pc, sn, \ 94 | angles_3d, shift 95 | 96 | 97 | def l2_norm(input, axis=1): 98 | norm = torch.norm(input, p=2, dim=axis, keepdim=True) 99 | output = torch.div(input, norm) 100 | return output 101 | 102 | 103 | def angles2rotation_matrix(angles): 104 | Rx = np.array([[1, 0, 0], 105 | [0, np.cos(angles[0]), -np.sin(angles[0])], 106 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 107 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 108 | [0, 1, 0], 109 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 110 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 111 | [np.sin(angles[2]), np.cos(angles[2]), 0], 112 | [0, 0, 1]]) 113 | R = np.dot(Rz, np.dot(Ry, Rx)) 114 | return R 115 | 116 | 117 | def pc_rotate_translate(data, angles, translates): 118 | ''' 119 | :param data: numpy array of Nx3 array 120 | :param angles: numpy array / list of 3 121 | :param translates: numpy array / list of 3 122 | :return: rotated_data: numpy array of Nx3 123 | ''' 124 | R = angles2rotation_matrix(angles) 125 | rotated_data = np.dot(data, np.transpose(R)) + translates 126 | 127 | return rotated_data 128 | 129 | 130 | def pc_rotate_translate_torch(data, angles, translates): 131 | ''' 132 | :param data: Tensor of BxNx3 array 133 | :param angles: Tensor of Bx3 134 | :param translates: Tensor of Bx3 135 |
130 | def pc_rotate_translate_torch(data, angles, translates):
131 |     '''
132 |     :param data: Tensor of BxNx3
133 |     :param angles: Tensor of Bx3
134 |     :param translates: Tensor of Bx3
135 |     :return: rotated_data: Tensor of BxNx3
136 |     '''
137 |     device = data.device
138 |     B, N, _ = data.shape
139 | 
140 |     R = np.zeros([B, 3, 3])
141 |     for i in range(B):
142 |         R[i] = angles2rotation_matrix(angles[i])  # 3x3
143 |     R = torch.FloatTensor(R).to(device)
144 | 
145 |     rotated_data = torch.matmul(data, R.transpose(-1, -2)) + torch.FloatTensor(translates).unsqueeze(1).to(device)
146 | 
147 |     return rotated_data
148 | 
149 | 
150 | def _pc_rotate_translate_torch(data, R, translates):
151 |     '''
152 |     :param data: Tensor of BxNx3
153 |     :param R: Tensor of Bx3x3 rotation matrices
154 |     :param translates: Tensor of Bx3
155 |     :return: rotated_data: Tensor of BxNx3
156 |     '''
157 |     device = data.device
158 |     B, N, _ = data.shape
159 | 
160 |     rotated_data = torch.matmul(data, R.to(device).transpose(-1, -2)) + torch.FloatTensor(translates).unsqueeze(1).to(
161 |         device)
162 | 
163 |     return rotated_data
164 | 
165 | 
166 | def max_ind(data):
167 |     B, C, row, col = data.shape
168 |     inds = np.zeros([B, 2])
169 |     for i in range(B):
170 |         ind = torch.argmax(data[i])  # flattened argmax (assumes a single channel)
171 |         r = int(ind // col)
172 |         c = ind % col
173 |         inds[i, 0] = r
174 |         inds[i, 1] = c
175 |     return inds
176 | 
177 | 
178 | def vec_rotate(data, angles):
179 |     '''
180 |     :param data: numpy array of Nx3 array
181 |     :param angles: numpy array / list of 3
182 |     :return: rotated_data: numpy array of Nx3
183 |     '''
184 |     R = angles2rotation_matrix(angles)
185 |     rotated_data = np.dot(data, R)  # note: right-multiplies by R, the opposite convention to pc_rotate_translate's R^T
186 | 
187 |     return rotated_data
188 | 
189 | 
190 | def vec_rotate_torch(data, angles):
191 |     '''
192 |     :param data: BxNx3 tensor
193 |     :param angles: Bx3 numpy array
194 |     :return: rotated BxNx3 tensor
195 |     '''
196 |     device = data.device
197 |     B, N, _ = data.shape
198 | 
199 |     R = np.zeros([B, 3, 3])
200 |     for i in range(B):
201 |         R[i] = angles2rotation_matrix(angles[i])  # 3x3
202 |     R = torch.FloatTensor(R).to(device)
203 | 
204 |     rotated_data = torch.matmul(data, R.transpose(-1, -2))  # BxNx3 * Bx3x3 -> BxNx3
205 |     return rotated_data
206 | 
207 | 
208 | def rotate_perturbation_point_cloud(data, angle_sigma=0.01, angle_clip=0.05):
209 |     """ Randomly perturb the point clouds by small rotations
210 |         Input:
211 |             Nx3 array, original point clouds
212 |         Return:
213 |             Nx3 array, rotated point clouds
214 |     """
215 |     # truncated Gaussian sampling
216 |     angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
217 |     rotated_data = vec_rotate(data, angles)
218 | 
219 |     return rotated_data
220 | 
221 | 
222 | def jitter_point_cloud(data, sigma=0.01, clip=0.05):
223 |     """ Randomly jitter points. jittering is per point.
224 |         Input:
225 |             BxNx3 array, original point clouds
226 |         Return:
227 |             BxNx3 array, jittered point clouds
228 |     """
229 |     B, N, C = data.shape
230 |     assert (clip > 0)
231 |     jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
232 |     jittered_data += data
233 |     return jittered_data
234 | 
235 | 
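# --- Editor's illustrative sketch (not part of the original source): the two
# augmentations above are typically chained at training time, a small random
# rotation per cloud followed by per-point truncated-Gaussian jitter.
def _demo_augment(batch):
    # batch: [B, N, 3] numpy array
    perturbed = np.stack([rotate_perturbation_point_cloud(pc) for pc in batch])
    return jitter_point_cloud(perturbed)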
236 | def square_distance(src, dst):
237 |     """
238 |     Calculate the squared Euclidean distance between each pair of points.
239 |     src^T * dst = xn * xm + yn * ym + zn * zm;
240 |     sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
241 |     sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
242 |     dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
243 |          = sum(src**2, dim=-1) + sum(dst**2, dim=-1) - 2 * src^T * dst
244 |     Input:
245 |         src: source points, [B, N, C]
246 |         dst: target points, [B, M, C]
247 |     Output:
248 |         dist: per-point squared distance, [B, N, M]
249 |     """
250 |     B, N, _ = src.shape
251 |     _, M, _ = dst.shape
252 |     dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
253 |     dist += torch.sum(src ** 2, -1).view(B, N, 1)
254 |     dist += torch.sum(dst ** 2, -1).view(B, 1, M)
255 |     return dist
256 | 
257 | 
258 | def cdist(a, b):
259 |     '''
260 |     :param a: [N, C] tensor
261 |     :param b: [M, C] tensor
262 |     :return: [M, N] matrix of pairwise Euclidean distances (eps added for a stable sqrt)
263 |     '''
264 |     diff = a.unsqueeze(0) - b.unsqueeze(1)
265 |     dis_matrix = torch.sqrt(torch.sum(diff * diff, dim=-1) + 1e-12)
266 |     return dis_matrix
267 | 
268 | 
269 | def s2_grid(n_alpha, n_beta):
270 |     '''
271 |     Build a cell-centered (beta, alpha) grid on the sphere S^2:
272 |     n_beta elevation rings times n_alpha azimuth steps (kernel size = n_alpha * n_beta).
273 |     '''
274 |     beta = np.linspace(start=0, stop=np.pi, num=n_beta, endpoint=False) + np.pi / n_beta / 2
275 |     # ele = np.arcsin(np.linspace(start=0, stop=1, num=n_beta / 2, endpoint=False) + 1 / n_beta / 4)
276 |     # beta = np.concatenate([np.sort(-ele), ele])
277 |     alpha = np.linspace(start=0, stop=2 * np.pi, num=n_alpha, endpoint=False) + np.pi / n_alpha
278 |     B, A = np.meshgrid(beta, alpha, indexing='ij')
279 |     B = B.flatten()
280 |     A = A.flatten()
281 |     grid = np.stack((B, A), axis=1)
282 |     return grid
283 | 
284 | 
285 | def pad_image(input, kernel_size):
286 |     """
287 |     Pad a spherical feature map for convolution: circular along azimuth (W), zeros along elevation (H).
288 |     :param input: [B, C, H, W]
289 |     :param kernel_size: convolution kernel size
290 |     :return: padded tensor
291 |     """
292 |     device = input.device
293 |     if kernel_size % 2 == 0:
294 |         pad_size = kernel_size // 2
295 |         output = torch.cat([input, input[:, :, :, 0:pad_size]], dim=3)
296 |         zeros_pad = torch.zeros([output.shape[0], output.shape[1], pad_size, output.shape[3]]).to(device)
297 |         output = torch.cat([output, zeros_pad], dim=2)
298 |     else:
299 |         pad_size = (kernel_size - 1) // 2
300 |         output = torch.cat([input, input[:, :, :, 0:pad_size]], dim=3)
301 |         output = torch.cat([input[:, :, :, -pad_size:], output], dim=3)
302 |         zeros_pad = torch.zeros([output.shape[0], output.shape[1], pad_size, output.shape[3]]).to(device)
303 |         output = torch.cat([output, zeros_pad], dim=2)
304 |         output = torch.cat([zeros_pad, output], dim=2)
305 |     return output
306 | 
307 | 
308 | def pad_image_3d(input, kernel_size):
309 |     """
310 |     Pad a spherical volume for convolution: circular along azimuth (W), zeros along elevation (H).
311 |     :param input: [B, C, D, H, W]
312 |     :param kernel_size: convolution kernel size
313 |     :return: padded tensor
314 |     """
315 |     device = input.device
316 |     if kernel_size % 2 == 0:
317 |         pad_size = kernel_size // 2
318 |         output = torch.cat([input, input[:, :, :, :, 0:pad_size]], dim=4)
319 |         zeros_pad = torch.zeros([output.shape[0], output.shape[1], output.shape[2], pad_size, output.shape[4]]).to(
320 |             device)
321 |         output = torch.cat([output, zeros_pad], dim=3)
322 |     else:
323 |         pad_size = (kernel_size - 1) // 2
324 |         output = torch.cat([input, input[:, :, :, :, 0:pad_size]], dim=4)
325 |         output = torch.cat([input[:, :, :, :, -pad_size:], output], dim=4)
326 |         zeros_pad = torch.zeros([output.shape[0], output.shape[1], output.shape[2], pad_size, output.shape[4]]).to(
327 |             device)
328 |         output = torch.cat([output, zeros_pad], dim=3)
329 |         output = torch.cat([zeros_pad, output], dim=3)
330 |     return output
331 | 
332 | 
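# --- Editor's illustrative sketch (not part of the original source): the
# padding above is circular along azimuth (last axis) and zero along
# elevation, matching the wrap-around topology of a spherical feature map.
def _demo_pad_image():
    img = torch.arange(12.).view(1, 1, 3, 4)  # [B, C, H=elevation, W=azimuth]
    padded = pad_image(img, kernel_size=3)    # -> [1, 1, 5, 6]
    assert torch.equal(padded[:, :, 1:-1, -1], img[:, :, :, 0])  # azimuth wraps around
    assert torch.all(padded[:, :, 0, :] == 0)  # elevation is zero-padded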
333 | def pad_image_on_azi(input, kernel_size):
334 |     """
335 |     Circularly pad a spherical feature map along the azimuth (W) axis only.
336 |     :param input: [B, C, H, W]
337 |     :param kernel_size: convolution kernel size (odd)
338 |     :return: padded tensor
339 |     """
340 |     device = input.device
341 |     pad_size = (kernel_size - 1) // 2
342 |     output = torch.cat([input, input[:, :, :, 0:pad_size]], dim=3)
343 |     output = torch.cat([input[:, :, :, -pad_size:], output], dim=3)
344 |     return output
345 | 
346 | 
347 | def kmax_pooling(x, dim, k):
348 |     kmax = x.topk(k, dim=dim)[0]
349 |     return kmax
350 | 
351 | 
352 | def change_coordinates(coords, radius, p_from='C', p_to='S'):
353 |     """
354 |     Change Spherical to Cartesian coordinates and vice versa, for points x in S^2.
355 | 
356 |     In the spherical system, we have coordinates beta and alpha,
357 |     where beta in [0, pi] and alpha in [0, 2pi].
358 | 
359 |     We use the names beta and alpha for compatibility with the SO(3) code (S^2 being a quotient SO(3)/SO(2)).
360 |     Many sources, like Wikipedia, use theta=beta and phi=alpha.
361 | 
362 |     :param coords: coordinate array
363 |     :param p_from: 'C' for Cartesian or 'S' for spherical coordinates
364 |     :param p_to: 'C' for Cartesian or 'S' for spherical coordinates
365 |     :return: new coordinates (radius is only applied in the 'S' -> 'C' direction)
366 |     """
367 |     if p_from == p_to:
368 |         return coords
369 |     elif p_from == 'S' and p_to == 'C':
370 | 
371 |         beta = coords[..., 0]
372 |         alpha = coords[..., 1]
373 |         r = radius
374 | 
375 |         out = np.empty(beta.shape + (3,))
376 | 
377 |         ct = np.cos(beta)
378 |         cp = np.cos(alpha)
379 |         st = np.sin(beta)
380 |         sp = np.sin(alpha)
381 |         out[..., 0] = r * st * cp  # x
382 |         out[..., 1] = r * st * sp  # y
383 |         out[..., 2] = r * ct  # z
384 |         return out
385 | 
386 |     elif p_from == 'C' and p_to == 'S':
387 | 
388 |         x = coords[..., 0]
389 |         y = coords[..., 1]
390 |         z = coords[..., 2]
391 | 
392 |         out = np.empty(x.shape + (2,))
393 |         out[..., 0] = np.arccos(z)  # beta (assumes points on the unit sphere)
394 |         out[..., 1] = np.arctan2(y, x)  # alpha
395 |         return out
396 | 
397 |     else:
398 |         raise ValueError('Unknown conversion: ' + str(p_from) + ' to ' + str(p_to))
399 | 
400 | 
401 | def get_voxel_coordinate(radius, rad_n, azi_n, ele_n):
402 |     grid = s2_grid(n_alpha=azi_n, n_beta=ele_n)
403 |     pts_xyz_on_S2 = change_coordinates(grid, radius, 'S', 'C')
404 |     pts_xyz_on_S2 = np.expand_dims(pts_xyz_on_S2, axis=0).repeat(rad_n, axis=0)
405 |     scale = np.reshape(np.arange(rad_n) / rad_n + 1 / (2 * rad_n), [rad_n, 1, 1])  # cell-centered radial shells
406 |     pts_xyz = scale * pts_xyz_on_S2
407 |     return pts_xyz
408 | 
409 | 
410 | def knn_query(pts, new_pts, knn):
411 |     """
412 |     :param pts: all points, [B, N, 3]
413 |     :param new_pts: query points, [B, S, 3]
414 |     :param knn: the number of queried points
415 |     :return: group_idx: indices of the knn nearest points, [B, S, knn]
416 |     """
417 |     device = pts.device
418 |     B, N, C = pts.shape
419 |     _, S, _ = new_pts.shape
420 |     sqrdists = square_distance(new_pts, pts)
421 |     _, group_idx = torch.topk(sqrdists, knn, dim=-1, largest=False)  # editorial completion: the original stub never returned
422 |     return group_idx
423 | 
424 | def sphere_query(pts, new_pts, radius, nsample):
425 |     """
426 |     :param pts: all points, [B, N, 3]
427 |     :param new_pts: query points, [B, S, 3]
428 |     :param radius: local spherical radius
429 |     :param nsample: max sample number in local sphere
430 |     :return: grouped points, [B, S, nsample, 3]
431 |     """
432 | 
433 |     device = pts.device
434 |     B, N, C = pts.shape
435 |     _, S, _ = new_pts.shape
436 | 
437 |     pts = pts.contiguous()
438 |     new_pts = new_pts.contiguous()
439 |     group_idx = pnt2.ball_query(radius, nsample, pts, new_pts)  # pads short neighborhoods by repeating the first index
440 |     mask = group_idx[:, :, 0].unsqueeze(2).repeat(1, 1, nsample)
441 |     mask = (group_idx == mask).float()  # flag entries equal to slot 0 (the padded duplicates)
442 |     mask[:, :, 0] = 0
443 | 
444 |     # C implementation
445 |     pts_trans = pts.transpose(1, 2).contiguous()
446 |     new_points = pnt2.grouping_operation(
447 |         pts_trans, group_idx
448 |     )  # (B, 3, npoint, nsample)
449 |     new_points = new_points.permute([0, 2, 3, 1])
450 | 
451 |     # replace the padded duplicates with the query point itself
452 |     mask = mask.unsqueeze(3).repeat([1, 1, 1, 3])
453 |     # new_pts = new_pts.unsqueeze(2).repeat([1, 1, nsample + 1, 1])
454 |     new_pts = new_pts.unsqueeze(2).repeat([1, 1, nsample, 1])
455 |     n_points = new_points * (1 - mask).float() + new_pts * mask.float()
456 | 
457 |     del mask
458 |     del new_points
459 |     del group_idx
460 |     del new_pts
461 |     del pts
462 |     del pts_trans
463 | 
464 |     return n_points
465 | 
466 | 
467 | def sphere_query_new(pts, new_pts, radius, nsample):
468 |     """
469 |     :param pts: all points, [B, N, 3]
470 |     :param new_pts: query points, [B, S, 3]
471 |     :param radius: local spherical radius
472 |     :param nsample: max sample number in local sphere
473 |     :return: grouped points, [B, S, nsample, 3]
474 |     """
475 | 
476 |     device = pts.device
477 |     B, N, C = pts.shape
478 |     _, S, _ = new_pts.shape
479 | 
480 |     pts = pts.contiguous()
481 |     new_pts = new_pts.contiguous()
482 |     group_idx = pnt2.ball_query(radius, nsample, pts, new_pts)
483 |     mask = group_idx[:, :, 0].unsqueeze(2).repeat(1, 1, nsample)
484 |     mask = (group_idx == mask).float()
485 |     mask[:, :, 0] = 0
486 | 
487 |     mask1 = (group_idx[:, :, 0] == 0).unsqueeze(2).float()  # neighborhoods whose first index is 0 (possibly empty balls)
488 |     mask1 = torch.cat([mask1, torch.zeros_like(mask)[:, :, :-1]], dim=2)
489 |     mask = mask + mask1
490 | 
491 |     # C implementation
492 |     pts_trans = pts.transpose(1, 2).contiguous()
493 |     new_points = pnt2.grouping_operation(
494 |         pts_trans, group_idx
495 |     )  # (B, 3, npoint, nsample)
496 |     new_points = new_points.permute([0, 2, 3, 1])
497 | 
498 |     # zero out the padded entries instead of substituting the query point
499 |     mask = mask.unsqueeze(3).repeat([1, 1, 1, 3])
500 |     n_points = new_points * (1 - mask).float()
501 | 
502 |     del mask
503 |     del new_points
504 |     del group_idx
505 |     del new_pts
506 |     del pts
507 |     del pts_trans
508 | 
509 |     return n_points
510 | 
511 | 
512 | def var_to_invar(pts, rad_n, azi_n, ele_n):
513 |     """
514 |     :param pts: input points data, [B, N, nsample, 3]
515 |     :param rad_n: radial bin number
516 |     :param azi_n: azimuth bin number
517 |     :param ele_n: elevation bin number
518 |     :return: points with each azimuthal bin rotated back by its own azimuth angle
519 |     """
520 |     device = pts.device
521 |     B, N, nsample, C = pts.shape
522 |     assert N == rad_n * azi_n * ele_n
523 |     angle_step = np.array([0, 0, 2 * np.pi / azi_n])
524 |     pts = pts.view(B, rad_n, ele_n, azi_n, nsample, C)
525 | 
526 |     R = np.zeros([azi_n, 3, 3])
527 |     for i in range(azi_n):
528 |         angle = -1 * i * angle_step  # undo the i-th bin's azimuth
529 |         r = angles2rotation_matrix(angle)
530 |         R[i] = r
531 |     R = torch.FloatTensor(R).to(device)
532 |     R = R.view(1, 1, 1, azi_n, 3, 3).repeat(B, rad_n, ele_n, 1, 1, 1)
533 |     new_pts = torch.matmul(pts, R.transpose(-1, -2))
534 | 
535 |     del R
536 |     del pts
537 | 
538 |     return new_pts.view(B, -1, nsample, C)
539 | 
540 | 
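# --- Editor's illustrative sketch (not part of the original source): because
# var_to_invar rotates each azimuthal bin back by its own azimuth angle, a
# global spin of the input only permutes bins along the azimuth axis, which is
# what the downstream convolution exploits for rotation invariance.
def _demo_var_to_invar():
    B, rad_n, ele_n, azi_n, nsample = 1, 2, 2, 8, 16
    pts = torch.rand(B, rad_n * azi_n * ele_n, nsample, 3)
    out = var_to_invar(pts, rad_n, azi_n, ele_n)
    assert out.shape == pts.shape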
541 | def cal_Z_axis(local_cor, local_weight=None, ref_point=None):
542 |     device = local_cor.device
543 |     B, N, _ = local_cor.shape
544 |     cov_matrix = torch.matmul(local_cor.transpose(-1, -2), local_cor) if local_weight is None \
545 |         else Variable(torch.matmul(local_cor.transpose(-1, -2), local_cor * local_weight), requires_grad=True)
546 |     Z_axis = torch.symeig(cov_matrix, eigenvectors=True)[1][:, :, 0]  # eigenvector of the smallest eigenvalue (torch.symeig is deprecated; torch.linalg.eigh is the modern equivalent)
547 |     mask = (torch.sum(-Z_axis * ref_point, dim=1) < 0).float().unsqueeze(1)  # flip the axis so it points towards the reference point
548 |     Z_axis = Z_axis * (1 - mask) - Z_axis * mask
549 | 
550 |     return Z_axis
551 | 
552 | 
553 | def RodsRotatFormula(a, b):  # Rodrigues' rotation formula: batch rotation matrices between vector pairs a -> b
554 |     B, _ = a.shape
555 |     device = a.device
556 |     b = b.to(device)
557 |     c = torch.cross(a, b)  # rotation axis
558 |     theta = torch.acos(F.cosine_similarity(a, b)).unsqueeze(1).unsqueeze(2)  # rotation angle
559 | 
560 |     c = F.normalize(c, p=2, dim=1)
561 |     one = torch.ones(B, 1, 1).to(device)
562 |     zero = torch.zeros(B, 1, 1).to(device)
563 |     a11 = zero
564 |     a12 = -c[:, 2].unsqueeze(1).unsqueeze(2)
565 |     a13 = c[:, 1].unsqueeze(1).unsqueeze(2)
566 |     a21 = c[:, 2].unsqueeze(1).unsqueeze(2)
567 |     a22 = zero
568 |     a23 = -c[:, 0].unsqueeze(1).unsqueeze(2)
569 |     a31 = -c[:, 1].unsqueeze(1).unsqueeze(2)
570 |     a32 = c[:, 0].unsqueeze(1).unsqueeze(2)
571 |     a33 = zero
572 |     Rx = torch.cat(
573 |         (torch.cat((a11, a12, a13), dim=2), torch.cat((a21, a22, a23), dim=2), torch.cat((a31, a32, a33), dim=2)),
574 |         dim=1)  # skew-symmetric cross-product matrix of the axis
575 |     I = torch.eye(3).to(device)
576 |     R = I.unsqueeze(0).repeat(B, 1, 1) + torch.sin(theta) * Rx + (1 - torch.cos(theta)) * torch.matmul(Rx, Rx)
577 |     return R.transpose(-1, -2)
578 | 
579 | 
580 | def rgbd_to_point_cloud(data_dir, ind, downsample=0.03, aligned=True):
581 |     pcd = open3d.read_point_cloud(os.path.join(data_dir, f'{ind}.ply'))  # legacy Open3D (< 0.8) API
582 |     # downsample the point cloud
583 |     if downsample != 0:
584 |         pcd = open3d.voxel_down_sample(pcd, voxel_size=downsample)
585 |     # align the point cloud
586 |     if aligned is True:
587 |         matrix = np.load(os.path.join(data_dir, f'{ind}.pose.npy'))
588 |         pcd.transform(matrix)
589 | 
590 |     return pcd
591 | 
592 | 
593 | def cal_local_normal(pcd):
594 |     if open3d.geometry.estimate_normals(pcd, open3d.KDTreeSearchParamKNN(knn=17)):
595 |         return True
596 |     else:
597 |         print("Normal estimation failed")
598 |         return False
599 | 
600 | 
601 | def select_referenced_point(pcd, num_patches=2048):
602 |     # A point sampling algorithm for 3D matching of irregular geometries.
603 |     pts = np.asarray(pcd.points)
604 |     num_points = pts.shape[0]
605 |     inds = np.random.choice(range(num_points), num_patches, replace=False)
606 |     return open3d.geometry.select_down_sample(pcd, inds)
607 | 
608 | 
609 | def collect_local_neighbor(ref_pcd, pcd, vicinity=0.3, num_points_per_patch=1024, random_state=None):
610 |     # collect the local neighborhood within `vicinity` of each interest point;
611 |     # each local patch is downsampled to 1024 points (the setting of PPFNet, p. 5)
612 |     kdtree = open3d.geometry.KDTreeFlann(pcd)
613 |     dict = []
614 |     for point in ref_pcd.points:
615 |         # Bug fix: the first returned neighbor is the query point itself, which would make the computed PPF NaN.
616 |         [k, idx, variant] = kdtree.search_radius_vector_3d(point, vicinity)
617 |         # randomly select a fixed number (num_points_per_patch) of points to form the local patch
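        # Editor's note: idx[0] is the query point itself (see the bug-fix comment
        # above), so sampling below starts from idx[1:]; when fewer than
        # num_points_per_patch neighbors exist, indices are drawn with
        # replacement to pad the patch to a fixed size.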
618 | if random_state is not None: 619 | if k > num_points_per_patch: 620 | idx = random_state.choice(idx[1:], num_points_per_patch, replace=False) 621 | else: 622 | idx = random_state.choice(idx[1:], num_points_per_patch) 623 | else: 624 | if k > num_points_per_patch: 625 | idx = np.random.choice(idx[1:], num_points_per_patch, replace=False) 626 | else: 627 | idx = np.random.choice(idx[1:], num_points_per_patch) 628 | dict.append(idx) 629 | return dict 630 | -------------------------------------------------------------------------------- /script/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd ../data 3 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_bm_1-brown_bm_1.zip 4 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_bm_4-brown_bm_4.zip 5 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_cogsci_1-brown_cogsci_1.zip 6 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_cs_2-brown_cs2.zip 7 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_cs_3-brown_cs3.zip 8 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c3-hv_c3_1.zip 9 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c5-hv_c5_1.zip 10 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c6-hv_c6_1.zip 11 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c8-hv_c8_3.zip 12 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c11-hv_c11_2.zip 13 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-home_at-home_at_scan1_2013_jan_1.zip 14 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-home_bksh-home_bksh_oct_30_2012_scan2_erika.zip 15 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-home_md-home_md_scan9_2012_sep_30.zip 16 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_nips2012-nips_4.zip 17 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_sf-scan1.zip 18 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_uc-scan3.zip 19 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_umd-maryland_hotel1.zip 20 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_umd-maryland_hotel3.zip 21 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_32_d507-d507_2.zip 22 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_46_ted_lab1-ted_lab_2.zip 23 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_76_417-76-417b.zip 24 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_76_studyroom-76-1studyroom2.zip 25 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika.zip 26 | wget 
http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika.zip 27 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika.zip 28 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-chess.zip 29 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-fire.zip 30 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-heads.zip 31 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-office.zip 32 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-pumpkin.zip 33 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-redkitchen.zip 34 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-stairs.zip 35 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_01.zip 36 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_02.zip 37 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_03.zip 38 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_04.zip 39 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_05.zip 40 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_06.zip 41 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_07.zip 42 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_08.zip 43 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_09.zip 44 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_10.zip 45 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_11.zip 46 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_12.zip 47 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_13.zip 48 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_14.zip 49 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-apt0.zip 50 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-apt1.zip 51 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-apt2.zip 52 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-copyroom.zip 53 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-office0.zip 54 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-office1.zip 55 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-office2.zip 56 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-office3.zip 57 | wget 
http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt1-kitchen.zip 58 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt1-living.zip 59 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt2-bed.zip 60 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt2-kitchen.zip 61 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt2-living.zip 62 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt2-luke.zip 63 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-office2-5a.zip 64 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-office2-5b.zip 65 | -------------------------------------------------------------------------------- /script/fuse_fragments_3DMatch.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | from pathlib import Path 5 | import argparse 6 | import math 7 | import numpy as np 8 | import os.path as osp 9 | import os 10 | import sys 11 | 12 | ROOT_DIR = osp.abspath('../') 13 | if ROOT_DIR not in sys.path: 14 | sys.path.append(ROOT_DIR) 15 | 16 | from script import io as uio 17 | 18 | 19 | # ---------------------------------------------------------------------------- # 20 | # Fuse rgbd frames into fragments in 3DMatch 21 | # - Use existing camera poses 22 | # - Save colors & normals 23 | # ---------------------------------------------------------------------------- # 24 | def read_intrinsic(filepath, width, height): 25 | import open3d as o3d 26 | 27 | m = np.loadtxt(filepath, dtype=np.float32) 28 | intrinsic = o3d.camera.PinholeCameraIntrinsic(width, height, m[0, 0], m[1, 1], m[0, 2], m[1, 2]) 29 | return intrinsic 30 | 31 | 32 | def read_extrinsic(filepath): 33 | m = np.loadtxt(filepath, dtype=np.float32) 34 | if np.isnan(m).any(): 35 | return None 36 | return m # (4, 4) 37 | 38 | 39 | def read_rgbd_image(cfg, color_file, depth_file, convert_rgb_to_intensity): 40 | import open3d as o3d 41 | if color_file is None: 42 | color_file = depth_file # to avoid "Unsupported image format." 
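    # Editor's note: for depth-only sequences Open3D still expects a readable
    # color image, so the depth map stands in for it; the color channel is
    # later discarded by the 'None' TSDF color type in process_single_fragment.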
43 | # rgbd_image = o3d.RGBDImage() 44 | # rgbd_image.depth = o3d.io.read_image(depth_file) 45 | # return rgbd_image 46 | color = o3d.io.read_image(color_file) 47 | depth = o3d.io.read_image(depth_file) 48 | rgbd_image = o3d.geometry.create_rgbd_image_from_color_and_depth(color, depth, cfg.depth_scale, cfg.depth_trunc, 49 | convert_rgb_to_intensity) 50 | return rgbd_image 51 | 52 | 53 | def process_single_fragment(cfg, color_files, depth_files, frag_id, n_frags, intrinsic_path, out_folder): 54 | import open3d as o3d 55 | 56 | depth_only_flag = (len(color_files) == 0) 57 | n_frames = len(depth_files) 58 | intrinsic = read_intrinsic(intrinsic_path, cfg.width, cfg.height) 59 | if depth_only_flag: 60 | color_type = o3d.integration.TSDFVolumeColorType.__dict__['None'] 61 | else: 62 | color_type = o3d.integration.TSDFVolumeColorType.__dict__['RGB8'] 63 | 64 | volume = o3d.integration.ScalableTSDFVolume(voxel_length=cfg.tsdf_cubic_size / 512.0, 65 | sdf_trunc=0.04, 66 | color_type=color_type) 67 | 68 | sid = frag_id * cfg.frames_per_frag 69 | eid = min(sid + cfg.frames_per_frag, n_frames) 70 | pose_base2world = None 71 | pose_base2world_inv = None 72 | for fid in range(sid, eid): 73 | if not depth_only_flag: 74 | color_path = color_files[fid] 75 | else: 76 | color_path = None 77 | depth_path = depth_files[fid] 78 | pose_path = depth_path[:-10] + '.pose.txt' 79 | 80 | pose_cam2world = read_extrinsic(pose_path) 81 | if pose_cam2world is None: 82 | continue 83 | if fid == sid: # Use as base frame 84 | pose_base2world = pose_cam2world 85 | pose_base2world_inv = np.linalg.inv(pose_base2world) 86 | if pose_base2world_inv is None: 87 | break 88 | # Relative camera pose 89 | pose_cam2world = np.matmul(pose_base2world_inv, pose_cam2world) 90 | 91 | rgbd = read_rgbd_image(cfg, color_path, depth_path, False) 92 | volume.integrate(rgbd, intrinsic, np.linalg.inv(pose_cam2world)) 93 | if pose_base2world_inv is None: 94 | return 95 | 96 | pcloud = volume.extract_point_cloud() 97 | o3d.geometry.estimate_normals(pcloud) 98 | o3d.write_point_cloud(osp.join(out_folder, 'cloud_bin_{}.ply'.format(frag_id)), pcloud) 99 | 100 | np.save(osp.join(out_folder, 'cloud_bin_{}.pose.npy'.format(frag_id)), pose_base2world) 101 | 102 | 103 | # ---------------------------------------------------------------------------- # 104 | # Iterate Folders 105 | # ---------------------------------------------------------------------------- # 106 | def run_seq(cfg, scene, seq): 107 | print(" Start {}".format(seq)) 108 | 109 | seq_folder = osp.join(cfg.dataset_root, scene, seq) 110 | color_names = uio.list_files(seq_folder, '*.color.png') 111 | color_paths = [osp.join(seq_folder, cf) for cf in color_names] 112 | depth_names = uio.list_files(seq_folder, '*.depth.png') 113 | depth_paths = [osp.join(seq_folder, df) for df in depth_names] 114 | # depth_paths = [osp.join(seq_folder, cf[:-10] + '.depth.png') for cf in depth_names] 115 | 116 | # n_frames = len(color_paths) 117 | n_frames = len(depth_paths) 118 | n_frags = int(math.ceil(float(n_frames) / cfg.frames_per_frag)) 119 | 120 | out_folder = osp.join(cfg.out_root, scene, seq) 121 | uio.may_create_folder(out_folder) 122 | 123 | intrinsic_path = osp.join(cfg.dataset_root, scene, 'camera-intrinsics.txt') 124 | 125 | if cfg.threads > 1: 126 | from joblib import Parallel, delayed 127 | import multiprocessing 128 | 129 | Parallel(n_jobs=cfg.threads)( 130 | delayed(process_single_fragment)(cfg, color_paths, depth_paths, frag_id, n_frags, intrinsic_path, 131 | out_folder) 132 | for frag_id in 
range(n_frags)) 133 | 134 | else: 135 | for frag_id in range(n_frags): 136 | process_single_fragment(cfg, color_paths, depth_paths, frag_id, n_frags, intrinsic_path, out_folder) 137 | 138 | print(" Finished {}".format(seq)) 139 | 140 | 141 | def run_scene(cfg, scene): 142 | print(" Start scene {} ".format(scene)) 143 | 144 | scene_folder = osp.join(cfg.dataset_root, scene) 145 | seqs = uio.list_folders(scene_folder) 146 | print(" {} sequences".format(len(seqs))) 147 | for seq in seqs: 148 | run_seq(cfg, scene, seq) 149 | 150 | print(" Finished scene {} ".format(scene)) 151 | 152 | 153 | def run(cfg): 154 | print("Start making fragments") 155 | 156 | uio.may_create_folder(cfg.out_root) 157 | 158 | scenes = uio.list_folders(cfg.dataset_root, sort=False) 159 | print("{} scenes".format(len(scenes))) 160 | for scene in scenes: 161 | # if not scene.startswith('analysis'): 162 | # continue 163 | run_scene(cfg, scene) 164 | 165 | print("Finished making fragments") 166 | 167 | 168 | # ---------------------------------------------------------------------------- # 169 | # Arguments 170 | # ---------------------------------------------------------------------------- # 171 | def parse_args(): 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument('--dataset_root', default='../../data/3DMatch_raw/') 174 | parser.add_argument('--out_root', default='../../data/3DMatch_fragments/') 175 | parser.add_argument('--depth_scale', type=float, default=1000.0) 176 | parser.add_argument('--depth_trunc', type=float, default=6.0) 177 | parser.add_argument('--frames_per_frag', type=int, default=50) 178 | parser.add_argument('--height', type=int, default=480) 179 | parser.add_argument('--threads', type=int, default=1) 180 | parser.add_argument('--tsdf_cubic_size', type=float, default=3.0) 181 | parser.add_argument('--width', type=int, default=640) 182 | 183 | return parser.parse_args() 184 | 185 | 186 | if __name__ == '__main__': 187 | cfg = parse_args() 188 | run(cfg) 189 | -------------------------------------------------------------------------------- /script/io.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | from collections import defaultdict 5 | from pathlib import Path 6 | import cv2 7 | import json 8 | import numpy as np 9 | import os 10 | import os.path as osp 11 | import re 12 | import shutil 13 | 14 | 15 | def is_number(s): 16 | try: 17 | float(s) 18 | return True 19 | except ValueError: 20 | return False 21 | 22 | 23 | # ---------------------------------------------------------------------------- # 24 | # Common IO 25 | # ---------------------------------------------------------------------------- # 26 | def may_create_folder(folder_path): 27 | if not osp.exists(folder_path): 28 | oldmask = os.umask(000) 29 | os.makedirs(folder_path, mode=0o777) 30 | os.umask(oldmask) 31 | return True 32 | return False 33 | 34 | 35 | def make_clean_folder(folder_path): 36 | success = may_create_folder(folder_path) 37 | if not success: 38 | shutil.rmtree(folder_path) 39 | may_create_folder(folder_path) 40 | 41 | 42 | def sorted_alphanum(file_list_ordered): 43 | convert = lambda text: int(text) if text.isdigit() else text 44 | alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key) if len(c) > 0] 45 | return sorted(file_list_ordered, key=alphanum_key) 46 | 47 | 48 | def list_files(folder_path, name_filter, sort=True): 49 | file_list = [p.name for p in 
list(Path(folder_path).glob(name_filter))] 50 | if sort: 51 | return sorted_alphanum(file_list) 52 | else: 53 | return file_list 54 | 55 | 56 | def list_folders(folder_path, name_filter=None, sort=True): 57 | folders = list() 58 | for subfolder in Path(folder_path).iterdir(): 59 | if subfolder.is_dir() and not subfolder.name.startswith('.'): 60 | folder_name = subfolder.name 61 | if name_filter is not None: 62 | if name_filter in folder_name: 63 | folders.append(folder_name) 64 | else: 65 | folders.append(folder_name) 66 | if sort: 67 | return sorted_alphanum(folders) 68 | else: 69 | return folders 70 | 71 | 72 | def read_lines(file_path): 73 | """ 74 | :param file_path: 75 | :return: 76 | """ 77 | with open(file_path, 'r') as fin: 78 | lines = [line.strip() for line in fin.readlines() if len(line.strip()) > 0] 79 | return lines 80 | 81 | 82 | def read_json(filepath): 83 | with open(filepath, 'r') as fh: 84 | ret = json.load(fh) 85 | return ret 86 | 87 | 88 | # ---------------------------------------------------------------------------- # 89 | # Image IO 90 | # ---------------------------------------------------------------------------- # 91 | def read_color_image(file_path): 92 | """ 93 | Args: 94 | file_path (str): 95 | 96 | Returns: 97 | np.array: RGB. 98 | """ 99 | img = cv2.imread(file_path) 100 | return img[..., ::-1] 101 | 102 | 103 | def read_gray_image(file_path): 104 | """Load a gray image 105 | 106 | Args: 107 | file_path (str): 108 | 109 | Returns: 110 | np.array: np.uint8, max 255. 111 | """ 112 | img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE) 113 | return img 114 | 115 | 116 | def read_16bit_image(file_path): 117 | """Load a 16bit image 118 | 119 | Args: 120 | file_path (str): 121 | 122 | Returns: 123 | np.array: np.uint16, max 65535. 124 | """ 125 | img = cv2.imread(file_path, cv2.IMREAD_UNCHANGED) 126 | return img 127 | 128 | 129 | def write_color_image(file_path, image): 130 | """ 131 | Args: 132 | file_path (str): 133 | image (np.array): in RGB. 134 | 135 | Returns: 136 | str: 137 | """ 138 | cv2.imwrite(file_path, image[..., ::-1]) 139 | return file_path 140 | 141 | 142 | def write_gray_image(file_path, image): 143 | """ 144 | Args: 145 | file_path (str): 146 | image (np.array): 147 | 148 | Returns: 149 | str: 150 | """ 151 | cv2.imwrite(file_path, image) 152 | return file_path 153 | 154 | 155 | def write_image(file_path, image): 156 | """ 157 | Args: 158 | file_path (str): 159 | image (np.array): 160 | 161 | Returns: 162 | str: 163 | """ 164 | if image.ndim == 2: 165 | return write_gray_image(file_path, image) 166 | elif image.ndim == 3: 167 | return write_color_image(file_path, image) 168 | else: 169 | raise RuntimeError('Image dimensions are not correct!') 170 | --------------------------------------------------------------------------------