├── .gitignore
├── KITTI
│   ├── Test
│   │   ├── test_kitti.py
│   │   └── test_kitti.txt
│   └── Train
│       ├── dataloader.py
│       ├── dataset.py
│       ├── train.py
│       └── trainer.py
├── LICENSE
├── README.md
├── ThreeDMatch
│   ├── Test
│   │   ├── 3dmatch
│   │   │   ├── evaluate.m
│   │   │   └── external
│   │   │       ├── ElasticReconstruction
│   │   │       │   ├── mrDrawTrajectory.m
│   │   │       │   ├── mrEvaluateRegistration.m
│   │   │       │   ├── mrEvaluateTrajectory.m
│   │   │       │   ├── mrLoadInfo.m
│   │   │       │   ├── mrLoadLog.m
│   │   │       │   ├── mrMatchDepthColor.m
│   │   │       │   ├── mrWriteInfo.m
│   │   │       │   └── mrWriteLog.m
│   │   │       └── npy-matlab
│   │   │           ├── constructNPYheader.m
│   │   │           ├── datToNPY.m
│   │   │           ├── readNPY.m
│   │   │           ├── readNPYheader.m
│   │   │           └── writeNPY.m
│   │   ├── evaluate.py
│   │   ├── gt_result
│   │   │   ├── 7-scenes-redkitchen-evaluation
│   │   │   │   └── gt.info
│   │   │   ├── sun3d-home_at-home_at_scan1_2013_jan_1-evaluation
│   │   │   │   └── gt.info
│   │   │   ├── sun3d-home_md-home_md_scan9_2012_sep_30-evaluation
│   │   │   │   └── gt.info
│   │   │   ├── sun3d-hotel_uc-scan3-evaluation
│   │   │   │   └── gt.info
│   │   │   ├── sun3d-hotel_umd-maryland_hotel1-evaluation
│   │   │   │   └── gt.info
│   │   │   ├── sun3d-hotel_umd-maryland_hotel3-evaluation
│   │   │   │   └── gt.info
│   │   │   ├── sun3d-mit_76_studyroom-76-1studyroom2-evaluation
│   │   │   │   └── gt.info
│   │   │   └── sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika-evaluation
│   │   │       └── gt.info
│   │   ├── preparation.py
│   │   └── tools.py
│   └── Train
│       ├── dataloader.py
│       ├── dataset.py
│       ├── train.py
│       └── trainer.py
├── figs
│   ├── Fig1.png
│   ├── Fig2.png
│   ├── Fig3.png
│   ├── Fig4.png
│   ├── Fig5.png
│   ├── Table1.png
│   ├── Table2.png
│   ├── Table3.png
│   ├── Table4.png
│   ├── Table5.png
│   ├── Table6.png
│   └── Table7.png
├── generalization
│   ├── KITTI-to-ThreeDMatch
│   │   ├── evaluate.py
│   │   └── preparation.py
│   ├── ThreeDMatch-to-ETH
│   │   ├── evaluate.py
│   │   └── preparation.py
│   └── ThreeDMatch-to-KITTI
│       ├── test.py
│       └── test_kitti.txt
├── loss
│   └── desc_loss.py
├── network
│   ├── SpinNet.py
│   └── ThreeDCCN.py
├── pre-trained_models
│   ├── 3DMatch_best.pkl
│   └── KITTI_best.pkl
└── script
    ├── cal_overlap.py
    ├── common.py
    ├── download.sh
    ├── fuse_fragments_3DMatch.py
    └── io.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/KITTI/Test/test_kitti.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "1"
4 | import logging
5 | import numpy as np
6 | import open3d as o3d
7 | import torch
8 | import torch.nn as nn
9 | import glob
10 | import time
11 | import gc
12 | import shutil
13 | import pointnet2_ops.pointnet2_utils as pnt2
14 | import copy
15 | import importlib
16 | import sys
17 |
18 | sys.path.append('../../')
19 | import script.common as cm
20 |
21 | kitti_icp_cache = {}
22 | kitti_cache = {}
23 |
24 |
25 | class Timer(object):
26 | """A simple timer."""
27 |
28 | def __init__(self, binary_fn=None, init_val=0):
29 | self.total_time = 0.
30 | self.calls = 0
31 | self.start_time = 0.
32 | self.diff = 0.
33 | self.binary_fn = binary_fn
34 | self.tmp = init_val
35 |
36 | def reset(self):
37 | self.total_time = 0
38 | self.calls = 0
39 | self.start_time = 0
40 | self.diff = 0
41 |
42 | @property
43 | def avg(self):
44 | return self.total_time / self.calls
45 |
46 | def tic(self):
47 | # using time.time instead of time.clock because time.clock
48 | # does not normalize for multithreading
49 | self.start_time = time.time()
50 |
51 | def toc(self, average=True):
52 | self.diff = time.time() - self.start_time
53 | self.total_time += self.diff
54 | self.calls += 1
55 | if self.binary_fn:
56 | self.tmp = self.binary_fn(self.tmp, self.diff)
57 | if average:
58 | return self.avg
59 | else:
60 | return self.diff
61 |
62 |
63 | class AverageMeter(object):
64 | """Computes and stores the average and current value"""
65 |
66 | def __init__(self):
67 | self.reset()
68 |
69 | def reset(self):
70 | self.val = 0
71 | self.avg = 0
72 | self.sum = 0.0
73 | self.sq_sum = 0.0
74 | self.count = 0
75 |
76 | def update(self, val, n=1):
77 | self.val = val
78 | self.sum += val * n
79 | self.count += n
80 | self.avg = self.sum / self.count
81 | self.sq_sum += val ** 2 * n
82 | self.var = self.sq_sum / self.count - self.avg ** 2
83 |
84 |
85 | def get_desc(descpath, filename):
86 | desc = np.load(os.path.join(descpath, filename + '.npy'))
87 | return desc
88 |
89 |
90 | def get_keypts(keypts_path, filename):
91 | keypts = np.load(os.path.join(keypts_path, filename + '.npy'))
92 | return keypts
93 |
94 |
95 | def make_open3d_feature(data, dim, npts):
96 | feature = o3d.pipelines.registration.Feature()
97 | feature.resize(dim, npts)
98 | feature.data = data.astype('d').transpose()
99 | return feature
100 |
101 |
102 | def make_open3d_point_cloud(xyz, color=None):
103 | pcd = o3d.geometry.PointCloud()
104 | pcd.points = o3d.utility.Vector3dVector(xyz)
105 | if color is not None:
106 | pcd.paint_uniform_color(color)
107 | return pcd
108 |
109 |
110 | def get_matching_indices(source, target, trans, search_voxel_size, K=None):
111 | source_copy = copy.deepcopy(source)
112 | target_copy = copy.deepcopy(target)
113 | source_copy.transform(trans)
114 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy)
115 |
116 | match_inds = []
117 | for i, point in enumerate(source_copy.points):
118 | [_, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size)
119 | if K is not None:
120 | idx = idx[:K]
121 | for j in idx:
122 | match_inds.append((i, j))
123 | return match_inds
124 |
125 |
126 | class KITTI(object):
127 | DATA_FILES = {
128 | 'train': 'train_kitti.txt',
129 | 'val': 'val_kitti.txt',
130 | 'test': 'test_kitti.txt'
131 | }
132 | """
133 | Given point cloud fragments and corresponding pose in '{root}'.
134 | 1. Save the aligned point cloud pts in '{savepath}/3DMatch_{downsample}_points.pkl'
135 | 2. Calculate the overlap ratio and save in '{savepath}/3DMatch_{downsample}_overlap.pkl'
136 | 3. Save the ids of anchor keypoints and positive keypoints in '{savepath}/3DMatch_{downsample}_keypts.pkl'
137 | """
138 |
139 | def __init__(self, root, descpath, icp_path, split, model, num_points_per_patch, use_random_points):
140 | self.root = root
141 | self.descpath = descpath
142 | self.split = split
143 | self.num_points_per_patch = num_points_per_patch
144 | self.icp_path = icp_path
145 | self.use_random_points = use_random_points
146 | self.model = model
147 | if not os.path.exists(self.icp_path):
148 | os.makedirs(self.icp_path)
149 |
150 | # list: anc & pos
151 | self.patches = []
152 | self.pose = []
153 | # Initiate containers
154 | self.files = {'train': [], 'val': [], 'test': []}
155 |
156 | self.prepare_kitti_ply(split=self.split)
157 |
158 | def prepare_kitti_ply(self, split='train'):
159 | subset_names = open(self.DATA_FILES[split]).read().split()
160 | for dirname in subset_names:
161 | drive_id = int(dirname)
162 | fnames = glob.glob(self.root + '/sequences/%02d/velodyne/*.bin' % drive_id)
163 | assert len(fnames) > 0, f"Make sure that the path {self.root} has data {dirname}"
164 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames])
165 |
166 | all_odo = self.get_video_odometry(drive_id, return_all=True)
167 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo])
168 | Ts = all_pos[:, :3, 3]
169 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3)) ** 2
170 | pdist = np.sqrt(pdist.sum(-1))
171 | more_than_10 = pdist > 10
172 | curr_time = inames[0]
173 | while curr_time in inames:
174 | next_time = np.where(more_than_10[curr_time][curr_time:curr_time + 100])[0]
175 | if len(next_time) == 0:
176 | curr_time += 1
177 | else:
178 | next_time = next_time[0] + curr_time - 1
179 |
180 | if next_time in inames:
181 | self.files[split].append((drive_id, curr_time, next_time))
182 | curr_time = next_time + 1
183 | # Remove problematic frame pairs
184 | for item in [
185 | (8, 15, 58),
186 | ]:
187 | if item in self.files[split]:
188 | self.files[split].pop(self.files[split].index(item))
189 |
190 | if split == 'train':
191 | self.num_train = len(self.files[split])
192 | print("Num_train", self.num_train)
193 | elif split == 'val':
194 | self.num_val = len(self.files[split])
195 | print("Num_val", self.num_val)
196 | elif split == 'test':
197 | self.num_test = len(self.files[split])
198 | print("Num_test", self.num_test)
199 |
200 | for idx in range(len(self.files[split])):
201 | drive = self.files[split][idx][0]
202 | t0, t1 = self.files[split][idx][1], self.files[split][idx][2]
203 | all_odometry = self.get_video_odometry(drive, [t0, t1])
204 | positions = [self.odometry_to_positions(odometry) for odometry in all_odometry]
205 | fname0 = self._get_velodyne_fn(drive, t0)
206 | fname1 = self._get_velodyne_fn(drive, t1)
207 |
208 | # XYZ and reflectance
209 | xyzr0 = np.fromfile(fname0, dtype=np.float32).reshape(-1, 4)
210 | xyzr1 = np.fromfile(fname1, dtype=np.float32).reshape(-1, 4)
211 |
212 | xyz0 = xyzr0[:, :3]
213 | xyz1 = xyzr1[:, :3]
214 |
215 | key = '%d_%d_%d' % (drive, t0, t1)
216 | filename = self.icp_path + '/' + key + '.npy'
217 | if key not in kitti_icp_cache:
218 | if not os.path.exists(filename):
219 | M = (self.velo2cam @ positions[0].T @ np.linalg.inv(positions[1].T)
220 | @ np.linalg.inv(self.velo2cam)).T
221 | xyz0_t = self.apply_transform(xyz0, M)
222 | pcd0 = make_open3d_point_cloud(xyz0_t, [0.5, 0.5, 0.5])
223 | pcd1 = make_open3d_point_cloud(xyz1, [0, 1, 0])
224 | reg = o3d.pipelines.registration.registration_icp(pcd0, pcd1, 0.10, np.eye(4),
225 | o3d.pipelines.registration.TransformationEstimationPointToPoint(),
226 | o3d.pipelines.registration.ICPConvergenceCriteria(
227 | max_iteration=400))
228 | pcd0.transform(reg.transformation)
229 | M2 = M @ reg.transformation
230 | # write to a file
231 | np.save(filename, M2)
232 | else:
233 | M2 = np.load(filename)
234 | kitti_icp_cache[key] = M2
235 | else:
236 | M2 = kitti_icp_cache[key]
237 | trans = M2
238 | # extract patches for anc&pos
239 | np.random.shuffle(xyz0)
240 | np.random.shuffle(xyz1)
241 |
242 | if is_rotate_dataset:
243 | # Add an arbitrary rotation:
244 | # rotate the second fragment by random angles around all three axes
245 | angles_3d = np.random.rand(3) * np.pi * 2
246 | R = cm.angles2rotation_matrix(angles_3d)
247 | T = np.identity(4)
248 | T[:3, :3] = R
249 | pcd1 = make_open3d_point_cloud(xyz1)
250 | pcd1.transform(T)
251 | xyz1 = np.array(pcd1.points)
252 | all_trans_matrix[key] = T
253 |
254 | if not os.path.exists(self.descpath + str(drive)):
255 | os.makedirs(self.descpath + str(drive))
256 | if self.use_random_points:
257 | num_keypts = 5000
258 | step_size = 50
259 | desc_len = 32
260 | model = self.model.cuda()
261 | # calc t0 descriptors
262 | desc_t0_path = os.path.join(self.descpath + str(drive), f"cloud_bin_{t0}.desc.bin.npy")
263 | keypts_t0_path = os.path.join(self.descpath + str(drive), f"cloud_bin_{t0}.keypts.npy")
264 | if not os.path.exists(desc_t0_path):
265 | keypoints_id = np.random.choice(xyz0.shape[0], num_keypts)
266 | keypts = xyz0[keypoints_id]
267 | np.save(keypts_t0_path, keypts.astype(np.float32))
268 | local_patches = self.select_patches(xyz0, keypts, vicinity=vicinity,
269 | num_points_per_patch=self.num_points_per_patch)
270 | B = local_patches.shape[0]
271 | # compute descriptors in chunks of step_size to avoid CUDA out-of-memory
272 | desc_list = []
273 | start_time = time.time()
274 | iter_num = int(np.ceil(B / step_size))  # builtin int: np.int is deprecated
275 | for k in range(iter_num):
276 | if k == iter_num - 1:
277 | desc = model(local_patches[k * step_size:, :, :])
278 | else:
279 | desc = model(local_patches[k * step_size: (k + 1) * step_size, :, :])
280 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy())
281 | del desc
282 | step_time = time.time() - start_time
283 | print(f'Computed {B} descriptors in {step_time:.4f}s')
284 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len])
285 | np.save(desc_t0_path, desc.astype(np.float32))
286 | else:
287 | print(f"{desc_t0_path} already exists.")
288 |
289 | # calc t1 descriptors
290 | desc_t1_path = os.path.join(self.descpath + str(drive), f"cloud_bin_{t1}.desc.bin.npy")
291 | keypts_t1_path = os.path.join(self.descpath + str(drive), f"cloud_bin_{t1}.keypts.npy")
292 | if not os.path.exists(desc_t1_path):
293 | keypoints_id = np.random.choice(xyz1.shape[0], num_keypts)
294 | keypts = xyz1[keypoints_id]
295 | np.save(keypts_t1_path, keypts.astype(np.float32))
296 | local_patches = self.select_patches(xyz1, keypts, vicinity=vicinity,
297 | num_points_per_patch=self.num_points_per_patch)
298 | B = local_patches.shape[0]
299 | # calculate descriptors
300 | desc_list = []
301 | start_time = time.time()
302 | iter_num = int(np.ceil(B / step_size))  # builtin int: np.int is deprecated
303 | for k in range(iter_num):
304 | if k == iter_num - 1:
305 | desc = model(local_patches[k * step_size:, :, :])
306 | else:
307 | desc = model(local_patches[k * step_size: (k + 1) * step_size, :, :])
308 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy())
309 | del desc
310 | step_time = time.time() - start_time
311 | print(f'Computed {B} descriptors in {step_time:.4f}s')
312 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len])
313 | np.save(desc_t1_path, desc.astype(np.float32))
314 | else:
315 | print(f"{desc_t1_path} already exists.")
316 | else:
317 | num_keypts = 512
318 |
319 | def select_patches(self, pts, refer_pts, vicinity, num_points_per_patch=1024):
320 | gc.collect()
321 | pts = torch.FloatTensor(pts).cuda().unsqueeze(0)
322 | refer_pts = torch.FloatTensor(refer_pts).cuda().unsqueeze(0)
323 | group_idx = pnt2.ball_query(vicinity, num_points_per_patch, pts, refer_pts)
324 | pts_trans = pts.transpose(1, 2).contiguous()
325 | new_points = pnt2.grouping_operation(
326 | pts_trans, group_idx
327 | )
328 | new_points = new_points.permute([0, 2, 3, 1])
329 | mask = group_idx[:, :, 0].unsqueeze(2).repeat(1, 1, num_points_per_patch)
330 | mask = (group_idx == mask).float()
331 | mask[:, :, 0] = 0
332 | mask[:, :, num_points_per_patch - 1] = 1
333 | mask = mask.unsqueeze(3).repeat([1, 1, 1, 3])
334 | new_pts = refer_pts.unsqueeze(2).repeat([1, 1, num_points_per_patch, 1])
335 | local_patches = new_points * (1 - mask).float() + new_pts * mask.float()
336 | local_patches = local_patches.squeeze(0)
337 | del mask
338 | del new_points
339 | del group_idx
340 | del new_pts
341 | del pts
342 | del pts_trans
343 |
344 | return local_patches
345 |
346 | def apply_transform(self, pts, trans):
347 | R = trans[:3, :3]
348 | T = trans[:3, 3]
349 | pts = pts @ R.T + T
350 | return pts
351 |
352 | @property
353 | def velo2cam(self):
354 | try:
355 | velo2cam = self._velo2cam
356 | except AttributeError:
357 | R = np.array([
358 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04,
359 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02
360 | ]).reshape(3, 3)
361 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1)
362 | velo2cam = np.hstack([R, T])
363 | self._velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T
364 | return self._velo2cam
365 |
366 | def get_video_odometry(self, drive, indices=None, ext='.txt', return_all=False):
367 | data_path = self.root + '/poses/%02d.txt' % drive
368 | if data_path not in kitti_cache:
369 | kitti_cache[data_path] = np.genfromtxt(data_path)
370 | if return_all:
371 | return kitti_cache[data_path]
372 | else:
373 | return kitti_cache[data_path][indices]
374 |
375 | def odometry_to_positions(self, odometry):
376 | T_w_cam0 = odometry.reshape(3, 4)
377 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1]))
378 | return T_w_cam0
379 |
380 | def _get_velodyne_fn(self, drive, t):
381 | fname = self.root + '/sequences/%02d/velodyne/%06d.bin' % (drive, t)
382 | return fname
383 |
384 |
385 | if __name__ == '__main__':
386 | is_rotate_dataset = False
387 | all_trans_matrix = {}
388 | experiment_id = time.strftime('%m%d%H%M') # '11210201'#
389 | model_str = experiment_id
390 | reg_timer = Timer()
391 | success_meter, rte_meter, rre_meter = AverageMeter(), AverageMeter(), AverageMeter()
392 | ch = logging.StreamHandler(sys.stdout)
393 | logging.getLogger().setLevel(logging.INFO)
394 | logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch])
395 |
396 | # dynamically load the model from snapshot
397 | module_file_path = '../model.py'
398 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), module_file_path)
399 | module_name = ''
400 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
401 | module = importlib.util.module_from_spec(module_spec)
402 | module_spec.loader.exec_module(module)
403 |
404 | vicinity = 2.0
405 | model = module.Descriptor_Net(vicinity, 9, 60, 30, 0.3, 30, 'KITTI')
406 | model = nn.DataParallel(model, device_ids=[0])
407 | model.load_state_dict(torch.load('../../pre-trained_models/KITTI_best.pkl'))
408 |
409 | test_data = KITTI(root='../../data/KITTI/dataset',
410 | descpath=f'SpinNet_desc_{model_str}/',
411 | icp_path='../../data/KITTI/icp',
412 | split='test',
413 | model=model,
414 | num_points_per_patch=2048,
415 | use_random_points=True
416 | )
417 |
418 | files = test_data.files[test_data.split]
419 | for idx in range(len(files)):
420 | drive = files[idx][0]
421 | t0, t1 = files[idx][1], files[idx][2]
422 | key = '%d_%d_%d' % (drive, t0, t1)
423 | filename = test_data.icp_path + '/' + key + '.npy'
424 | T_gth = kitti_icp_cache[key]
425 | if is_rotate_dataset:
426 | T_gth = np.matmul(all_trans_matrix[key], T_gth)
427 |
428 | descpath = os.path.join(test_data.descpath, str(drive))
429 | fname0 = test_data._get_velodyne_fn(drive, t0)
430 | fname1 = test_data._get_velodyne_fn(drive, t1)
431 | # load the sampled keypoints for both frames
432 | xyz0 = get_keypts(descpath, f"cloud_bin_{t0}.keypts")
433 | xyz1 = get_keypts(descpath, f"cloud_bin_{t1}.keypts")
434 | pcd0 = make_open3d_point_cloud(xyz0)
435 | pcd1 = make_open3d_point_cloud(xyz1)
436 |
437 | source_desc = get_desc(descpath, f"cloud_bin_{t0}.desc.bin")
438 | target_desc = get_desc(descpath, f"cloud_bin_{t1}.desc.bin")
439 | feat0 = make_open3d_feature(source_desc, 32, source_desc.shape[0])
440 | feat1 = make_open3d_feature(target_desc, 32, target_desc.shape[0])
441 |
442 | reg_timer.tic()
443 | distance_threshold = 0.3
444 | ransac_result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
445 | pcd0, pcd1, feat0, feat1, distance_threshold,
446 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False), 4, [
447 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
448 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)
449 | ], o3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000))
450 | T_ransac = torch.from_numpy(ransac_result.transformation.astype(np.float32))
451 | reg_timer.toc()
452 |
453 | # Translation and rotation errors
454 | rte = np.linalg.norm(T_ransac[:3, 3] - T_gth[:3, 3])
455 | rre = np.arccos((np.trace(T_ransac[:3, :3].t() @ T_gth[:3, :3]) - 1) / 2)
456 |
457 | if rte < 2:
458 | rte_meter.update(rte)
459 |
460 | if not np.isnan(rre) and rre < np.pi / 180 * 5:
461 | rre_meter.update(rre * 180 / np.pi)
462 |
463 | if rte < 2 and not np.isnan(rre) and rre < np.pi / 180 * 5:
464 | success_meter.update(1)
465 | else:
466 | success_meter.update(0)
467 | logging.info(f"Failed with RTE: {rte}, RRE: {rre}")
468 |
469 | if (idx + 1) % 10 == 0:
470 | logging.info(
471 | f" RRE: {rre_meter.avg}, Success: {success_meter.sum} / {success_meter.count}" +
472 | f" ({success_meter.avg * 100} %)"
473 | )
474 | reg_timer.reset()
475 |
476 | logging.info(
477 | f"RTE: {rte_meter.avg}, var: {rte_meter.var}," +
478 | f" RRE: {rre_meter.avg}, var: {rre_meter.var}, Success: {success_meter.sum} " +
479 | f"/ {success_meter.count} ({success_meter.avg * 100} %)"
480 | )
481 |
--------------------------------------------------------------------------------
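
A minimal NumPy sketch of the RTE/RRE metrics computed at the end of test_kitti.py above (a restatement for reference, not repo code; assumes 4x4 homogeneous pose matrices):

```python
import numpy as np

def registration_errors(T_est, T_gt):
    # Relative Translation Error: distance between the two translation vectors.
    rte = np.linalg.norm(T_est[:3, 3] - T_gt[:3, 3])
    # Relative Rotation Error: geodesic angle between the two rotations,
    # arccos((trace(R_est^T R_gt) - 1) / 2); the clip guards against round-off.
    cos_angle = (np.trace(T_est[:3, :3].T @ T_gt[:3, :3]) - 1) / 2
    rre_deg = np.degrees(np.arccos(np.clip(cos_angle, -1.0, 1.0)))
    return rte, rre_deg
```

As in the script, a pair counts as successfully registered when RTE < 2 m and RRE < 5 degrees.
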
/KITTI/Test/test_kitti.txt:
--------------------------------------------------------------------------------
1 | 8
2 | 9
3 | 10
4 |
--------------------------------------------------------------------------------
/KITTI/Train/dataloader.py:
--------------------------------------------------------------------------------
1 | import time
2 | from KITTI.Train.dataset import KITTIDataset
3 | import torch
4 |
5 |
6 | def get_dataloader(root, split, batch_size=1, num_workers=0, shuffle=True, drop_last=True):
7 | dataset = KITTIDataset(
8 | root=root,
9 | split=split,
10 | batch_size=batch_size,
11 | shuffle=shuffle,
12 | drop_last=drop_last
13 | )
14 | dataset.initial()
15 | dataloader = torch.utils.data.DataLoader(
16 | dataset=dataset,
17 | batch_size=batch_size,
18 | num_workers=0,  # keep 0: KITTIDataset is stateful and streams its pkl files itself
19 | drop_last=drop_last
20 | )
21 |
22 | return dataloader
23 |
--------------------------------------------------------------------------------
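
A hypothetical usage sketch for get_dataloader above (the root path follows train.py, and the pkl patch files must already exist):

```python
from KITTI.Train.dataloader import get_dataloader

loader = get_dataloader(root='../../data/KITTI/patches', split='train', batch_size=48)
for anc_local_patch, pos_local_patch, rotate, shift in loader:
    # anc/pos patches: (B, num_points_per_patch, 3); rotate: (B, 3, 3); shift: (B, 3)
    print(anc_local_patch.shape)
    break
```
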
/KITTI/Train/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as Data
2 | import os
3 | import random
4 | import glob
5 | import pickle
6 | import open3d as o3d
7 | import numpy as np
8 |
9 |
10 | class KITTIDataset(Data.Dataset):
11 | def __init__(self, root, split, batch_size, shuffle, drop_last):
12 | """
13 | Create KITTIDataset to read multiple training files
14 | Args:
15 | root: the path to the dataset file
16 | shuffle: whether the data need to shuffle
17 | """
18 | self.patches_path = os.path.join(root, split)
19 | self.split = split
20 | # Get name of all training pkl files
21 | training_data_files = glob.glob(self.patches_path + '/*.pkl')
22 | ids = [file.split("/")[-1] for file in training_data_files]
23 | ids = sorted(ids, key=lambda x: int(x.split("_")[-1].split(".")[0]))
24 | ids = [file for file in ids if file.split("_")[1] == 'anc&pos']
25 | self.training_data_files = ids
26 | # Get info of training files
27 | self.per_num_patch = int(training_data_files[0].split("/")[-1].split("_")[2])
28 | self.dataset_len = int(ids[-1].split("_")[-1].split(".")[0]) * self.per_num_patch
29 | self.batch_size = batch_size
30 | self.shuffle = shuffle
31 | self.drop_last = drop_last
32 | # Record the loaded i-th training file
33 | self.num_file = 0
34 | # load poses for each type of patches
35 | self.per_patch_points = int(self.training_data_files[-1].split("_")[3])
36 | self.num_fragments = int(self.training_data_files[-1].split("_")[4].split(".")[0])
37 | with open(os.path.join(root,
38 | f'{self.split}/{self.split}_poses_{self.per_num_patch}_{self.per_patch_points}_{self.num_fragments}.pkl'),
39 | 'rb') as file:
40 | self.poses = pickle.load(file)
41 | print(
42 | f"load training poses {os.path.join(root, f'{self.split}_poses_{self.per_num_patch}_{self.per_patch_points}_{self.num_fragments}.pkl')}")
43 | self.cur_pose_ind = 0
44 |
45 | def initial(self):
46 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file:
47 | self.patches = pickle.load(file)
48 | print(f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}")
49 |
50 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1])
51 | poses = self.poses[self.cur_pose_ind:next_pose_ind]
52 | for i in range(len(self.patches)):
53 | ind = int(np.floor(i / self.per_num_patch))
54 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6).astype(np.float32)
55 | self.patches[i] = np.concatenate([pose, self.patches[i]])
56 | self.cur_pose_ind = next_pose_ind
57 |
58 | self.current_patches_num = len(self.patches)
59 | self.index = list(range(self.current_patches_num))
60 | if self.shuffle:
61 | random.shuffle(self.patches)
62 |
63 | def __len__(self):
64 | return self.dataset_len
65 |
66 | def __getitem__(self, item):
67 | idx = self.index[0]
68 | patches = self.patches[idx]
69 | self.index = self.index[1:]
70 | self.current_patches_num -= 1
71 |
72 | if self.drop_last:
73 | if self.current_patches_num <= (len(self.patches) % self.batch_size): # reach the end of training file
74 | self.num_file = self.num_file + 1
75 | if self.num_file < len(self.training_data_files):
76 | remain_patches = [self.patches[i] for i in self.index] # the remained training patches
77 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file:
78 | self.patches = pickle.load(file)
79 | print(
80 | f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}")
81 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1])
82 | poses = self.poses[self.cur_pose_ind:next_pose_ind]
83 | for i in range(len(self.patches)):
84 | ind = int(np.floor(i / self.per_num_patch))
85 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6).astype(
86 | np.float32)
87 | self.patches[i] = np.concatenate([pose, self.patches[i]])
88 | self.cur_pose_ind = next_pose_ind
89 | self.patches += remain_patches # add the remained patches to compose a set of new patches
90 | self.current_patches_num = len(self.patches)
91 | self.index = list(range(self.current_patches_num))
92 | if self.shuffle:
93 | random.shuffle(self.patches)
94 | else:
95 | self.num_file = 0
96 | self.cur_pose_ind = 0
97 | self.initial()
98 | else:
99 | if self.current_patches_num <= 0:
100 | self.num_file = self.num_file + 1
101 | if self.num_file < len(self.training_data_files):
102 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file:
103 | self.patches = pickle.load(file)
104 | print(
105 | f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}")
106 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1])
107 | poses = self.poses[self.cur_pose_ind:next_pose_ind]
108 | for i in range(len(self.patches)):
109 | ind = int(np.floor(i / self.per_num_patch))
110 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6).astype(
111 | np.float32)
112 | self.patches[i] = np.concatenate([pose, self.patches[i]])
113 | self.cur_pose_ind = next_pose_ind
114 | self.current_patches_num = len(self.patches)
115 | self.index = list(range(self.current_patches_num))
116 | if self.shuffle:
117 | random.shuffle(self.patches)
118 | else:
119 | self.num_file = 0
120 | self.cur_pose_ind = 0
121 | self.initial()
122 |
123 | anc_local_patch = patches[2:, :3]
124 | pos_local_patch = patches[2:, 3:]
125 | rotate = patches[:2, :].reshape(12)[:9].reshape(3, 3)
126 | shift = patches[:2, :].reshape(12)[9:]
127 |
128 | # np.random.shuffle(anc_local_patch)
129 | # np.random.shuffle(pos_local_patch)
130 |
131 | return anc_local_patch, pos_local_patch, rotate, shift
132 |
133 |
134 | if __name__ == "__main__":
135 | data_root = "../../data/KITTI_patches/"
136 | batch_size = 48
137 | epoch = 1
138 | train_dataset = KITTIDataset(root=data_root, split='train', batch_size=batch_size, shuffle=True, drop_last=True)
139 | train_dataset.initial()
140 | for _ in range(epoch):
141 | train_iter = Data.DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True)
142 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(train_iter):
143 | B = anc_local_patch.shape[0]
144 |
--------------------------------------------------------------------------------
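
KITTIDataset prepends each patch with a 2x6 block that packs the ground-truth pose (9 rotation entries plus 3 translation entries), which __getitem__ later splits back out. A NumPy round-trip of that packing, for reference:

```python
import numpy as np

pose = np.eye(4)  # stand-in for one ground-truth 4x4 pose
packed = np.concatenate([pose[:3, :3].reshape(9), pose[:3, 3]]).reshape(2, 6)

flat = packed.reshape(12)
rotate, shift = flat[:9].reshape(3, 3), flat[9:]  # mirrors __getitem__ above
assert np.allclose(rotate, pose[:3, :3]) and np.allclose(shift, pose[:3, 3])
```
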
/KITTI/Train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | import time
5 | import shutil
6 | import sys
7 |
8 | sys.path.append('../../')
9 | from KITTI.Train.dataloader import get_dataloader
10 | from KITTI.Train.trainer import Trainer
11 | from network.SpinNet import Descriptor_Net
12 | from torch import optim
13 |
14 |
15 | class Args(object):
16 | def __init__(self):
17 | self.experiment_id = "Proposal" + time.strftime('%m%d%H%M')
18 | snapshot_root = 'snapshot/%s' % self.experiment_id
19 | tensorboard_root = 'tensorboard/%s' % self.experiment_id
20 | os.makedirs(snapshot_root, exist_ok=True)
21 | os.makedirs(tensorboard_root, exist_ok=True)
22 | shutil.copy2(os.path.join('', 'train.py'), os.path.join(snapshot_root, 'train.py'))
23 | shutil.copy2(os.path.join('', 'trainer.py'), os.path.join(snapshot_root, 'trainer.py'))
24 | shutil.copy2(os.path.join('', '../../network/SpinNet.py'), os.path.join(snapshot_root, 'SpinNet.py'))
25 | shutil.copy2(os.path.join('', '../../network/ThreeDCCN.py'), os.path.join(snapshot_root, 'ThreeDCCN.py'))
26 | shutil.copy2(os.path.join('', '../../loss/desc_loss.py'), os.path.join(snapshot_root, 'loss.py'))
27 | self.epoch = 20
28 | self.num_patches = 10
29 | self.num_points_per_patch = 2048  # number of points per patch
30 | self.batch_size = 60
31 | self.rad_n = 9
32 | self.azi_n = 60
33 | self.ele_n = 30
34 | self.des_r = 2.0
35 | self.voxel_r = 0.3
36 | self.voxel_sample = 30
37 |
38 | self.dataset = 'KITTI'
39 | self.data_train_dir = '../../data/KITTI/patches'
40 | self.data_val_dir = '../../data/KITTI/patches'
41 |
42 | self.gpu_mode = True
43 | self.verbose = True
44 | self.freeze_epoch = 5
45 |
46 | # model & optimizer
47 | self.model = Descriptor_Net(self.des_r, self.rad_n, self.azi_n, self.ele_n,
48 | self.voxel_r, self.voxel_sample, self.dataset)
49 | self.pretrain = ''
50 | self.parameter = self.model.get_parameter()
51 | self.optimizer = optim.Adam(self.parameter, lr=0.001, betas=(0.9, 0.999), weight_decay=1e-6)
52 | self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.5)
53 | self.scheduler_interval = 7
54 |
55 | # dataloader
56 | self.train_loader = get_dataloader(root=self.data_train_dir,
57 | batch_size=self.batch_size,
58 | split='train',
59 | shuffle=True,
60 | num_workers=0,
61 | )
62 | self.val_loader = get_dataloader(root=self.data_val_dir,
63 | batch_size=self.batch_size,
64 | split='val',
65 | shuffle=False,
66 | num_workers=0,
67 | )
68 |
69 | print("Training set size:", self.train_loader.dataset.__len__())
70 | print("Validate set size:", self.val_loader.dataset.__len__())
71 |
72 | # snapshot
73 | self.snapshot_interval = int(self.train_loader.dataset.__len__() / self.batch_size / 2)
74 | self.save_dir = os.path.join(snapshot_root, 'models/')
75 | self.result_dir = os.path.join(snapshot_root, 'results/')
76 | self.tboard_dir = tensorboard_root
77 |
78 | # evaluate
79 | self.evaluate_interval = 1
80 |
81 | self.check_args()
82 |
83 | def check_args(self):
84 | """checking arguments"""
85 | if not os.path.exists(self.save_dir):
86 | os.makedirs(self.save_dir)
87 | if not os.path.exists(self.result_dir):
88 | os.makedirs(self.result_dir)
89 | if not os.path.exists(self.tboard_dir):
90 | os.makedirs(self.tboard_dir)
91 | return self
92 |
93 |
94 | if __name__ == '__main__':
95 | args = Args()
96 | trainer = Trainer(args)
97 | trainer.train()
98 |
--------------------------------------------------------------------------------
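
For reference, the Descriptor_Net constructor call above uses the same positional arguments as KITTI/Test/test_kitti.py; the parameter glosses below are inferred from the argument names in Args:

```python
from network.SpinNet import Descriptor_Net

model = Descriptor_Net(2.0,        # des_r: support radius of a local patch ("vicinity" in test_kitti.py)
                       9, 60, 30,  # rad_n, azi_n, ele_n: radial / azimuthal / elevational bins
                       0.3, 30,    # voxel_r, voxel_sample: voxel radius and samples per voxel
                       'KITTI')    # dataset tag
```
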
/KITTI/Train/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import time, os
4 | import numpy as np
5 | from loss.desc_loss import ContrastiveLoss
6 | from tensorboardX import SummaryWriter
7 |
8 |
9 | class Trainer(object):
10 | def __init__(self, args):
11 | # parameters
12 | self.epoch = args.epoch
13 | self.num_points_per_patch = args.num_points_per_patch
14 | self.batch_size = args.batch_size
15 | self.dataset = args.dataset
16 | self.save_dir = args.save_dir
17 | self.result_dir = args.result_dir
18 | self.gpu_mode = args.gpu_mode
19 | self.verbose = args.verbose
20 | self.freeze_epoch = args.freeze_epoch
21 |
22 | self.rad_n = args.rad_n
23 | self.azi_n = args.azi_n
24 | self.ele_n = args.ele_n
25 | self.des_r = args.des_r
26 | self.voxel_r = args.voxel_r
27 | self.voxel_sample = args.voxel_sample
28 |
29 | self.model = args.model
30 | self.optimizer = args.optimizer
31 | self.scheduler = args.scheduler
32 | self.scheduler_interval = args.scheduler_interval
33 | self.snapshot_interval = args.snapshot_interval
34 | self.evaluate_interval = args.evaluate_interval
35 | self.writer = SummaryWriter(log_dir=args.tboard_dir)
36 |
37 | self.train_loader = args.train_loader
38 | self.val_loader = args.val_loader
39 |
40 | self.desc_loss = ContrastiveLoss()
41 |
42 | if self.gpu_mode:
43 | self.model = self.model.cuda()
44 | self.model = torch.nn.DataParallel(self.model, device_ids=[0])
45 |
46 | if args.pretrain != '':
47 | self._load_pretrain(args.pretrain)
48 |
49 | def train(self):
50 | self.train_hist = {
51 | 'loss': [],
52 | 'per_epoch_time': [],
53 | 'total_time': []
54 | }
55 | best_loss = 1000000000
56 | print('training start!!')
57 | start_time = time.time()
58 |
59 | self.model.train()
60 | freeze_sign = 1
61 | for epoch in range(self.epoch):
62 |
63 | self.train_epoch(epoch)
64 |
65 | if epoch % self.evaluate_interval == 0 or epoch == 0:
66 | res = self.evaluate(epoch + 1)
67 | print(f'Evaluation: Epoch {epoch}: Loss {res["loss"]}')
68 |
69 | if res['loss'] < best_loss:
70 | best_loss = res['loss']
71 | self._snapshot('best')
72 | if self.writer:
73 | self.writer.add_scalar('Loss', res['loss'], epoch)
74 |
75 | if epoch % self.scheduler_interval == 0:
76 | old_lr = self.optimizer.param_groups[0]['lr']
77 | self.scheduler.step()
78 | new_lr = self.optimizer.param_groups[0]['lr']
79 | print('update detector learning rate: %f -> %f' % (old_lr, new_lr))
80 |
81 | if self.writer:
82 | self.writer.add_scalar('Learning Rate', self._get_lr(), epoch)
83 | self.writer.add_scalar('Train Loss', self.train_hist['loss'][-1], epoch)
84 |
85 | # finish all epoch
86 | self.train_hist['total_time'].append(time.time() - start_time)
87 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
88 | self.epoch, self.train_hist['total_time'][0]))
89 | print("Training finish!... save training results")
90 |
91 | def train_epoch(self, epoch):
92 | epoch_start_time = time.time()
93 | loss_buf = []
94 | num_batch = int(len(self.train_loader.dataset) / self.batch_size)
95 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(self.train_loader):
96 |
97 | B = anc_local_patch.shape[0]
98 | anc_local_patch = anc_local_patch.float()
99 | pos_local_patch = pos_local_patch.float()
100 | rotate = rotate.float()
101 | shift = shift.float()
102 |
103 | if self.gpu_mode:
104 | anc_local_patch = anc_local_patch.cuda()
105 | pos_local_patch = pos_local_patch.cuda()
106 |
107 | # forward
108 | self.optimizer.zero_grad()
109 | a_desc = self.model(anc_local_patch)
110 | p_desc = self.model(pos_local_patch)
111 | anc_desc = F.normalize(a_desc.view(B, -1), p=2, dim=1)
112 | pos_desc = F.normalize(p_desc.view(B, -1), p=2, dim=1)
113 |
114 | # calculate the contrastive loss
115 | des_loss, accuracy = self.desc_loss(anc_desc, pos_desc)
116 | loss = des_loss
117 |
118 | # backward
119 | loss.backward()
120 | self.optimizer.step()
121 | loss_buf.append(float(loss))
122 |
123 | if iter % self.snapshot_interval == 0:
124 | self._snapshot(f'{epoch}_{iter + 1}')
125 |
126 | if iter % 200 == 0 and self.verbose:
127 | iter_time = time.time() - epoch_start_time
128 | print(f"Epoch: {epoch} [{iter:4d}/{num_batch}] loss: {loss:.2f} time: {iter_time:.2f}s")
129 | print(f"Epoch: {epoch} [{iter:4d}/{num_batch}] des loss: {des_loss:.2f} time: {iter_time:.2f}s")
130 | print(f"Accuracy: {accuracy.item():.4f}\n")
131 | del loss
132 | del anc_local_patch
133 | del pos_local_patch
134 | # finish one epoch
135 | epoch_time = time.time() - epoch_start_time
136 | self.train_hist['per_epoch_time'].append(epoch_time)
137 | self.train_hist['loss'].append(np.mean(loss_buf))
138 | print(f'Epoch {epoch}: Loss {np.mean(loss_buf)}, time {epoch_time:.4f}s')
139 |
140 | del loss_buf
141 |
142 | def evaluate(self, epoch=None):  # epoch accepted to match the call in train()
143 | self.model.eval()
144 | loss_buf = []
145 | with torch.no_grad():
146 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(self.val_loader):
147 |
148 | B = anc_local_patch.shape[0]
149 | anc_local_patch = anc_local_patch.float()
150 | pos_local_patch = pos_local_patch.float()
151 | rotate = rotate.float()
152 | shift = shift.float()
153 |
154 | if self.gpu_mode:
155 | anc_local_patch = anc_local_patch.cuda()
156 | pos_local_patch = pos_local_patch.cuda()
157 |
158 | # forward
159 | a_des = self.model(anc_local_patch)
160 | p_des = self.model(pos_local_patch)
161 | anc_des = F.normalize(a_des.view(B, -1), p=2, dim=1)
162 | pos_des = F.normalize(p_des.view(B, -1), p=2, dim=1)
163 |
164 | # calculate the contrastive loss
165 | des_loss, accuracy = self.desc_loss(anc_des, pos_des)
166 | loss = des_loss
167 | loss_buf.append(float(loss))
168 |
169 | del loss
170 | del anc_local_patch
171 | del pos_local_patch
172 |
173 | self.model.train()
174 |
175 | res = {
176 | 'loss': np.mean(loss_buf)
177 | }
178 | del loss_buf
179 | return res
180 |
181 | def _snapshot(self, epoch):
182 | save_dir = os.path.join(self.save_dir, self.dataset)
183 | torch.save(self.model.state_dict(), save_dir + "_" + str(epoch) + '.pkl')
184 | print(f"Save model to {save_dir}_{str(epoch)}.pkl")
185 |
186 | def _load_pretrain(self, pretrain):
187 | state_dict = torch.load(pretrain, map_location='cpu')
188 | self.model.load_state_dict(state_dict)
189 | print(f"Load model from {pretrain}.pkl")
190 |
191 | def _get_lr(self, group=0):
192 | return self.optimizer.param_groups[group]['lr']
193 |
--------------------------------------------------------------------------------
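
The trainer only assumes that desc_loss takes two batches of L2-normalized descriptors and returns (loss, accuracy); loss/desc_loss.py itself is not shown in this section. The following is an illustrative hardest-in-batch contrastive sketch with made-up margins, not the repo's implementation:

```python
import torch

def contrastive_sketch(anc, pos, pos_margin=0.1, neg_margin=1.4):
    dists = torch.cdist(anc, pos)                  # (B, B) pairwise descriptor distances
    pos_d = dists.diagonal()                       # distances of the true anchor/positive pairs
    # mask the diagonal, then take the hardest (closest) non-matching pair per anchor
    neg_d = (dists + 1e5 * torch.eye(len(anc), device=anc.device)).min(dim=1).values
    loss = (torch.clamp(pos_d - pos_margin, min=0).pow(2).mean()
            + torch.clamp(neg_margin - neg_d, min=0).pow(2).mean())
    # accuracy: fraction of anchors whose nearest positive is the true match
    accuracy = (dists.argmin(dim=1) == torch.arange(len(anc), device=anc.device)).float().mean()
    return loss, accuracy
```
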
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Qingyong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://paperswithcode.com/sota/point-cloud-registration-on-3dmatch-benchmark?p=spinnet-learning-a-general-surface-descriptor)
2 | [](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode)
3 | [](https://arxiv.org/abs/2011.12149)
4 | # SpinNet: Learning a General Surface Descriptor for 3D Point Cloud Registration (CVPR 2021)
5 |
6 | This is the official repository of **SpinNet**, a conceptually simple neural architecture to extract local
7 | features which are rotationally invariant whilst sufficiently informative to enable accurate registration. For technical details, please refer to:
8 |
9 | **[SpinNet: Learning a General Surface Descriptor for 3D Point Cloud Registration](https://arxiv.org/abs/2011.12149)**
10 | [Sheng Ao*](http://scholar.google.com/citations?user=cvS1yuMAAAAJ&hl=zh-CN), [Qingyong Hu*](https://www.cs.ox.ac.uk/people/qingyong.hu/), [Bo Yang](https://yang7879.github.io/), [Andrew Markham](https://www.cs.ox.ac.uk/people/andrew.markham/), [Yulan Guo](http://yulanguo.me/).
11 | (* *indicates equal contribution*)
12 |
13 | **[[Paper](https://arxiv.org/abs/2011.12149)] [Video] [Project page]**
14 |
15 |
16 | ### (1) Overview
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ### (2) Setup
25 | This code has been tested with Python 3.6, PyTorch 1.6.0, CUDA 10.2 on Ubuntu 18.04.
26 |
27 | - Clone the repository
28 | ```
29 | git clone https://github.com/QingyongHu/SpinNet && cd SpinNet
30 | ```
31 | - Setup conda virtual environment
32 | ```
33 | conda create -n spinnet python=3.6
34 | source activate spinnet
35 | conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.2 -c pytorch
36 | conda install -c open3d-admin open3d==0.11.1
37 | pip install "git+git://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
38 | ```
39 |
40 | ### (3) 3DMatch
41 | Download the processed dataset from [Google Drive](https://drive.google.com/file/d/1PrkSE0nY79gOF_VJcKv2VpxQ8s7DOITg/view?usp=sharing), [Baidu Yun](https://pan.baidu.com/s/1FB7IUbKAAlk7RVnB_AgwcQ) (Verification code:d1vn) and put the folder into `data`.
42 | Then the structure should be as follows:
43 | ```
44 | --data--3DMatch--fragments
45 | |--intermediate-files-real
46 | |--patches
47 |
48 | ```
49 |
50 | **Training**
51 |
52 | Training SpinNet on the 3DMatch dataset:
53 | ```
54 | cd ./ThreeDMatch/Train
55 | python train.py
56 | ```
57 | **Testing**
58 |
59 | Evaluate the performance of the trained models on the 3DMatch dataset:
60 |
61 | ```
62 | cd ./ThreeDMatch/Test
63 | python preparation.py
64 | ```
65 | The learned descriptors for each point will be saved in `ThreeDMatch/Test/SpinNet_{timestr}/` folder.
66 | Then the `Feature Matching Recall (FMR)` and `Inlier Ratio (IR)` can be calculated by running:
67 | ```
68 | python evaluate.py [timestr]
69 | ```
70 | The ground truth poses have been put in the `ThreeDMatch/Test/gt_result` folder.
71 | The `Registration Recall` can be calculated by running `evaluate.m` in `ThreeDMatch/Test/3dmatch`, which is provided by [3DMatch](https://github.com/andyzeng/3dmatch-toolbox/tree/master/evaluation/geometric-registration).
72 | Note that you need to modify `descriptorName` to `SpinNet_{timestr}` in the `ThreeDMatch/Test/3dmatch/evaluate.m` file.
73 |
74 |
75 | ### (4) KITTI
76 | Download the processed dataset from [Google Drive](https://drive.google.com/file/d/1fuJiQwAay23BUKtxBG3__MwStyMuvrMQ/view?usp=sharing), [Baidu Yun](https://pan.baidu.com/s/1FB7IUbKAAlk7RVnB_AgwcQ) (Verification code:d1vn), and put the folder into `data`.
77 | Then the structure is as follows:
78 | ```
79 | --data--KITTI--dataset
80 | |--icp
81 | |--patches
82 |
83 | ```
84 |
85 | **Training**
86 |
87 | Training SpinNet on the KITTI dataset:
88 |
89 | ```
90 | cd ./KITTI/Train/
91 | python train.py
92 | ```
93 |
94 | **Testing**
95 |
96 | Evaluate the performance of the trained models on the KITTI dataset:
97 |
98 | ```
99 | cd ./KITTI/Test/
100 | python test_kitti.py
101 | ```
102 |
103 |
104 | ### (5) ETH
105 |
106 | The test set can be downloaded from [here](https://share.phys.ethz.ch/~gsg/3DSmoothNet/data/ETH.rar). Put the folder into `data`; the structure is then as follows:
107 | ```
108 | --data--ETH--gazebo_summer
109 | |--gazebo_winter
110 | |--wood_autmn
111 | |--wood_summer
112 | ```
113 |
114 | ### (6) Generalization across Unseen Datasets
115 |
116 | **3DMatch to ETH**
117 |
118 | Generalization from 3DMatch dataset to ETH dataset:
119 | ```
120 | cd ./generalization/ThreeDMatch-to-ETH
121 | python preparation.py
122 | ```
123 | The descriptors for each point will be generated and saved in the `generalization/ThreeDMatch-to-ETH/SpinNet_{timestr}/` folder.
124 | Then the `Feature Matching Recall (FMR)` and `Inlier Ratio (IR)` can be calculated by running:
125 | ```
126 | python evaluate.py [timestr]
127 | ```
128 |
129 | **3DMatch to KITTI**
130 |
131 | Generalization from 3DMatch dataset to KITTI dataset:
132 |
133 | ```
134 | cd ./generalization/ThreeDMatch-to-KITTI
135 | python test.py
136 | ```
137 |
138 | **KITTI to 3DMatch**
139 |
140 | Generalization from KITTI dataset to 3DMatch dataset:
141 | ```
142 | cd ./generalization/KITTI-to-ThreeDMatch
143 | python preparation.py
144 | ```
145 | The descriptors for each point will be generated and saved in the `generalization/KITTI-to-ThreeDMatch/SpinNet_{timestr}/` folder.
146 | Then the `Feature Matching Recall (FMR)` and `Inlier Ratio (IR)` can be calculated by running:
147 | ```
148 | python evaluate.py [timestr]
149 | ```
150 |
151 | ## Acknowledgement
152 |
153 | In this project, we use (parts of) the implementations of the following works:
154 |
155 | * [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch)
156 | * [PPF-FoldNet](https://github.com/XuyangBai/PPF-FoldNet)
157 | * [Spherical CNNs](https://github.com/jonas-koehler/s2cnn)
158 | * [FCGF](https://github.com/chrischoy/FCGF)
159 | * [r2d2](https://github.com/naver/r2d2)
160 | * [D3Feat](https://github.com/XuyangBai/D3Feat)
161 | * [D3Feat.pytorch](https://github.com/XuyangBai/D3Feat.pytorch)
162 |
163 |
164 | ### Citation
165 | If you find our work useful in your research, please consider citing:
166 |
167 | @inproceedings{ao2020SpinNet,
168 | title={SpinNet: Learning a General Surface Descriptor for 3D Point Cloud Registration},
169 | author={Ao, Sheng and Hu, Qingyong and Yang, Bo and Markham, Andrew and Guo, Yulan},
170 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
171 | year={2021}
172 | }
173 |
174 | ### References
175 |
176 |
177 | [1] 3DMatch: Learning Local Geometric Descriptors from RGB-D Reconstructions, Andy Zeng, Shuran Song, Matthias Nießner, Matthew Fisher, Jianxiong Xiao, and Thomas Funkhouser, CVPR 2017.
178 |
179 |
180 |
181 | ### Updates
182 | * 03/04/2021: The code is released!
183 | * 01/03/2021: This paper has been accepted by CVPR 2021!
184 | * 25/11/2020: Initial release!
185 |
186 | ## Related Repos
187 | 1. [RandLA-Net: Efficient Semantic Segmentation of Large-Scale Point Clouds](https://github.com/QingyongHu/RandLA-Net) 
188 | 2. [SoTA-Point-Cloud: Deep Learning for 3D Point Clouds: A Survey](https://github.com/QingyongHu/SoTA-Point-Cloud) 
189 | 3. [3D-BoNet: Learning Object Bounding Boxes for 3D Instance Segmentation on Point Clouds](https://github.com/Yang7879/3D-BoNet) 
190 | 4. [SensatUrban: Learning Semantics from Urban-Scale Photogrammetric Point Clouds](https://github.com/QingyongHu/SensatUrban) 
191 | 5. [SQN: Weakly-Supervised Semantic Segmentation of Large-Scale 3D Point Clouds with 1000x Fewer Labels](https://github.com/QingyongHu/SQN) 
192 |
193 |
--------------------------------------------------------------------------------
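
A quick, optional sanity check for the Setup section of the README above (not part of the repo): confirm that the pinned packages import and CUDA is visible before training:

```python
import torch, torchvision, open3d as o3d
import pointnet2_ops.pointnet2_utils as pnt2  # from the Pointnet2_PyTorch install step

print('torch', torch.__version__, '| cuda available:', torch.cuda.is_available())
print('open3d', o3d.__version__)  # expect 0.11.x so that o3d.pipelines.registration exists
```
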
/ThreeDMatch/Test/3dmatch/evaluate.m:
--------------------------------------------------------------------------------
1 | % Script to evaluate .log files for the geometric registration benchmarks,
2 | % in the same spirit as Choi et al 2015. Please see:
3 | %
4 | % http://redwood-data.org/indoor/regbasic.html
5 | % https://github.com/qianyizh/ElasticReconstruction/tree/master/Matlab_Toolbox
6 |
7 |
8 | descriptorName = 'SpinNet_10051828'; % replace with your SpinNet_{timestr} folder name
9 |
10 | % Locations of evaluation files
11 | dataPath = '../log_result';
12 |
13 | % Real data benchmark
14 | sceneList = {
15 | '7-scenes-redkitchen-evaluation', ...
16 | 'sun3d-home_at-home_at_scan1_2013_jan_1-evaluation', ...
17 | 'sun3d-home_md-home_md_scan9_2012_sep_30-evaluation', ...
18 | 'sun3d-hotel_uc-scan3-evaluation', ...
19 | 'sun3d-hotel_umd-maryland_hotel1-evaluation', ...
20 | 'sun3d-hotel_umd-maryland_hotel3-evaluation', ...
21 | 'sun3d-mit_76_studyroom-76-1studyroom2-evaluation', ...
22 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika-evaluation'
23 | };
24 |
25 | % Load Elastic Reconstruction toolbox
26 | addpath(genpath('external'));
27 |
28 | % Compute precision and recall
29 | totalRecall = []; totalPrecision = [];
30 | totalGt = 0;
31 | totalTP = 0;
32 | for sceneIdx = 1:length(sceneList)
33 | scenePath = fullfile(dataPath,sceneList{sceneIdx});
34 | gtPath = fullfile('../gt_result',sceneList{sceneIdx});
35 |
36 | % Compute registration error
37 | gt = mrLoadLog(fullfile(gtPath,'gt.log'));
38 | gt_info = mrLoadInfo(fullfile(gtPath,'gt.info'));
39 | result = mrLoadLog(fullfile(scenePath,sprintf('%s.log',descriptorName)));
40 | [recall,precision,gt_num] = mrEvaluateRegistration(result,gt,gt_info);
41 | totalRecall = [totalRecall;recall];
42 | totalPrecision = [totalPrecision;precision];
43 | totalGt = totalGt + gt_num;
44 | totalTP = totalTP + round(gt_num * recall);
45 | end
46 | totalRecall
47 | fprintf('Mean registration recall: %f precision: %f\n',mean(totalRecall),mean(totalPrecision));
48 | fprintf('True average recall: %f (%d/%d)\n',totalTP/totalGt,totalTP, totalGt);
49 |
--------------------------------------------------------------------------------
/ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrDrawTrajectory.m:
--------------------------------------------------------------------------------
1 | function mrDrawTraj( traj, c, init_trans )
2 | if ~exist( 'c', 'var' )
3 | c = 'b-';
4 | end
5 |
6 | if ~exist( 'init_trans', 'var' )
7 | init_trans = traj( 1 ).trans;
8 | end
9 |
10 | n = size( traj, 2 );
11 | x = zeros( 2, n );
12 | init_inverse = init_trans ^ -1;
13 |
14 | for k = 1 : n
15 | m = init_inverse * traj( k ).trans;
16 | x( :, k ) = [ m( 1, 4 ); m( 3, 4 ) ];
17 | end
18 |
19 | plot( -x( 1, : ), -x( 2, : ), c, 'LineWidth',2 );
20 | axis equal;
21 | end
--------------------------------------------------------------------------------
/ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrEvaluateRegistration.m:
--------------------------------------------------------------------------------
1 | function [ recall, precision, gt_num ] = mrEvaluateRegistration( result, gt, gt_info, err2 )
2 | if ~exist( 'err2', 'var' )
3 | err2 = 0.04;
4 | end
5 | num = gt( 1 ).info( 3 );
6 |
7 | mask = zeros( 1, num * num );
8 | gt_num = 0;
9 | for i = 1 : size( gt, 2 )
10 | if ( gt( i ).info( 2 ) - gt( i ).info( 1 ) > 1 )
11 | mask( gt( i ).info( 1 ) + gt( i ).info( 2 ) * num + 1 ) = i;
12 | gt_num = gt_num + 1;
13 | end
14 | end
15 |
16 | rs_num = 0;
17 | good = 0;
18 | bad = 0;
19 | false_pos = 0;
20 | error_dis = [];
21 | for i = 1 : size( result, 2 )
22 | if ( result( i ).info( 2 ) - result( i ).info( 1 ) > 1 )
23 | rs_num = rs_num + 1;
24 | idx = mask( result( i ).info( 1 ) + result( i ).info( 2 ) * num + 1 );
25 | if idx == 0
26 | false_pos = false_pos + 1;
27 | else
28 | p = mrComputeTransformationError( gt( idx ).trans ^ -1 * result( i ).trans, gt_info( idx ).mat );
29 | error_dis = [ error_dis, p ];
30 | if ( p <= err2 )
31 | good = good + 1;
32 | else
33 | bad = bad + 1;
34 | end
35 | end
36 | end
37 | end
38 | recall = good / gt_num;
39 | precision = good / rs_num;
40 | disp( [ 'recall : ' num2str( recall ) ' ( ' num2str( good ) ' / ' num2str( gt_num ) ' )' ] );
41 | %disp( [ 'precision : ' num2str( precision ) ' ( ' num2str( good ) ' / ' num2str( rs_num ) ' )' ] );
42 | end
43 |
44 | function [ p ] = mrComputeTransformationError( trans, info )
45 | te = trans( 1 : 3, 4 );
46 | qt = dcm2quat( trans( 1 : 3, 1 : 3 ) );
47 | er = [ te; - qt( 2 : 4 )' ];
48 | p = er' * info * er / info( 1, 1 );
49 | end
50 |
51 | function [qout] = dcm2quat(DCM)
52 | % this is consistent with the matlab function in
53 | % the Aerospace Toolbox
54 | qout = zeros(1,4);
55 | qout(1) = 0.5 * sqrt(1 + DCM(1,1) + DCM(2,2) + DCM(3,3));
56 | qout(2) = - (DCM(3,2) - DCM(2,3)) / ( 4 * qout(1) );
57 | qout(3) = - (DCM(1,3) - DCM(3,1)) / ( 4 * qout(1) );
58 | qout(4) = - (DCM(2,1) - DCM(1,2)) / ( 4 * qout(1) );
59 | end
--------------------------------------------------------------------------------
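
In mrComputeTransformationError above, a candidate pose is scored against the ground truth through the 6x6 information matrix stored in gt.info. Restated in the notation of Choi et al. (a paraphrase of the code, not repo text):

\[
\Delta = T_{gt}^{-1} T_{est}, \qquad
\xi = \begin{bmatrix} t_{\Delta} \\ -q_{2:4}(\Delta) \end{bmatrix}, \qquad
p = \frac{\xi^{\top} \Sigma \, \xi}{\Sigma_{11}},
\]

where \(t_{\Delta}\) is the translation of \(\Delta\), \(q_{2:4}(\Delta)\) the vector part of its rotation quaternion, and \(\Sigma\) the gt.info matrix; a pair is accepted when \(p \le \texttt{err2} = 0.04\).
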
/ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrEvaluateTrajectory.m:
--------------------------------------------------------------------------------
1 | function [ rmse, trans ] = mrEvaluateTraj( traj_et, traj_gt )
2 | gt_n = size( traj_gt, 2 );
3 | et_n = size( traj_et, 2 );
4 | if (gt_n ~= et_n)
5 | fprintf('WARNING: There are Lost Frames!\n');
6 | fprintf('ground truth traj : %d frames\n', gt_n);
7 | fprintf('estimated traj : %d frames\n', et_n);
8 | gt_n = min( [ gt_n, et_n ] );
9 | et_n = gt_n;
10 | end
11 | n = et_n;
12 |
13 | trans = mrAlignTraj( traj_et, traj_gt );
14 | err = zeros( 1, n );
15 |
16 | for i = 1 : n
17 | assert( traj_et( i ).info( 3 ) == traj_gt( i ).info( 3 ),...
18 | 'bad trajectory file format or asynchronized frame.' );
19 | trans_et = trans * traj_et( i ).trans;
20 | trans_gt = traj_gt( i ).trans;
21 | err( i ) = norm( trans_gt( 1 : 3, 4 ) - trans_et( 1 : 3, 4 ) );
22 | end
23 |
24 | rmse = sqrt( err * err' / size( err, 2 ) );
25 | fprintf( 'median absolute translational error %f m\n', median( err ) );
26 | fprintf( 'rmse %f m\n', rmse );
27 | end
28 |
29 | function [ trans ] = mrAlignTraj( traj_et, traj_gt )
30 | n = size( traj_et, 2 );
31 | gt_trans = zeros( 3, n );
32 | et_trans = zeros( 3, n );
33 |
34 | for i = 1 : n
35 | gt_trans( :, i ) = traj_gt( i ).trans( 1 : 3, 4 );
36 | et_trans( :, i ) = traj_et( i ).trans( 1 : 3, 4 );
37 | end
38 |
39 | gt_mean = mean( gt_trans, 2 );
40 | et_mean = mean( et_trans, 2 );
41 | gt_centered = gt_trans - repmat( gt_mean, 1, n );
42 | et_centered = et_trans - repmat( et_mean, 1, n );
43 |
44 | W = zeros( 3, 3 );
45 | for i = 1 : n
46 | W = W + et_centered( :, i ) * gt_centered( :, i )';
47 | end
48 |
49 | [ U, ~, V ] = svd( W' );
50 | Vh = V';
51 | S = eye( 3 );
52 | if ( det( U ) * det( Vh ) < 0 )
53 | S( 3, 3 ) = -1;
54 | end
55 |
56 | r = U * S * Vh;
57 | t = gt_mean - r * et_mean;
58 |
59 | trans = [ r, t; 0, 0, 0, 1 ];
60 | end
--------------------------------------------------------------------------------
/ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrLoadInfo.m:
--------------------------------------------------------------------------------
1 | function [ info ] = mrLoadInfo( filename )
2 | fid = fopen( filename );
3 | k = 1;
4 | x = fscanf( fid, '%d', [ 1, 3 ] );
5 | while ( size( x, 2 ) == 3 )
6 | m = fscanf( fid, '%f', [ 6, 6 ] );
7 | info( k ) = struct( 'info', x, 'mat', m' );
8 | k = k + 1;
9 | x = fscanf( fid, '%d', [ 1, 3 ] );
10 | end
11 | fclose( fid );
12 | %disp( [ num2str( size( info, 2 ) ), ' matrices have been read.' ] );
13 | end
14 |
--------------------------------------------------------------------------------
/ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrLoadLog.m:
--------------------------------------------------------------------------------
1 | function [ traj ] = mrLoadLog( filename )
2 | fid = fopen( filename );
3 | k = 1;
4 | x = fscanf( fid, '%d', [1 3] );
5 | while ( size( x, 2 ) == 3 )
6 | m = fscanf( fid, '%f', [4 4] );
7 | traj( k ) = struct( 'info', x, 'trans', m' );
8 | k = k + 1;
9 | x = fscanf( fid, '%d', [1 3] );
10 | end
11 | fclose( fid );
12 | %disp( [ num2str( size( traj, 2 ) ), ' frames have been read.' ] );
13 | end
14 |
--------------------------------------------------------------------------------
/ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrMatchDepthColor.m:
--------------------------------------------------------------------------------
1 | function mrMatchDepthColor( basepath, unique, depthdir, imagedir, matchfile )
2 | if ~exist( 'matchfile', 'var' )
3 | matchfile = 'match';
4 | end
5 | if ~exist( 'imagedir', 'var' )
6 | imagedir = 'rgb';
7 | end
8 | if ~exist( 'depthdir', 'var' )
9 | depthdir = 'depth';
10 | end
11 | if ~exist( 'unique', 'var' )
12 | unique = 1;
13 | end
14 |
15 | depth_file_list = dir( [ basepath, depthdir, '/*.png' ] );
16 | if ( size( depth_file_list, 1 ) <= 1 )
17 | disp( 'Error: path not found' );
18 | return;
19 | end
20 | disp( [ num2str( size( depth_file_list, 1 ) ) ' depth images detected.' ] );
21 | depth_timestamp = parseTimestamp( depth_file_list );
22 |
23 | color_file_list = dir( [ basepath, imagedir, '/*.jpg' ] );
24 | if ( size( color_file_list, 1 ) <= 1 )
25 | disp( 'Error: path not found' );
26 | return;
27 | end
28 | disp( [ num2str( size( color_file_list, 1 ) ) ' color images detected.' ] );
29 | color_timestamp = parseTimestamp( color_file_list );
30 | color_timestamp_mat = cell2mat( color_timestamp( :, 1 ) );
31 |
32 | fid = fopen( [ basepath, matchfile ], 'w' );
33 | used_color = zeros( size( color_timestamp, 1 ), 1 );
34 | k = 0;
35 | for i = 1 : size( depth_timestamp, 1 )
36 | idx = findClosestColor( depth_timestamp{ i, 1 }, color_timestamp_mat );
37 | if ( unique == 0 || used_color( idx ) == 0 )
38 | used_color( idx ) = 1;
39 | fprintf( fid, '%s/%s %s/%s\n', depthdir, depth_timestamp{ i, 2 }, imagedir, color_timestamp{ idx, 2 } );
40 | k = k + 1;
41 | end
42 | end
43 | fclose( fid );
44 | disp( [ num2str( k ) ' pairs have been written.' ] );
45 | end
46 |
47 | function [ i ] = findClosestColor( depth_ts, color_ts_mat )
48 | [ ~, i ] = min( abs( color_ts_mat - depth_ts ) );
49 | end
50 |
51 | function [ timestamp ] = parseTimestamp( filelist )
52 | num = size( filelist, 1 );
53 | timestamp = cell( num, 2 );
54 | for i = 1 : num
55 | x = sscanf( filelist( i ).name, '%f-%f.' )';
56 | timestamp{ i, 1 } = x( 2 );
57 | timestamp{ i, 2 } = filelist( i ).name;
58 | end
59 |     [ ~, order ] = sort( cell2mat( timestamp( :, 1 ) ) );
60 |     timestamp = timestamp( order, : );
61 | end
--------------------------------------------------------------------------------
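mrMatchDepthColor pairs every depth frame with the color frame whose filename timestamp is nearest, optionally consuming each color frame at most once. The pairing rule in isolation, as a Python sketch (match_depth_color is a hypothetical name):

import numpy as np

def match_depth_color(depth_ts, color_ts, unique=True):
    """Greedy nearest-timestamp pairing: returns (depth_index, color_index)
    pairs; with unique=True a color frame is used at most once."""
    color_ts = np.asarray(color_ts)
    used = np.zeros(len(color_ts), dtype=bool)
    pairs = []
    for di, ts in enumerate(depth_ts):
        ci = int(np.argmin(np.abs(color_ts - ts)))
        if not unique or not used[ci]:
            used[ci] = True
            pairs.append((di, ci))
    return pairs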
/ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrWriteInfo.m:
--------------------------------------------------------------------------------
1 | function mrWriteInfo( info, filename )
2 | fid = fopen( filename, 'w' );
3 | for i = 1 : size( info, 2 )
4 | mrWriteInfoStruct( fid, info( i ).info, info( i ).mat );
5 | end
6 | fclose( fid );
7 | %disp( [ num2str( size( info, 2 ) ), ' matrices have been written.' ] );
8 | end
9 |
10 | function mrWriteInfoStruct( fid, x, m )
11 | fprintf( fid, '%d\t%d\t%d\n', x(1), x(2), x(3) );
12 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
13 | m(1,1), m(1,2), m(1,3), m(1,4), m(1,5), m(1,6) );
14 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
15 | m(2,1), m(2,2), m(2,3), m(2,4), m(2,5), m(2,6) );
16 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
17 | m(3,1), m(3,2), m(3,3), m(3,4), m(3,5), m(3,6) );
18 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
19 | m(4,1), m(4,2), m(4,3), m(4,4), m(4,5), m(4,6) );
20 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
21 | m(5,1), m(5,2), m(5,3), m(5,4), m(5,5), m(5,6) );
22 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\t%.10f\t%.10f\n', ...
23 | m(6,1), m(6,2), m(6,3), m(6,4), m(6,5), m(6,6) );
24 | end
25 |
--------------------------------------------------------------------------------
/ThreeDMatch/Test/3dmatch/external/ElasticReconstruction/mrWriteLog.m:
--------------------------------------------------------------------------------
1 | function mrWriteLog( traj, filename )
2 | fid = fopen( filename, 'w' );
3 | for i = 1 : size( traj, 2 )
4 | mrWriteLogStruct( fid, traj( i ).info, traj( i ).trans );
5 | end
6 | fclose( fid );
7 | %disp( [ num2str( size( traj, 2 ) ), ' frames have been written.' ] );
8 | end
9 |
10 | function mrWriteLogStruct( fid, x, m )
11 | fprintf( fid, '%d\t%d\t%d\n', x(1), x(2), x(3) );
12 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\n', m(1,1), m(1,2), m(1,3), m(1,4) );
13 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\n', m(2,1), m(2,2), m(2,3), m(2,4) );
14 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\n', m(3,1), m(3,2), m(3,3), m(3,4) );
15 | fprintf( fid, '%.10f\t%.10f\t%.10f\t%.10f\n', m(4,1), m(4,2), m(4,3), m(4,4) );
16 | end
17 |
--------------------------------------------------------------------------------
/ThreeDMatch/Test/3dmatch/external/npy-matlab/constructNPYheader.m:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | function header = constructNPYheader(dataType, shape, varargin)
5 |
6 | if ~isempty(varargin)
7 | fortranOrder = varargin{1}; % must be true/false
8 | littleEndian = varargin{2}; % must be true/false
9 | else
10 | fortranOrder = true;
11 | littleEndian = true;
12 | end
13 |
14 | dtypesMatlab = {'uint8','uint16','uint32','uint64','int8','int16','int32','int64','single','double', 'logical'};
15 | dtypesNPY = {'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'f4', 'f8', 'b1'};
16 |
17 | magicString = uint8([147 78 85 77 80 89]); %x93NUMPY
18 |
19 | majorVersion = uint8(1);
20 | minorVersion = uint8(0);
21 |
22 | % build the dict specifying data type, array order, endianness, and
23 | % shape
24 | dictString = '{''descr'': ''';
25 |
26 | if littleEndian
27 | dictString = [dictString '<'];
28 | else
29 | dictString = [dictString '>'];
30 | end
31 |
32 | dictString = [dictString dtypesNPY{strcmp(dtypesMatlab,dataType)} ''', '];
33 |
34 | dictString = [dictString '''fortran_order'': '];
35 |
36 | if fortranOrder
37 | dictString = [dictString 'True, '];
38 | else
39 | dictString = [dictString 'False, '];
40 | end
41 |
42 | dictString = [dictString '''shape'': ('];
43 |
44 | % if length(shape)==1 && shape==1
45 | %
46 | % else
47 | % for s = 1:length(shape)
48 | % if s==length(shape) && shape(s)==1
49 | %
50 | % else
51 | % dictString = [dictString num2str(shape(s))];
52 | % if length(shape)>1 && s+1==length(shape) && shape(s+1)==1
53 | % dictString = [dictString ','];
54 | %             elseif length(shape)>1 && s<length(shape)
55 | %                 dictString = [dictString ', '];
56 | %             end
57 | %         end
58 | %     end
59 | % end
60 | 
61 | for s = 1:length(shape)
62 |     dictString = [dictString num2str(shape(s))];
63 |     if s < length(shape)
64 |         dictString = [dictString ', '];
65 |     end
66 | end
67 | 
68 | dictString = [dictString '), '];
69 | dictString = [dictString '}'];
70 | 
71 | % pad with spaces so the total header (magic string, version, length field,
72 | % dict, trailing newline) is a multiple of 16 bytes, per the NPY 1.0 format
73 | totalHeaderLength = length(dictString) + 10;
74 | headerLengthPadded = ceil(double(totalHeaderLength + 1) / 16) * 16;
75 | padLength = headerLengthPadded - totalHeaderLength;
76 | headerPadding = repmat(uint8(32), 1, padLength - 1);
77 | 
78 | header = [magicString majorVersion minorVersion ...
79 |     typecast(uint16(headerLengthPadded - 10), 'uint8') ...
80 |     uint8(dictString) headerPadding uint8(10)];
81 | 
82 | end
--------------------------------------------------------------------------------
/ThreeDMatch/Test/3dmatch/external/npy-matlab/datToNPY.m:
--------------------------------------------------------------------------------
37 |             system(sprintf('copy /b %s+%s %s', tempFilename, inFilename, outFilename));
38 | 
39 |         otherwise
40 |             fprintf(1, 'I don''t know how to concatenate files for your OS, but you can finish making the NPY yourself by concatenating %s with %s.\n', tempFilename, inFilename);
41 |     end
42 | 
43 | 
--------------------------------------------------------------------------------
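constructNPYheader builds exactly the v1.0 layout that readNPYheader expects: the 6-byte magic string, two version bytes, a little-endian uint16 header length, then the Python-dict text padded with spaces to a 16-byte boundary and terminated by a newline. This can be cross-checked against numpy's own output; a quick probe (probe.npy is a throwaway file written to the working directory):

import numpy as np

np.save('probe.npy', np.zeros((3, 4), dtype=np.float32))
with open('probe.npy', 'rb') as f:
    raw = f.read()

assert raw[:6] == b'\x93NUMPY'                  # magic string
major, minor = raw[6], raw[7]                   # version bytes
header_len = int.from_bytes(raw[8:10], 'little')
print(raw[10:10 + header_len])                  # the descr/fortran_order/shape dict
assert (10 + header_len) % 16 == 0              # padded total header length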
/ThreeDMatch/Test/3dmatch/external/npy-matlab/readNPY.m:
--------------------------------------------------------------------------------
1 |
2 |
3 | function data = readNPY(filename)
4 | % Function to read NPY files into matlab.
5 | % *** Only reads a subset of all possible NPY files, specifically N-D arrays of certain data types.
6 | % See https://github.com/kwikteam/npy-matlab/blob/master/tests/npy.ipynb for
7 | % more.
8 | %
9 |
10 | [shape, dataType, fortranOrder, littleEndian, totalHeaderLength, ~] = readNPYheader(filename);
11 |
12 | if littleEndian
13 | fid = fopen(filename, 'r', 'l');
14 | else
15 | fid = fopen(filename, 'r', 'b');
16 | end
17 |
18 | try
19 |
20 | [~] = fread(fid, totalHeaderLength, 'uint8');
21 |
22 | % read the data
23 | data = fread(fid, prod(shape), [dataType '=>' dataType]);
24 |
25 | if length(shape)>1 && ~fortranOrder
26 | data = reshape(data, shape(end:-1:1));
27 | data = permute(data, [length(shape):-1:1]);
28 | elseif length(shape)>1
29 | data = reshape(data, shape);
30 | end
31 |
32 | fclose(fid);
33 |
34 | catch me
35 | fclose(fid);
36 | rethrow(me);
37 | end
38 |
--------------------------------------------------------------------------------
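For C-ordered (non-Fortran) files, readNPY.m reshapes to the reversed shape and then reverses the axes, because MATLAB's reshape fills column-major. The equivalence, demonstrated in numpy, where order='F' plays the role of MATLAB's reshape:

import numpy as np

flat = np.arange(24)     # element stream as stored in a C-ordered .npy file
shape = (2, 3, 4)

# MATLAB: reshape(data, shape(end:-1:1)); permute(data, [ndims:-1:1])
as_matlab = flat.reshape(shape[::-1], order='F').T  # .T reverses all axes

assert (as_matlab == flat.reshape(shape)).all()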
/ThreeDMatch/Test/3dmatch/external/npy-matlab/readNPYheader.m:
--------------------------------------------------------------------------------
1 |
2 |
3 | function [arrayShape, dataType, fortranOrder, littleEndian, totalHeaderLength, npyVersion] = readNPYheader(filename)
4 | % function [arrayShape, dataType, fortranOrder, littleEndian, ...
5 | % totalHeaderLength, npyVersion] = readNPYheader(filename)
6 | %
7 | % parse the header of a .npy file and return all the info contained
8 | % therein.
9 | %
10 | % Based on spec at http://docs.scipy.org/doc/numpy-dev/neps/npy-format.html
11 |
12 | fid = fopen(filename);
13 |
14 | % verify that the file exists
15 | if (fid == -1)
16 | if ~isempty(dir(filename))
17 | error('Permission denied: %s', filename);
18 | else
19 | error('File not found: %s', filename);
20 | end
21 | end
22 |
23 | try
24 |
25 | dtypesMatlab = {'uint8','uint16','uint32','uint64','int8','int16','int32','int64','single','double', 'logical'};
26 | dtypesNPY = {'u1', 'u2', 'u4', 'u8', 'i1', 'i2', 'i4', 'i8', 'f4', 'f8', 'b1'};
27 |
28 |
29 | magicString = fread(fid, [1 6], 'uint8=>uint8');
30 |
31 | if ~all(magicString == [147,78,85,77,80,89])
32 | error('readNPY:NotNUMPYFile', 'Error: This file does not appear to be NUMPY format based on the header.');
33 | end
34 |
35 | majorVersion = fread(fid, [1 1], 'uint8=>uint8');
36 | minorVersion = fread(fid, [1 1], 'uint8=>uint8');
37 |
38 | npyVersion = [majorVersion minorVersion];
39 |
40 | headerLength = fread(fid, [1 1], 'uint16=>uint16');
41 |
42 | totalHeaderLength = 10+headerLength;
43 |
44 | arrayFormat = fread(fid, [1 headerLength], 'char=>char');
45 |
46 | % to interpret the array format info, we make some fairly strict
47 | % assumptions about its format...
48 |
49 | r = regexp(arrayFormat, '''descr''\s*:\s*''(.*?)''', 'tokens');
50 | dtNPY = r{1}{1};
51 |
52 | littleEndian = ~strcmp(dtNPY(1), '>');
53 |
54 | dataType = dtypesMatlab{strcmp(dtNPY(2:3), dtypesNPY)};
55 |
56 | r = regexp(arrayFormat, '''fortran_order''\s*:\s*(\w+)', 'tokens');
57 | fortranOrder = strcmp(r{1}{1}, 'True');
58 |
59 | r = regexp(arrayFormat, '''shape''\s*:\s*\((.*?)\)', 'tokens');
60 | shapeStr = r{1}{1};
61 | arrayShape = str2num(shapeStr(shapeStr~='L'));
62 |
63 |
64 | fclose(fid);
65 |
66 | catch me
67 | fclose(fid);
68 | rethrow(me);
69 | end
70 |
--------------------------------------------------------------------------------
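The fields readNPYheader extracts correspond one-to-one to what numpy's (semi-public) np.lib.format module reports, which is handy for validating files exchanged between the Python and MATLAB sides of this evaluation pipeline:

import numpy as np

with open('probe.npy', 'rb') as f:               # e.g. any file written by np.save
    version = np.lib.format.read_magic(f)        # (major, minor)
    shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f)
print(version, shape, fortran_order, dtype)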
/ThreeDMatch/Test/3dmatch/external/npy-matlab/writeNPY.m:
--------------------------------------------------------------------------------
1 |
2 |
3 | function writeNPY(var, filename)
4 | % function writeNPY(var, filename)
5 | %
6 | % Only writes little endian, fortran (column-major) ordering; only writes
7 | % with NPY version number 1.0.
8 | %
9 | % Always outputs a shape according to matlab's convention, e.g. (10, 1)
10 | % rather than (10,).
11 |
12 |
13 | shape = size(var);
14 | dataType = class(var);
15 |
16 | header = constructNPYheader(dataType, shape);
17 |
18 | fid = fopen(filename, 'w');
19 | fwrite(fid, header, 'uint8');
20 | fwrite(fid, var, dataType);
21 | fclose(fid);
22 |
23 |
24 | end
25 |
26 |
--------------------------------------------------------------------------------
/ThreeDMatch/Test/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../../')
4 | import open3d
5 | import numpy as np
6 | import time
7 | import os
8 | from ThreeDMatch.Test.tools import get_pcd, get_keypts, get_desc, loadlog
9 | from sklearn.neighbors import KDTree
10 |
11 |
12 | def calculate_M(source_desc, target_desc):
13 | """
14 | Find the mutually closest point pairs in feature space.
15 | source and target are descriptor for 2 point cloud key points. [5000, 512]
16 | """
17 |
18 | kdtree_s = KDTree(target_desc)
19 | sourceNNdis, sourceNNidx = kdtree_s.query(source_desc, 1)
20 | kdtree_t = KDTree(source_desc)
21 | targetNNdis, targetNNidx = kdtree_t.query(target_desc, 1)
22 | result = []
23 | for i in range(len(sourceNNidx)):
24 | if targetNNidx[sourceNNidx[i]] == i:
25 | result.append([i, sourceNNidx[i][0]])
26 | return np.array(result)
27 |
28 |
29 | def register2Fragments(id1, id2, keyptspath, descpath, resultpath, desc_name='SpinNet'):
30 | cloud_bin_s = f'cloud_bin_{id1}'
31 | cloud_bin_t = f'cloud_bin_{id2}'
32 | write_file = f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'
33 | if os.path.exists(os.path.join(resultpath, write_file)):
34 | print(f"{write_file} already exists.")
35 | return 0, 0, 0
36 |
37 | if is_D3Feat_keypts:
38 | keypts_path = './D3Feat_contralo-54-pred/keypoints/' + keyptspath.split('/')[-2] + '/' + cloud_bin_s + '.npy'
39 | source_keypts = np.load(keypts_path)
40 | source_keypts = source_keypts[-num_keypoints:, :]
41 | keypts_path = './D3Feat_contralo-54-pred/keypoints/' + keyptspath.split('/')[-2] + '/' + cloud_bin_t + '.npy'
42 | target_keypts = np.load(keypts_path)
43 | target_keypts = target_keypts[-num_keypoints:, :]
44 | source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
45 | target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
46 | source_desc = np.nan_to_num(source_desc)
47 | target_desc = np.nan_to_num(target_desc)
48 | source_desc = source_desc[-num_keypoints:, :]
49 | target_desc = target_desc[-num_keypoints:, :]
50 | else:
51 | source_keypts = get_keypts(keyptspath, cloud_bin_s)
52 | target_keypts = get_keypts(keyptspath, cloud_bin_t)
53 | source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
54 | target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
55 | source_desc = np.nan_to_num(source_desc)
56 | target_desc = np.nan_to_num(target_desc)
57 | if source_desc.shape[0] > num_keypoints:
58 | rand_ind = np.random.choice(source_desc.shape[0], num_keypoints, replace=False)
59 | source_keypts = source_keypts[rand_ind]
60 | target_keypts = target_keypts[rand_ind]
61 | source_desc = source_desc[rand_ind]
62 | target_desc = target_desc[rand_ind]
63 |
64 |     # find mutually closest points in feature space (also needed for RANSAC below)
65 |     corr = calculate_M(source_desc, target_desc)
66 | 
67 |     key = f'{cloud_bin_s.split("_")[-1]}_{cloud_bin_t.split("_")[-1]}'
68 |     if key not in gtLog.keys():
69 |         num_inliers = 0
70 |         inlier_ratio = 0
71 |         gt_flag = 0
72 |     else:
73 |         gtTrans = gtLog[key]
74 | frag1 = source_keypts[corr[:, 0]]
75 | frag2_pc = open3d.geometry.PointCloud()
76 | frag2_pc.points = open3d.utility.Vector3dVector(target_keypts[corr[:, 1]])
77 | frag2_pc.transform(gtTrans)
78 | frag2 = np.asarray(frag2_pc.points)
79 | distance = np.sqrt(np.sum(np.power(frag1 - frag2, 2), axis=1))
80 | num_inliers = np.sum(distance < 0.10)
81 | inlier_ratio = num_inliers / len(distance)
82 | gt_flag = 1
83 |
84 | # calculate the transformation matrix using RANSAC, this is for Registration Recall.
85 | source_pcd = open3d.geometry.PointCloud()
86 | source_pcd.points = open3d.utility.Vector3dVector(source_keypts)
87 | target_pcd = open3d.geometry.PointCloud()
88 | target_pcd.points = open3d.utility.Vector3dVector(target_keypts)
89 | s_desc = open3d.pipelines.registration.Feature()
90 | s_desc.data = source_desc.T
91 | t_desc = open3d.pipelines.registration.Feature()
92 | t_desc.data = target_desc.T
93 |
94 |     # estimate the transformation with RANSAC over the mutual-NN correspondences
95 | corr_v = open3d.utility.Vector2iVector(corr)
96 | result = open3d.pipelines.registration.registration_ransac_based_on_correspondence(
97 | source_pcd, target_pcd, corr_v,
98 | 0.05,
99 | open3d.pipelines.registration.TransformationEstimationPointToPoint(False), 3,
100 | open3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000))
101 |
102 | # write the transformation matrix into .log file for evaluation.
103 | with open(os.path.join(logpath, f'{desc_name}_{timestr}.log'), 'a+') as f:
104 | trans = result.transformation
105 | trans = np.linalg.inv(trans)
106 |         s1 = f'{id1}\t {id2}\t {num_frag}\n'  # header: source id, target id, total fragments
107 | f.write(s1)
108 | f.write(f"{trans[0, 0]}\t {trans[0, 1]}\t {trans[0, 2]}\t {trans[0, 3]}\t \n")
109 | f.write(f"{trans[1, 0]}\t {trans[1, 1]}\t {trans[1, 2]}\t {trans[1, 3]}\t \n")
110 | f.write(f"{trans[2, 0]}\t {trans[2, 1]}\t {trans[2, 2]}\t {trans[2, 3]}\t \n")
111 | f.write(f"{trans[3, 0]}\t {trans[3, 1]}\t {trans[3, 2]}\t {trans[3, 3]}\t \n")
112 |
113 | s = f"{cloud_bin_s}\t{cloud_bin_t}\t{num_inliers}\t{inlier_ratio:.8f}\t{gt_flag}"
114 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'w+') as f:
115 | f.write(s)
116 | return num_inliers, inlier_ratio, gt_flag
117 |
118 |
119 | def read_register_result(id1, id2):
120 | cloud_bin_s = f'cloud_bin_{id1}'
121 | cloud_bin_t = f'cloud_bin_{id2}'
122 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'r') as f:
123 | content = f.readlines()
124 | nums = content[0].replace("\n", "").split("\t")[2:5]
125 | return nums
126 |
127 |
128 | if __name__ == '__main__':
129 | scene_list = [
130 | '7-scenes-redkitchen',
131 | 'sun3d-home_at-home_at_scan1_2013_jan_1',
132 | 'sun3d-home_md-home_md_scan9_2012_sep_30',
133 | 'sun3d-hotel_uc-scan3',
134 | 'sun3d-hotel_umd-maryland_hotel1',
135 | 'sun3d-hotel_umd-maryland_hotel3',
136 | 'sun3d-mit_76_studyroom-76-1studyroom2',
137 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika'
138 | ]
139 | desc_name = 'SpinNet'
140 | timestr = sys.argv[1]
141 | inliers_list = []
142 | recall_list = []
143 | inliers_ratio_list = []
144 | num_keypoints = 5000
145 | is_D3Feat_keypts = False
146 | for scene in scene_list:
147 | pcdpath = f"../../data/3DMatch/fragments/{scene}/"
148 | interpath = f"../../data/3DMatch/intermediate-files-real/{scene}/"
149 | gtpath = f'../../data/3DMatch/fragments/{scene}-evaluation/'
150 | keyptspath = interpath # os.path.join(interpath, "keypoints/")
151 | descpath = os.path.join(".", f"{desc_name}_desc_{timestr}/{scene}")
152 | gtLog = loadlog(gtpath)
153 | logpath = f"log_result/{scene}-evaluation"
154 | resultpath = os.path.join(".", f"pred_result/{scene}/{desc_name}_result_{timestr}")
155 | if not os.path.exists(resultpath):
156 | os.makedirs(resultpath)
157 | if not os.path.exists(logpath):
158 | os.makedirs(logpath)
159 |
160 | # register each pair
161 | num_frag = len(os.listdir(pcdpath))
162 | print(f"Start Evaluate Descriptor {desc_name} for {scene}")
163 | start_time = time.time()
164 | for id1 in range(num_frag):
165 | for id2 in range(id1 + 1, num_frag):
166 | num_inliers, inlier_ratio, gt_flag = register2Fragments(id1, id2, keyptspath, descpath, resultpath,
167 | desc_name)
168 | print(f"Finish Evaluation, time: {time.time() - start_time:.2f}s")
169 |
170 | # evaluate
171 | result = []
172 | for id1 in range(num_frag):
173 | for id2 in range(id1 + 1, num_frag):
174 | line = read_register_result(id1, id2)
175 | result.append([int(line[0]), float(line[1]), int(line[2])])
176 | result = np.array(result)
177 | indices_results = np.sum(result[:, 2] == 1)
178 | correct_match = np.sum(result[:, 1] > 0.05)
179 | recall = float(correct_match / indices_results) * 100
180 | print(f"Correct Match {correct_match}, ground truth Match {indices_results}")
181 | print(f"Recall {recall}%")
182 | ave_num_inliers = np.sum(np.where(result[:, 1] > 0.05, result[:, 0], np.zeros(result.shape[0]))) / correct_match
183 | print(f"Average Num Inliners: {ave_num_inliers}")
184 | ave_inlier_ratio = np.sum(
185 | np.where(result[:, 1] > 0.05, result[:, 1], np.zeros(result.shape[0]))) / correct_match
186 | print(f"Average Num Inliner Ratio: {ave_inlier_ratio}")
187 | recall_list.append(recall)
188 | inliers_list.append(ave_num_inliers)
189 | inliers_ratio_list.append(ave_inlier_ratio)
190 | print(recall_list)
191 |     average_recall = sum(recall_list) / len(recall_list)
192 |     print(f"All 8 scenes, average recall: {average_recall}%")
193 |     average_inliers = sum(inliers_list) / len(inliers_list)
194 |     print(f"All 8 scenes, average num inliers: {average_inliers}")
195 |     average_inliers_ratio = sum(inliers_ratio_list) / len(inliers_ratio_list)
196 |     print(f"All 8 scenes, average inlier ratio: {average_inliers_ratio}")
197 |
--------------------------------------------------------------------------------
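calculate_M loops over every source keypoint in Python; the same mutual-nearest-neighbour test can be written fully vectorized with the sklearn KDTree the script already imports. A sketch (mutual_nearest is a hypothetical name, not used by the script):

import numpy as np
from sklearn.neighbors import KDTree

def mutual_nearest(source_desc, target_desc):
    """Pairs (i, j) such that target j is source i's nearest neighbour in
    descriptor space and vice versa -- equivalent to calculate_M."""
    s2t = KDTree(target_desc).query(source_desc, 1, return_distance=False).ravel()
    t2s = KDTree(source_desc).query(target_desc, 1, return_distance=False).ravel()
    src = np.arange(len(source_desc))
    mask = t2s[s2t] == src
    return np.stack([src[mask], s2t[mask]], axis=1)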
/ThreeDMatch/Test/preparation.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | import time
5 | import numpy as np
6 | import torch
7 | import shutil
8 | import torch.nn as nn
9 | import sys
10 |
11 | sys.path.append('../../')
12 | import script.common as cm
13 | from ThreeDMatch.Test.tools import get_pcd, get_keypts
14 | from sklearn.neighbors import KDTree
15 | import importlib
16 | import open3d
17 |
18 |
19 | def make_open3d_point_cloud(xyz, color=None):
20 | pcd = open3d.geometry.PointCloud()
21 | pcd.points = open3d.utility.Vector3dVector(xyz)
22 | if color is not None:
23 | pcd.paint_uniform_color(color)
24 | return pcd
25 |
26 |
27 | def build_patch_input(pcd, keypts, vicinity=0.3, num_points_per_patch=2048):
28 | refer_pts = keypts.astype(np.float32)
29 | pts = np.array(pcd.points).astype(np.float32)
30 | num_patches = refer_pts.shape[0]
31 | tree = KDTree(pts[:, 0:3])
32 | ind_local = tree.query_radius(refer_pts[:, 0:3], r=vicinity)
33 | local_patches = np.zeros([num_patches, num_points_per_patch, 3], dtype=float)
34 | for i in range(num_patches):
35 | local_neighbors = pts[ind_local[i], :]
36 | if local_neighbors.shape[0] >= num_points_per_patch:
37 | temp = np.random.choice(range(local_neighbors.shape[0]), num_points_per_patch, replace=False)
38 | local_neighbors = local_neighbors[temp]
39 | local_neighbors[-1, :] = refer_pts[i, :]
40 | else:
41 | fix_idx = np.asarray(range(local_neighbors.shape[0]))
42 | while local_neighbors.shape[0] + fix_idx.shape[0] < num_points_per_patch:
43 | fix_idx = np.concatenate((fix_idx, np.asarray(range(local_neighbors.shape[0]))), axis=0)
44 | random_idx = np.random.choice(local_neighbors.shape[0], num_points_per_patch - fix_idx.shape[0],
45 | replace=False)
46 | choice_idx = np.concatenate((fix_idx, random_idx), axis=0)
47 | local_neighbors = local_neighbors[choice_idx]
48 | local_neighbors[-1, :] = refer_pts[i, :]
49 | local_patches[i] = local_neighbors
50 |
51 | return local_patches
52 |
53 |
54 | def prepare_patch(pcdpath, filename, keyptspath, trans_matrix):
55 | pcd = get_pcd(pcdpath, filename)
56 | keypts = get_keypts(keyptspath, filename)
57 | # load D3Feat keypts
58 | if is_D3Feat_keypts:
59 | keypts_path = './D3Feat_contralo-54-pred/keypoints/' + pcdpath.split('/')[-2] + '/' + filename + '.npy'
60 | keypts = np.load(keypts_path)
61 | keypts = keypts[-5000:, :]
62 | if is_rotate_dataset:
63 |         # add an arbitrary rotation: rotate the fragment by random
64 |         # angles around all three axes
65 | angles_3d = np.random.rand(3) * np.pi * 2
66 | R = cm.angles2rotation_matrix(angles_3d)
67 | T = np.identity(4)
68 | T[:3, :3] = R
69 | pcd.transform(T)
70 | keypts_pcd = make_open3d_point_cloud(keypts)
71 | keypts_pcd.transform(T)
72 | keypts = np.array(keypts_pcd.points)
73 | trans_matrix.append(T)
74 |
75 |     local_patches = build_patch_input(pcd, keypts)  # [num_keypts, 2048, 3]
76 | return local_patches
77 |
78 |
79 | def generate_descriptor(model, desc_name, pcdpath, keyptspath, descpath):
80 | model.eval()
81 | num_frag = len(os.listdir(pcdpath))
82 | num_desc = len(os.listdir(descpath))
83 | trans_matrix = []
84 | if num_frag == num_desc:
85 | print("Descriptor already prepared.")
86 | return
87 | for j in range(num_frag):
88 | local_patches = prepare_patch(pcdpath, 'cloud_bin_' + str(j), keyptspath, trans_matrix)
89 | input_ = torch.tensor(local_patches.astype(np.float32))
90 | B = input_.shape[0]
91 | input_ = input_.cuda()
92 | model = model.cuda()
93 | # calculate descriptors
94 | desc_list = []
95 | start_time = time.time()
96 | desc_len = 32
97 | step_size = 100
98 |         iter_num = int(np.ceil(B / step_size))  # np.int was removed from NumPy
99 | for k in range(iter_num):
100 | if k == iter_num - 1:
101 | desc = model(input_[k * step_size:, :, :])
102 | else:
103 | desc = model(input_[k * step_size: (k + 1) * step_size, :, :])
104 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy())
105 | del desc
106 | step_time = time.time() - start_time
107 |         print(f'Finished {B} descriptors in {step_time:.4f}s')
108 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len])
109 | np.save(descpath + 'cloud_bin_' + str(j) + f".desc.{desc_name}.bin", desc.astype(np.float32))
110 | if is_rotate_dataset:
111 | scene_name = pcdpath.split('/')[-2]
112 | all_trans_matrix[scene_name] = trans_matrix
113 |
114 |
115 | if __name__ == '__main__':
116 | scene_list = [
117 | '7-scenes-redkitchen',
118 | 'sun3d-home_at-home_at_scan1_2013_jan_1',
119 | 'sun3d-home_md-home_md_scan9_2012_sep_30',
120 | 'sun3d-hotel_uc-scan3',
121 | 'sun3d-hotel_umd-maryland_hotel1',
122 | 'sun3d-hotel_umd-maryland_hotel3',
123 | 'sun3d-mit_76_studyroom-76-1studyroom2',
124 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika'
125 | ]
126 |
127 | experiment_id = time.strftime('%m%d%H%M')
128 | model_str = experiment_id # sys.argv[1]
129 | if not os.path.exists(f"SpinNet_desc_{model_str}/"):
130 | os.mkdir(f"SpinNet_desc_{model_str}")
131 |
132 | # dynamically load the model
133 | module_file_path = '../model.py'
134 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), module_file_path)
135 | module_name = ''
136 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
137 | module = importlib.util.module_from_spec(module_spec)
138 | module_spec.loader.exec_module(module)
139 | model = module.Descriptor_Net(0.30, 9, 80, 40, 0.04, 30, '3DMatch')
140 | model = nn.DataParallel(model, device_ids=[0])
141 | model.load_state_dict(torch.load('../../pre-trained_models/3DMatch_best.pkl'))
142 |
143 | all_trans_matrix = {}
144 | is_rotate_dataset = False
145 | is_D3Feat_keypts = False
146 | for scene in scene_list:
147 | pcdpath = f"../../data/3DMatch/fragments/{scene}/"
148 | interpath = f"../../data/3DMatch/intermediate-files-real/{scene}/"
149 | keyptspath = interpath
150 | descpath = os.path.join('.', f"SpinNet_desc_{model_str}/{scene}/")
151 | if not os.path.exists(descpath):
152 | os.makedirs(descpath)
153 | start_time = time.time()
154 | print(f"Begin Processing {scene}")
155 | generate_descriptor(model, desc_name='SpinNet', pcdpath=pcdpath, keyptspath=keyptspath, descpath=descpath)
156 | print(f"Finish in {time.time() - start_time}s")
157 | if is_rotate_dataset:
158 | np.save(f"trans_matrix", all_trans_matrix)
159 |
--------------------------------------------------------------------------------
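build_patch_input enforces a fixed patch size with two branches: random subsampling without replacement when the radius search returns enough neighbours, and tile-then-randomly-pad when it does not (the last row is afterwards overwritten with the keypoint itself). The sampling rule in isolation, with resample_neighbors as a hypothetical helper:

import numpy as np

def resample_neighbors(neighbors, n):
    """Return exactly n rows: subsample when there are >= n neighbours,
    otherwise tile the index range and pad with distinct random picks."""
    m = neighbors.shape[0]
    if m >= n:
        idx = np.random.choice(m, n, replace=False)
    else:
        idx = np.arange(m)
        while len(idx) + m < n:
            idx = np.concatenate([idx, np.arange(m)])
        idx = np.concatenate([idx, np.random.choice(m, n - len(idx), replace=False)])
    return neighbors[idx]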
/ThreeDMatch/Test/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import open3d
3 | import numpy as np
4 |
5 |
6 | def get_pcd(pcdpath, filename):
7 | return open3d.io.read_point_cloud(os.path.join(pcdpath, filename + '.ply'))
8 |
9 |
10 | def get_keypts(keyptspath, filename):
11 | keypts = np.fromfile(os.path.join(keyptspath, filename + '.keypts.bin'), dtype=np.float32)
12 | num_keypts = int(keypts[0])
13 | keypts = keypts[1:].reshape([num_keypts, 3])
14 | return keypts
15 |
16 |
17 | def get_ETH_keypts(pcd, keyptspath, filename):
18 | pts = np.array(pcd.points)
19 |     key_ind = np.loadtxt(os.path.join(keyptspath, filename + '_Keypoints.txt'), dtype=int)  # np.int was removed from NumPy
20 | keypts = pts[key_ind]
21 | return keypts
22 |
23 |
24 | def get_keypts_(keyptspath, filename):
25 | keypts = np.load(os.path.join(keyptspath, filename + f'.keypts.bin.npy'))
26 | return keypts
27 |
28 |
29 | def get_desc(descpath, filename, desc_name):
30 | if desc_name == '3dmatch':
31 | desc = np.fromfile(os.path.join(descpath, filename + '.desc.3dmatch.bin'), dtype=np.float32)
32 | num_desc = int(desc[0])
33 | desc_size = int(desc[1])
34 | desc = desc[2:].reshape([num_desc, desc_size])
35 | elif desc_name == 'SpinNet':
36 | desc = np.load(os.path.join(descpath, filename + '.desc.SpinNet.bin.npy'))
37 | else:
38 | print("No such descriptor")
39 | exit(-1)
40 | return desc
41 |
42 |
43 | def loadlog(gtpath):
44 | with open(os.path.join(gtpath, 'gt.log')) as f:
45 | content = f.readlines()
46 | result = {}
47 | i = 0
48 | while i < len(content):
49 | line = content[i].replace("\n", "").split("\t")[0:3]
50 | trans = np.zeros([4, 4])
51 | trans[0] = [float(x) for x in content[i + 1].replace("\n", "").split("\t")[0:4]]
52 | trans[1] = [float(x) for x in content[i + 2].replace("\n", "").split("\t")[0:4]]
53 | trans[2] = [float(x) for x in content[i + 3].replace("\n", "").split("\t")[0:4]]
54 | trans[3] = [float(x) for x in content[i + 4].replace("\n", "").split("\t")[0:4]]
55 | i = i + 5
56 | result[f'{int(line[0])}_{int(line[1])}'] = trans
57 |
58 | return result
59 |
--------------------------------------------------------------------------------
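loadlog assumes the standard 3DMatch gt.log layout: a tab-separated header line (source fragment id, target fragment id, total number of fragments in the scene), followed by the four rows of the 4x4 ground-truth transformation, repeated for every overlapping pair. Illustrative values only:

0	1	57
0.96	-0.05	0.27	1.25
0.04	0.99	0.06	-0.31
-0.27	-0.04	0.96	0.78
0.00	0.00	0.00	1.00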
/ThreeDMatch/Train/dataloader.py:
--------------------------------------------------------------------------------
1 | import time
2 | from ThreeDMatch.Train.dataset import ThreeDMatchDataset
3 | import torch
4 |
5 |
6 | def get_dataloader(root, split, batch_size=1, num_workers=4, shuffle=True, drop_last=True):
7 | dataset = ThreeDMatchDataset(
8 | root=root,
9 | split=split,
10 | batch_size=batch_size,
11 | shuffle=shuffle,
12 | drop_last=drop_last
13 | )
14 | dataset.initial()
15 | dataloader = torch.utils.data.DataLoader(
16 | dataset=dataset,
17 | batch_size=batch_size,
18 | num_workers=num_workers,
19 | drop_last=drop_last
20 | )
21 |
22 | return dataloader
23 |
24 |
25 | if __name__ == '__main__':
26 | dataset = 'sun3d'
27 | dataroot = "/data/3DMatch/whole"
28 | trainloader = get_dataloader(dataroot, split='test', batch_size=32)
29 | start_time = time.time()
30 | print(f"Totally {len(trainloader)} iter.")
31 | for iter, (patches, ids) in enumerate(trainloader):
32 | if iter % 100 == 0:
33 | print(f"Iter {iter}: {time.time() - start_time} s")
34 | print(f"On the fly: {time.time() - start_time}")
35 |
--------------------------------------------------------------------------------
/ThreeDMatch/Train/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as Data
2 | import os
3 | import random
4 | import glob
5 | import pickle
6 | import open3d as o3d
7 | import numpy as np
8 |
9 |
10 | class ThreeDMatchDataset(Data.Dataset):
11 | def __init__(self, root, split, batch_size, shuffle, drop_last):
12 | """
13 | Create ThreeDMatchDataset to read multiple training files
14 | Args:
15 | root: the path to the dataset file
16 | shuffle: whether the data need to shuffle
17 | """
18 | self.patches_path = os.path.join(root, split)
19 | self.split = split
20 | # Get name of all training pkl files
21 | training_data_files = glob.glob(self.patches_path + '/*.pkl')
22 | ids = [file.split("/")[-1] for file in training_data_files]
23 | ids = sorted(ids, key=lambda x: int(x.split("_")[-1].split(".")[0]))
24 | ids = [file for file in ids if file.split("_")[1] == 'anc&pos']
25 | self.training_data_files = ids
26 | # Get info of training files
27 | self.per_num_patch = int(training_data_files[0].split("/")[-1].split("_")[2])
28 | self.dataset_len = int(ids[-1].split("_")[-1].split(".")[0]) * self.per_num_patch
29 | self.batch_size = batch_size
30 | self.shuffle = shuffle
31 | self.drop_last = drop_last
32 | # Record the loaded i-th training file
33 | self.num_file = 0
34 | # load poses for each type of patches
35 | self.per_patch_points = int(self.training_data_files[-1].split("_")[3])
36 | self.num_framents = int(self.training_data_files[-1].split("_")[4].split(".")[0])
37 | with open(os.path.join(root,
38 | f'{self.split}/{self.split}_poses_{self.per_num_patch}_{self.per_patch_points}_{self.num_framents}.pkl'),
39 | 'rb') as file:
40 | self.poses = pickle.load(file)
41 | print(
42 | f"load training poses {os.path.join(root, f'{self.split}_poses_{self.per_num_patch}_{self.per_patch_points}_{self.num_framents}.pkl')}")
43 | self.cur_pose_ind = 0
44 |
45 | def initial(self):
46 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file:
47 | self.patches = pickle.load(file)
48 | print(f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}")
49 |
50 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1])
51 | poses = self.poses[self.cur_pose_ind:next_pose_ind]
52 | for i in range(len(self.patches)):
53 | ind = int(np.floor(i / self.per_num_patch))
54 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6)
55 | self.patches[i] = np.concatenate([pose, self.patches[i]])
56 | self.cur_pose_ind = next_pose_ind
57 |
58 | self.current_patches_num = len(self.patches)
59 | self.index = list(range(self.current_patches_num))
60 | if self.shuffle:
61 | random.shuffle(self.patches)
62 |
63 | def __len__(self):
64 | return self.dataset_len
65 |
66 | def __getitem__(self, item):
67 | idx = self.index[0]
68 | patches = self.patches[idx]
69 | self.index = self.index[1:]
70 | self.current_patches_num -= 1
71 |
72 | if self.drop_last:
73 | if self.current_patches_num <= (len(self.patches) % self.batch_size): # reach the end of training file
74 | self.num_file = self.num_file + 1
75 | if self.num_file < len(self.training_data_files):
76 | remain_patches = [self.patches[i] for i in self.index] # the remained training patches
77 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file:
78 | self.patches = pickle.load(file)
79 | print(
80 | f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}")
81 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1])
82 | poses = self.poses[self.cur_pose_ind:next_pose_ind]
83 | for i in range(len(self.patches)):
84 | ind = int(np.floor(i / self.per_num_patch))
85 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6)
86 | self.patches[i] = np.concatenate([pose, self.patches[i]])
87 | self.cur_pose_ind = next_pose_ind
88 | self.patches += remain_patches # add the remained patches to compose a set of new patches
89 | self.current_patches_num = len(self.patches)
90 | self.index = list(range(self.current_patches_num))
91 | if self.shuffle:
92 | random.shuffle(self.patches)
93 | else:
94 | self.num_file = 0
95 | self.cur_pose_ind = 0
96 | self.initial()
97 | else:
98 | if self.current_patches_num <= 0:
99 | self.num_file = self.num_file + 1
100 | if self.num_file < len(self.training_data_files):
101 | with open(os.path.join(self.patches_path, self.training_data_files[self.num_file]), 'rb') as file:
102 | self.patches = pickle.load(file)
103 | print(
104 | f"load training files {os.path.join(self.patches_path, self.training_data_files[self.num_file])}")
105 | next_pose_ind = int(self.training_data_files[self.num_file].split(".")[0].split("_")[-1])
106 | poses = self.poses[self.cur_pose_ind:next_pose_ind]
107 | for i in range(len(self.patches)):
108 | ind = int(np.floor(i / self.per_num_patch))
109 | pose = np.concatenate([poses[ind][:3, :3].reshape(9), poses[ind][:3, 3]]).reshape(2, 6)
110 | self.patches[i] = np.concatenate([pose, self.patches[i]])
111 | self.cur_pose_ind = next_pose_ind
112 | self.current_patches_num = len(self.patches)
113 | self.index = list(range(self.current_patches_num))
114 | if self.shuffle:
115 | random.shuffle(self.patches)
116 | else:
117 | self.num_file = 0
118 | self.cur_pose_ind = 0
119 | self.initial()
120 |
121 | anc_local_patch = patches[2:, :3]
122 | pos_local_patch = patches[2:, 3:]
123 | rotate = patches[:2, :].reshape(12)[:9].reshape(3, 3)
124 | shift = patches[:2, :].reshape(12)[9:]
125 |
126 | # np.random.shuffle(anc_local_patch)
127 | # np.random.shuffle(pos_local_patch)
128 |
129 | return anc_local_patch, pos_local_patch, rotate, shift
130 |
131 |
132 | if __name__ == "__main__":
133 | data_root = "../data/3DMatch_patches/"
134 | batch_size = 48
135 | epoch = 1
136 | train_dataset = ThreeDMatchDataset(root=data_root, split='train', batch_size=batch_size, shuffle=True,
137 | drop_last=True)
138 | train_dataset.initial()
139 | for _ in range(epoch):
140 | train_iter = Data.DataLoader(dataset=train_dataset, batch_size=batch_size, drop_last=True)
141 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(train_iter):
142 | B = anc_local_patch.shape[0]
143 |
--------------------------------------------------------------------------------
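Each sample's ground-truth pose rides along inside the patch array itself: initial() prepends a 2x6 block holding the 9 rotation entries and the 3 translation entries, and __getitem__ strips it back off. The round trip, as a small check:

import numpy as np

T = np.eye(4)  # any 4x4 pose
# packing, as in initial():
header = np.concatenate([T[:3, :3].reshape(9), T[:3, 3]]).reshape(2, 6)
# unpacking, as in __getitem__():
flat = header.reshape(12)
R, t = flat[:9].reshape(3, 3), flat[9:]
assert np.allclose(R, T[:3, :3]) and np.allclose(t, T[:3, 3])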
/ThreeDMatch/Train/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | import time
5 | import shutil
6 | import sys
7 |
8 | sys.path.append('../../')
9 | from ThreeDMatch.Train.dataloader import get_dataloader
10 | from ThreeDMatch.Train.trainer import Trainer
11 | from network.SpinNet import Descriptor_Net
12 | from torch import optim
13 |
14 |
15 | class Args(object):
16 | def __init__(self):
17 | self.experiment_id = "Proposal" + time.strftime('%m%d%H%M')
18 | snapshot_root = 'snapshot/%s' % self.experiment_id
19 | tensorboard_root = 'tensorboard/%s' % self.experiment_id
20 | os.makedirs(snapshot_root, exist_ok=True)
21 | os.makedirs(tensorboard_root, exist_ok=True)
22 | shutil.copy2(os.path.join('', 'train.py'), os.path.join(snapshot_root, 'train.py'))
23 | shutil.copy2(os.path.join('', 'trainer.py'), os.path.join(snapshot_root, 'trainer.py'))
24 | shutil.copy2(os.path.join('', '../../network/SpinNet.py'), os.path.join(snapshot_root, 'SpinNet.py'))
25 | shutil.copy2(os.path.join('', '../../network/ThreeDCCN.py'), os.path.join(snapshot_root, 'ThreeDCCN.py'))
26 | shutil.copy2(os.path.join('', '../../loss/desc_loss.py'), os.path.join(snapshot_root, 'loss.py'))
27 | self.epoch = 20
28 | self.batch_size = 76
29 | self.rad_n = 9
30 | self.azi_n = 80
31 | self.ele_n = 40
32 | self.des_r = 0.30
33 | self.voxel_r = 0.04
34 | self.voxel_sample = 30
35 |
36 | self.dataset = '3DMatch'
37 | self.data_train_dir = '../../data/3DMatch/patches'
38 | self.data_val_dir = '../../data/3DMatch/patches'
39 |
40 | self.gpu_mode = True
41 | self.verbose = True
42 | self.freeze_epoch = 5
43 |
44 | # model & optimizer
45 | self.model = Descriptor_Net(self.des_r, self.rad_n, self.azi_n, self.ele_n,
46 | self.voxel_r, self.voxel_sample, self.dataset)
47 | self.pretrain = ''
48 | self.parameter = self.model.get_parameter()
49 | self.optimizer = optim.Adam(self.parameter, lr=0.001, betas=(0.9, 0.999), weight_decay=1e-6)
50 | self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.5)
51 | self.scheduler_interval = 5
52 |
53 | # dataloader
54 | self.train_loader = get_dataloader(root=self.data_train_dir,
55 | batch_size=self.batch_size,
56 | split='train',
57 | shuffle=True,
58 |                                            num_workers=0,  # must be 0: the dataset object keeps per-file state that workers cannot share
59 | )
60 | self.val_loader = get_dataloader(root=self.data_val_dir,
61 | batch_size=self.batch_size,
62 | split='val',
63 | shuffle=False,
64 |                                          num_workers=0,  # must be 0: the dataset object keeps per-file state that workers cannot share
65 | )
66 |
67 | print("Training set size:", self.train_loader.dataset.__len__())
68 | print("Validate set size:", self.val_loader.dataset.__len__())
69 |
70 | # snapshot
71 | self.snapshot_interval = int(self.train_loader.dataset.__len__() / self.batch_size / 2)
72 | self.save_dir = os.path.join(snapshot_root, 'models/')
73 | self.result_dir = os.path.join(snapshot_root, 'results/')
74 | self.tboard_dir = tensorboard_root
75 |
76 | # evaluate
77 | self.evaluate_interval = 1
78 |
79 | self.check_args()
80 |
81 | def check_args(self):
82 | """checking arguments"""
83 | if not os.path.exists(self.save_dir):
84 | os.makedirs(self.save_dir)
85 | if not os.path.exists(self.result_dir):
86 | os.makedirs(self.result_dir)
87 | if not os.path.exists(self.tboard_dir):
88 | os.makedirs(self.tboard_dir)
89 | return self
90 |
91 |
92 | if __name__ == '__main__':
93 | args = Args()
94 | trainer = Trainer(args)
95 | trainer.train()
96 |
--------------------------------------------------------------------------------
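trainer.py steps the ExponentialLR at the end of every epoch where epoch % scheduler_interval == 0, and that test also fires at epoch 0, so the first gamma=0.5 decay of lr=0.001 already lands after the very first epoch. The implied schedule over the 20 epochs:

lr = 0.001
for epoch in range(20):
    print(epoch, lr)     # 0.001 for epoch 0, 0.0005 for 1-5, 0.00025 for 6-10, ...
    if epoch % 5 == 0:   # scheduler_interval; trainer.py steps after the epoch
        lr *= 0.5        # gamma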
/ThreeDMatch/Train/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import time, os
4 | import numpy as np
5 | from loss.desc_loss import ContrastiveLoss
6 | from tensorboardX import SummaryWriter
7 |
8 |
9 | class Trainer(object):
10 | def __init__(self, args):
11 | # parameters
12 | self.epoch = args.epoch
13 | self.batch_size = args.batch_size
14 | self.dataset = args.dataset
15 | self.save_dir = args.save_dir
16 | self.result_dir = args.result_dir
17 | self.gpu_mode = args.gpu_mode
18 | self.verbose = args.verbose
19 | self.model = args.model
20 | self.optimizer = args.optimizer
21 | self.scheduler = args.scheduler
22 | self.scheduler_interval = args.scheduler_interval
23 | self.snapshot_interval = args.snapshot_interval
24 | self.evaluate_interval = args.evaluate_interval
25 | self.writer = SummaryWriter(log_dir=args.tboard_dir)
26 |
27 | self.train_loader = args.train_loader
28 | self.val_loader = args.val_loader
29 |
30 | self.desc_loss = ContrastiveLoss()
31 |
32 | if self.gpu_mode:
33 | self.model = self.model.cuda()
34 | self.model = torch.nn.DataParallel(self.model, device_ids=[0])
35 |
36 | if args.pretrain != '':
37 | self._load_pretrain(args.pretrain)
38 |
39 | def train(self):
40 | self.train_hist = {
41 | 'loss': [],
42 | 'per_epoch_time': [],
43 | 'total_time': []
44 | }
45 | best_loss = 1000000000
46 |         print('Training start!')
47 | start_time = time.time()
48 |
49 | self.model.train()
50 | for epoch in range(self.epoch):
51 |
52 | self.train_epoch(epoch)
53 |
54 | if epoch % self.evaluate_interval == 0 or epoch == 0:
55 | res = self.evaluate()
56 | print(f'Evaluation: Epoch {epoch}: Loss {res["loss"]}')
57 |
58 | if res['loss'] < best_loss:
59 | best_loss = res['loss']
60 | self._snapshot('best')
61 | if self.writer:
62 | self.writer.add_scalar('Loss', res['loss'], epoch)
63 |
64 | if epoch % self.scheduler_interval == 0:
65 | old_lr = self.optimizer.param_groups[0]['lr']
66 | self.scheduler.step()
67 | new_lr = self.optimizer.param_groups[0]['lr']
68 |                 print('update learning rate: %f -> %f' % (old_lr, new_lr))
69 |
70 | if self.writer:
71 | self.writer.add_scalar('Learning Rate', self._get_lr(), epoch)
72 | self.writer.add_scalar('Train Loss', self.train_hist['loss'][-1], epoch)
73 |
74 | # finish all epoch
75 | self.train_hist['total_time'].append(time.time() - start_time)
76 | print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
77 | self.epoch, self.train_hist['total_time'][0]))
78 | print("Training finish!... save training results")
79 |
80 | def train_epoch(self, epoch):
81 | epoch_start_time = time.time()
82 | loss_buf = []
83 | num_batch = int(len(self.train_loader.dataset) / self.batch_size)
84 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(self.train_loader):
85 |
86 | B = anc_local_patch.shape[0]
87 | anc_local_patch = anc_local_patch.float()
88 | pos_local_patch = pos_local_patch.float()
89 | rotate = rotate.float()
90 | shift = shift.float()
91 |
92 | if self.gpu_mode:
93 | anc_local_patch = anc_local_patch.cuda()
94 | pos_local_patch = pos_local_patch.cuda()
95 |
96 | # forward
97 | self.optimizer.zero_grad()
98 | a_des = self.model(anc_local_patch)
99 | p_des = self.model(pos_local_patch)
100 | anc_des = F.normalize(a_des.view(B, -1), p=2, dim=1)
101 | pos_des = F.normalize(p_des.view(B, -1), p=2, dim=1)
102 |
103 | # calculate the contrastive loss
104 | des_loss, accuracy = self.desc_loss(anc_des, pos_des)
105 | loss = des_loss
106 |
107 | # backward
108 | loss.backward()
109 | self.optimizer.step()
110 | loss_buf.append(float(loss))
111 |
112 | if iter % self.snapshot_interval == 0:
113 | self._snapshot(f'{epoch}_{iter + 1}')
114 |
115 | if iter % 200 == 0 and self.verbose:
116 | iter_time = time.time() - epoch_start_time
117 | print(f"Epoch: {epoch} [{iter:4d}/{num_batch}] loss: {loss:.2f} time: {iter_time:.2f}s")
118 | print(f"Accuracy: {accuracy.item():.4f}\n")
119 | del loss
120 | del anc_local_patch
121 | del pos_local_patch
122 | # finish one epoch
123 | epoch_time = time.time() - epoch_start_time
124 | self.train_hist['per_epoch_time'].append(epoch_time)
125 | self.train_hist['loss'].append(np.mean(loss_buf))
126 | print(f'Epoch {epoch}: Loss {np.mean(loss_buf)}, time {epoch_time:.4f}s')
127 |
128 | del loss_buf
129 |
130 | def evaluate(self):
131 | self.model.eval()
132 | loss_buf = []
133 | with torch.no_grad():
134 | for iter, (anc_local_patch, pos_local_patch, rotate, shift) in enumerate(self.val_loader):
135 |
136 | B = anc_local_patch.shape[0]
137 | anc_local_patch = anc_local_patch.float()
138 | pos_local_patch = pos_local_patch.float()
139 | rotate = rotate.float()
140 | shift = shift.float()
141 |
142 | if self.gpu_mode:
143 | anc_local_patch = anc_local_patch.cuda()
144 | pos_local_patch = pos_local_patch.cuda()
145 |
146 | # forward
147 | a_des = self.model(anc_local_patch)
148 | p_des = self.model(pos_local_patch)
149 |
150 | # descriptor loss
151 | anc_des = F.normalize(a_des.view(B, -1), p=2, dim=1)
152 | pos_des = F.normalize(p_des.view(B, -1), p=2, dim=1)
153 |
154 | # calculate the contrastive loss
155 | des_loss, accuracy = self.desc_loss(anc_des, pos_des)
156 | loss = des_loss
157 | loss_buf.append(float(loss))
158 |
159 | del loss
160 | del anc_local_patch
161 | del pos_local_patch
162 |
163 | self.model.train()
164 |
165 | res = {
166 | 'loss': np.mean(loss_buf),
167 | }
168 | del loss_buf
169 | return res
170 |
171 | def _snapshot(self, epoch):
172 | save_dir = os.path.join(self.save_dir, self.dataset)
173 | torch.save(self.model.state_dict(), save_dir + "_" + str(epoch) + '.pkl')
174 | print(f"Save model to {save_dir}_{str(epoch)}.pkl")
175 |
176 | def _load_pretrain(self, pretrain):
177 | state_dict = torch.load(pretrain)
178 | self.model.load_state_dict(state_dict)
179 | print(f"Load model from {pretrain}.pkl")
180 |
181 | def _get_lr(self, group=0):
182 | return self.optimizer.param_groups[group]['lr']
183 |
--------------------------------------------------------------------------------
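loss/desc_loss.py itself is not reproduced in this listing; trainer.py only relies on ContrastiveLoss taking two batches of L2-normalized descriptors and returning a (loss, accuracy) pair. Purely as an illustration of that contract, a hypothetical batch-hard contrastive loss could look like this (margins and mining strategy are placeholders, not the repo's actual values):

import torch

def batch_hard_contrastive(anc, pos, pos_margin=0.1, neg_margin=1.4):
    """Hypothetical sketch: pull matching descriptor pairs together, push the
    hardest in-batch non-match beyond a margin; returns (loss, accuracy)."""
    d = torch.cdist(anc, pos)                                   # B x B distances
    pos_d = d.diagonal()                                        # matching pairs
    masked = d + 1e5 * torch.eye(len(d), device=d.device)       # hide the diagonal
    neg_d = masked.min(dim=1).values                            # hardest non-match
    loss = torch.clamp(pos_d - pos_margin, min=0).pow(2).mean() \
         + torch.clamp(neg_margin - neg_d, min=0).pow(2).mean()
    accuracy = (d.argmin(dim=1) == torch.arange(len(d), device=d.device)).float().mean()
    return loss, accuracy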
/figs/Fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig1.png
--------------------------------------------------------------------------------
/figs/Fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig2.png
--------------------------------------------------------------------------------
/figs/Fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig3.png
--------------------------------------------------------------------------------
/figs/Fig4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig4.png
--------------------------------------------------------------------------------
/figs/Fig5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Fig5.png
--------------------------------------------------------------------------------
/figs/Table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table1.png
--------------------------------------------------------------------------------
/figs/Table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table2.png
--------------------------------------------------------------------------------
/figs/Table3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table3.png
--------------------------------------------------------------------------------
/figs/Table4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table4.png
--------------------------------------------------------------------------------
/figs/Table5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table5.png
--------------------------------------------------------------------------------
/figs/Table6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table6.png
--------------------------------------------------------------------------------
/figs/Table7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/figs/Table7.png
--------------------------------------------------------------------------------
/generalization/KITTI-to-ThreeDMatch/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../../')
4 | import open3d
5 | import numpy as np
6 | import time
7 | import os
8 | from ThreeDMatch.Test.tools import get_pcd, get_keypts, get_desc, loadlog
9 | from sklearn.neighbors import KDTree
10 |
11 |
12 | def calculate_M(source_desc, target_desc):
13 | """
14 | Find the mutually closest point pairs in feature space.
15 | source and target are descriptor for 2 point cloud key points. [5000, 512]
16 | """
17 |
18 | kdtree_s = KDTree(target_desc)
19 | sourceNNdis, sourceNNidx = kdtree_s.query(source_desc, 1)
20 | kdtree_t = KDTree(source_desc)
21 | targetNNdis, targetNNidx = kdtree_t.query(target_desc, 1)
22 | result = []
23 | for i in range(len(sourceNNidx)):
24 | if targetNNidx[sourceNNidx[i]] == i:
25 | result.append([i, sourceNNidx[i][0]])
26 | return np.array(result)
27 |
28 |
29 | def register2Fragments(id1, id2, keyptspath, descpath, resultpath, desc_name='SpinNet'):
30 | cloud_bin_s = f'cloud_bin_{id1}'
31 | cloud_bin_t = f'cloud_bin_{id2}'
32 | write_file = f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'
33 | if os.path.exists(os.path.join(resultpath, write_file)):
34 | print(f"{write_file} already exists.")
35 | return 0, 0, 0
36 |
37 | if is_D3Feat_keypts:
38 | keypts_path = './D3Feat_contralo-54-pred/keypoints/' + keyptspath.split('/')[-2] + '/' + cloud_bin_s + '.npy'
39 | source_keypts = np.load(keypts_path)
40 | source_keypts = source_keypts[-num_keypoints:, :]
41 | keypts_path = './D3Feat_contralo-54-pred/keypoints/' + keyptspath.split('/')[-2] + '/' + cloud_bin_t + '.npy'
42 | target_keypts = np.load(keypts_path)
43 | target_keypts = target_keypts[-num_keypoints:, :]
44 | source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
45 | target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
46 | source_desc = np.nan_to_num(source_desc)
47 | target_desc = np.nan_to_num(target_desc)
48 | source_desc = source_desc[-num_keypoints:, :]
49 | target_desc = target_desc[-num_keypoints:, :]
50 | else:
51 | source_keypts = get_keypts(keyptspath, cloud_bin_s)
52 | target_keypts = get_keypts(keyptspath, cloud_bin_t)
53 | # print(source_keypts.shape)
54 | source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
55 | target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
56 | source_desc = np.nan_to_num(source_desc)
57 | target_desc = np.nan_to_num(target_desc)
58 | if source_desc.shape[0] > num_keypoints:
59 | rand_ind = np.random.choice(source_desc.shape[0], num_keypoints, replace=False)
60 | source_keypts = source_keypts[rand_ind]
61 | target_keypts = target_keypts[rand_ind]
62 | source_desc = source_desc[rand_ind]
63 | target_desc = target_desc[rand_ind]
64 |
65 |     # find mutually closest points in feature space (also needed for RANSAC below)
66 |     corr = calculate_M(source_desc, target_desc)
67 | 
68 |     key = f'{cloud_bin_s.split("_")[-1]}_{cloud_bin_t.split("_")[-1]}'
69 |     if key not in gtLog.keys():
70 |         num_inliers = 0
71 |         inlier_ratio = 0
72 |         gt_flag = 0
73 |     else:
74 |         gtTrans = gtLog[key]
75 | frag1 = source_keypts[corr[:, 0]]
76 | frag2_pc = open3d.geometry.PointCloud()
77 | frag2_pc.points = open3d.utility.Vector3dVector(target_keypts[corr[:, 1]])
78 | frag2_pc.transform(gtTrans)
79 | frag2 = np.asarray(frag2_pc.points)
80 | distance = np.sqrt(np.sum(np.power(frag1 - frag2, 2), axis=1))
81 | num_inliers = np.sum(distance < 0.10)
82 | inlier_ratio = num_inliers / len(distance)
83 | gt_flag = 1
84 |
85 | # calculate the transformation matrix using RANSAC, this is for Registration Recall.
86 | source_pcd = open3d.geometry.PointCloud()
87 | source_pcd.points = open3d.utility.Vector3dVector(source_keypts)
88 | target_pcd = open3d.geometry.PointCloud()
89 | target_pcd.points = open3d.utility.Vector3dVector(target_keypts)
90 | s_desc = open3d.pipelines.registration.Feature()
91 | s_desc.data = source_desc.T
92 | t_desc = open3d.pipelines.registration.Feature()
93 | t_desc.data = target_desc.T
94 |
95 |     # estimate the transformation with RANSAC over the mutual-NN correspondences
96 | corr_v = open3d.utility.Vector2iVector(corr)
97 | result = open3d.pipelines.registration.registration_ransac_based_on_correspondence(
98 | source_pcd, target_pcd, corr_v,
99 | 0.05,
100 | open3d.pipelines.registration.TransformationEstimationPointToPoint(False), 3,
101 | open3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000))
102 |
103 | # write the transformation matrix into .log file for evaluation.
104 | with open(os.path.join(logpath, f'{desc_name}_{timestr}.log'), 'a+') as f:
105 | trans = result.transformation
106 | trans = np.linalg.inv(trans)
107 |         s1 = f'{id1}\t {id2}\t {num_frag}\n'  # header: source id, target id, total fragments
108 | f.write(s1)
109 | f.write(f"{trans[0, 0]}\t {trans[0, 1]}\t {trans[0, 2]}\t {trans[0, 3]}\t \n")
110 | f.write(f"{trans[1, 0]}\t {trans[1, 1]}\t {trans[1, 2]}\t {trans[1, 3]}\t \n")
111 | f.write(f"{trans[2, 0]}\t {trans[2, 1]}\t {trans[2, 2]}\t {trans[2, 3]}\t \n")
112 | f.write(f"{trans[3, 0]}\t {trans[3, 1]}\t {trans[3, 2]}\t {trans[3, 3]}\t \n")
113 |
114 | s = f"{cloud_bin_s}\t{cloud_bin_t}\t{num_inliers}\t{inlier_ratio:.8f}\t{gt_flag}"
115 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'w+') as f:
116 | f.write(s)
117 | return num_inliers, inlier_ratio, gt_flag
118 |
119 |
120 | def read_register_result(id1, id2):
121 | cloud_bin_s = f'cloud_bin_{id1}'
122 | cloud_bin_t = f'cloud_bin_{id2}'
123 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'r') as f:
124 | content = f.readlines()
125 | nums = content[0].replace("\n", "").split("\t")[2:5]
126 | return nums
127 |
128 |
129 | if __name__ == '__main__':
130 | scene_list = [
131 | '7-scenes-redkitchen',
132 | 'sun3d-home_at-home_at_scan1_2013_jan_1',
133 | 'sun3d-home_md-home_md_scan9_2012_sep_30',
134 | 'sun3d-hotel_uc-scan3',
135 | 'sun3d-hotel_umd-maryland_hotel1',
136 | 'sun3d-hotel_umd-maryland_hotel3',
137 | 'sun3d-mit_76_studyroom-76-1studyroom2',
138 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika'
139 | ]
140 | desc_name = 'SpinNet'
141 | timestr = sys.argv[1]
142 | inliers_list = []
143 | recall_list = []
144 | inliers_ratio_list = []
145 | num_keypoints = 5000
146 | is_D3Feat_keypts = False
147 | for scene in scene_list:
148 | pcdpath = f"../../data/3DMatch/fragments/{scene}/"
149 | interpath = f"../../data/3DMatch/intermediate-files-real/{scene}/"
150 | gtpath = f'../../data/3DMatch/fragments/{scene}-evaluation/'
151 | keyptspath = interpath # os.path.join(interpath, "keypoints/")
152 | descpath = os.path.join(".", f"{desc_name}_desc_{timestr}/{scene}")
153 | gtLog = loadlog(gtpath)
154 | logpath = f"log_result/{scene}-evaluation"
155 | resultpath = os.path.join(".", f"pred_result/{scene}/{desc_name}_result_{timestr}")
156 | if not os.path.exists(resultpath):
157 | os.makedirs(resultpath)
158 | if not os.path.exists(logpath):
159 | os.makedirs(logpath)
160 |
161 | # register each pair
162 | num_frag = len(os.listdir(pcdpath))
163 | print(f"Start Evaluate Descriptor {desc_name} for {scene}")
164 | start_time = time.time()
165 | for id1 in range(num_frag):
166 | for id2 in range(id1 + 1, num_frag):
167 | num_inliers, inlier_ratio, gt_flag = register2Fragments(id1, id2, keyptspath, descpath, resultpath,
168 | desc_name)
169 | print(f"Finish Evaluation, time: {time.time() - start_time:.2f}s")
170 |
171 | # evaluate
172 | result = []
173 | for id1 in range(num_frag):
174 | for id2 in range(id1 + 1, num_frag):
175 | line = read_register_result(id1, id2)
176 | result.append([int(line[0]), float(line[1]), int(line[2])])
177 | result = np.array(result)
178 | indices_results = np.sum(result[:, 2] == 1)
179 | correct_match = np.sum(result[:, 1] > 0.05)
180 | recall = float(correct_match / indices_results) * 100
181 | print(f"Correct Match {correct_match}, ground truth Match {indices_results}")
182 | print(f"Recall {recall}%")
183 | ave_num_inliers = np.sum(np.where(result[:, 1] > 0.05, result[:, 0], np.zeros(result.shape[0]))) / correct_match
184 | print(f"Average Num Inliners: {ave_num_inliers}")
185 | ave_inlier_ratio = np.sum(
186 | np.where(result[:, 1] > 0.05, result[:, 1], np.zeros(result.shape[0]))) / correct_match
187 | print(f"Average Num Inliner Ratio: {ave_inlier_ratio}")
188 | recall_list.append(recall)
189 | inliers_list.append(ave_num_inliers)
190 | inliers_ratio_list.append(ave_inlier_ratio)
191 | print(recall_list)
192 | average_recall = sum(recall_list) / len(recall_list)
193 | print(f"All 8 scene, average recall: {average_recall}%")
194 | average_inliers = sum(inliers_list) / len(inliers_list)
195 | print(f"All 8 scene, average num inliers: {average_inliers}")
196 | average_inliers_ratio = sum(inliers_ratio_list) / len(inliers_list)
197 | print(f"All 8 scene, average num inliers ratio: {average_inliers_ratio}")
198 |
--------------------------------------------------------------------------------
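
The per-pair statistic written by `register2Fragments` reduces to a few lines of NumPy. A minimal sketch (the helper name `pair_inlier_stats` and the synthetic data are illustrative, not part of the repo; the 10 cm inlier radius and the 5% success threshold mirror the script above):

import numpy as np

def pair_inlier_stats(src_kpts, tgt_kpts, corr, gt_trans, inlier_dist=0.10):
    # apply the ground-truth transform to the matched target keypoints
    tgt = tgt_kpts[corr[:, 1]]
    tgt_h = np.hstack([tgt, np.ones((len(tgt), 1))])  # homogeneous coordinates
    tgt_aligned = (gt_trans @ tgt_h.T).T[:, :3]
    # residual distance of each mutual correspondence after alignment
    dist = np.linalg.norm(src_kpts[corr[:, 0]] - tgt_aligned, axis=1)
    num_inliers = int((dist < inlier_dist).sum())
    return num_inliers, num_inliers / len(dist)

# a pair counts as correctly matched when inlier_ratio > 0.05; recall is the
# fraction of ground-truth overlapping pairs that pass this test
src = np.random.default_rng(0).random((100, 3))
corr = np.stack([np.arange(100), np.arange(100)], axis=1)
print(pair_inlier_stats(src, src.copy(), corr, np.eye(4)))  # (100, 1.0)
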
/generalization/KITTI-to-ThreeDMatch/preparation.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | import time
5 | import numpy as np
6 | import torch
7 | import shutil
8 | import torch.nn as nn
9 | import sys
10 |
11 | sys.path.append('../../')
12 | from ThreeDMatch.Test.tools import get_pcd, get_keypts
13 | from sklearn.neighbors import KDTree
14 | import importlib.util  # spec_from_file_location lives in the util submodule
15 | import script.common as cm
16 | import open3d
17 |
18 |
19 | def make_open3d_point_cloud(xyz, color=None):
20 | pcd = open3d.geometry.PointCloud()
21 | pcd.points = open3d.utility.Vector3dVector(xyz)
22 | if color is not None:
23 | pcd.paint_uniform_color(color)
24 | return pcd
25 |
26 |
27 | def build_patch_input(pcd, keypts, vicinity=0.3, num_points_per_patch=2048):
28 | refer_pts = keypts.astype(np.float32)
29 | pts = np.array(pcd.points).astype(np.float32)
30 | num_patches = refer_pts.shape[0]
31 | tree = KDTree(pts[:, 0:3])
32 | ind_local = tree.query_radius(refer_pts[:, 0:3], r=vicinity)
33 | local_patches = np.zeros([num_patches, num_points_per_patch, 3], dtype=float)
34 | for i in range(num_patches):
35 | local_neighbors = pts[ind_local[i], :]
36 | if local_neighbors.shape[0] >= num_points_per_patch:
37 | temp = np.random.choice(range(local_neighbors.shape[0]), num_points_per_patch, replace=False)
38 | local_neighbors = local_neighbors[temp]
39 | local_neighbors[-1, :] = refer_pts[i, :]
40 | else:
41 | fix_idx = np.asarray(range(local_neighbors.shape[0]))
42 | while local_neighbors.shape[0] + fix_idx.shape[0] < num_points_per_patch:
43 | fix_idx = np.concatenate((fix_idx, np.asarray(range(local_neighbors.shape[0]))), axis=0)
44 | random_idx = np.random.choice(local_neighbors.shape[0], num_points_per_patch - fix_idx.shape[0],
45 | replace=False)
46 | choice_idx = np.concatenate((fix_idx, random_idx), axis=0)
47 | local_neighbors = local_neighbors[choice_idx]
48 | local_neighbors[-1, :] = refer_pts[i, :]
49 | local_patches[i] = local_neighbors
50 |
51 | return local_patches
52 |
53 |
54 | def prepare_patch(pcdpath, filename, keyptspath, trans_matrix):
55 | pcd = get_pcd(pcdpath, filename)
56 | keypts = get_keypts(keyptspath, filename)
57 | # load D3Feat keypts
58 | if is_D3Feat_keypts:
59 | keypts_path = './D3Feat_contralo-54-pred/keypoints/' + pcdpath.split('/')[-2] + '/' + filename + '.npy'
60 | keypts = np.load(keypts_path)
61 | keypts = keypts[-5000:, :]
62 | if is_rotate_dataset:
63 | # Add arbitrary rotation
64 | # rotate the second fragment by an arbitrary angle
65 | angles_3d = np.random.rand(3) * np.pi * 2
66 | R = cm.angles2rotation_matrix(angles_3d)
67 | T = np.identity(4)
68 | T[:3, :3] = R
69 | pcd.transform(T)
70 | keypts_pcd = make_open3d_point_cloud(keypts)
71 | keypts_pcd.transform(T)
72 | keypts = np.array(keypts_pcd.points)
73 | trans_matrix.append(T)
74 |
75 | local_patches = build_patch_input(pcd, keypts, des_r)
76 | return local_patches
77 |
78 |
79 | def generate_descriptor(model, desc_name, pcdpath, keyptspath, descpath):
80 | model.eval()
81 | num_frag = len(os.listdir(pcdpath))
82 | num_desc = len(os.listdir(descpath))
83 | trans_matrix = []
84 | if num_frag == num_desc:
85 | print("Descriptor already prepared.")
86 | return
87 | for j in range(num_frag):
88 | local_patches = prepare_patch(pcdpath, 'cloud_bin_' + str(j), keyptspath, trans_matrix)
89 | input_ = torch.tensor(local_patches.astype(np.float32))
90 | B = input_.shape[0]
91 | input_ = input_.cuda()
92 | model = model.cuda()
93 | # calculate descriptors
94 | desc_list = []
95 | start_time = time.time()
96 | desc_len = 32
97 | step_size = 100
98 | iter_num = int(np.ceil(B / step_size))  # np.int was removed in NumPy 1.24; use the builtin
99 | for k in range(iter_num):
100 | if k == iter_num - 1:
101 | desc = model(input_[k * step_size:, :, :])
102 | else:
103 | desc = model(input_[k * step_size: (k + 1) * step_size, :, :])
104 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy())
105 | del desc
106 | step_time = time.time() - start_time
107 | print(f'Finished {B} descriptors in {step_time:.4f}s')
108 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len])
109 | np.save(descpath + 'cloud_bin_' + str(j) + f".desc.{desc_name}.bin", desc.astype(np.float32))
110 | if is_rotate_dataset:
111 | scene_name = pcdpath.split('/')[-2]
112 | all_trans_matrix[scene_name] = trans_matrix
113 |
114 |
115 | if __name__ == '__main__':
116 | scene_list = [
117 | '7-scenes-redkitchen',
118 | 'sun3d-home_at-home_at_scan1_2013_jan_1',
119 | 'sun3d-home_md-home_md_scan9_2012_sep_30',
120 | 'sun3d-hotel_uc-scan3',
121 | 'sun3d-hotel_umd-maryland_hotel1',
122 | 'sun3d-hotel_umd-maryland_hotel3',
123 | 'sun3d-mit_76_studyroom-76-1studyroom2',
124 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika'
125 | ]
126 |
127 | experiment_id = time.strftime('%m%d%H%M')
128 | model_str = experiment_id # sys.argv[1]
129 | if not os.path.exists(f"SpinNet_desc_{model_str}/"):
130 | os.mkdir(f"SpinNet_desc_{model_str}")
131 |
132 | # dynamically load the model
133 | module_file_path = '../model.py'
134 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), '../model.py')
135 | module_name = ''
136 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
137 | module = importlib.util.module_from_spec(module_spec)
138 | module_spec.loader.exec_module(module)
139 |
140 | des_r = 0.45
141 | model = module.Descriptor_Net(des_r, 9, 60, 30, 0.15, 30, '3DMatch')
142 | model = nn.DataParallel(model, device_ids=[0])
143 | model.load_state_dict(torch.load('../../pre-trained_models/KITTI_best.pkl'))
144 | all_trans_matrix = {}
145 | is_rotate_dataset = False
146 | is_D3Feat_keypts = False
147 | for scene in scene_list:
148 | pcdpath = f"../../data/3DMatch/fragments/{scene}/"
149 | interpath = f"../../data/3DMatch/intermediate-files-real/{scene}/"
150 | keyptspath = interpath
151 | descpath = os.path.join('.', f"SpinNet_desc_{model_str}/{scene}/")
152 | if not os.path.exists(descpath):
153 | os.makedirs(descpath)
154 | start_time = time.time()
155 | print(f"Begin Processing {scene}")
156 | generate_descriptor(model, desc_name='SpinNet', pcdpath=pcdpath, keyptspath=keyptspath, descpath=descpath)
157 | print(f"Finish in {time.time() - start_time}s")
158 | if is_rotate_dataset:
159 | np.save(f"trans_matrix", all_trans_matrix)
160 |
--------------------------------------------------------------------------------
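
`build_patch_input` above pads or subsamples every neighborhood to exactly `num_points_per_patch` points. Its resampling rule, isolated as a self-contained sketch (`fix_patch_size` is an illustrative name, not a repo function):

import numpy as np

def fix_patch_size(m, n_points, rng=None):
    """Indices into an m-point neighborhood, resampled to exactly n_points."""
    rng = np.random.default_rng() if rng is None else rng
    if m >= n_points:
        # enough neighbors: subsample without replacement
        return rng.choice(m, n_points, replace=False)
    # too few: tile all indices while the remaining deficit still exceeds m,
    # then top up with a non-repeating random draw
    fix_idx = np.arange(m)
    while m + len(fix_idx) < n_points:
        fix_idx = np.concatenate([fix_idx, np.arange(m)])
    rand_idx = rng.choice(m, n_points - len(fix_idx), replace=False)
    return np.concatenate([fix_idx, rand_idx])

idx = fix_patch_size(700, 2048)
assert len(idx) == 2048  # the caller then overwrites slot -1 with the keypoint
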
/generalization/ThreeDMatch-to-ETH/evaluate.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../../')
4 | import open3d
5 | import numpy as np
6 | import time
7 | import os
8 | from ThreeDMatch.Test.tools import get_pcd, get_ETH_keypts, get_desc, loadlog
9 | from sklearn.neighbors import KDTree
10 | import glob
11 |
12 |
13 | def calculate_M(source_desc, target_desc):
14 | """
15 | Find the mutually closest point pairs in feature space.
16 | source_desc and target_desc hold the keypoint descriptors of the two fragments, shaped [num_keypts, desc_dim].
17 | """
18 |
19 | kdtree_s = KDTree(target_desc)
20 | sourceNNdis, sourceNNidx = kdtree_s.query(source_desc, 1)
21 | kdtree_t = KDTree(source_desc)
22 | targetNNdis, targetNNidx = kdtree_t.query(target_desc, 1)
23 | result = []
24 | for i in range(len(sourceNNidx)):
25 | if targetNNidx[sourceNNidx[i]] == i:
26 | result.append([i, sourceNNidx[i][0]])
27 | return np.array(result)
28 |
29 |
30 | def register2Fragments(id1, id2, keyptspath, descpath, resultpath, desc_name='ppf'):
31 | cloud_bin_s = f'Hokuyo_{id1}'
32 | cloud_bin_t = f'Hokuyo_{id2}'
33 | write_file = f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'
34 | if os.path.exists(os.path.join(resultpath, write_file)):
35 | # print(f"{write_file} already exists.")
36 | return 0, 0, 0
37 | pcd_s = get_pcd(pcdpath, cloud_bin_s)
38 | source_keypts = get_ETH_keypts(pcd_s, keyptspath, cloud_bin_s)
39 | pcd_t = get_pcd(pcdpath, cloud_bin_t)
40 | target_keypts = get_ETH_keypts(pcd_t, keyptspath, cloud_bin_t)
41 | # print(source_keypts.shape)
42 | source_desc = get_desc(descpath, cloud_bin_s, desc_name=desc_name)
43 | target_desc = get_desc(descpath, cloud_bin_t, desc_name=desc_name)
44 | source_desc = np.nan_to_num(source_desc)
45 | target_desc = np.nan_to_num(target_desc)
46 |
47 | key = f'{cloud_bin_s.split("_")[-1]}_{cloud_bin_t.split("_")[-1]}'
48 | if key not in gtLog.keys():
49 | num_inliers = 0
50 | inlier_ratio = 0
51 | gt_flag = 0
52 | else:
53 | # find mutually closest points in descriptor space
54 | corr = calculate_M(source_desc, target_desc)
55 |
56 | gtTrans = gtLog[key]
57 | frag1 = source_keypts[corr[:, 0]]
58 | frag2_pc = open3d.geometry.PointCloud()
59 | frag2_pc.points = open3d.utility.Vector3dVector(target_keypts[corr[:, 1]])
60 | frag2_pc.transform(gtTrans)
61 | frag2 = np.asarray(frag2_pc.points)
62 | distance = np.sqrt(np.sum(np.power(frag1 - frag2, 2), axis=1))
63 | num_inliers = np.sum(distance < 0.1)
64 | inlier_ratio = num_inliers / len(distance)
65 | gt_flag = 1
66 |
67 | # estimate the transformation matrix with RANSAC; the result is used for Registration Recall
68 | source_pcd = open3d.geometry.PointCloud()
69 | source_pcd.points = open3d.utility.Vector3dVector(source_keypts)
70 | target_pcd = open3d.geometry.PointCloud()
71 | target_pcd.points = open3d.utility.Vector3dVector(target_keypts)
72 | s_desc = open3d.pipelines.registration.Feature()
73 | s_desc.data = source_desc.T
74 | t_desc = open3d.pipelines.registration.Feature()
75 | t_desc.data = target_desc.T
76 | result = open3d.pipelines.registration.registration_ransac_based_on_feature_matching(
77 | source_pcd, target_pcd, s_desc, t_desc,
78 | 0.05,
79 | open3d.pipelines.registration.TransformationEstimationPointToPoint(False), 3,
80 | [open3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
81 | open3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(0.05)],
82 | open3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000))
83 | # write the transformation matrix into .log file for evaluation.
84 | with open(os.path.join(logpath, f'{desc_name}_{timestr}.log'), 'a+') as f:
85 | trans = result.transformation
86 | trans = np.linalg.inv(trans)
87 | s1 = f'{id1}\t {id2}\t 37\n'  # .log header: pair ids plus a (hard-coded) fragment count
88 | f.write(s1)
89 | f.write(f"{trans[0, 0]}\t {trans[0, 1]}\t {trans[0, 2]}\t {trans[0, 3]}\t \n")
90 | f.write(f"{trans[1, 0]}\t {trans[1, 1]}\t {trans[1, 2]}\t {trans[1, 3]}\t \n")
91 | f.write(f"{trans[2, 0]}\t {trans[2, 1]}\t {trans[2, 2]}\t {trans[2, 3]}\t \n")
92 | f.write(f"{trans[3, 0]}\t {trans[3, 1]}\t {trans[3, 2]}\t {trans[3, 3]}\t \n")
93 |
94 | s = f"{cloud_bin_s}\t{cloud_bin_t}\t{num_inliers}\t{inlier_ratio:.8f}\t{gt_flag}"
95 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'w+') as f:
96 | f.write(s)
97 | return num_inliers, inlier_ratio, gt_flag
98 |
99 |
100 | def read_register_result(id1, id2):
101 | cloud_bin_s = f'Hokuyo_{id1}'
102 | cloud_bin_t = f'Hokuyo_{id2}'
103 | with open(os.path.join(resultpath, f'{cloud_bin_s}_{cloud_bin_t}.rt.txt'), 'r') as f:
104 | content = f.readlines()
105 | nums = content[0].replace("\n", "").split("\t")[2:5]
106 | return nums
107 |
108 |
109 | if __name__ == '__main__':
110 | scene_list = [
111 | 'gazebo_summer',
112 | 'gazebo_winter',
113 | 'wood_autmn',
114 | 'wood_summer',
115 | ]
116 | desc_name = 'SpinNet'
117 | timestr = sys.argv[1]
118 | inliers_list = []
119 | recall_list = []
120 | for scene in scene_list:
121 | pcdpath = f"../../data/ETH/{scene}/"
122 | interpath = f"../../data/ETH/{scene}/01_Keypoints/"
123 | gtpath = f'../../data/ETH/{scene}/'
124 | keyptspath = interpath # os.path.join(interpath, "keypoints/")
125 | descpath = os.path.join(".", f"{desc_name}_desc_{timestr}/{scene}")
126 | logpath = f"log_result/{scene}-evaluation"
127 | gtLog = loadlog(gtpath)
128 | resultpath = os.path.join(".", f"pred_result/{scene}/{desc_name}_result_{timestr}")
129 | if not os.path.exists(resultpath):
130 | os.makedirs(resultpath)
131 | if not os.path.exists(logpath):
132 | os.makedirs(logpath)
133 |
134 | # register each pair
135 | fragments = glob.glob(pcdpath + '*.ply')
136 | num_frag = len(fragments)
137 | print(f"Start Evaluate Descriptor {desc_name} for {scene}")
138 | start_time = time.time()
139 | for id1 in range(num_frag):
140 | for id2 in range(id1 + 1, num_frag):
141 | num_inliers, inlier_ratio, gt_flag = register2Fragments(id1, id2, keyptspath, descpath, resultpath,
142 | desc_name)
143 | print(f"Finish Evaluation, time: {time.time() - start_time:.2f}s")
144 |
145 | # evaluate
146 | result = []
147 | for id1 in range(num_frag):
148 | for id2 in range(id1 + 1, num_frag):
149 | line = read_register_result(id1, id2)
150 | result.append([int(line[0]), float(line[1]), int(line[2])])
151 | result = np.array(result)
152 | indices_results = np.sum(result[:, 2] == 1)
153 | correct_match = np.sum(result[:, 1] > 0.05)
154 | recall = float(correct_match / indices_results) * 100
155 | print(f"Correct Match {correct_match}, ground truth Match {indices_results}")
156 | print(f"Recall {recall}%")
157 | ave_num_inliers = np.sum(np.where(result[:, 1] > 0.05, result[:, 0], np.zeros(result.shape[0]))) / correct_match
158 | print(f"Average Num Inliners: {ave_num_inliers}")
159 | recall_list.append(recall)
160 | inliers_list.append(ave_num_inliers)
161 | print(recall_list)
162 | average_recall = sum(recall_list) / len(recall_list)
163 | print(f"All 8 scene, average recall: {average_recall}%")
164 | average_inliers = sum(inliers_list) / len(inliers_list)
165 | print(f"All 8 scene, average num inliers: {average_inliers}")
166 |
--------------------------------------------------------------------------------
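
`calculate_M` above loops over every source keypoint in Python; the same mutual-nearest-neighbour test vectorizes cleanly with the sklearn KDTree the script already imports. A sketch (`mutual_nn` is an illustrative name):

import numpy as np
from sklearn.neighbors import KDTree

def mutual_nn(source_desc, target_desc):
    # nearest target index for every source descriptor, and vice versa
    s2t = KDTree(target_desc).query(source_desc, k=1, return_distance=False)[:, 0]
    t2s = KDTree(source_desc).query(target_desc, k=1, return_distance=False)[:, 0]
    src_idx = np.arange(len(source_desc))
    keep = t2s[s2t] == src_idx  # keep (i, j) only if both directions agree
    return np.stack([src_idx[keep], s2t[keep]], axis=1)

a = np.random.default_rng(1).random((500, 32))
print(mutual_nn(a, a).shape)  # identical descriptor sets -> all 500 pairs survive
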
/generalization/ThreeDMatch-to-ETH/preparation.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | import time
5 | import numpy as np
6 | import torch
7 | import shutil
8 | import torch.nn as nn
9 | import glob
10 | import sys
11 |
12 | sys.path.append('../../')
13 | import script.common as cm
14 | import open3d
15 | from ThreeDMatch.Test.tools import get_pcd, get_ETH_keypts, get_desc, loadlog
16 | from sklearn.neighbors import KDTree
17 | import importlib.util  # spec_from_file_location lives in the util submodule
18 |
19 |
20 | def make_open3d_point_cloud(xyz, color=None):
21 | pcd = open3d.geometry.PointCloud()
22 | pcd.points = open3d.utility.Vector3dVector(xyz)
23 | if color is not None:
24 | pcd.paint_uniform_color(color)
25 | return pcd
26 |
27 |
28 | def build_patch_input(pcd, keypts, vicinity=0.3, num_points_per_patch=2048):
29 | refer_pts = keypts.astype(np.float32)
30 | pts = np.array(pcd.points).astype(np.float32)
31 | num_patches = refer_pts.shape[0]
32 | tree = KDTree(pts[:, 0:3])
33 | ind_local = tree.query_radius(refer_pts[:, 0:3], r=vicinity)
34 | local_patches = np.zeros([num_patches, num_points_per_patch, 3], dtype=float)
35 | for i in range(num_patches):
36 | local_neighbors = pts[ind_local[i], :]
37 | if local_neighbors.shape[0] >= num_points_per_patch:
38 | temp = np.random.choice(range(local_neighbors.shape[0]), num_points_per_patch, replace=False)
39 | local_neighbors = local_neighbors[temp]
40 | local_neighbors[-1, :] = refer_pts[i, :]
41 | else:
42 | fix_idx = np.asarray(range(local_neighbors.shape[0]))
43 | while local_neighbors.shape[0] + fix_idx.shape[0] < num_points_per_patch:
44 | fix_idx = np.concatenate((fix_idx, np.asarray(range(local_neighbors.shape[0]))), axis=0)
45 | random_idx = np.random.choice(local_neighbors.shape[0], num_points_per_patch - fix_idx.shape[0],
46 | replace=False)
47 | choice_idx = np.concatenate((fix_idx, random_idx), axis=0)
48 | local_neighbors = local_neighbors[choice_idx]
49 | local_neighbors[-1, :] = refer_pts[i, :]
50 | local_patches[i] = local_neighbors
51 |
52 | return local_patches
53 |
54 |
55 | def prepare_patch(pcdpath, filename, keyptspath, trans_matrix):
56 | pcd = get_pcd(pcdpath, filename)
57 | keypts = get_ETH_keypts(pcd, keyptspath, filename)
58 | if is_rotate_dataset:
59 | # Add arbitrary rotation
60 | # rotate the second fragment by arbitrary angles (about all three axes, not just z)
61 | angles_3d = np.random.rand(3) * np.pi * 2
62 | R = cm.angles2rotation_matrix(angles_3d)
63 | T = np.identity(4)
64 | T[:3, :3] = R
65 | pcd.transform(T)
66 | keypts_pcd = make_open3d_point_cloud(keypts)
67 | keypts_pcd.transform(T)
68 | keypts = np.array(keypts_pcd.points)
69 | trans_matrix.append(T)
70 | local_patches = build_patch_input(pcd, keypts, des_r)  # [num_keypts, num_points_per_patch, 3]
71 | return local_patches
72 |
73 |
74 | def generate_descriptor(model, desc_name, pcdpath, keyptspath, descpath):
75 | model.eval()
76 | fragments = glob.glob(pcdpath + '*.ply')
77 | num_frag = len(fragments)
78 | num_desc = len(os.listdir(descpath))
79 | trans_matrix = []
80 | if num_frag == num_desc:
81 | print("Descriptor already prepared.")
82 | return
83 | for j in range(num_frag):
84 | local_patches = prepare_patch(pcdpath, 'Hokuyo_' + str(j), keyptspath, trans_matrix)
85 | input_ = torch.tensor(local_patches.astype(np.float32))
86 | B = input_.shape[0]
87 | input_ = input_.cuda()
88 | model = model.cuda()
89 | # calculate descriptors
90 | desc_list = []
91 | start_time = time.time()
92 | desc_len = 32
93 | step_size = 100
94 | iter_num = int(np.ceil(B / step_size))  # np.int was removed in NumPy 1.24; use the builtin
95 | for k in range(iter_num):
96 | if k == iter_num - 1:
97 | desc = model(input_[k * step_size:, :, :])
98 | else:
99 | desc = model(input_[k * step_size: (k + 1) * step_size, :, :])
100 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy())
101 | del desc
102 | step_time = time.time() - start_time
103 | print(f'Finished {B} descriptors in {step_time:.4f}s')
104 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len])
105 | np.save(descpath + 'Hokuyo_' + str(j) + f".desc.{desc_name}.bin", desc.astype(np.float32))
106 | if is_rotate_dataset:
107 | scene_name = pcdpath.split('/')[-2]
108 | all_trans_matrix[scene_name] = trans_matrix
109 |
110 |
111 | if __name__ == '__main__':
112 | scene_list = [
113 | 'gazebo_summer',
114 | 'gazebo_winter',
115 | 'wood_autmn',
116 | 'wood_summer',
117 | ]
118 | experiment_id = time.strftime('%m%d%H%M')
119 | model_str = experiment_id # sys.argv[1]
120 | if not os.path.exists(f"SpinNet_desc_{model_str}/"):
121 | os.mkdir(f"SpinNet_desc_{model_str}")
122 |
123 | # dynamically load the model
124 | module_file_path = '../model.py'
125 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), '../model.py')
126 | module_name = ''
127 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
128 | module = importlib.util.module_from_spec(module_spec)
129 | module_spec.loader.exec_module(module)
130 |
131 | des_r = 0.8
132 | model = module.Descriptor_Net(des_r, 9, 80, 40, 0.10, 30, '3DMatch')
133 | model = nn.DataParallel(model, device_ids=[0])
134 | model.load_state_dict(torch.load('../../pre-trained_models/3DMatch_best.pkl'))
135 | all_trans_matrix = {}
136 | is_rotate_dataset = False
137 |
138 | for scene in scene_list:
139 | pcdpath = f"../../data/ETH/{scene}/"
140 | interpath = f"../../data/ETH/{scene}/01_Keypoints/"
141 | keyptspath = interpath
142 | descpath = os.path.join('.', f"SpinNet_desc_{model_str}/{scene}/")
143 | if not os.path.exists(descpath):
144 | os.makedirs(descpath)
145 | start_time = time.time()
146 | print(f"Begin Processing {scene}")
147 | generate_descriptor(model, desc_name='SpinNet', pcdpath=pcdpath, keyptspath=keyptspath, descpath=descpath)
148 | print(f"Finish in {time.time() - start_time}s")
149 | if is_rotate_dataset:
150 | np.save(f"trans_matrix", all_trans_matrix)
151 |
--------------------------------------------------------------------------------
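
Both preparation scripts share the same chunked-inference loop inside `generate_descriptor`. A condensed sketch (illustrative names; half-open slicing handles the final partial chunk, so the explicit last-iteration branch above is not strictly needed):

import numpy as np
import torch

def describe_in_chunks(model, patches, step_size=100, desc_len=32):
    """patches: float tensor [B, N, 3], already on the model's device."""
    outs = []
    with torch.no_grad():  # inference only; mirrors the .detach() above
        for k in range(int(np.ceil(len(patches) / step_size))):
            chunk = patches[k * step_size:(k + 1) * step_size]
            outs.append(model(chunk).reshape(len(chunk), desc_len).cpu().numpy())
    return np.concatenate(outs, axis=0)  # [B, desc_len]
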
/generalization/ThreeDMatch-to-KITTI/test.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
4 | import logging
5 | import numpy as np
6 | import open3d as o3d
7 | import torch
8 | import torch.nn as nn
9 | import glob
10 | import time
11 | import gc
12 | import shutil
13 | import pointnet2_ops.pointnet2_utils as pnt2
14 | import copy
15 | import importlib.util  # spec_from_file_location lives in the util submodule
16 | import sys
17 |
18 | sys.path.append('../../')
19 | import script.common as cm
20 |
21 | kitti_icp_cache = {}
22 | kitti_cache = {}
23 |
24 |
25 | class Timer(object):
26 | """A simple timer."""
27 |
28 | def __init__(self, binary_fn=None, init_val=0):
29 | self.total_time = 0.
30 | self.calls = 0
31 | self.start_time = 0.
32 | self.diff = 0.
33 | self.binary_fn = binary_fn
34 | self.tmp = init_val
35 |
36 | def reset(self):
37 | self.total_time = 0
38 | self.calls = 0
39 | self.start_time = 0
40 | self.diff = 0
41 |
42 | @property
43 | def avg(self):
44 | return self.total_time / self.calls
45 |
46 | def tic(self):
47 | # using time.time instead of time.clock because time.clock
48 | # does not normalize for multithreading
49 | self.start_time = time.time()
50 |
51 | def toc(self, average=True):
52 | self.diff = time.time() - self.start_time
53 | self.total_time += self.diff
54 | self.calls += 1
55 | if self.binary_fn:
56 | self.tmp = self.binary_fn(self.tmp, self.diff)
57 | if average:
58 | return self.avg
59 | else:
60 | return self.diff
61 |
62 |
63 | class AverageMeter(object):
64 | """Computes and stores the average and current value"""
65 |
66 | def __init__(self):
67 | self.reset()
68 |
69 | def reset(self):
70 | self.val = 0
71 | self.avg = 0
72 | self.sum = 0.0
73 | self.sq_sum = 0.0
74 | self.count = 0
75 |
76 | def update(self, val, n=1):
77 | self.val = val
78 | self.sum += val * n
79 | self.count += n
80 | self.avg = self.sum / self.count
81 | self.sq_sum += val ** 2 * n
82 | self.var = self.sq_sum / self.count - self.avg ** 2
83 |
84 |
85 | def get_desc(descpath, filename):
86 | desc = np.load(os.path.join(descpath, filename + '.npy'))
87 | return desc
88 |
89 |
90 | def get_keypts(keypts_path, filename):
91 | keypts = np.load(os.path.join(keypts_path, filename + '.npy'))
92 | return keypts
93 |
94 |
95 | def make_open3d_feature(data, dim, npts):
96 | feature = o3d.pipelines.registration.Feature()
97 | feature.resize(dim, npts)
98 | feature.data = data.astype('d').transpose()
99 | return feature
100 |
101 |
102 | def make_open3d_point_cloud(xyz, color=None):
103 | pcd = o3d.geometry.PointCloud()
104 | pcd.points = o3d.utility.Vector3dVector(xyz)
105 | if color is not None:
106 | pcd.paint_uniform_color(color)
107 | return pcd
108 |
109 |
110 | def get_matching_indices(source, target, trans, search_voxel_size, K=None):
111 | source_copy = copy.deepcopy(source)
112 | target_copy = copy.deepcopy(target)
113 | source_copy.transform(trans)
114 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy)
115 |
116 | match_inds = []
117 | for i, point in enumerate(source_copy.points):
118 | [_, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size)
119 | if K is not None:
120 | idx = idx[:K]
121 | for j in idx:
122 | match_inds.append((i, j))
123 | return match_inds
124 |
125 |
126 | class KITTI(object):
127 | DATA_FILES = {
128 | 'train': 'train_kitti.txt',
129 | 'val': 'val_kitti.txt',
130 | 'test': 'test_kitti.txt'
131 | }
132 | """
133 | Given point cloud fragments and corresponding pose in '{root}'.
134 | 1. Save the aligned point cloud pts in '{savepath}/3DMatch_{downsample}_points.pkl'
135 | 2. Calculate the overlap ratio and save in '{savepath}/3DMatch_{downsample}_overlap.pkl'
136 | 3. Save the ids of anchor keypoints and positive keypoints in '{savepath}/3DMatch_{downsample}_keypts.pkl'
137 | """
138 |
139 | def __init__(self, root, descpath, icp_path, split, model, num_points_per_patch, use_random_points):
140 | self.root = root
141 | self.descpath = descpath
142 | self.split = split
143 | self.num_points_per_patch = num_points_per_patch
144 | self.icp_path = icp_path
145 | self.use_random_points = use_random_points
146 | self.model = model
147 | if not os.path.exists(self.icp_path):
148 | os.makedirs(self.icp_path)
149 |
150 | # list: anc & pos
151 | self.patches = []
152 | self.pose = []
153 | # Initiate containers
154 | self.files = {'train': [], 'val': [], 'test': []}
155 |
156 | self.prepare_kitti_ply(split=self.split)
157 |
158 | def prepare_kitti_ply(self, split='train'):
159 | subset_names = open(self.DATA_FILES[split]).read().split()
160 | for dirname in subset_names:
161 | drive_id = int(dirname)
162 | fnames = glob.glob(self.root + '/sequences/%02d/velodyne/*.bin' % drive_id)
163 | assert len(fnames) > 0, f"Make sure that the path {self.root} has data {dirname}"
164 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames])
165 |
166 | all_odo = self.get_video_odometry(drive_id, return_all=True)
167 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo])
168 | Ts = all_pos[:, :3, 3]
169 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3)) ** 2
170 | pdist = np.sqrt(pdist.sum(-1))
171 | more_than_10 = pdist > 10
172 | curr_time = inames[0]
173 | while curr_time in inames:
174 | next_time = np.where(more_than_10[curr_time][curr_time:curr_time + 100])[0]
175 | if len(next_time) == 0:
176 | curr_time += 1
177 | else:
178 | next_time = next_time[0] + curr_time - 1
179 |
180 | if next_time in inames:
181 | self.files[split].append((drive_id, curr_time, next_time))
182 | curr_time = next_time + 1
183 | # Remove problematic sequence
184 | for item in [
185 | (8, 15, 58),
186 | ]:
187 | if item in self.files[split]:
188 | self.files[split].pop(self.files[split].index(item))
189 |
190 | if split == 'train':
191 | self.num_train = len(self.files[split])
192 | print("Num_train", self.num_train)
193 | elif split == 'val':
194 | self.num_val = len(self.files[split])
195 | print("Num_val", self.num_val)
196 | elif split == 'test':
197 | self.num_test = len(self.files[split])
198 | print("Num_test", self.num_test)
199 |
200 | for idx in range(len(self.files[split])):
201 | drive = self.files[split][idx][0]
202 | t0, t1 = self.files[split][idx][1], self.files[split][idx][2]
203 | all_odometry = self.get_video_odometry(drive, [t0, t1])
204 | positions = [self.odometry_to_positions(odometry) for odometry in all_odometry]
205 | fname0 = self._get_velodyne_fn(drive, t0)
206 | fname1 = self._get_velodyne_fn(drive, t1)
207 |
208 | # XYZ and reflectance
209 | xyzr0 = np.fromfile(fname0, dtype=np.float32).reshape(-1, 4)
210 | xyzr1 = np.fromfile(fname1, dtype=np.float32).reshape(-1, 4)
211 |
212 | xyz0 = xyzr0[:, :3]
213 | xyz1 = xyzr1[:, :3]
214 |
215 | key = '%d_%d_%d' % (drive, t0, t1)
216 | filename = self.icp_path + '/' + key + '.npy'
217 | if key not in kitti_icp_cache:
218 | if not os.path.exists(filename):
219 | M = (self.velo2cam @ positions[0].T @ np.linalg.inv(positions[1].T)
220 | @ np.linalg.inv(self.velo2cam)).T
221 | xyz0_t = self.apply_transform(xyz0, M)
222 | pcd0 = make_open3d_point_cloud(xyz0_t, [0.5, 0.5, 0.5])
223 | pcd1 = make_open3d_point_cloud(xyz1, [0, 1, 0])
224 | reg = o3d.pipelines.registration.registration_icp(pcd0, pcd1, 0.10, np.eye(4),
225 | o3d.pipelines.registration.TransformationEstimationPointToPoint(),
226 | o3d.pipelines.registration.ICPConvergenceCriteria(
227 | max_iteration=400))
228 | pcd0.transform(reg.transformation)
229 | M2 = M @ reg.transformation
230 | # write to a file
231 | np.save(filename, M2)
232 | else:
233 | M2 = np.load(filename)
234 | kitti_icp_cache[key] = M2
235 | else:
236 | M2 = kitti_icp_cache[key]
237 | trans = M2
238 | # extract patches for anc&pos
239 | np.random.shuffle(xyz0)
240 | np.random.shuffle(xyz1)
241 |
242 | if is_rotate_dataset:
243 | # Add arbitrary rotation
244 | # rotate the second fragment by an arbitrary angle
245 | angles_3d = np.random.rand(3) * np.pi * 2
246 | R = cm.angles2rotation_matrix(angles_3d)
247 | T = np.identity(4)
248 | T[:3, :3] = R
249 | pcd1 = make_open3d_point_cloud(xyz1)
250 | pcd1.transform(T)
251 | xyz1 = np.array(pcd1.points)
252 | all_trans_matrix[key] = T
253 |
254 | if not os.path.exists(self.descpath + str(drive)):
255 | os.makedirs(self.descpath + str(drive))
256 | if self.use_random_points:
257 | num_keypts = 5000
258 | step_size = 100
259 | desc_len = 32
260 | model = self.model.cuda()
261 | # calc t0 descriptors
262 | desc_t0_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t0) + f".desc.bin.npy")
263 | keypts_t0_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t0) + f".keypts.npy")
264 | if not os.path.exists(desc_t0_path):
265 | keypoints_id = np.random.choice(xyz0.shape[0], num_keypts)
266 | keypts = xyz0[keypoints_id]
267 | np.save(keypts_t0_path, keypts.astype(np.float32))
268 | local_patches = self.select_patches(xyz0, keypts, vicinity=vicinity,
269 | num_points_per_patch=self.num_points_per_patch)
270 | B = local_patches.shape[0]
271 | # process in chunks to avoid CUDA running out of memory
272 | desc_list = []
273 | start_time = time.time()
274 | iter_num = int(np.ceil(B / step_size))  # np.int was removed in NumPy 1.24
275 | for k in range(iter_num):
276 | if k == iter_num - 1:
277 | desc = model(local_patches[k * step_size:, :, :])
278 | else:
279 | desc = model(local_patches[k * step_size: (k + 1) * step_size, :, :])
280 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy())
281 | del desc
282 | step_time = time.time() - start_time
283 | print(f'Finished {B} descriptors in {step_time:.4f}s')
284 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len])
285 | np.save(desc_t0_path, desc.astype(np.float32))
286 | else:
287 | print(f"{desc_t0_path} already exists.")
288 |
289 | # calc t1 descriptors
290 | desc_t1_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t1) + f".desc.bin.npy")
291 | keypts_t1_path = os.path.join(self.descpath + str(drive), f"cloud_bin_" + str(t1) + f".keypts.npy")
292 | if not os.path.exists(desc_t1_path):
293 | keypoints_id = np.random.choice(xyz1.shape[0], num_keypts)
294 | keypts = xyz1[keypoints_id]
295 | np.save(keypts_t1_path, keypts.astype(np.float32))
296 | local_patches = self.select_patches(xyz1, keypts, vicinity=vicinity,
297 | num_points_per_patch=self.num_points_per_patch)
298 | B = local_patches.shape[0]
299 | # process in chunks to avoid CUDA running out of memory
300 | desc_list = []
301 | start_time = time.time()
302 | iter_num = int(np.ceil(B / step_size))  # np.int was removed in NumPy 1.24
303 | for k in range(iter_num):
304 | if k == iter_num - 1:
305 | desc = model(local_patches[k * step_size:, :, :])
306 | else:
307 | desc = model(local_patches[k * step_size: (k + 1) * step_size, :, :])
308 | desc_list.append(desc.view(desc.shape[0], desc_len).detach().cpu().numpy())
309 | del desc
310 | step_time = time.time() - start_time
311 | print(f'Finished {B} descriptors in {step_time:.4f}s')
312 | desc = np.concatenate(desc_list, 0).reshape([B, desc_len])
313 | np.save(desc_t1_path, desc.astype(np.float32))
314 | else:
315 | print(f"{desc_t1_path} already exists.")
316 | else:
317 | num_keypts = 512
318 |
319 | def select_patches(self, pts, refer_pts, vicinity, num_points_per_patch=1024):
320 | gc.collect()
321 | pts = torch.FloatTensor(pts).cuda().unsqueeze(0)
322 | refer_pts = torch.FloatTensor(refer_pts).cuda().unsqueeze(0)
323 | group_idx = pnt2.ball_query(vicinity, num_points_per_patch, pts, refer_pts)
324 | pts_trans = pts.transpose(1, 2).contiguous()
325 | new_points = pnt2.grouping_operation(
326 | pts_trans, group_idx
327 | )
328 | new_points = new_points.permute([0, 2, 3, 1])
329 | mask = group_idx[:, :, 0].unsqueeze(2).repeat(1, 1, num_points_per_patch)
330 | mask = (group_idx == mask).float()
331 | mask[:, :, 0] = 0
332 | mask[:, :, num_points_per_patch - 1] = 1
333 | mask = mask.unsqueeze(3).repeat([1, 1, 1, 3])
334 | new_pts = refer_pts.unsqueeze(2).repeat([1, 1, num_points_per_patch, 1])
335 | local_patches = new_points * (1 - mask).float() + new_pts * mask.float()
336 | # local_patches = list(local_patches.squeeze(0).detach().cpu().numpy())
337 | local_patches = local_patches.squeeze(0)
338 | del mask
339 | del new_points
340 | del group_idx
341 | del new_pts
342 | del pts
343 | del pts_trans
344 |
345 | return local_patches
346 |
347 | def apply_transform(self, pts, trans):
348 | R = trans[:3, :3]
349 | T = trans[:3, 3]
350 | pts = pts @ R.T + T
351 | return pts
352 |
353 | @property
354 | def velo2cam(self):
355 | try:
356 | velo2cam = self._velo2cam
357 | except AttributeError:
358 | R = np.array([
359 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04,
360 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02
361 | ]).reshape(3, 3)
362 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1)
363 | velo2cam = np.hstack([R, T])
364 | self._velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T
365 | return self._velo2cam
366 |
367 | def get_video_odometry(self, drive, indices=None, ext='.txt', return_all=False):
368 | data_path = self.root + '/poses/%02d.txt' % drive
369 | if data_path not in kitti_cache:
370 | kitti_cache[data_path] = np.genfromtxt(data_path)
371 | if return_all:
372 | return kitti_cache[data_path]
373 | else:
374 | return kitti_cache[data_path][indices]
375 |
376 | def odometry_to_positions(self, odometry):
377 | T_w_cam0 = odometry.reshape(3, 4)
378 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1]))
379 | return T_w_cam0
380 |
381 | def _get_velodyne_fn(self, drive, t):
382 | fname = self.root + '/sequences/%02d/velodyne/%06d.bin' % (drive, t)
383 | return fname
384 |
385 |
386 | if __name__ == '__main__':
387 | is_rotate_dataset = False
388 | all_trans_matrix = {}
389 | experiment_id = time.strftime('%m%d%H%M') # '11210201'#
390 | model_str = experiment_id
391 | reg_timer = Timer()
392 | success_meter, rte_meter, rre_meter = AverageMeter(), AverageMeter(), AverageMeter()
393 | ch = logging.StreamHandler(sys.stdout)
394 | logging.getLogger().setLevel(logging.INFO)
395 | logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch])
396 |
397 | # dynamically load the model from snapshot
398 | module_file_path = '../model.py'
399 | shutil.copy2(os.path.join('.', '../../network/SpinNet.py'), module_file_path)
400 | module_name = ''
401 | module_spec = importlib.util.spec_from_file_location(module_name, module_file_path)
402 | module = importlib.util.module_from_spec(module_spec)
403 | module_spec.loader.exec_module(module)
404 |
405 | vicinity = 3.0
406 | model = module.Descriptor_Net(vicinity, 9, 80, 40, 0.5, 30, 'KITTI')
407 | model = nn.DataParallel(model, device_ids=[0])
408 | model.load_state_dict(torch.load('../../pre-trained_models/3DMatch_best.pkl'))
409 |
410 | test_data = KITTI(root='../../data/KITTI/dataset',
411 | descpath=f'SpinNet_desc_{model_str}/',
412 | icp_path='../../data/KITTI/icp',
413 | split='test',
414 | model=model,
415 | num_points_per_patch=2048,
416 | use_random_points=True
417 | )
418 |
419 | files = test_data.files[test_data.split]
420 | for idx in range(len(files)):
421 | drive = files[idx][0]
422 | t0, t1 = files[idx][1], files[idx][2]
423 | key = '%d_%d_%d' % (drive, t0, t1)
424 | filename = test_data.icp_path + '/' + key + '.npy'
425 | T_gth = kitti_icp_cache[key]
426 | if is_rotate_dataset:
427 | T_gth = np.matmul(all_trans_matrix[key], T_gth)
428 |
429 | descpath = os.path.join(test_data.descpath, str(drive))
430 | fname0 = test_data._get_velodyne_fn(drive, t0)
431 | fname1 = test_data._get_velodyne_fn(drive, t1)
432 | # XYZ and reflectance
433 | xyz0 = get_keypts(descpath, f"cloud_bin_" + str(t0) + f".keypts")
434 | xyz1 = get_keypts(descpath, f"cloud_bin_" + str(t1) + f".keypts")
435 | pcd0 = make_open3d_point_cloud(xyz0)
436 | pcd1 = make_open3d_point_cloud(xyz1)
437 |
438 | source_desc = get_desc(descpath, f"cloud_bin_" + str(t0) + f".desc.bin")
439 | target_desc = get_desc(descpath, f"cloud_bin_" + str(t1) + f".desc.bin")
440 | feat0 = make_open3d_feature(source_desc, 32, source_desc.shape[0])
441 | feat1 = make_open3d_feature(target_desc, 32, target_desc.shape[0])
442 |
443 | reg_timer.tic()
444 | distance_threshold = 0.3
445 | ransac_result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching(
446 | pcd0, pcd1, feat0, feat1, distance_threshold,
447 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False), 4, [
448 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9),
449 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold)
450 | ], o3d.pipelines.registration.RANSACConvergenceCriteria(50000, 1000))
451 | T_ransac = torch.from_numpy(ransac_result.transformation.astype(np.float32))
452 | reg_timer.toc()
453 |
454 | # translation and rotation errors w.r.t. the ground-truth transform
455 | rte = np.linalg.norm(T_ransac[:3, 3] - T_gth[:3, 3])
456 | rre = np.arccos((np.trace(T_ransac[:3, :3].t() @ T_gth[:3, :3]) - 1) / 2)
457 |
458 | if rte < 2:
459 | rte_meter.update(rte)
460 |
461 | if not np.isnan(rre) and rre < np.pi / 180 * 5:
462 | rre_meter.update(rre * 180 / np.pi)
463 |
464 | if rte < 2 and not np.isnan(rre) and rre < np.pi / 180 * 5:
465 | success_meter.update(1)
466 | else:
467 | success_meter.update(0)
468 | logging.info(f"Failed with RTE: {rte}, RRE: {rre}")
469 |
470 | if (idx + 1) % 10 == 0:
471 | logging.info(
472 | f" RRE: {rre_meter.avg}, Success: {success_meter.sum} / {success_meter.count}" +
473 | f" ({success_meter.avg * 100} %)"
474 | )
475 | reg_timer.reset()
476 |
477 | logging.info(
478 | f"RTE: {rte_meter.avg}, var: {rte_meter.var}," +
479 | f" RRE: {rre_meter.avg}, var: {rre_meter.var}, Success: {success_meter.sum} " +
480 | f"/ {success_meter.count} ({success_meter.avg * 100} %)"
481 | )
482 |
--------------------------------------------------------------------------------
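
The success criterion in the script is RTE < 2 m and RRE < 5 degrees. The two metrics in isolation (a sketch; `registration_errors` is an illustrative name, and clipping the cosine avoids the NaN the script guards against with `np.isnan`):

import numpy as np

def registration_errors(T_est, T_gt):
    rte = np.linalg.norm(T_est[:3, 3] - T_gt[:3, 3])  # metres
    cos_a = (np.trace(T_est[:3, :3].T @ T_gt[:3, :3]) - 1.0) / 2.0
    rre = np.degrees(np.arccos(np.clip(cos_a, -1.0, 1.0)))  # degrees
    return rte, rre

print(registration_errors(np.eye(4), np.eye(4)))  # (0.0, 0.0)
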
/generalization/ThreeDMatch-to-KITTI/test_kitti.txt:
--------------------------------------------------------------------------------
1 | 8
2 | 9
3 | 10
4 |
--------------------------------------------------------------------------------
/loss/desc_loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | def all_diffs(a, b):
7 | """ Returns a tensor of all combinations of a - b.
8 |
9 | Args:
10 | a (2D tensor): A batch of vectors shaped (B1, F).
11 | b (2D tensor): A batch of vectors shaped (B2, F).
12 |
13 | Returns:
14 | The matrix of all pairwise differences between all vectors in `a` and in
15 | `b`, will be of shape (B1, B2, F).
16 |
17 | """
18 | return torch.unsqueeze(a, dim=1) - torch.unsqueeze(b, dim=0)
19 |
20 |
21 | def cdist(a, b, metric='euclidean'):
22 | """Similar to scipy.spatial's cdist, but symbolic.
23 |
24 | The currently supported metrics can be listed as `cdist.supported_metrics` and are:
25 | - 'euclidean', although with a fudge-factor epsilon.
26 | - 'sqeuclidean', the squared euclidean.
27 | - 'cityblock', the manhattan or L1 distance.
28 |
29 | Args:
30 | a (2D tensor): The left-hand side, shaped (B1, F).
31 | b (2D tensor): The right-hand side, shaped (B2, F).
32 | metric (string): Which distance metric to use, see notes.
33 |
34 | Returns:
35 | The matrix of all pairwise distances between all vectors in `a` and in
36 | `b`, will be of shape (B1, B2).
37 |
38 | Note:
39 | When a square root is taken (such as in the Euclidean case), a small
40 | epsilon is added because the gradient of the square-root at zero is
41 | undefined. Thus, it will never return exact zero in these cases.
42 | """
43 |
44 | diffs = all_diffs(a, b)
45 | if metric == 'sqeuclidean':
46 | return torch.sum(diffs ** 2, dim=-1)
47 | elif metric == 'euclidean':
48 | return torch.sqrt(torch.sum(diffs ** 2, dim=-1) + 1e-12)
49 | elif metric == 'cityblock':
50 | return torch.sum(torch.abs(diffs), dim=-1)
51 | else:
52 | raise NotImplementedError(
53 | 'The following metric is not implemented by `cdist` yet: {}'.format(metric))
54 |
55 |
56 | class ContrastiveLoss(nn.Module):
57 | def __init__(self, pos_margin=0.1, neg_margin=1.4, metric='euclidean', safe_radius=0.25):
58 | super(ContrastiveLoss, self).__init__()
59 | self.pos_margin = pos_margin
60 | self.neg_margin = neg_margin
61 | self.metric = metric
62 | self.safe_radius = safe_radius
63 |
64 | def forward(self, anchor, positive):
65 | pids = torch.FloatTensor(np.arange(len(anchor))).to(anchor.device)
66 | dist = cdist(anchor, positive, metric=self.metric)
67 | return self.calculate_loss(dist, pids)
68 |
69 | def calculate_loss(self, dists, pids):
70 | """Computes the batch-hard loss from arxiv.org/abs/1703.07737.
71 |
72 | Args:
73 | dists (2D tensor): A square all-to-all distance matrix as given by cdist.
74 | pids (1D tensor): The identities of the entries in `batch`, shape (B,).
75 | This can be of any type that can be compared, thus also a string.
76 | (The positive and negative margins are taken from `pos_margin` and
77 | `neg_margin` set in the constructor; this method itself takes no
78 | margin argument.)
79 |
80 | Returns:
81 | The mean loss over the batch, together with the batch-hard accuracy in percent.
82 | """
83 | # build a mask where mask[i, j] indicates whether the i-th and j-th entries share an identity.
84 | # torch.equal checks whether two tensors have the same size and elements;
85 | # torch.eq computes element-wise equality.
86 | same_identity_mask = torch.eq(torch.unsqueeze(pids, dim=1), torch.unsqueeze(pids, dim=0))
87 |
88 | # dists * same_identity_mask get the distance of each valid anchor-positive pair.
89 | furthest_positive, _ = torch.max(dists * same_identity_mask.float(), dim=1)
90 | closest_negative, _ = torch.min(dists + 1e5 * same_identity_mask.float(), dim=1)
91 | diff = furthest_positive - closest_negative
92 | accuracy = (diff < 0).sum() * 100.0 / diff.shape[0]
93 | loss = torch.max(furthest_positive - self.pos_margin, torch.zeros_like(diff)) + torch.max(
94 | self.neg_margin - closest_negative, torch.zeros_like(diff))
95 |
96 | return torch.mean(loss), accuracy
97 |
--------------------------------------------------------------------------------
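
A usage sketch for the `ContrastiveLoss` defined above (synthetic descriptors; because `anchor[i]` and `positive[i]` describe the same physical point, the identity labels are simply 0..B-1 and every off-diagonal pair acts as a negative):

import torch

loss_fn = ContrastiveLoss(pos_margin=0.1, neg_margin=1.4)  # class defined above
anchor = torch.nn.functional.normalize(torch.randn(64, 32), dim=1)
positive = anchor + 0.01 * torch.randn(64, 32)  # near-duplicate positives
loss, accuracy = loss_fn(anchor, positive)
print(float(loss), float(accuracy))  # small loss, accuracy close to 100%
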
/network/SpinNet.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../')
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import network.ThreeDCCN as pn
8 | import script.common as cm
9 | from script.common import switch
10 |
11 |
12 | class Descriptor_Net(nn.Module):
13 | def __init__(self, des_r, rad_n, azi_n, ele_n, voxel_r, voxel_sample, dataset):
14 | super(Descriptor_Net, self).__init__()
15 | self.des_r = des_r
16 | self.rad_n = rad_n
17 | self.azi_n = azi_n
18 | self.ele_n = ele_n
19 | self.voxel_r = voxel_r
20 | self.voxel_sample = voxel_sample
21 | self.dataset = dataset
22 |
23 | self.bn_xyz_raising = nn.BatchNorm2d(16)
24 | self.bn_mapping = nn.BatchNorm2d(16)
25 | self.activation = nn.ReLU()
26 | self.xyz_raising = nn.Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))
27 | self.conv_net = pn.Cylindrical_Net(inchan=16, dim=32)
28 |
29 | def forward(self, input):
30 | center = input[:, -1, :].unsqueeze(1)
31 | delta_x = input[:, :, 0:3] - center[:, :, 0:3]  # (B, npoint, 3), coordinates centred on the keypoint
32 | for case in switch(self.dataset):
33 | if case('3DMatch'):
34 | z_axis = cm.cal_Z_axis(delta_x, ref_point=input[:, -1, :3])
35 | z_axis = cm.l2_norm(z_axis, axis=1)
36 | R = cm.RodsRotatFormula(z_axis, torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(z_axis.shape[0], 1))
37 | delta_x = torch.matmul(delta_x, R)
38 | break
39 | if case('KITTI'):
40 | break
41 |
42 | # partition the local surface along the elevation, azimuth and radial dimensions
43 | S2_xyz = torch.FloatTensor(cm.get_voxel_coordinate(radius=self.des_r,
44 | rad_n=self.rad_n,
45 | azi_n=self.azi_n,
46 | ele_n=self.ele_n))
47 |
48 | pts_xyz = S2_xyz.view(1, -1, 3).repeat([delta_x.shape[0], 1, 1]).cuda()
49 | # query points in sphere
50 | new_points = cm.sphere_query(delta_x, pts_xyz, radius=self.voxel_r,
51 | nsample=self.voxel_sample)
52 | # transform rotation-variant coords into rotation-invariant coords
53 | new_points = new_points - pts_xyz.unsqueeze(2).repeat([1, 1, self.voxel_sample, 1])
54 | new_points = cm.var_to_invar(new_points, self.rad_n, self.azi_n, self.ele_n)
55 |
56 | new_points = new_points.permute(0, 3, 1, 2) # (B, C_in, npoint, nsample), input features
57 | C_in = new_points.size()[1]
58 | nsample = new_points.size()[3]
59 | x = self.activation(self.bn_xyz_raising(self.xyz_raising(new_points)))
60 | x = F.max_pool2d(x, kernel_size=(1, nsample)).squeeze(3) # (B, C_in, npoint)
61 | del new_points
62 | del pts_xyz
63 | x = x.view(x.shape[0], x.shape[1], self.rad_n, self.ele_n, self.azi_n)
64 |
65 | x = self.conv_net(x)
66 | x = F.max_pool2d(x, kernel_size=(x.shape[2], x.shape[3]))
67 |
68 | return x
69 |
70 | def get_parameter(self):
71 | return list(self.parameters())
72 |
--------------------------------------------------------------------------------
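
`cm.get_voxel_coordinate` lives in script/common.py and is not shown in this listing. A plausible uniform (radius, elevation, azimuth) bin-centre grid of the kind `forward` queries might look like the sketch below; this is an assumption for illustration, not the repo's actual layout:

import numpy as np

def spherical_anchor_grid(radius, rad_n, ele_n, azi_n):
    # hypothetical stand-in for cm.get_voxel_coordinate (assumption)
    r = (np.arange(rad_n) + 0.5) / rad_n * radius  # bin centres in radius
    ele = (np.arange(ele_n) + 0.5) / ele_n * np.pi  # elevation in (0, pi)
    azi = (np.arange(azi_n) + 0.5) / azi_n * 2 * np.pi  # azimuth, periodic
    rr, ee, aa = np.meshgrid(r, ele, azi, indexing='ij')
    xyz = np.stack([rr * np.sin(ee) * np.cos(aa),
                    rr * np.sin(ee) * np.sin(aa),
                    rr * np.cos(ee)], axis=-1)
    return xyz.reshape(-1, 3)  # [rad_n * ele_n * azi_n, 3], cf. view(1, -1, 3)

print(spherical_anchor_grid(0.3, 9, 40, 80).shape)  # (28800, 3)
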
/network/ThreeDCCN.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import script.common as cm
6 |
7 |
8 | class BaseNet(nn.Module):
9 | """ Takes a list of images as input, and returns for each image:
10 | - a pixelwise descriptor
11 | - a pixelwise confidence
12 | """
13 |
14 | def forward_one(self, x):
15 | raise NotImplementedError()
16 |
17 | def forward(self, imgs):
18 | res = self.forward_one(imgs)
19 | return res
20 |
21 |
22 | class Cyclindrical_ConvNet(BaseNet):
23 | def __init__(self, inchan=3, dilated=True, dilation=1, bn=True, bn_affine=False):
24 | BaseNet.__init__(self)
25 | self.inchan = inchan
26 | self.curchan = inchan
27 | self.dilated = dilated
28 | self.dilation = dilation
29 | self.bn = bn
30 | self.bn_affine = bn_affine
31 | self.ops = nn.ModuleList([])
32 |
33 | def _make_bn_2d(self, outd):
34 | return nn.BatchNorm2d(outd, affine=self.bn_affine)
35 |
36 | def _make_bn_3d(self, outd):
37 | return nn.BatchNorm3d(outd, affine=self.bn_affine)
38 |
39 | def _add_conv_2d(self, outd, k=3, stride=1, dilation=1, bn=True, relu=True):
40 | d = self.dilation * dilation
41 | self.dilation *= stride
42 | self.ops.append(nn.Conv2d(self.curchan, outd, kernel_size=(k, k), dilation=d))
43 | if bn and self.bn: self.ops.append(self._make_bn_2d(outd))
44 | if relu: self.ops.append(nn.ReLU(inplace=True))
45 | self.curchan = outd
46 |
47 | def _add_conv_3d(self, outd, k, stride=1, dilation=1, bn=True, relu=True):
48 | d = self.dilation * dilation
49 | self.dilation *= stride
50 | self.ops.append(nn.Conv3d(self.curchan, outd, kernel_size=(k[0], k[1], k[2]), dilation=d))
51 | if bn and self.bn: self.ops.append(self._make_bn_3d(outd))
52 | if relu: self.ops.append(nn.ReLU(inplace=True))
53 | self.curchan = outd
54 |
55 | def forward_one(self, x):
56 | assert self.ops, "You need to add convolutions first"
57 | for n, op in enumerate(self.ops):
58 | k_exist = hasattr(op, 'kernel_size')
59 | if k_exist:
60 | if len(op.kernel_size) == 3:
61 | x = cm.pad_image_3d(x, op.kernel_size[1] + (op.kernel_size[1] - 1) * (op.dilation[0] - 1))
62 | else:
63 | if len(x.shape) == 5:
64 | x = x.squeeze(2)
65 | x = cm.pad_image(x, op.kernel_size[0] + (op.kernel_size[0] - 1) * (op.dilation[0] - 1))
66 | x = op(x)
67 | return x
68 |
69 |
70 | class Cylindrical_Net(Cyclindrical_ConvNet):
71 | """ Compute a descriptor for all overlapping patches.
72 | From the L2Net paper (CVPR'17).
73 | """
74 |
75 | def __init__(self, inchan=16, dim=32, **kw):
76 | Cyclindrical_ConvNet.__init__(self, inchan=inchan, **kw)
77 | add_conv_2d = lambda n, **kw: self._add_conv_2d(n, **kw)
78 | add_conv_3d = lambda n, **kw: self._add_conv_3d(n, **kw)
79 | add_conv_3d(32, k=[3, 3, 3])
80 | add_conv_3d(32, k=[3, 3, 3])
81 | add_conv_3d(64, k=[3, 3, 3])
82 | add_conv_3d(64, k=[3, 3, 3])
83 | add_conv_2d(128, stride=2)
84 | add_conv_2d(128)
85 | add_conv_2d(64, stride=2)
86 | add_conv_2d(64)
87 | add_conv_2d(32, k=2, stride=2, relu=False)
88 | add_conv_2d(32, k=2, stride=2, relu=False)
89 | add_conv_2d(dim, k=2, stride=2, bn=False, relu=False)
90 | self.out_dim = dim
91 |
--------------------------------------------------------------------------------
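
`cm.pad_image` and `cm.pad_image_3d` (from script/common.py, not shown here) keep each convolution 'same'-sized on the cylindrical grid. On such a grid the azimuth axis is periodic, so a natural padding wraps it circularly; the stand-in below is a hypothetical sketch of that idea, not the repo's actual implementation:

import torch
import torch.nn.functional as F

def pad_azimuth_2d(x, k):
    """Hypothetical padding. x: [B, C, ele, azi]; k: effective kernel extent."""
    p = (k - 1) // 2
    x = F.pad(x, (p, p, 0, 0), mode='circular')  # wrap the periodic azimuth axis
    x = F.pad(x, (0, 0, p, p))  # zero-pad the non-periodic elevation axis
    return x

x = torch.randn(1, 16, 40, 80)
print(pad_azimuth_2d(x, 3).shape)  # torch.Size([1, 16, 42, 82])
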
/pre-trained_models/3DMatch_best.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/pre-trained_models/3DMatch_best.pkl
--------------------------------------------------------------------------------
/pre-trained_models/KITTI_best.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/QingyongHu/SpinNet/5581e7d184bc3b4d525d5b5e58777ea04dfdc9ab/pre-trained_models/KITTI_best.pkl
--------------------------------------------------------------------------------
/script/cal_overlap.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import exists, join
3 | import pickle
4 | import numpy as np
5 | import open3d
6 | import cv2
7 | import time
8 |
9 |
10 | class ThreeDMatch(object):
11 | """
12 | Given point cloud fragments and corresponding pose in '{root}'.
13 | 1. Save the aligned point cloud pts in '{savepath}/3DMatch_{downsample}_points.pkl'
14 | 2. Calculate the overlap ratio and save in '{savepath}/3DMatch_{downsample}_overlap.pkl'
15 | 3. Save the ids of anchor keypoints and positive keypoints in '{savepath}/3DMatch_{downsample}_keypts.pkl'
16 | """
17 |
18 | def __init__(self, root, savepath, split, downsample):
19 | self.root = root
20 | self.savepath = savepath
21 | self.split = split
22 | self.downsample = downsample
23 |
24 | # dict: from id to pts.
25 | self.pts = {}
26 |
27 | # dict: from id_id to overlap_ratio
28 | self.overlap_ratio = {}
29 | # dict: from id_id to anc_keypts id & pos_keypts id
30 | self.keypts_pairs = {}
31 |
32 | with open(os.path.join(root, f'scene_list_{split}.txt')) as f:
33 | scene_list = f.readlines()
34 | self.ids_list = []
35 | self.scene_to_ids = {}
36 | for scene in scene_list:
37 | scene = scene.replace("\n", "")
38 | self.scene_to_ids[scene] = []
39 | for seq in sorted(os.listdir(os.path.join(self.root, scene))):
40 | if not seq.startswith('seq'):
41 | continue
42 | scene_path = os.path.join(self.root, scene + f'/{seq}')
43 | ids = [scene + f"/{seq}/" + str(filename.split(".")[0]) for filename in os.listdir(scene_path) if
44 | filename.endswith('ply')]
45 | ids = sorted(ids, key=lambda x: int(x.split("_")[-1]))
46 | self.ids_list += ids
47 | self.scene_to_ids[scene] += ids
48 | print(f"Scene {scene}, seq {seq}: num ply: {len(ids)}")
49 | print(f"Total {len(scene_list)} scenes, {len(self.ids_list)} point cloud fragments.")
50 | self.idpair_list = []
51 | self.load_all_ply(downsample)
52 | self.cal_overlap(downsample)
53 |
54 | def load_ply(self, data_dir, ind, downsample, aligned=True):
55 | pcd = open3d.io.read_point_cloud(join(data_dir, f'{ind}.ply'))
56 | pcd = open3d.geometry.PointCloud.voxel_down_sample(pcd, voxel_size=downsample)
57 | if aligned is True:
58 | matrix = np.load(join(data_dir, f'{ind}.pose.npy'))
59 | pcd.transform(matrix)
60 | return pcd
61 |
62 | def load_all_ply(self, downsample):
63 | pts_filename = join(self.savepath, f'3DMatch_{self.split}_{downsample:.3f}_points.pkl')
64 | if exists(pts_filename):
65 | with open(pts_filename, 'rb') as file:
66 | self.pts = pickle.load(file)
67 | print(f"Load pts file from {self.savepath}")
68 | return
69 | self.pts = {}
70 | for i, anc_id in enumerate(self.ids_list):
71 | anc_pcd = self.load_ply(self.root, anc_id, downsample=downsample, aligned=True)
72 | points = np.array(anc_pcd.points)
73 | print(len(points))
74 | self.pts[anc_id] = points
75 | print('processing ply: {:.1f}%'.format(100 * i / len(self.ids_list)))
76 | with open(pts_filename, 'wb') as file:
77 | pickle.dump(self.pts, file)
78 |
79 | def get_matching_indices(self, anc_pts, pos_pts, search_voxel_size, K=None):
80 | match_inds = []
81 | bf_matcher = cv2.BFMatcher(cv2.NORM_L2)
82 | match = bf_matcher.match(anc_pts, pos_pts)
83 | for match_val in match:
84 | if match_val.distance < search_voxel_size:
85 | match_inds.append([match_val.queryIdx, match_val.trainIdx])
86 | return np.array(match_inds)
87 |
88 | def cal_overlap(self, downsample):
89 | overlap_filename = join(self.savepath, f'3DMatch_{self.split}_{downsample:.3f}_overlap.pkl')
90 | keypts_filename = join(self.savepath, f'3DMatch_{self.split}_{downsample:.3f}_keypts.pkl')
91 | if exists(overlap_filename) and exists(keypts_filename):
92 | with open(overlap_filename, 'rb') as file:
93 | self.overlap_ratio = pickle.load(file)
94 | print(f"Reload overlap info from {overlap_filename}")
95 | with open(keypts_filename, 'rb') as file:
96 | self.keypts_pairs = pickle.load(file)
97 | print(f"Reload keypts info from {keypts_filename}")
98 | # leftover debugging breakpoint removed here: `import pdb; pdb.set_trace()`
99 | # would halt the script every time cached overlap info is reloaded
100 | return
101 | t0 = time.time()
102 | for scene, scene_ids in self.scene_to_ids.items():
103 | scene_overlap = {}
104 | print(f"Begin processing scene {scene}")
105 | for i in range(0, len(scene_ids)):
106 | anc_id = scene_ids[i]
107 | for j in range(i + 1, len(scene_ids)):
108 | pos_id = scene_ids[j]
109 | anc_pts = self.pts[anc_id].astype(np.float32)
110 | pos_pts = self.pts[pos_id].astype(np.float32)
111 |
112 | try:
113 | matching_01 = self.get_matching_indices(anc_pts, pos_pts, self.downsample)
114 | except BaseException as e:
115 | print(f"Something wrong with get_matching_indices {e} for {anc_id}, {pos_id}")
116 | matching_01 = np.array([])
117 | overlap_ratio = len(matching_01) / len(anc_pts)
118 | scene_overlap[f'{anc_id}@{pos_id}'] = overlap_ratio
119 | if overlap_ratio > 0.30:
120 | self.keypts_pairs[f'{anc_id}@{pos_id}'] = matching_01.astype(np.int32)
121 | self.overlap_ratio[f'{anc_id}@{pos_id}'] = overlap_ratio
122 | print(f'\t {anc_id}, {pos_id} overlap ratio: {overlap_ratio}')
123 | print('processing {:s} ply: {:.1f}%'.format(scene, 100 * i / len(scene_ids)))
124 | print('Finish {:s}, Done in {:.1f}s'.format(scene, time.time() - t0))
125 |
126 | with open(overlap_filename, 'wb') as file:
127 | pickle.dump(self.overlap_ratio, file)
128 | with open(keypts_filename, 'wb') as file:
129 | pickle.dump(self.keypts_pairs, file)
130 |
131 |
132 | if __name__ == '__main__':
133 | ThreeDMatch(root='path to your ply file.',
134 | savepath='data/3DMatch',
135 | split='train',
136 | downsample=0.030
137 | )
138 |
--------------------------------------------------------------------------------
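
The overlap test that gates training pairs is a nearest-neighbour distance check. The same statistic with the KD-tree already used elsewhere in the repo, instead of cv2.BFMatcher (illustrative sketch):

import numpy as np
from sklearn.neighbors import KDTree

def overlap_ratio(anc_pts, pos_pts, voxel=0.03):
    """Fraction of anchor points whose nearest neighbour in the (already
    aligned) positive fragment lies within one downsampling voxel; pairs
    above 0.30 become training pairs."""
    dist, _ = KDTree(pos_pts).query(anc_pts, k=1)
    return float((dist[:, 0] < voxel).mean())

pts = np.random.default_rng(2).random((1000, 3))
print(overlap_ratio(pts, pts + 0.001))  # ~1.0 for an almost identical pair
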
/script/common.py:
--------------------------------------------------------------------------------
1 | import open3d
2 | import numpy as np
3 | import os
4 | import time
5 | import torch
6 | from sklearn.neighbors import KDTree
7 | import pointnet2_ops.pointnet2_utils as pnt2
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 |
11 |
12 | class switch(object):
13 | def __init__(self, value):
14 | self.value = value
15 | self.fall = False
16 |
17 | def __iter__(self):
18 | """Return the match method once, then stop"""
19 | yield self.match
20 |         return  # PEP 479: raising StopIteration inside a generator is an error in Python 3.7+
21 |
22 | def match(self, *args):
23 | """Indicate whether or not to enter a case suite"""
24 | if self.fall or not args:
25 | return True
26 |         elif self.value in args:
27 | self.fall = True
28 | return True
29 | else:
30 | return False
31 |
32 |
33 | def select_patches(pts, ind, num_patches=1024, vicinity=0.15, num_points_per_patch=1024, is_rand=True):
34 | # A point sampling algorithm for 3d matching of irregular geometries.
35 | tree = KDTree(pts[:, 0:3])
36 | num_points = pts.shape[0]
37 | if is_rand:
38 | out_inds = np.random.choice(range(ind.shape[0]), num_patches, replace=False)
39 | inds = ind[out_inds]
40 | else:
41 | inds = ind
42 | refer_pts = pts[inds]
43 |
44 | ind_local = tree.query_radius(refer_pts[:, 0:3], r=vicinity)
45 | local_patches = []
46 | for i in range(np.size(ind_local)):
47 | local_neighbors = pts[ind_local[i], :]
48 | if local_neighbors.shape[0] >= num_points_per_patch:
49 | temp = np.random.choice(range(local_neighbors.shape[0]), num_points_per_patch, replace=False)
50 | local_neighbors = local_neighbors[temp]
51 | local_neighbors[-1, :] = refer_pts[i, :]
52 | else:
53 | fix_idx = np.asarray(range(local_neighbors.shape[0]))
54 | while local_neighbors.shape[0] + fix_idx.shape[0] < num_points_per_patch:
55 | fix_idx = np.concatenate((fix_idx, np.asarray(range(local_neighbors.shape[0]))), axis=0)
56 | random_idx = np.random.choice(local_neighbors.shape[0], num_points_per_patch - fix_idx.shape[0],
57 | replace=False)
58 | choice_idx = np.concatenate((fix_idx, random_idx), axis=0)
59 | local_neighbors = local_neighbors[choice_idx]
60 | local_neighbors[-1, :] = refer_pts[i, :]
61 |
62 | # fill_num = num_points_per_patch-local_neighbors.shape[0]
63 | # local_neighbors = np.concatenate((local_neighbors, np.tile(refer_pts[i,:],(fill_num,1))), axis=0)
64 | local_patches.append(local_neighbors)
65 | if is_rand:
66 | return local_patches, out_inds
67 | else:
68 | return local_patches
69 |
70 |
71 | def transform_pc_pytorch(pc, sn):
72 | '''
73 |     Randomly jitter, rotate and translate a point cloud and its normals.
74 |     :param pc: Nx3 array
75 |     :param sn: Nx4 array (the first three columns are the normals)
76 |     :return: transformed pc and sn, plus the applied rotation angles (3,)
77 |              and shift (1, 3)
78 | '''
79 | angles_3d = np.random.rand(3) * np.pi * 2
80 | shift = np.random.uniform(-1, 1, (1, 3))
81 |
82 | sigma, clip = 0.010, 0.02
83 | N, C = pc.shape
84 | jitter_pc = np.clip(sigma * np.random.randn(N, 3), -1 * clip, clip)
85 | sigma, clip = 0.010, 0.02
86 | jitter_sn = np.clip(sigma * np.random.randn(N, 4), -1 * clip, clip)
87 | pc += jitter_pc
88 | sn += jitter_sn
89 |
90 | pc = pc_rotate_translate(pc, angles_3d, shift)
91 |     sn[:, 0:3] = vec_rotate(sn[:, 0:3], angles_3d)  # rotate the normals with the same angles
92 |
93 | return pc, sn, \
94 | angles_3d, shift
95 |
96 |
97 | def l2_norm(input, axis=1):
98 | norm = torch.norm(input, p=2, dim=axis, keepdim=True)
99 | output = torch.div(input, norm)
100 | return output
101 |
102 |
103 | def angles2rotation_matrix(angles):
104 | Rx = np.array([[1, 0, 0],
105 | [0, np.cos(angles[0]), -np.sin(angles[0])],
106 | [0, np.sin(angles[0]), np.cos(angles[0])]])
107 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
108 | [0, 1, 0],
109 | [-np.sin(angles[1]), 0, np.cos(angles[1])]])
110 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
111 | [np.sin(angles[2]), np.cos(angles[2]), 0],
112 | [0, 0, 1]])
113 | R = np.dot(Rz, np.dot(Ry, Rx))
114 | return R
115 |
116 |
117 | def pc_rotate_translate(data, angles, translates):
118 | '''
119 | :param data: numpy array of Nx3 array
120 | :param angles: numpy array / list of 3
121 | :param translates: numpy array / list of 3
122 | :return: rotated_data: numpy array of Nx3
123 | '''
124 | R = angles2rotation_matrix(angles)
125 | rotated_data = np.dot(data, np.transpose(R)) + translates
126 |
127 | return rotated_data
128 |
129 |
130 | def pc_rotate_translate_torch(data, angles, translates):
131 | '''
132 | :param data: Tensor of BxNx3 array
133 | :param angles: Tensor of Bx3
134 | :param translates: Tensor of Bx3
135 |     :return: rotated_data: Tensor of BxNx3
136 | '''
137 | device = data.device
138 | B, N, _ = data.shape
139 |
140 | R = np.zeros([B, 3, 3])
141 | for i in range(B):
142 | R[i] = angles2rotation_matrix(angles[i]) # 3x3
143 | R = torch.FloatTensor(R).to(device)
144 |
145 | rotated_data = torch.matmul(data, R.transpose(-1, -2)) + torch.FloatTensor(translates).unsqueeze(1).to(device)
146 |
147 | return rotated_data
148 |
149 |
150 | def _pc_rotate_translate_torch(data, R, translates):
151 | '''
152 | :param data: Tensor of BxNx3 array
153 |     :param R: Tensor of Bx3x3 rotation matrices
154 |     :param translates: Tensor of Bx3
155 |     :return: rotated_data: Tensor of BxNx3
156 | '''
157 | device = data.device
158 | B, N, _ = data.shape
159 |
160 | rotated_data = torch.matmul(data, R.to(device).transpose(-1, -2)) + torch.FloatTensor(translates).unsqueeze(1).to(
161 | device)
162 |
163 | return rotated_data
164 |
165 |
166 | def max_ind(data):
167 | B, C, row, col = data.shape
168 | inds = np.zeros([B, 2])
169 | for i in range(B):
170 | ind = torch.argmax(data[i])
171 | r = int(ind // col)
172 |         c = int(ind % col)
173 | inds[i, 0] = r
174 | inds[i, 1] = c
175 | return inds
176 |
177 |
178 | def vec_rotate(data, angles):
179 | '''
180 | :param data: numpy array of Nx3 array
181 | :param angles: numpy array / list of 3
182 | :return: rotated_data: numpy array of Nx3
183 | '''
184 | R = angles2rotation_matrix(angles)
185 |     rotated_data = np.dot(data, np.transpose(R))  # use R^T so points and normals rotate consistently with pc_rotate_translate / vec_rotate_torch
186 |
187 | return rotated_data
188 |
189 |
190 | def vec_rotate_torch(data, angles):
191 | '''
192 | :param data: BxNx3 tensor
193 | :param angles: Bx3 numpy array
194 | :return:
195 | '''
196 | device = data.device
197 | B, N, _ = data.shape
198 |
199 | R = np.zeros([B, 3, 3])
200 | for i in range(B):
201 | R[i] = angles2rotation_matrix(angles[i]) # 3x3
202 | R = torch.FloatTensor(R).to(device)
203 |
204 | rotated_data = torch.matmul(data, R.transpose(-1, -2)) # BxNx3 * Bx3x3 -> BxNx3
205 | return rotated_data
206 |
207 |
208 | def rotate_perturbation_point_cloud(data, angle_sigma=0.01, angle_clip=0.05):
209 | """ Randomly perturb the point clouds by small rotations
210 | Input:
211 | Nx3 array, original point clouds
212 | Return:
213 | Nx3 array, rotated point clouds
214 | """
215 | # truncated Gaussian sampling
216 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
217 | rotated_data = vec_rotate(data, angles)
218 |
219 | return rotated_data
220 |
221 |
222 | def jitter_point_cloud(data, sigma=0.01, clip=0.05):
223 | """ Randomly jitter points. jittering is per point.
224 | Input:
225 | BxNx3 array, original point clouds
226 | Return:
227 | BxNx3 array, jittered point clouds
228 | """
229 | B, N, C = data.shape
230 | assert (clip > 0)
231 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1 * clip, clip)
232 | jittered_data += data
233 | return jittered_data
234 |
235 |
236 | def square_distance(src, dst):
237 | """
238 |     Calculate the squared Euclidean distance between each pair of points.
239 | src^T * dst = xn * xm + yn * ym + zn * zm;
240 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
241 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
242 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
243 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
244 | Input:
245 | src: source points, [B, N, C]
246 | dst: target points, [B, M, C]
247 | Output:
248 | dist: per-point square distance, [B, N, M]
249 | """
250 | B, N, _ = src.shape
251 | _, M, _ = dst.shape
252 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
253 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
254 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
255 | return dist
256 |
257 |
258 | def cdist(a, b):
259 | '''
260 |     :param a: NxC tensor
261 |     :param b: MxC tensor
262 |     :return: MxN pairwise Euclidean distance matrix
263 | '''
264 | diff = a.unsqueeze(0) - b.unsqueeze(1)
265 | dis_matrix = torch.sqrt(torch.sum(diff * diff, dim=-1) + 1e-12)
266 | return dis_matrix
267 |
268 |
269 | def s2_grid(n_alpha, n_beta):
270 | '''
271 |     :return: (n_beta * n_alpha, 2) array of (beta, alpha) angles forming
272 |              rings around the equator; size of the kernel = n_alpha * n_beta
273 | '''
274 | beta = np.linspace(start=0, stop=np.pi, num=n_beta, endpoint=False) + np.pi / n_beta / 2
275 | # ele = np.arcsin(np.linspace(start=0, stop=1, num=n_beta / 2, endpoint=False) + 1 / n_beta / 4)
276 | # beta = np.concatenate([np.sort(-ele), ele])
277 | alpha = np.linspace(start=0, stop=2 * np.pi, num=n_alpha, endpoint=False) + np.pi / n_alpha
278 | B, A = np.meshgrid(beta, alpha, indexing='ij')
279 | B = B.flatten()
280 | A = A.flatten()
281 | grid = np.stack((B, A), axis=1)
282 | return grid
283 |
284 |
285 | def pad_image(input, kernel_size):
286 | """
287 |     Circularly pad the width (azimuth) and zero-pad the height for convolution
288 | :param input: [B, C, H, W]
289 | :param kernel_size:
290 | :return:
291 | """
292 | device = input.device
293 | if kernel_size % 2 == 0:
294 | pad_size = kernel_size // 2
295 | output = torch.cat([input, input[:, :, :, 0:pad_size]], dim=3)
296 | zeros_pad = torch.zeros([output.shape[0], output.shape[1], pad_size, output.shape[3]]).to(device)
297 | output = torch.cat([output, zeros_pad], dim=2)
298 | else:
299 | pad_size = (kernel_size - 1) // 2
300 | output = torch.cat([input, input[:, :, :, 0:pad_size]], dim=3)
301 | output = torch.cat([input[:, :, :, -pad_size:], output], dim=3)
302 | zeros_pad = torch.zeros([output.shape[0], output.shape[1], pad_size, output.shape[3]]).to(device)
303 | output = torch.cat([output, zeros_pad], dim=2)
304 | output = torch.cat([zeros_pad, output], dim=2)
305 | return output
306 |
307 |
308 | def pad_image_3d(input, kernel_size):
309 | """
310 |     Circularly pad the width (azimuth) and zero-pad the height for 3D convolution
311 | :param input: [B, C, D, H, W]
312 | :param kernel_size:
313 | :return:
314 | """
315 | device = input.device
316 | if kernel_size % 2 == 0:
317 | pad_size = kernel_size // 2
318 | output = torch.cat([input, input[:, :, :, :, 0:pad_size]], dim=4)
319 | zeros_pad = torch.zeros([output.shape[0], output.shape[1], output.shape[2], pad_size, output.shape[4]]).to(
320 | device)
321 | output = torch.cat([output, zeros_pad], dim=3)
322 | else:
323 | pad_size = (kernel_size - 1) // 2
324 | output = torch.cat([input, input[:, :, :, :, 0:pad_size]], dim=4)
325 | output = torch.cat([input[:, :, :, :, -pad_size:], output], dim=4)
326 | zeros_pad = torch.zeros([output.shape[0], output.shape[1], output.shape[2], pad_size, output.shape[4]]).to(
327 | device)
328 | output = torch.cat([output, zeros_pad], dim=3)
329 | output = torch.cat([zeros_pad, output], dim=3)
330 | return output
331 |
332 |
333 | def pad_image_on_azi(input, kernel_size):
334 | """
335 |     Circularly pad the width (azimuth) for convolution
336 | :param input: [B, C, H, W]
337 | :param kernel_size:
338 | :return:
339 | """
340 | device = input.device
341 | pad_size = (kernel_size - 1) // 2
342 | output = torch.cat([input, input[:, :, :, 0:pad_size]], dim=3)
343 | output = torch.cat([input[:, :, :, -pad_size:], output], dim=3)
344 | return output
345 |
346 |
347 | def kmax_pooling(x, dim, k):
348 | kmax = x.topk(k, dim=dim)[0]
349 | return kmax
350 |
351 |
352 | def change_coordinates(coords, radius, p_from='C', p_to='S'):
353 | """
354 | Change Spherical to Cartesian coordinates and vice versa, for points x in S^2.
355 |
356 | In the spherical system, we have coordinates beta and alpha,
357 | where beta in [0, pi] and alpha in [0, 2pi]
358 |
359 | We use the names beta and alpha for compatibility with the SO(3) code (S^2 being a quotient SO(3)/SO(2)).
360 | Many sources, like wikipedia use theta=beta and phi=alpha.
361 |
362 | :param coords: coordinate array
363 |     :param radius: sphere radius (used for the 'S' -> 'C' conversion)
364 |     :param p_from / p_to: 'C' for Cartesian or 'S' for spherical coordinates
365 |     :return: new coordinates
366 | """
367 | if p_from == p_to:
368 | return coords
369 | elif p_from == 'S' and p_to == 'C':
370 |
371 | beta = coords[..., 0]
372 | alpha = coords[..., 1]
373 | r = radius
374 |
375 | out = np.empty(beta.shape + (3,))
376 |
377 | ct = np.cos(beta)
378 | cp = np.cos(alpha)
379 | st = np.sin(beta)
380 | sp = np.sin(alpha)
381 | out[..., 0] = r * st * cp # x
382 | out[..., 1] = r * st * sp # y
383 | out[..., 2] = r * ct # z
384 | return out
385 |
386 | elif p_from == 'C' and p_to == 'S':
387 |
388 | x = coords[..., 0]
389 | y = coords[..., 1]
390 | z = coords[..., 2]
391 |
392 | out = np.empty(x.shape + (2,))
393 | out[..., 0] = np.arccos(z) # beta
394 | out[..., 1] = np.arctan2(y, x) # alpha
395 | return out
396 |
397 | else:
398 | raise ValueError('Unknown conversion:' + str(p_from) + ' to ' + str(p_to))
399 |
400 |
401 | def get_voxel_coordinate(radius, rad_n, azi_n, ele_n):
402 | grid = s2_grid(n_alpha=azi_n, n_beta=ele_n)
403 | pts_xyz_on_S2 = change_coordinates(grid, radius, 'S', 'C')
404 | pts_xyz_on_S2 = np.expand_dims(pts_xyz_on_S2, axis=0).repeat(rad_n, axis=0)
405 | scale = np.reshape(np.arange(rad_n) / rad_n + 1 / (2 * rad_n), [rad_n, 1, 1])
406 | pts_xyz = scale * pts_xyz_on_S2
407 | return pts_xyz
408 |
409 |
410 | def knn_query(pts, new_pts, knn):
411 | """
412 |     :param pts: all points, [B, N, 3]
413 |     :param new_pts: query points, [B, S, 3]
414 |     :param knn: the number of neighbors to query
415 |     :return: group_idx: indices of the knn nearest points, [B, S, knn]
416 | """
417 | device = pts.device
418 | B, N, C = pts.shape
419 | _, S, _ = new_pts.shape
420 |     sqrdists = square_distance(new_pts, pts)
421 |     _, group_idx = torch.topk(sqrdists, knn, dim=-1, largest=False)  # knn smallest distances
422 |     return group_idx
423 | 
424 | def sphere_query(pts, new_pts, radius, nsample):
425 | """
426 |     :param pts: all points, [B, N, 3]
427 |     :param new_pts: query points, [B, S, 3]
428 |     :param radius: local spherical radius
429 |     :param nsample: max sample number in the local sphere
430 |     :return: grouped points, [B, S, nsample, 3]
431 | """
432 |
433 | device = pts.device
434 | B, N, C = pts.shape
435 | _, S, _ = new_pts.shape
436 |
437 | pts = pts.contiguous()
438 | new_pts = new_pts.contiguous()
439 | group_idx = pnt2.ball_query(radius, nsample, pts, new_pts)
440 | mask = group_idx[:, :, 0].unsqueeze(2).repeat(1, 1, nsample)
441 | mask = (group_idx == mask).float()
442 | mask[:, :, 0] = 0
443 |
444 | # C implementation
445 | pts_trans = pts.transpose(1, 2).contiguous()
446 | new_points = pnt2.grouping_operation(
447 | pts_trans, group_idx
448 | ) # (B, 3, npoint, nsample)
449 | new_points = new_points.permute([0, 2, 3, 1])
450 |
451 | # replace the wrong points using new_pts
452 | mask = mask.unsqueeze(3).repeat([1, 1, 1, 3])
453 | # new_pts = new_pts.unsqueeze(2).repeat([1, 1, nsample + 1, 1])
454 | new_pts = new_pts.unsqueeze(2).repeat([1, 1, nsample, 1])
455 | n_points = new_points * (1 - mask).float() + new_pts * mask.float()
456 |
457 | del mask
458 | del new_points
459 | del group_idx
460 | del new_pts
461 | del pts
462 | del pts_trans
463 |
464 | return n_points
465 |
466 |
467 | def sphere_query_new(pts, new_pts, radius, nsample):
468 | """
469 |     :param pts: all points, [B, N, 3]
470 |     :param new_pts: query points, [B, S, 3]
471 |     :param radius: local spherical radius
472 |     :param nsample: max sample number in the local sphere
473 |     :return: grouped points, [B, S, nsample, 3]
474 | """
475 |
476 | device = pts.device
477 | B, N, C = pts.shape
478 | _, S, _ = new_pts.shape
479 |
480 | pts = pts.contiguous()
481 | new_pts = new_pts.contiguous()
482 | group_idx = pnt2.ball_query(radius, nsample, pts, new_pts)
483 | mask = group_idx[:, :, 0].unsqueeze(2).repeat(1, 1, nsample)
484 | mask = (group_idx == mask).float()
485 | mask[:, :, 0] = 0
486 |
487 | mask1 = (group_idx[:, :, 0] == 0).unsqueeze(2).float()
488 | mask1 = torch.cat([mask1, torch.zeros_like(mask)[:, :, :-1]], dim=2)
489 | mask = mask + mask1
490 |
491 | # C implementation
492 | pts_trans = pts.transpose(1, 2).contiguous()
493 | new_points = pnt2.grouping_operation(
494 | pts_trans, group_idx
495 | ) # (B, 3, npoint, nsample)
496 | new_points = new_points.permute([0, 2, 3, 1])
497 |
498 | # replace the wrong points using new_pts
499 | mask = mask.unsqueeze(3).repeat([1, 1, 1, 3])
500 | n_points = new_points * (1 - mask).float()
501 |
502 | del mask
503 | del new_points
504 | del group_idx
505 | del new_pts
506 | del pts
507 | del pts_trans
508 |
509 | return n_points
510 |
511 |
512 | def var_to_invar(pts, rad_n, azi_n, ele_n):
513 | """
514 | :param pts: input points data, [B, N, nsample, 3]
515 | :param rad_n: radial number
516 | :param azi_n: azimuth number
517 |     :param ele_n: elevation number
518 |     :return: azimuth-aligned points, [B, N, nsample, 3]
519 | """
520 | device = pts.device
521 | B, N, nsample, C = pts.shape
522 | assert N == rad_n * azi_n * ele_n
523 | angle_step = np.array([0, 0, 2 * np.pi / azi_n])
524 | pts = pts.view(B, rad_n, ele_n, azi_n, nsample, C)
525 |
526 | R = np.zeros([azi_n, 3, 3])
527 | for i in range(azi_n):
528 | angle = -1 * i * angle_step
529 | r = angles2rotation_matrix(angle)
530 | R[i] = r
531 | R = torch.FloatTensor(R).to(device)
532 | R = R.view(1, 1, 1, azi_n, 3, 3).repeat(B, rad_n, ele_n, 1, 1, 1)
533 | new_pts = torch.matmul(pts, R.transpose(-1, -2))
534 |
535 | del R
536 | del pts
537 |
538 | return new_pts.view(B, -1, nsample, C)
539 |
540 |
541 | def cal_Z_axis(local_cor, local_weight=None, ref_point=None):
542 | device = local_cor.device
543 | B, N, _ = local_cor.shape
544 | cov_matrix = torch.matmul(local_cor.transpose(-1, -2), local_cor) if local_weight is None \
545 | else Variable(torch.matmul(local_cor.transpose(-1, -2), local_cor * local_weight), requires_grad=True)
546 | Z_axis = torch.symeig(cov_matrix, eigenvectors=True)[1][:, :, 0]
547 | mask = (torch.sum(-Z_axis * ref_point, dim=1) < 0).float().unsqueeze(1)
548 | Z_axis = Z_axis * (1 - mask) - Z_axis * mask
549 |
550 | return Z_axis
551 |
552 |
553 | def RodsRotatFormula(a, b):
554 | B, _ = a.shape
555 | device = a.device
556 | b = b.to(device)
557 | c = torch.cross(a, b)
558 | theta = torch.acos(F.cosine_similarity(a, b)).unsqueeze(1).unsqueeze(2)
559 |
560 | c = F.normalize(c, p=2, dim=1)
561 | one = torch.ones(B, 1, 1).to(device)
562 | zero = torch.zeros(B, 1, 1).to(device)
563 | a11 = zero
564 | a12 = -c[:, 2].unsqueeze(1).unsqueeze(2)
565 | a13 = c[:, 1].unsqueeze(1).unsqueeze(2)
566 | a21 = c[:, 2].unsqueeze(1).unsqueeze(2)
567 | a22 = zero
568 | a23 = -c[:, 0].unsqueeze(1).unsqueeze(2)
569 | a31 = -c[:, 1].unsqueeze(1).unsqueeze(2)
570 | a32 = c[:, 0].unsqueeze(1).unsqueeze(2)
571 | a33 = zero
572 | Rx = torch.cat(
573 | (torch.cat((a11, a12, a13), dim=2), torch.cat((a21, a22, a23), dim=2), torch.cat((a31, a32, a33), dim=2)),
574 | dim=1)
575 | I = torch.eye(3).to(device)
576 | R = I.unsqueeze(0).repeat(B, 1, 1) + torch.sin(theta) * Rx + (1 - torch.cos(theta)) * torch.matmul(Rx, Rx)
577 | return R.transpose(-1, -2)
578 |
579 |
580 | def rgbd_to_point_cloud(data_dir, ind, downsample=0.03, aligned=True):
581 | pcd = open3d.read_point_cloud(os.path.join(data_dir, f'{ind}.ply'))
582 | # downsample the point cloud
583 | if downsample != 0:
584 | pcd = open3d.voxel_down_sample(pcd, voxel_size=downsample)
585 | # align the point cloud
586 | if aligned is True:
587 | matrix = np.load(os.path.join(data_dir, f'{ind}.pose.npy'))
588 | pcd.transform(matrix)
589 |
590 | return pcd
591 |
592 |
593 | def cal_local_normal(pcd):
594 | if open3d.geometry.estimate_normals(pcd, open3d.KDTreeSearchParamKNN(knn=17)):
595 | return True
596 | else:
597 | print("Calculate Normal Error")
598 | return False
599 |
600 |
601 | def select_referenced_point(pcd, num_patches=2048):
602 | # A point sampling algorithm for 3d matching of irregular geometries.
603 | pts = np.asarray(pcd.points)
604 | num_points = pts.shape[0]
605 | inds = np.random.choice(range(num_points), num_patches, replace=False)
606 | return open3d.geometry.select_down_sample(pcd, inds)
607 |
608 |
609 | def collect_local_neighbor(ref_pcd, pcd, vicinity=0.3, num_points_per_patch=1024, random_state=None):
610 | # collect local neighbor within vicinity for each interest point.
611 | # each local patch is downsampled to 1024 (setting of PPFNet p5.)
612 | kdtree = open3d.geometry.KDTreeFlann(pcd)
613 |     patches = []  # list of neighbor-index arrays (avoid shadowing the built-in dict)
614 | for point in ref_pcd.points:
615 |         # Bug fix: the first returned index is the query point itself, which would make the calculated ppf nan, so idx[0] is skipped below.
616 | [k, idx, variant] = kdtree.search_radius_vector_3d(point, vicinity)
617 | # random select fix number [num_points] of points to form the local patch.
618 | if random_state is not None:
619 | if k > num_points_per_patch:
620 | idx = random_state.choice(idx[1:], num_points_per_patch, replace=False)
621 | else:
622 | idx = random_state.choice(idx[1:], num_points_per_patch)
623 | else:
624 | if k > num_points_per_patch:
625 | idx = np.random.choice(idx[1:], num_points_per_patch, replace=False)
626 | else:
627 | idx = np.random.choice(idx[1:], num_points_per_patch)
628 |         patches.append(idx)
629 |     return patches
630 |
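631 | # ---------------------------------------------------------------------------- #
632 | # Usage sketches (illustrative only; `mode` and 'train' below are placeholder
633 | # values, and the shapes follow the docstrings above)
634 | # ---------------------------------------------------------------------------- #
635 | # The switch helper at the top of this file emulates a C-style switch; a
636 | # case() with no arguments acts as the default branch:
637 | #
638 | #   for case in switch(mode):
639 | #       if case('train'):
640 | #           ...
641 | #           break
642 | #       if case():
643 | #           raise ValueError(mode)
644 | #
645 | # s2_grid returns an (n_beta * n_alpha, 2) array of (beta, alpha) angles, and
646 | # get_voxel_coordinate lifts that grid onto rad_n spherical shells:
647 | #
648 | #   grid = s2_grid(n_alpha=8, n_beta=4)                         # (32, 2)
649 | #   xyz = change_coordinates(grid, 1.0, p_from='S', p_to='C')   # (32, 3)
650 | #   vox = get_voxel_coordinate(1.0, rad_n=3, azi_n=8, ele_n=4)  # (3, 32, 3)
651 | #
652 | # square_distance returns *squared* pairwise distances, [B, N, M]; cdist
653 | # returns true (root) distances. RodsRotatFormula(a, b) builds the rotation
654 | # taking unit row vectors a onto b, so that
655 | # torch.matmul(a.unsqueeze(1), R).squeeze(1) is approximately b.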
--------------------------------------------------------------------------------
/script/download.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | mkdir -p ../data && cd ../data
3 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_bm_1-brown_bm_1.zip
4 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_bm_4-brown_bm_4.zip
5 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_cogsci_1-brown_cogsci_1.zip
6 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_cs_2-brown_cs2.zip
7 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-brown_cs_3-brown_cs3.zip
8 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c3-hv_c3_1.zip
9 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c5-hv_c5_1.zip
10 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c6-hv_c6_1.zip
11 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c8-hv_c8_3.zip
12 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-harvard_c11-hv_c11_2.zip
13 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-home_at-home_at_scan1_2013_jan_1.zip
14 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-home_bksh-home_bksh_oct_30_2012_scan2_erika.zip
15 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-home_md-home_md_scan9_2012_sep_30.zip
16 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_nips2012-nips_4.zip
17 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_sf-scan1.zip
18 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_uc-scan3.zip
19 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_umd-maryland_hotel1.zip
20 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-hotel_umd-maryland_hotel3.zip
21 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_32_d507-d507_2.zip
22 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_46_ted_lab1-ted_lab_2.zip
23 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_76_417-76-417b.zip
24 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_76_studyroom-76-1studyroom2.zip
25 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika.zip
26 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika.zip
27 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika.zip
28 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-chess.zip
29 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-fire.zip
30 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-heads.zip
31 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-office.zip
32 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-pumpkin.zip
33 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-redkitchen.zip
34 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/7-scenes-stairs.zip
35 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_01.zip
36 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_02.zip
37 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_03.zip
38 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_04.zip
39 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_05.zip
40 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_06.zip
41 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_07.zip
42 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_08.zip
43 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_09.zip
44 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_10.zip
45 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_11.zip
46 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_12.zip
47 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_13.zip
48 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/rgbd-scenes-v2-scene_14.zip
49 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-apt0.zip
50 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-apt1.zip
51 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-apt2.zip
52 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-copyroom.zip
53 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-office0.zip
54 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-office1.zip
55 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-office2.zip
56 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/bundlefusion-office3.zip
57 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt1-kitchen.zip
58 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt1-living.zip
59 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt2-bed.zip
60 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt2-kitchen.zip
61 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt2-living.zip
62 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-apt2-luke.zip
63 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-office2-5a.zip
64 | wget http://vision.princeton.edu/projects/2016/3DMatch/downloads/rgbd-datasets/analysis-by-synthesis-office2-5b.zip
65 |
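66 | # All archives are saved into ../data; a minimal extraction sketch
67 | # (illustrative):
68 | #   for f in *.zip; do unzip -q "$f"; done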
--------------------------------------------------------------------------------
/script/fuse_fragments_3DMatch.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 |
4 | from pathlib import Path
5 | import argparse
6 | import math
7 | import numpy as np
8 | import os.path as osp
9 | import os
10 | import sys
11 |
12 | ROOT_DIR = osp.abspath('../')
13 | if ROOT_DIR not in sys.path:
14 | sys.path.append(ROOT_DIR)
15 |
16 | from script import io as uio
17 |
18 |
19 | # ---------------------------------------------------------------------------- #
20 | # Fuse rgbd frames into fragments in 3DMatch
21 | # - Use existing camera poses
22 | # - Save colors & normals
23 | # ---------------------------------------------------------------------------- #
24 | def read_intrinsic(filepath, width, height):
25 | import open3d as o3d
26 |
27 | m = np.loadtxt(filepath, dtype=np.float32)
28 | intrinsic = o3d.camera.PinholeCameraIntrinsic(width, height, m[0, 0], m[1, 1], m[0, 2], m[1, 2])
29 | return intrinsic
30 |
31 |
32 | def read_extrinsic(filepath):
33 | m = np.loadtxt(filepath, dtype=np.float32)
34 | if np.isnan(m).any():
35 | return None
36 | return m # (4, 4)
37 |
38 |
39 | def read_rgbd_image(cfg, color_file, depth_file, convert_rgb_to_intensity):
40 | import open3d as o3d
41 | if color_file is None:
42 | color_file = depth_file # to avoid "Unsupported image format."
43 | # rgbd_image = o3d.RGBDImage()
44 | # rgbd_image.depth = o3d.io.read_image(depth_file)
45 | # return rgbd_image
46 | color = o3d.io.read_image(color_file)
47 | depth = o3d.io.read_image(depth_file)
48 | rgbd_image = o3d.geometry.create_rgbd_image_from_color_and_depth(color, depth, cfg.depth_scale, cfg.depth_trunc,
49 | convert_rgb_to_intensity)
50 | return rgbd_image
51 |
52 |
53 | def process_single_fragment(cfg, color_files, depth_files, frag_id, n_frags, intrinsic_path, out_folder):
54 | import open3d as o3d
55 |
56 | depth_only_flag = (len(color_files) == 0)
57 | n_frames = len(depth_files)
58 | intrinsic = read_intrinsic(intrinsic_path, cfg.width, cfg.height)
59 | if depth_only_flag:
60 | color_type = o3d.integration.TSDFVolumeColorType.__dict__['None']
61 | else:
62 | color_type = o3d.integration.TSDFVolumeColorType.__dict__['RGB8']
63 |
64 | volume = o3d.integration.ScalableTSDFVolume(voxel_length=cfg.tsdf_cubic_size / 512.0,
65 | sdf_trunc=0.04,
66 | color_type=color_type)
67 |
68 | sid = frag_id * cfg.frames_per_frag
69 | eid = min(sid + cfg.frames_per_frag, n_frames)
70 | pose_base2world = None
71 | pose_base2world_inv = None
72 | for fid in range(sid, eid):
73 | if not depth_only_flag:
74 | color_path = color_files[fid]
75 | else:
76 | color_path = None
77 | depth_path = depth_files[fid]
78 | pose_path = depth_path[:-10] + '.pose.txt'
79 |
80 | pose_cam2world = read_extrinsic(pose_path)
81 | if pose_cam2world is None:
82 | continue
83 | if fid == sid: # Use as base frame
84 | pose_base2world = pose_cam2world
85 | pose_base2world_inv = np.linalg.inv(pose_base2world)
86 | if pose_base2world_inv is None:
87 | break
88 | # Relative camera pose
89 | pose_cam2world = np.matmul(pose_base2world_inv, pose_cam2world)
90 |
91 | rgbd = read_rgbd_image(cfg, color_path, depth_path, False)
92 | volume.integrate(rgbd, intrinsic, np.linalg.inv(pose_cam2world))
93 | if pose_base2world_inv is None:
94 | return
95 |
96 | pcloud = volume.extract_point_cloud()
97 | o3d.geometry.estimate_normals(pcloud)
98 | o3d.write_point_cloud(osp.join(out_folder, 'cloud_bin_{}.ply'.format(frag_id)), pcloud)
99 |
100 | np.save(osp.join(out_folder, 'cloud_bin_{}.pose.npy'.format(frag_id)), pose_base2world)
101 |
102 |
103 | # ---------------------------------------------------------------------------- #
104 | # Iterate Folders
105 | # ---------------------------------------------------------------------------- #
106 | def run_seq(cfg, scene, seq):
107 | print(" Start {}".format(seq))
108 |
109 | seq_folder = osp.join(cfg.dataset_root, scene, seq)
110 | color_names = uio.list_files(seq_folder, '*.color.png')
111 | color_paths = [osp.join(seq_folder, cf) for cf in color_names]
112 | depth_names = uio.list_files(seq_folder, '*.depth.png')
113 | depth_paths = [osp.join(seq_folder, df) for df in depth_names]
114 | # depth_paths = [osp.join(seq_folder, cf[:-10] + '.depth.png') for cf in depth_names]
115 |
116 | # n_frames = len(color_paths)
117 | n_frames = len(depth_paths)
118 | n_frags = int(math.ceil(float(n_frames) / cfg.frames_per_frag))
119 |
120 | out_folder = osp.join(cfg.out_root, scene, seq)
121 | uio.may_create_folder(out_folder)
122 |
123 | intrinsic_path = osp.join(cfg.dataset_root, scene, 'camera-intrinsics.txt')
124 |
125 | if cfg.threads > 1:
126 | from joblib import Parallel, delayed
127 | import multiprocessing
128 |
129 | Parallel(n_jobs=cfg.threads)(
130 | delayed(process_single_fragment)(cfg, color_paths, depth_paths, frag_id, n_frags, intrinsic_path,
131 | out_folder)
132 | for frag_id in range(n_frags))
133 |
134 | else:
135 | for frag_id in range(n_frags):
136 | process_single_fragment(cfg, color_paths, depth_paths, frag_id, n_frags, intrinsic_path, out_folder)
137 |
138 | print(" Finished {}".format(seq))
139 |
140 |
141 | def run_scene(cfg, scene):
142 | print(" Start scene {} ".format(scene))
143 |
144 | scene_folder = osp.join(cfg.dataset_root, scene)
145 | seqs = uio.list_folders(scene_folder)
146 | print(" {} sequences".format(len(seqs)))
147 | for seq in seqs:
148 | run_seq(cfg, scene, seq)
149 |
150 | print(" Finished scene {} ".format(scene))
151 |
152 |
153 | def run(cfg):
154 | print("Start making fragments")
155 |
156 | uio.may_create_folder(cfg.out_root)
157 |
158 | scenes = uio.list_folders(cfg.dataset_root, sort=False)
159 | print("{} scenes".format(len(scenes)))
160 | for scene in scenes:
161 | # if not scene.startswith('analysis'):
162 | # continue
163 | run_scene(cfg, scene)
164 |
165 | print("Finished making fragments")
166 |
167 |
168 | # ---------------------------------------------------------------------------- #
169 | # Arguments
170 | # ---------------------------------------------------------------------------- #
171 | def parse_args():
172 | parser = argparse.ArgumentParser()
173 | parser.add_argument('--dataset_root', default='../../data/3DMatch_raw/')
174 | parser.add_argument('--out_root', default='../../data/3DMatch_fragments/')
175 | parser.add_argument('--depth_scale', type=float, default=1000.0)
176 | parser.add_argument('--depth_trunc', type=float, default=6.0)
177 | parser.add_argument('--frames_per_frag', type=int, default=50)
178 | parser.add_argument('--height', type=int, default=480)
179 | parser.add_argument('--threads', type=int, default=1)
180 | parser.add_argument('--tsdf_cubic_size', type=float, default=3.0)
181 | parser.add_argument('--width', type=int, default=640)
182 |
183 | return parser.parse_args()
184 |
185 |
186 | if __name__ == '__main__':
187 | cfg = parse_args()
188 | run(cfg)
189 |
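190 | # Example invocation (paths are placeholders; all flags default to the
191 | # values declared in parse_args above):
192 | #   python fuse_fragments_3DMatch.py --dataset_root ../../data/3DMatch_raw/ \
193 | #       --out_root ../../data/3DMatch_fragments/ --threads 8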
--------------------------------------------------------------------------------
/script/io.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | from collections import defaultdict
5 | from pathlib import Path
6 | import cv2
7 | import json
8 | import numpy as np
9 | import os
10 | import os.path as osp
11 | import re
12 | import shutil
13 |
14 |
15 | def is_number(s):
16 | try:
17 | float(s)
18 | return True
19 | except ValueError:
20 | return False
21 |
22 |
23 | # ---------------------------------------------------------------------------- #
24 | # Common IO
25 | # ---------------------------------------------------------------------------- #
26 | def may_create_folder(folder_path):
27 | if not osp.exists(folder_path):
28 |         oldmask = os.umask(0o000)
29 | os.makedirs(folder_path, mode=0o777)
30 | os.umask(oldmask)
31 | return True
32 | return False
33 |
34 |
35 | def make_clean_folder(folder_path):
36 | success = may_create_folder(folder_path)
37 | if not success:
38 | shutil.rmtree(folder_path)
39 | may_create_folder(folder_path)
40 |
41 |
42 | def sorted_alphanum(file_list_ordered):
43 | convert = lambda text: int(text) if text.isdigit() else text
44 | alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key) if len(c) > 0]
45 | return sorted(file_list_ordered, key=alphanum_key)
46 |
47 |
48 | def list_files(folder_path, name_filter, sort=True):
49 | file_list = [p.name for p in list(Path(folder_path).glob(name_filter))]
50 | if sort:
51 | return sorted_alphanum(file_list)
52 | else:
53 | return file_list
54 |
55 |
56 | def list_folders(folder_path, name_filter=None, sort=True):
57 | folders = list()
58 | for subfolder in Path(folder_path).iterdir():
59 | if subfolder.is_dir() and not subfolder.name.startswith('.'):
60 | folder_name = subfolder.name
61 | if name_filter is not None:
62 | if name_filter in folder_name:
63 | folders.append(folder_name)
64 | else:
65 | folders.append(folder_name)
66 | if sort:
67 | return sorted_alphanum(folders)
68 | else:
69 | return folders
70 |
71 |
72 | def read_lines(file_path):
73 | """
74 |     :param file_path: path to a text file
75 |     :return: list of non-empty, stripped lines
76 | """
77 | with open(file_path, 'r') as fin:
78 | lines = [line.strip() for line in fin.readlines() if len(line.strip()) > 0]
79 | return lines
80 |
81 |
82 | def read_json(filepath):
83 | with open(filepath, 'r') as fh:
84 | ret = json.load(fh)
85 | return ret
86 |
87 |
88 | # ---------------------------------------------------------------------------- #
89 | # Image IO
90 | # ---------------------------------------------------------------------------- #
91 | def read_color_image(file_path):
92 | """
93 | Args:
94 | file_path (str):
95 |
96 | Returns:
97 | np.array: RGB.
98 | """
99 | img = cv2.imread(file_path)
100 | return img[..., ::-1]
101 |
102 |
103 | def read_gray_image(file_path):
104 | """Load a gray image
105 |
106 | Args:
107 | file_path (str):
108 |
109 | Returns:
110 | np.array: np.uint8, max 255.
111 | """
112 | img = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
113 | return img
114 |
115 |
116 | def read_16bit_image(file_path):
117 | """Load a 16bit image
118 |
119 | Args:
120 | file_path (str):
121 |
122 | Returns:
123 | np.array: np.uint16, max 65535.
124 | """
125 | img = cv2.imread(file_path, cv2.IMREAD_UNCHANGED)
126 | return img
127 |
128 |
129 | def write_color_image(file_path, image):
130 | """
131 | Args:
132 | file_path (str):
133 | image (np.array): in RGB.
134 |
135 | Returns:
136 | str:
137 | """
138 | cv2.imwrite(file_path, image[..., ::-1])
139 | return file_path
140 |
141 |
142 | def write_gray_image(file_path, image):
143 | """
144 | Args:
145 | file_path (str):
146 | image (np.array):
147 |
148 | Returns:
149 | str:
150 | """
151 | cv2.imwrite(file_path, image)
152 | return file_path
153 |
154 |
155 | def write_image(file_path, image):
156 | """
157 | Args:
158 | file_path (str):
159 | image (np.array):
160 |
161 | Returns:
162 | str:
163 | """
164 | if image.ndim == 2:
165 | return write_gray_image(file_path, image)
166 | elif image.ndim == 3:
167 | return write_color_image(file_path, image)
168 | else:
169 | raise RuntimeError('Image dimensions are not correct!')
170 |
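171 | # sorted_alphanum sorts embedded integers numerically rather than
172 | # lexicographically (illustrative):
173 | #   sorted_alphanum(['cloud_bin_10.ply', 'cloud_bin_2.ply'])
174 | #   # -> ['cloud_bin_2.ply', 'cloud_bin_10.ply']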
--------------------------------------------------------------------------------