├── lib ├── __init__.py ├── transforms.py ├── metrics.py ├── eval.py ├── timer.py ├── data_loaders.py └── trainer.py ├── util ├── __init__.py ├── trajectory.py ├── visualization.py ├── file.py ├── misc.py ├── transform_estimation.py └── pointcloud.py ├── config ├── val_kitti.txt ├── test_kitti.txt ├── train_kitti.txt ├── val_3dmatch.txt ├── test_3dmatch.txt └── train_3dmatch.txt ├── assets ├── 1.png ├── 2.png ├── 3.png ├── 3_1.png ├── 3_2.png ├── 4.png ├── demo.png ├── table.png ├── fps_acc.png ├── kitchen_0.png ├── kitchen_1.png ├── text_scene000.gif ├── fig4_dist_thresh.txt └── fig4_inlier_thresh.txt ├── .style.yapf ├── .gitignore ├── model ├── common.py ├── __init__.py ├── residual_block.py ├── resunet.py └── simpleunet.py ├── requirements.txt ├── LICENSE ├── scripts ├── download_datasets.sh ├── download_3dmatch_test.sh ├── train_fcgf_kitti.sh ├── benchmark_util.py ├── test_kitti.py └── benchmark_3dmatch.py ├── benchmark.py ├── demo.py ├── train.py ├── config.py └── README.md /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/val_kitti.txt: -------------------------------------------------------------------------------- 1 | 6 2 | 7 3 | -------------------------------------------------------------------------------- /config/test_kitti.txt: -------------------------------------------------------------------------------- 1 | 8 2 | 9 3 | 10 4 | -------------------------------------------------------------------------------- /config/train_kitti.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 1 3 | 2 4 | 3 5 | 4 6 | 5 7 | -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/1.png -------------------------------------------------------------------------------- /assets/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/2.png -------------------------------------------------------------------------------- /assets/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/3.png -------------------------------------------------------------------------------- /assets/3_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/3_1.png -------------------------------------------------------------------------------- /assets/3_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/3_2.png -------------------------------------------------------------------------------- /assets/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/4.png -------------------------------------------------------------------------------- /assets/demo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/demo.png -------------------------------------------------------------------------------- /assets/table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/table.png -------------------------------------------------------------------------------- /assets/fps_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/fps_acc.png -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = yapf 3 | column_limit = 88 4 | indent_width = 2 5 | -------------------------------------------------------------------------------- /assets/kitchen_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/kitchen_0.png -------------------------------------------------------------------------------- /assets/kitchen_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/kitchen_1.png -------------------------------------------------------------------------------- /assets/text_scene000.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chrischoy/FCGF/HEAD/assets/text_scene000.gif -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Temp files 2 | __pycache__ 3 | *.swp 4 | *.swo 5 | *.orig 6 | .idea 7 | outputs/ 8 | *.pyc 9 | *.ply 10 | *.pth -------------------------------------------------------------------------------- /config/val_3dmatch.txt: -------------------------------------------------------------------------------- 1 | sun3d-brown_bm_4-brown_bm_4 2 | sun3d-harvard_c11-hv_c11_2 3 | 7-scenes-heads 4 | rgbd-scenes-v2-scene_10 5 | bundlefusion-office0 6 | analysis-by-synthesis-apt2-kitchen 7 | -------------------------------------------------------------------------------- /config/test_3dmatch.txt: -------------------------------------------------------------------------------- 1 | 7-scenes-redkitchen 2 | sun3d-home_at-home_at_scan1_2013_jan_1 3 | sun3d-home_md-home_md_scan9_2012_sep_30 4 | sun3d-hotel_uc-scan3 5 | sun3d-hotel_umd-maryland_hotel1 6 | sun3d-hotel_umd-maryland_hotel3 7 | sun3d-mit_76_studyroom-76-1studyroom2 8 | sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika -------------------------------------------------------------------------------- /model/common.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | 3 | 4 | def get_norm(norm_type, num_feats, bn_momentum=0.05, D=-1): 5 | if norm_type == 'BN': 6 | return ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum) 7 | elif norm_type == 'IN': 8 | return ME.MinkowskiInstanceNorm(num_feats, dimension=D) 9 | else: 10 | raise ValueError(f'Type {norm_type}, not defined') 11 | -------------------------------------------------------------------------------- /requirements.txt: 
--------------------------------------------------------------------------------
1 | numpy
2 | # pytorch # for anaconda, please refer to pytorch.org for installation
3 | scipy
4 | matplotlib
5 | open3d
6 | # to visualize tensorboardX logs, you need tensorflow, but it doesn't have to be in the same virtual environment :)
7 | tensorboardX
8 | MinkowskiEngine
9 | # Or follow the installation instructions on github.com/StanfordVL/MinkowskiEngine
10 | future-fstrings
11 | easydict
12 | joblib
13 | scikit-learn
14 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import model.simpleunet as simpleunets
3 | import model.resunet as resunets
4 | 
5 | MODELS = []
6 | 
7 | 
8 | def add_models(module):
9 |   MODELS.extend([getattr(module, a) for a in dir(module) if 'Net' in a or 'MLP' in a])
10 | 
11 | 
12 | add_models(simpleunets)
13 | add_models(resunets)
14 | 
15 | 
16 | def load_model(name):
17 |   '''Creates and returns an instance of the model given its class name.
18 |   '''
19 |   # Find the model class from its name
20 |   all_models = MODELS
21 |   mdict = {model.__name__: model for model in all_models}
22 |   if name not in mdict:
23 |     logging.info(f'Invalid model name: {name}. Options are:')
24 |     # Display a list of valid model names
25 |     for model in all_models:
26 |       logging.info('\t* {}'.format(model.__name__))
27 |     return None
28 |   NetClass = mdict[name]
29 | 
30 |   return NetClass
31 | 
--------------------------------------------------------------------------------
/lib/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | 
4 | import torch
5 | 
6 | 
7 | class Compose:
8 | 
9 |   def __init__(self, transforms):
10 |     self.transforms = transforms
11 | 
12 |   def __call__(self, coords, feats):
13 |     for transform in self.transforms:
14 |       coords, feats = transform(coords, feats)
15 |     return coords, feats
16 | 
17 | 
18 | class Jitter:
19 | 
20 |   def __init__(self, mu=0, sigma=0.01):
21 |     self.mu = mu
22 |     self.sigma = sigma
23 | 
24 |   def __call__(self, coords, feats):
25 |     if random.random() < 0.95:
26 |       if isinstance(feats, np.ndarray):
27 |         feats += np.random.normal(self.mu, self.sigma, (feats.shape[0], feats.shape[1]))
28 |       else:
29 |         feats += (torch.randn_like(feats) * self.sigma) + self.mu
30 |     return coords, feats
31 | 
32 | 
33 | class ChromaticShift:
34 | 
35 |   def __init__(self, mu=0, sigma=0.1):
36 |     self.mu = mu
37 |     self.sigma = sigma
38 | 
39 |   def __call__(self, coords, feats):
40 |     if random.random() < 0.95:
41 |       feats[:, :3] += np.random.normal(self.mu, self.sigma, (1, 3))
42 |     return coords, feats
43 | 
--------------------------------------------------------------------------------
/lib/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | 
6 | 
7 | def eval_metrics(output, target):
8 |   output = (F.sigmoid(output) > 0.5).cpu().data.numpy()
9 |   target = target.cpu().data.numpy()
10 |   return np.linalg.norm(output - target)
11 | 
12 | 
13 | def corr_dist(est, gth, xyz0, xyz1, weight=None, max_dist=1):
14 |   xyz0_est = xyz0 @ est[:3, :3].t() + est[:3, 3]
15 |   xyz0_gth = xyz0 @ gth[:3, :3].t() + gth[:3, 3]
16 |   dists = torch.clamp(torch.sqrt(((xyz0_est - xyz0_gth).pow(2)).sum(1)), max=max_dist)
17 |   if weight is not None:
18 |     dists = weight * dists
19 |   return
dists.mean() 20 | 21 | 22 | def pdist(A, B, dist_type='L2'): 23 | if dist_type == 'L2': 24 | D2 = torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 25 | return torch.sqrt(D2 + 1e-7) 26 | elif dist_type == 'SquareL2': 27 | return torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 28 | else: 29 | raise NotImplementedError('Not implemented') 30 | 31 | 32 | def get_loss_fn(loss): 33 | if loss == 'corr_dist': 34 | return corr_dist 35 | else: 36 | raise ValueError(f'Loss {loss}, not defined') 37 | -------------------------------------------------------------------------------- /util/trajectory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | 5 | class CameraPose: 6 | 7 | def __init__(self, meta, mat): 8 | self.metadata = meta 9 | self.pose = mat 10 | 11 | def __str__(self): 12 | return 'metadata : ' + ' '.join(map(str, self.metadata)) + '\n' + \ 13 | "pose : " + "\n" + np.array_str(self.pose) 14 | 15 | 16 | def read_trajectory(filename, dim=4): 17 | traj = [] 18 | assert os.path.exists(filename) 19 | with open(filename, 'r') as f: 20 | metastr = f.readline() 21 | while metastr: 22 | metadata = list(map(int, metastr.split())) 23 | mat = np.zeros(shape=(dim, dim)) 24 | for i in range(dim): 25 | matstr = f.readline() 26 | mat[i, :] = np.fromstring(matstr, dtype=float, sep=' \t') 27 | traj.append(CameraPose(metadata, mat)) 28 | metastr = f.readline() 29 | return traj 30 | 31 | 32 | def write_trajectory(traj, filename, dim=4): 33 | with open(filename, 'w') as f: 34 | for x in traj: 35 | p = x.pose.tolist() 36 | f.write(' '.join(map(str, x.metadata)) + '\n') 37 | f.write('\n'.join(' '.join(map('{0:.12f}'.format, p[i])) for i in range(dim))) 38 | f.write('\n') 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Chris Choy (chrischoy@ai.stanford.edu), Jaesik Park (jaesik.park@postech.ac.kr) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 9 | of the Software, and to permit persons to whom the Software is furnished to do 10 | so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/scripts/download_datasets.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | DATA_DIR=$1
4 | 
5 | function download() {
6 |   TMP_PATH="$DATA_DIR/tmp"
7 |   echo "#################################"
8 |   echo "Data Root Dir: ${DATA_DIR}"
9 |   echo "Download Path: ${TMP_PATH}"
10 |   echo "#################################"
11 |   urls=(
12 |     'http://node2.chrischoy.org/data/datasets/registration/threedmatch.tgz'
13 |   )
14 | 
15 |   if [ ! -d "$TMP_PATH" ]; then
16 |     echo ">> Create temporary directory: ${TMP_PATH}"
17 |     mkdir -p "$TMP_PATH"
18 |   fi
19 |   cd "$TMP_PATH"
20 | 
21 |   echo ">> Start downloading"
22 |   echo ${urls[@]} | xargs -n 1 -P 3 wget --no-check-certificate -q -c --show-progress
23 | 
24 |   echo ">> Unpack .tgz files"
25 |   for filename in *.tgz
26 |   do
27 |     tar -xvzf $filename -C ../
28 |   done
29 | 
30 |   echo ">> Clear tmp directory"
31 |   cd ..
32 |   rm -rf ./tmp
33 | 
34 |   echo "#################################"
35 |   echo "Done!"
36 |   echo "#################################"
37 | }
38 | 
39 | function main() {
40 |   echo $DATA_DIR
41 |   if [ -z "$DATA_DIR" ]; then
42 |     echo "DATA_DIR is required config!"
43 |   else
44 |     download
45 |   fi
46 | }
47 | 
48 | main;
49 | 
--------------------------------------------------------------------------------
/util/visualization.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import open3d as o3d
3 | import numpy as np
4 | 
5 | from sklearn.manifold import TSNE
6 | from matplotlib import pyplot as plt
7 | 
8 | 
9 | def get_color_map(x):
10 |   colours = plt.cm.Spectral(x)
11 |   return colours[:, :3]
12 | 
13 | 
14 | def mesh_sphere(pcd, voxel_size, sphere_size=0.6):
15 |   # Create a mesh sphere
16 |   spheres = o3d.geometry.TriangleMesh()
17 |   s = o3d.geometry.TriangleMesh.create_sphere(radius=voxel_size * sphere_size)
18 |   s.compute_vertex_normals()
19 | 
20 |   for i, p in enumerate(pcd.points):
21 |     si = copy.deepcopy(s)
22 |     trans = np.identity(4)
23 |     trans[:3, 3] = p
24 |     si.transform(trans)
25 |     si.paint_uniform_color(pcd.colors[i])
26 |     spheres += si
27 |   return spheres
28 | 
29 | 
30 | def get_colored_point_cloud_feature(pcd, feature, voxel_size):
31 |   tsne_results = embed_tsne(feature)
32 | 
33 |   color = get_color_map(tsne_results)
34 |   pcd.colors = o3d.utility.Vector3dVector(color)
35 |   spheres = mesh_sphere(pcd, voxel_size)
36 | 
37 |   return spheres
38 | 
39 | 
40 | def embed_tsne(data):
41 |   """
42 |   N x D np.array data
43 |   """
44 |   tsne = TSNE(n_components=1, verbose=1, perplexity=40, n_iter=300, random_state=0)
45 |   tsne_results = tsne.fit_transform(data)
46 |   tsne_results = np.squeeze(tsne_results)
47 |   tsne_min = np.min(tsne_results)
48 |   tsne_max = np.max(tsne_results)
49 |   return (tsne_results - tsne_min) / (tsne_max - tsne_min)
50 | 
--------------------------------------------------------------------------------
/config/train_3dmatch.txt:
--------------------------------------------------------------------------------
1 | sun3d-brown_bm_1-brown_bm_1
2 | sun3d-brown_cogsci_1-brown_cogsci_1
3 | sun3d-brown_cs_2-brown_cs2
4 | sun3d-brown_cs_3-brown_cs3
5 | sun3d-harvard_c3-hv_c3_1
6 | sun3d-harvard_c5-hv_c5_1
7 | sun3d-harvard_c6-hv_c6_1
8 | sun3d-harvard_c8-hv_c8_3
9 | sun3d-home_bksh-home_bksh_oct_30_2012_scan2_erika
10 | sun3d-hotel_nips2012-nips_4
11 | sun3d-hotel_sf-scan1
12 | sun3d-mit_32_d507-d507_2
13 | sun3d-mit_46_ted_lab1-ted_lab_2
14 | 
sun3d-mit_76_417-76-417b 15 | sun3d-mit_dorm_next_sj-dorm_next_sj_oct_30_2012_scan1_erika 16 | sun3d-mit_w20_athena-sc_athena_oct_29_2012_scan1_erika 17 | 7-scenes-chess 18 | 7-scenes-fire 19 | 7-scenes-office 20 | 7-scenes-pumpkin 21 | 7-scenes-stairs 22 | rgbd-scenes-v2-scene_01 23 | rgbd-scenes-v2-scene_02 24 | rgbd-scenes-v2-scene_03 25 | rgbd-scenes-v2-scene_04 26 | rgbd-scenes-v2-scene_05 27 | rgbd-scenes-v2-scene_06 28 | rgbd-scenes-v2-scene_07 29 | rgbd-scenes-v2-scene_08 30 | rgbd-scenes-v2-scene_09 31 | rgbd-scenes-v2-scene_11 32 | rgbd-scenes-v2-scene_12 33 | rgbd-scenes-v2-scene_13 34 | rgbd-scenes-v2-scene_14 35 | bundlefusion-apt0 36 | bundlefusion-apt1 37 | bundlefusion-apt2 38 | bundlefusion-copyroom 39 | bundlefusion-office1 40 | bundlefusion-office2 41 | bundlefusion-office3 42 | analysis-by-synthesis-apt1-kitchen 43 | analysis-by-synthesis-apt1-living 44 | analysis-by-synthesis-apt2-bed 45 | analysis-by-synthesis-apt2-living 46 | analysis-by-synthesis-apt2-luke 47 | analysis-by-synthesis-office2-5a 48 | analysis-by-synthesis-office2-5b 49 | -------------------------------------------------------------------------------- /lib/eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import open3d as o3d 4 | 5 | from lib.metrics import pdist 6 | from scipy.spatial import cKDTree 7 | 8 | 9 | def find_nn_cpu(feat0, feat1, return_distance=False): 10 | feat1tree = cKDTree(feat1) 11 | dists, nn_inds = feat1tree.query(feat0, k=1, n_jobs=-1) 12 | if return_distance: 13 | return nn_inds, dists 14 | else: 15 | return nn_inds 16 | 17 | 18 | def find_nn_gpu(F0, F1, nn_max_n=-1, return_distance=False, dist_type='SquareL2'): 19 | # Too much memory if F0 or F1 large. Divide the F0 20 | if nn_max_n > 1: 21 | N = len(F0) 22 | C = int(np.ceil(N / nn_max_n)) 23 | stride = nn_max_n 24 | dists, inds = [], [] 25 | for i in range(C): 26 | dist = pdist(F0[i * stride:(i + 1) * stride], F1, dist_type=dist_type) 27 | min_dist, ind = dist.min(dim=1) 28 | dists.append(min_dist.detach().unsqueeze(1).cpu()) 29 | inds.append(ind.cpu()) 30 | 31 | if C * stride < N: 32 | dist = pdist(F0[C * stride:], F1, dist_type=dist_type) 33 | min_dist, ind = dist.min(dim=1) 34 | dists.append(min_dist.detach().unsqueeze(1).cpu()) 35 | inds.append(ind.cpu()) 36 | 37 | dists = torch.cat(dists) 38 | inds = torch.cat(inds) 39 | assert len(inds) == N 40 | else: 41 | dist = pdist(F0, F1, dist_type=dist_type) 42 | min_dist, inds = dist.min(dim=1) 43 | dists = min_dist.detach().unsqueeze(1).cpu() 44 | inds = inds.cpu() 45 | if return_distance: 46 | return inds, dists 47 | else: 48 | return inds 49 | -------------------------------------------------------------------------------- /util/file.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from os import listdir 4 | from os.path import isfile, isdir, join, splitext 5 | 6 | 7 | def read_txt(path): 8 | """Read txt file into lines. 
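 
  Illustrative usage (a sketch inferred from the config files in this repo,
  not part of the original source):
      scenes = read_txt('config/val_3dmatch.txt')
      # -> ['sun3d-brown_bm_4-brown_bm_4', 'sun3d-harvard_c11-hv_c11_2', ...]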
9 | """ 10 | with open(path) as f: 11 | lines = f.readlines() 12 | lines = [x.strip() for x in lines] 13 | return lines 14 | 15 | 16 | def ensure_dir(path): 17 | if not os.path.exists(path): 18 | os.makedirs(path, mode=0o755) 19 | 20 | 21 | def sorted_alphanum(file_list_ordered): 22 | 23 | def convert(text): 24 | return int(text) if text.isdigit() else text 25 | 26 | def alphanum_key(key): 27 | return [convert(c) for c in re.split('([0-9]+)', key)] 28 | 29 | return sorted(file_list_ordered, key=alphanum_key) 30 | 31 | 32 | def get_file_list(path, extension=None): 33 | if extension is None: 34 | file_list = [join(path, f) for f in listdir(path) if isfile(join(path, f))] 35 | else: 36 | file_list = [ 37 | join(path, f) 38 | for f in listdir(path) 39 | if isfile(join(path, f)) and splitext(f)[1] == extension 40 | ] 41 | file_list = sorted_alphanum(file_list) 42 | return file_list 43 | 44 | 45 | def get_file_list_specific(path, color_depth, extension=None): 46 | if extension is None: 47 | file_list = [join(path, f) for f in listdir(path) if isfile(join(path, f))] 48 | else: 49 | file_list = [ 50 | join(path, f) 51 | for f in listdir(path) 52 | if isfile(join(path, f)) and color_depth in f and splitext(f)[1] == extension 53 | ] 54 | file_list = sorted_alphanum(file_list) 55 | return file_list 56 | 57 | 58 | def get_folder_list(path): 59 | folder_list = [join(path, f) for f in listdir(path) if isdir(join(path, f))] 60 | folder_list = sorted_alphanum(folder_list) 61 | return folder_list 62 | -------------------------------------------------------------------------------- /lib/timer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import time 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0.0 15 | self.sq_sum = 0.0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count 23 | self.sq_sum += val**2 * n 24 | self.var = self.sq_sum / self.count - self.avg**2 25 | 26 | 27 | class Timer(object): 28 | """A simple timer.""" 29 | 30 | def __init__(self, binary_fn=None, init_val=0): 31 | self.total_time = 0. 32 | self.calls = 0 33 | self.start_time = 0. 34 | self.diff = 0. 
35 | self.binary_fn = binary_fn 36 | self.tmp = init_val 37 | 38 | def reset(self): 39 | self.total_time = 0 40 | self.calls = 0 41 | self.start_time = 0 42 | self.diff = 0 43 | 44 | @property 45 | def avg(self): 46 | return self.total_time / self.calls 47 | 48 | def tic(self): 49 | # using time.time instead of time.clock because time time.clock 50 | # does not normalize for multithreading 51 | self.start_time = time.time() 52 | 53 | def toc(self, average=True): 54 | self.diff = time.time() - self.start_time 55 | self.total_time += self.diff 56 | self.calls += 1 57 | if self.binary_fn: 58 | self.tmp = self.binary_fn(self.tmp, self.diff) 59 | if average: 60 | return self.avg 61 | else: 62 | return self.diff 63 | 64 | 65 | class MinTimer(Timer): 66 | 67 | def __init__(self): 68 | Timer.__init__(self, binary_fn=lambda x, y: min(x, y), init_val=math.inf) 69 | 70 | @property 71 | def min(self): 72 | return self.tmp 73 | -------------------------------------------------------------------------------- /assets/fig4_dist_thresh.txt: -------------------------------------------------------------------------------- 1 | method 0.0 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 0.09 0.1 0.11 0.12 0.13 0.14 0.15 0.16 0.17 0.18 0.19 0.2 2 | FPFH 0.0 0.32518435 3.0640497 7.408897 12.879427 17.654156 21.894594 24.814766 28.676664 32.31617 35.307827 38.895893 41.363243 44.120644 47.001804 49.11531 52.33542 55.02593 56.853123 59.70479 61.781036 3 | SpinImages 0.0 0.2650882 1.4736623 3.4342945 5.703008 8.022694 9.878324 12.34205 15.342176 18.735947 21.89705 24.729364 28.054333 31.125975 33.943203 36.965153 40.642887 43.33119 46.57485 49.260128 52.152588 4 | SHOT 0.0 0.47008017 2.6745644 5.1028447 7.4799767 9.238545 11.927085 16.089037 18.879076 22.094217 24.974531 27.351418 29.76674 32.573856 34.916573 37.086315 40.306644 43.138355 45.73636 48.786892 50.972153 5 | 3DMatch 0.0 0.34988788 1.735477 6.3130245 13.2243185 20.68582 28.49041 34.01513 40.062992 45.687683 52.00012 55.463696 59.081406 61.986275 65.45967 68.43626 70.95725 73.55156 76.07139 77.85846 80.27285 6 | CGF 0.0 0.024703559 2.1318865 6.50821 13.123867 20.129152 28.38341 32.935925 38.11737 44.04901 47.796814 52.474613 56.87633 59.58374 62.497223 65.93624 68.82815 71.535965 74.28786 75.93417 77.80694 7 | PPFNet 0.0 0.0 0.060096156 0.86937517 3.1743605 8.557394 19.648903 32.87094 44.737736 54.218216 62.422962 68.653786 74.78312 79.51771 82.49317 84.438225 86.81797 89.20367 90.20915 91.9664 92.85675 8 | PPF-FoldNet 0.0 0.6783977 7.9283724 22.183943 34.720436 44.628433 51.92579 57.17631 61.85746 66.01428 69.106926 71.37115 73.30209 74.196884 75.28364 76.539635 77.94661 79.70165 80.93724 81.7935 82.764626 9 | Ours 0.0 0.4708969187453404 21.502159936042897 50.05412810643319 70.68030594316472 81.07975133647629 87.4676806078735 91.11393845764234 93.41293607558425 94.42439530203379 95.5016728687582 96.0080013333998 96.71576097308037 97.12192918131171 97.2141445151143 97.87661267544205 98.02974089028855 98.32543524018643 98.36824345936452 98.85109394221501 99.07352775839882 10 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import open3d as o3d 5 | from urllib.request import urlretrieve 6 | from model.resunet import ResUNetBN2C 7 | from lib.timer import MinTimer 8 | 9 | import torch 10 | import MinkowskiEngine as ME 11 | 12 | if not os.path.isfile('redkitchen-20.ply'): 13 | 
print('Downloading a mesh...') 14 | urlretrieve("https://node1.chrischoy.org/data/publications/fcgf/redkitchen-20.ply", 15 | 'redkitchen-20.ply') 16 | 17 | 18 | def benchmark(config): 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | model = ResUNetBN2C(1, 16, normalize_feature=True, conv1_kernel_size=3, D=3) 22 | model.eval() 23 | model = model.to(device) 24 | 25 | pcd = o3d.io.read_point_cloud(config.input) 26 | coords = ME.utils.batched_coordinates( 27 | [torch.from_numpy(np.array(pcd.points)) / config.voxel_size]) 28 | feats = torch.from_numpy(np.ones((len(coords), 1))).float() 29 | 30 | with torch.no_grad(): 31 | t = MinTimer() 32 | for i in range(100): 33 | # initialization time includes copy to GPU 34 | t.tic() 35 | sinput = ME.SparseTensor( 36 | feats, 37 | coords, 38 | minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED, 39 | # minkowski_algorithm=ME.MinkowskiAlgorithm.MEMORY_EFFICIENT, 40 | device=device) 41 | model(sinput) 42 | t.toc() 43 | print(t.min) 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument( 49 | '-i', 50 | '--input', 51 | default='redkitchen-20.ply', 52 | type=str, 53 | help='path to a pointcloud file') 54 | parser.add_argument( 55 | '--voxel_size', 56 | default=0.05, 57 | type=float, 58 | help='voxel size to preprocess point cloud') 59 | 60 | config = parser.parse_args() 61 | benchmark(config) 62 | -------------------------------------------------------------------------------- /scripts/download_3dmatch_test.sh: -------------------------------------------------------------------------------- 1 | export DATA_DIR=$1 2 | export BASE_PATH=http://vision.princeton.edu/projects/2016/3DMatch/downloads/scene-fragments 3 | 4 | function download() { 5 | TMP_PATH="$DATA_DIR/tmp" 6 | echo "#################################" 7 | echo "Data Root Dir: ${DATA_DIR}" 8 | echo "Download Path: ${TMP_PATH}" 9 | echo "#################################" 10 | urls=( 11 | '7-scenes-redkitchen' 12 | '7-scenes-redkitchen-evaluation' 13 | 'sun3d-home_at-home_at_scan1_2013_jan_1' 14 | 'sun3d-home_at-home_at_scan1_2013_jan_1-evaluation' 15 | 'sun3d-home_md-home_md_scan9_2012_sep_30' 16 | 'sun3d-home_md-home_md_scan9_2012_sep_30-evaluation' 17 | 'sun3d-hotel_uc-scan3' 18 | 'sun3d-hotel_uc-scan3-evaluation' 19 | 'sun3d-hotel_umd-maryland_hotel1' 20 | 'sun3d-hotel_umd-maryland_hotel1-evaluation' 21 | 'sun3d-hotel_umd-maryland_hotel3' 22 | 'sun3d-hotel_umd-maryland_hotel3-evaluation' 23 | 'sun3d-mit_76_studyroom-76-1studyroom2' 24 | 'sun3d-mit_76_studyroom-76-1studyroom2-evaluation' 25 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika' 26 | 'sun3d-mit_lab_hj-lab_hj_tea_nov_2_2012_scan1_erika-evaluation' 27 | ) 28 | 29 | if [ ! -d "$TMP_PATH" ]; then 30 | echo ">> Create temporary directory: ${TMP_PATH}" 31 | mkdir -p "$TMP_PATH" 32 | fi 33 | cd "$TMP_PATH" 34 | 35 | echo ">> Start downloading" 36 | for url in ${urls[@]} 37 | do 38 | echo $BASE_PATH/$url 39 | wget --no-check-certificate --show-progress $BASE_PATH/$url.zip ./ 40 | done 41 | 42 | echo ">> Unpack .zip file" 43 | for filename in *.zip 44 | do 45 | unzip $filename 46 | done 47 | 48 | rm *.zip 49 | mv * ../ 50 | 51 | rm -r tmp 52 | } 53 | 54 | function main() { 55 | echo $DATA_DIR 56 | if [ -z "$DATA_DIR" ]; then 57 | echo "DATA_DIR is required config!" 
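    # usage (inferred from DATA_DIR=$1 at the top of this script):
    #   ./scripts/download_3dmatch_test.sh /path/to/3dmatch_root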
58 | else 59 | download 60 | fi 61 | } 62 | 63 | main; 64 | -------------------------------------------------------------------------------- /assets/fig4_inlier_thresh.txt: -------------------------------------------------------------------------------- 1 | method 0.01 0.02 0.03 0.04 0.05 0.060000000000000005 0.06999999999999999 0.08 0.09 0.09999999999999999 0.11 0.12 0.13 0.14 0.15000000000000002 0.16 0.17 0.18000000000000002 0.19 0.2 0.21000000000000002 2 | FPFH 84.33316 68.71096 54.144688 43.4266 35.885315 29.24774 24.669365 21.693983 18.398415 15.19658 13.052581 10.593815 9.197111 7.121167 6.0142927 5.459413 4.7025414 3.7748396 3.2582376 2.9224794 2.7575514 3 | SpinImages 80.370415 57.09011 41.317715 30.377916 22.666761 16.668486 12.008004 9.49888 6.9450083 5.2102103 4.6000204 3.9652426 3.6801224 2.8298047 2.4588833 1.6771933 1.4722013 1.3379945 1.1484778 0.9681894 0.8833897 4 | SHOT 74.7338 55.754692 41.421146 31.38755 23.82493 18.377901 14.077372 11.229745 9.498036 7.2400875 6.259093 5.4715004 4.264607 3.6964087 3.070859 2.418819 2.0778346 1.5370839 1.5123804 1.4522841 1.2673242 5 | 3DMatch 93.57423 84.989296 72.26616 60.18662 50.84514 40.988052 35.264496 29.492676 24.875406 20.927317 18.053923 14.7144 12.210866 10.735968 9.550746 8.107913 7.4702764 6.172435 5.2574573 4.3098807 3.7686903 6 | CGF 93.651184 81.15855 68.73972 56.782322 47.758163 39.41682 34.020992 28.672413 23.078693 19.149256 16.00261 13.47433 11.930753 9.2401 6.4093018 5.2058983 4.630285 3.859022 3.5339522 2.9732845 2.6834931 7 | PPFNet 95.92501 89.856384 81.85167 70.101074 62.30703 53.38279 46.71212 38.00567 31.49468 25.324678 20.358257 17.48238 13.581946 10.503738 7.8975635 6.4955277 5.223775 4.1398544 3.4175255 3.1279635 2.5851185 8 | PPF-FoldNet 90.44259 83.66921 77.68714 72.016045 68.04171 62.860317 58.093655 54.45246 50.859642 47.77691 45.581078 42.854107 39.8563 36.73241 33.889057 32.212887 30.558687 28.446907 27.44521 25.17946 23.433472 9 | Ours 99.52991980845242 98.84371726893317 97.46276148034085 96.77077123423552 95.5016728687582 93.87989671321939 92.9612544809232 91.80873139043958 89.55258094103811 88.15087025481105 86.45012108855058 83.95821858043234 82.71260616076292 80.49154883453722 78.89347421469562 76.53057772636661 74.26297769943262 71.7272846810864 69.05880287956903 67.35529094542803 65.39114150434874 10 | -------------------------------------------------------------------------------- /model/residual_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from model.common import get_norm 4 | 5 | import MinkowskiEngine as ME 6 | import MinkowskiEngine.MinkowskiFunctional as MEF 7 | 8 | 9 | class BasicBlockBase(nn.Module): 10 | expansion = 1 11 | NORM_TYPE = 'BN' 12 | 13 | def __init__(self, 14 | inplanes, 15 | planes, 16 | stride=1, 17 | dilation=1, 18 | downsample=None, 19 | bn_momentum=0.1, 20 | D=3): 21 | super(BasicBlockBase, self).__init__() 22 | 23 | self.conv1 = ME.MinkowskiConvolution( 24 | inplanes, planes, kernel_size=3, stride=stride, dimension=D) 25 | self.norm1 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) 26 | self.conv2 = ME.MinkowskiConvolution( 27 | planes, 28 | planes, 29 | kernel_size=3, 30 | stride=1, 31 | dilation=dilation, 32 | bias=False, 33 | dimension=D) 34 | self.norm2 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) 35 | self.downsample = downsample 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.norm1(out) 42 | out = 
MEF.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.norm2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = MEF.relu(out) 52 | 53 | return out 54 | 55 | 56 | class BasicBlockBN(BasicBlockBase): 57 | NORM_TYPE = 'BN' 58 | 59 | 60 | class BasicBlockIN(BasicBlockBase): 61 | NORM_TYPE = 'IN' 62 | 63 | 64 | def get_block(norm_type, 65 | inplanes, 66 | planes, 67 | stride=1, 68 | dilation=1, 69 | downsample=None, 70 | bn_momentum=0.1, 71 | D=3): 72 | if norm_type == 'BN': 73 | return BasicBlockBN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) 74 | elif norm_type == 'IN': 75 | return BasicBlockIN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) 76 | else: 77 | raise ValueError(f'Type {norm_type}, not defined') 78 | -------------------------------------------------------------------------------- /scripts/train_fcgf_kitti.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export PATH_POSTFIX=$1 3 | export MISC_ARGS=$2 4 | 5 | export DATA_ROOT="./outputs/Experiments" 6 | export DATASET=${DATASET:-KITTINMPairDataset} 7 | export TRAINER=${TRAINER:-HardestContrastiveLossTrainer} 8 | export MODEL=${MODEL:-ResUNetBN2C} 9 | export MODEL_N_OUT=${MODEL_N_OUT:-16} 10 | export OPTIMIZER=${OPTIMIZER:-SGD} 11 | export LR=${LR:-1e-1} 12 | export MAX_EPOCH=${MAX_EPOCH:-200} 13 | export BATCH_SIZE=${BATCH_SIZE:-8} 14 | export ITER_SIZE=${ITER_SIZE:-1} 15 | export VOXEL_SIZE=${VOXEL_SIZE:-0.3} 16 | export POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER=${POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER:-1.5} 17 | export CONV1_KERNEL_SIZE=${CONV1_KERNEL_SIZE:-5} 18 | export EXP_GAMMA=${EXP_GAMMA:-0.99} 19 | export RANDOM_SCALE=${RANDOM_SCALE:-True} 20 | export TIME=$(date +"%Y-%m-%d_%H-%M-%S") 21 | export KITTI_PATH=${KITTI_PATH:-/home/chrischoy/datasets/KITTI_FCGF} 22 | export VERSION=$(git rev-parse HEAD) 23 | 24 | export OUT_DIR=${DATA_ROOT}/${DATASET}-v${VOXEL_SIZE}/${TRAINER}/${MODEL}/${OPTIMIZER}-lr${LR}-e${MAX_EPOCH}-b${BATCH_SIZE}i${ITER_SIZE}-modelnout${MODEL_N_OUT}${PATH_POSTFIX}/${TIME} 25 | 26 | export PYTHONUNBUFFERED="True" 27 | 28 | echo $OUT_DIR 29 | 30 | mkdir -m 755 -p $OUT_DIR 31 | 32 | LOG=${OUT_DIR}/log_${TIME}.txt 33 | 34 | echo "Host: " $(hostname) | tee -a $LOG 35 | echo "Conda " $(which conda) | tee -a $LOG 36 | echo $(pwd) | tee -a $LOG 37 | echo "Version: " $VERSION | tee -a $LOG 38 | echo "Git diff" | tee -a $LOG 39 | echo "" | tee -a $LOG 40 | git diff | tee -a $LOG 41 | echo "" | tee -a $LOG 42 | nvidia-smi | tee -a $LOG 43 | 44 | # Training 45 | python train.py \ 46 | --dataset ${DATASET} \ 47 | --trainer ${TRAINER} \ 48 | --model ${MODEL} \ 49 | --model_n_out ${MODEL_N_OUT} \ 50 | --conv1_kernel_size ${CONV1_KERNEL_SIZE} \ 51 | --optimizer ${OPTIMIZER} \ 52 | --lr ${LR} \ 53 | --batch_size ${BATCH_SIZE} \ 54 | --iter_size ${ITER_SIZE} \ 55 | --max_epoch ${MAX_EPOCH} \ 56 | --voxel_size ${VOXEL_SIZE} \ 57 | --out_dir ${OUT_DIR} \ 58 | --use_random_scale ${RANDOM_SCALE} \ 59 | --positive_pair_search_voxel_size_multiplier ${POSITIVE_PAIR_SEARCH_VOXEL_SIZE_MULTIPLIER} \ 60 | --kitti_root ${KITTI_PATH} \ 61 | --hit_ratio_thresh 0.3 \ 62 | $MISC_ARGS 2>&1 | tee -a $LOG 63 | 64 | # Test 65 | python -m scripts.test_kitti \ 66 | --kitti_root ${KITTI_PATH} \ 67 | --save_dir ${OUT_DIR} | tee -a $LOG 68 | -------------------------------------------------------------------------------- /demo.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import open3d as o3d 5 | from urllib.request import urlretrieve 6 | from util.visualization import get_colored_point_cloud_feature 7 | from util.misc import extract_features 8 | 9 | from model.resunet import ResUNetBN2C 10 | 11 | import torch 12 | 13 | if not os.path.isfile('ResUNetBN2C-16feat-3conv.pth'): 14 | print('Downloading weights...') 15 | urlretrieve( 16 | "https://node1.chrischoy.org/data/publications/fcgf/2019-09-18_14-15-59.pth", 17 | 'ResUNetBN2C-16feat-3conv.pth') 18 | 19 | if not os.path.isfile('redkitchen-20.ply'): 20 | print('Downloading a mesh...') 21 | urlretrieve("https://node1.chrischoy.org/data/publications/fcgf/redkitchen-20.ply", 22 | 'redkitchen-20.ply') 23 | 24 | 25 | def demo(config): 26 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | 28 | checkpoint = torch.load(config.model) 29 | model = ResUNetBN2C(1, 16, normalize_feature=True, conv1_kernel_size=3, D=3) 30 | model.load_state_dict(checkpoint['state_dict']) 31 | model.eval() 32 | 33 | model = model.to(device) 34 | 35 | pcd = o3d.io.read_point_cloud(config.input) 36 | xyz_down, feature = extract_features( 37 | model, 38 | xyz=np.array(pcd.points), 39 | voxel_size=config.voxel_size, 40 | device=device, 41 | skip_check=True) 42 | 43 | vis_pcd = o3d.geometry.PointCloud() 44 | vis_pcd.points = o3d.utility.Vector3dVector(xyz_down) 45 | 46 | vis_pcd = get_colored_point_cloud_feature(vis_pcd, 47 | feature.detach().cpu().numpy(), 48 | config.voxel_size) 49 | o3d.visualization.draw_geometries([vis_pcd]) 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument( 55 | '-i', 56 | '--input', 57 | default='redkitchen-20.ply', 58 | type=str, 59 | help='path to a pointcloud file') 60 | parser.add_argument( 61 | '-m', 62 | '--model', 63 | default='ResUNetBN2C-16feat-3conv.pth', 64 | type=str, 65 | help='path to latest checkpoint (default: None)') 66 | parser.add_argument( 67 | '--voxel_size', 68 | default=0.025, 69 | type=float, 70 | help='voxel size to preprocess point cloud') 71 | 72 | config = parser.parse_args() 73 | demo(config) 74 | -------------------------------------------------------------------------------- /scripts/benchmark_util.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import os 3 | import logging 4 | import numpy as np 5 | 6 | from util.trajectory import CameraPose 7 | from util.pointcloud import compute_overlap_ratio, \ 8 | make_open3d_point_cloud, make_open3d_feature_from_numpy 9 | 10 | 11 | def run_ransac(xyz0, xyz1, feat0, feat1, voxel_size): 12 | distance_threshold = voxel_size * 1.5 13 | result_ransac = o3d.registration.registration_ransac_based_on_feature_matching( 14 | xyz0, xyz1, feat0, feat1, distance_threshold, 15 | o3d.registration.TransformationEstimationPointToPoint(False), 4, [ 16 | o3d.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), 17 | o3d.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold) 18 | ], o3d.registration.RANSACConvergenceCriteria(4000000, 500)) 19 | return result_ransac.transformation 20 | 21 | 22 | def gather_results(results): 23 | traj = [] 24 | for r in results: 25 | success = r[0] 26 | if success: 27 | traj.append(CameraPose([r[1], r[2], r[3]], r[4])) 28 | return traj 29 | 30 | 31 | def gen_matching_pair(pts_num): 32 | matching_pairs = [] 33 | for i in range(pts_num): 
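    # enumerates every unordered fragment pair (i, j) with i < j; pts_num is
    # stored in each entry so downstream matching code knows the fragment count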
34 | for j in range(i + 1, pts_num): 35 | matching_pairs.append([i, j, pts_num]) 36 | return matching_pairs 37 | 38 | 39 | def read_data(feature_path, name): 40 | data = np.load(os.path.join(feature_path, name + ".npz")) 41 | xyz = make_open3d_point_cloud(data['xyz']) 42 | feat = make_open3d_feature_from_numpy(data['feature']) 43 | return data['points'], xyz, feat 44 | 45 | 46 | def do_single_pair_matching(feature_path, set_name, m, voxel_size): 47 | i, j, s = m 48 | name_i = "%s_%03d" % (set_name, i) 49 | name_j = "%s_%03d" % (set_name, j) 50 | logging.info("matching %s %s" % (name_i, name_j)) 51 | points_i, xyz_i, feat_i = read_data(feature_path, name_i) 52 | points_j, xyz_j, feat_j = read_data(feature_path, name_j) 53 | if len(xyz_i.points) < len(xyz_j.points): 54 | trans = run_ransac(xyz_i, xyz_j, feat_i, feat_j, voxel_size) 55 | else: 56 | trans = run_ransac(xyz_j, xyz_i, feat_j, feat_i, voxel_size) 57 | trans = np.linalg.inv(trans) 58 | ratio = compute_overlap_ratio(xyz_i, xyz_j, trans, voxel_size) 59 | logging.info(f"{ratio}") 60 | if ratio > 0.3: 61 | return [True, i, j, s, np.linalg.inv(trans)] 62 | else: 63 | return [False, i, j, s, np.identity(4)] 64 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | import open3d as o3d # prevent loading error 3 | 4 | import sys 5 | import json 6 | import logging 7 | import torch 8 | from easydict import EasyDict as edict 9 | 10 | from lib.data_loaders import make_data_loader 11 | from config import get_config 12 | 13 | from lib.trainer import ContrastiveLossTrainer, HardestContrastiveLossTrainer, \ 14 | TripletLossTrainer, HardestTripletLossTrainer 15 | 16 | ch = logging.StreamHandler(sys.stdout) 17 | logging.getLogger().setLevel(logging.INFO) 18 | logging.basicConfig( 19 | format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch]) 20 | 21 | torch.manual_seed(0) 22 | torch.cuda.manual_seed(0) 23 | 24 | logging.basicConfig(level=logging.INFO, format="") 25 | 26 | 27 | def get_trainer(trainer): 28 | if trainer == 'ContrastiveLossTrainer': 29 | return ContrastiveLossTrainer 30 | elif trainer == 'HardestContrastiveLossTrainer': 31 | return HardestContrastiveLossTrainer 32 | elif trainer == 'TripletLossTrainer': 33 | return TripletLossTrainer 34 | elif trainer == 'HardestTripletLossTrainer': 35 | return HardestTripletLossTrainer 36 | else: 37 | raise ValueError(f'Trainer {trainer} not found') 38 | 39 | 40 | def main(config, resume=False): 41 | train_loader = make_data_loader( 42 | config, 43 | config.train_phase, 44 | config.batch_size, 45 | num_threads=config.train_num_thread) 46 | 47 | if config.test_valid: 48 | val_loader = make_data_loader( 49 | config, 50 | config.val_phase, 51 | config.val_batch_size, 52 | num_threads=config.val_num_thread) 53 | else: 54 | val_loader = None 55 | 56 | Trainer = get_trainer(config.trainer) 57 | trainer = Trainer( 58 | config=config, 59 | data_loader=train_loader, 60 | val_data_loader=val_loader, 61 | ) 62 | 63 | trainer.train() 64 | 65 | 66 | if __name__ == "__main__": 67 | logger = logging.getLogger() 68 | config = get_config() 69 | 70 | dconfig = vars(config) 71 | if config.resume_dir: 72 | resume_config = json.load(open(config.resume_dir + '/config.json', 'r')) 73 | for k in dconfig: 74 | if k not in ['resume_dir'] and k in resume_config: 75 | dconfig[k] = resume_config[k] 76 | dconfig['resume'] = resume_config['out_dir'] + 
'/checkpoint.pth' 77 | 78 | logging.info('===> Configurations') 79 | for k in dconfig: 80 | logging.info(' {}: {}'.format(k, dconfig[k])) 81 | 82 | # Convert to dict 83 | config = edict(dconfig) 84 | main(config) 85 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import MinkowskiEngine as ME 4 | 5 | 6 | def _hash(arr, M): 7 | if isinstance(arr, np.ndarray): 8 | N, D = arr.shape 9 | else: 10 | N, D = len(arr[0]), len(arr) 11 | 12 | hash_vec = np.zeros(N, dtype=np.int64) 13 | for d in range(D): 14 | if isinstance(arr, np.ndarray): 15 | hash_vec += arr[:, d] * M**d 16 | else: 17 | hash_vec += arr[d] * M**d 18 | return hash_vec 19 | 20 | 21 | def extract_features(model, 22 | xyz, 23 | rgb=None, 24 | normal=None, 25 | voxel_size=0.05, 26 | device=None, 27 | skip_check=False, 28 | is_eval=True): 29 | ''' 30 | xyz is a N x 3 matrix 31 | rgb is a N x 3 matrix and all color must range from [0, 1] or None 32 | normal is a N x 3 matrix and all normal range from [-1, 1] or None 33 | 34 | if both rgb and normal are None, we use Nx1 one vector as an input 35 | 36 | if device is None, it tries to use gpu by default 37 | 38 | if skip_check is True, skip rigorous checks to speed up 39 | 40 | model = model.to(device) 41 | xyz, feats = extract_features(model, xyz) 42 | ''' 43 | if is_eval: 44 | model.eval() 45 | 46 | if not skip_check: 47 | assert xyz.shape[1] == 3 48 | 49 | N = xyz.shape[0] 50 | if rgb is not None: 51 | assert N == len(rgb) 52 | assert rgb.shape[1] == 3 53 | if np.any(rgb > 1): 54 | raise ValueError('Invalid color. Color must range from [0, 1]') 55 | 56 | if normal is not None: 57 | assert N == len(normal) 58 | assert normal.shape[1] == 3 59 | if np.any(normal > 1): 60 | raise ValueError('Invalid normal. 
Normal must range from [-1, 1]') 61 | 62 | if device is None: 63 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 64 | 65 | feats = [] 66 | if rgb is not None: 67 | # [0, 1] 68 | feats.append(rgb - 0.5) 69 | 70 | if normal is not None: 71 | # [-1, 1] 72 | feats.append(normal / 2) 73 | 74 | if rgb is None and normal is None: 75 | feats.append(np.ones((len(xyz), 1))) 76 | 77 | feats = np.hstack(feats) 78 | 79 | # Voxelize xyz and feats 80 | coords = np.floor(xyz / voxel_size) 81 | coords, inds = ME.utils.sparse_quantize(coords, return_index=True) 82 | # Convert to batched coords compatible with ME 83 | coords = ME.utils.batched_coordinates([coords]) 84 | return_coords = xyz[inds] 85 | 86 | feats = feats[inds] 87 | 88 | feats = torch.tensor(feats, dtype=torch.float32) 89 | coords = torch.tensor(coords, dtype=torch.int32) 90 | 91 | stensor = ME.SparseTensor(feats, coordinates=coords, device=device) 92 | 93 | return return_coords, model(stensor).F 94 | -------------------------------------------------------------------------------- /util/transform_estimation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | 4 | 5 | def rot_x(x): 6 | out = torch.zeros((3, 3)) 7 | c = torch.cos(x) 8 | s = torch.sin(x) 9 | out[0, 0] = 1 10 | out[1, 1] = c 11 | out[1, 2] = -s 12 | out[2, 1] = s 13 | out[2, 2] = c 14 | return out 15 | 16 | 17 | def rot_y(x): 18 | out = torch.zeros((3, 3)) 19 | c = torch.cos(x) 20 | s = torch.sin(x) 21 | out[0, 0] = c 22 | out[0, 2] = s 23 | out[1, 1] = 1 24 | out[2, 0] = -s 25 | out[2, 2] = c 26 | return out 27 | 28 | 29 | def rot_z(x): 30 | out = torch.zeros((3, 3)) 31 | c = torch.cos(x) 32 | s = torch.sin(x) 33 | out[0, 0] = c 34 | out[0, 1] = -s 35 | out[1, 0] = s 36 | out[1, 1] = c 37 | out[2, 2] = 1 38 | return out 39 | 40 | 41 | def get_trans(x): 42 | trans = torch.eye(4) 43 | trans[:3, :3] = rot_z(x[2]).mm(rot_y(x[1])).mm(rot_x(x[0])) 44 | trans[:3, 3] = x[3:, 0] 45 | return trans 46 | 47 | 48 | def update_pcd(pts, trans): 49 | R = trans[:3, :3] 50 | T = trans[:3, 3] 51 | # pts = R.mm(pts.t()).t() + T.unsqueeze(1).t().expand_as(pts) 52 | pts = torch.t(R @ torch.t(pts)) + T 53 | return pts 54 | 55 | 56 | def build_linear_system(pts0, pts1, weight): 57 | npts0 = pts0.shape[0] 58 | A0 = torch.zeros((npts0, 6)) 59 | A1 = torch.zeros((npts0, 6)) 60 | A2 = torch.zeros((npts0, 6)) 61 | A0[:, 1] = pts0[:, 2] 62 | A0[:, 2] = -pts0[:, 1] 63 | A0[:, 3] = 1 64 | A1[:, 0] = -pts0[:, 2] 65 | A1[:, 2] = pts0[:, 0] 66 | A1[:, 4] = 1 67 | A2[:, 0] = pts0[:, 1] 68 | A2[:, 1] = -pts0[:, 0] 69 | A2[:, 5] = 1 70 | ww1 = weight.repeat(3, 6) 71 | ww2 = weight.repeat(3, 1) 72 | A = ww1 * torch.cat((A0, A1, A2), 0) 73 | b = ww2 * torch.cat( 74 | (pts1[:, 0] - pts0[:, 0], pts1[:, 1] - pts0[:, 1], pts1[:, 2] - pts0[:, 2]), 75 | 0, 76 | ).unsqueeze(1) 77 | return A, b 78 | 79 | 80 | def solve_linear_system(A, b): 81 | temp = torch.inverse(A.t().mm(A)) 82 | return temp.mm(A.t()).mm(b) 83 | 84 | 85 | def compute_weights(pts0, pts1, par): 86 | return par / (torch.norm(pts0 - pts1, dim=1).unsqueeze(1) + par) 87 | 88 | 89 | def est_quad_linear_robust(pts0, pts1, weight=None): 90 | # TODO: 2. 
residual scheduling 91 | pts0_curr = pts0 92 | trans = torch.eye(4) 93 | 94 | par = 1.0 # todo: need to decide 95 | if weight is None: 96 | weight = torch.ones(pts0.size()[0], 1) 97 | 98 | for i in range(20): 99 | if i > 0 and i % 5 == 0: 100 | par /= 2.0 101 | 102 | # compute weights 103 | A, b = build_linear_system(pts0_curr, pts1, weight) 104 | x = solve_linear_system(A, b) 105 | 106 | # TODO: early termination 107 | # residual = np.linalg.norm(A@x - b) 108 | # print(residual) 109 | 110 | # x = torch.empty(6, 1).uniform_(0, 1) 111 | trans_curr = get_trans(x) 112 | pts0_curr = update_pcd(pts0_curr, trans_curr) 113 | weight = compute_weights(pts0_curr, pts1, par) 114 | trans = trans_curr.mm(trans) 115 | 116 | return trans 117 | 118 | 119 | def pose_estimation(model, 120 | device, 121 | xyz0, 122 | xyz1, 123 | coord0, 124 | coord1, 125 | feats0, 126 | feats1, 127 | return_corr=False): 128 | sinput0 = ME.SparseTensor(feats0.to(device), coordinates=coord0.to(device)) 129 | F0 = model(sinput0).F 130 | 131 | sinput1 = ME.SparseTensor(feats1.to(device), coordinates=coord1.to(device)) 132 | F1 = model(sinput1).F 133 | 134 | corr = F0.mm(F1.t()) 135 | weight, inds = corr.max(dim=1) 136 | weight = weight.unsqueeze(1).cpu() 137 | xyz1_corr = xyz1[inds, :] 138 | 139 | trans = est_quad_linear_robust(xyz0, xyz1_corr, weight) # let's do this later 140 | 141 | if return_corr: 142 | return trans, weight, corr 143 | else: 144 | return trans, weight 145 | -------------------------------------------------------------------------------- /util/pointcloud.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import math 4 | 5 | import open3d as o3d 6 | from lib.eval import find_nn_cpu 7 | 8 | 9 | def make_open3d_point_cloud(xyz, color=None): 10 | pcd = o3d.geometry.PointCloud() 11 | pcd.points = o3d.utility.Vector3dVector(xyz) 12 | if color is not None: 13 | pcd.colors = o3d.utility.Vector3dVector(color) 14 | return pcd 15 | 16 | 17 | def make_open3d_feature(data, dim, npts): 18 | feature = o3d.registration.Feature() 19 | feature.resize(dim, npts) 20 | feature.data = data.cpu().numpy().astype('d').transpose() 21 | return feature 22 | 23 | 24 | def make_open3d_feature_from_numpy(data): 25 | assert isinstance(data, np.ndarray) 26 | assert data.ndim == 2 27 | 28 | feature = o3d.registration.Feature() 29 | feature.resize(data.shape[1], data.shape[0]) 30 | feature.data = data.astype('d').transpose() 31 | return feature 32 | 33 | 34 | def prepare_pointcloud(filename, voxel_size): 35 | pcd = o3d.io.read_point_cloud(filename) 36 | T = get_random_transformation(pcd) 37 | pcd.transform(T) 38 | pcd_down = pcd.voxel_down_sample(voxel_size) 39 | return pcd_down, T 40 | 41 | 42 | def compute_overlap_ratio(pcd0, pcd1, trans, voxel_size): 43 | pcd0_down = pcd0.voxel_down_sample(voxel_size) 44 | pcd1_down = pcd1.voxel_down_sample(voxel_size) 45 | matching01 = get_matching_indices(pcd0_down, pcd1_down, trans, voxel_size, 1) 46 | matching10 = get_matching_indices(pcd1_down, pcd0_down, np.linalg.inv(trans), 47 | voxel_size, 1) 48 | overlap0 = len(matching01) / len(pcd0_down.points) 49 | overlap1 = len(matching10) / len(pcd1_down.points) 50 | return max(overlap0, overlap1) 51 | 52 | 53 | def get_matching_indices(source, target, trans, search_voxel_size, K=None): 54 | source_copy = copy.deepcopy(source) 55 | target_copy = copy.deepcopy(target) 56 | source_copy.transform(trans) 57 | pcd_tree = o3d.geometry.KDTreeFlann(target_copy) 58 | 59 | match_inds = [] 60 
| for i, point in enumerate(source_copy.points):
61 |     [_, idx, _] = pcd_tree.search_radius_vector_3d(point, search_voxel_size)
62 |     if K is not None:
63 |       idx = idx[:K]
64 |     for j in idx:
65 |       match_inds.append((i, j))
66 |   return match_inds
67 | 
68 | 
69 | def evaluate_feature(pcd0, pcd1, feat0, feat1, trans_gth, search_voxel_size):
70 |   match_inds = get_matching_indices(pcd0, pcd1, trans_gth, search_voxel_size)
71 |   pcd_tree = o3d.geometry.KDTreeFlann(feat1)
72 |   dist = []
73 |   for ind in match_inds:
74 |     k, idx, _ = pcd_tree.search_knn_vector_xd(feat0.data[:, ind[0]], 1)
75 |     dist.append(
76 |         np.clip(
77 |             np.power(pcd1.points[ind[1]] - pcd1.points[idx[0]], 2),
78 |             a_min=0.0,
79 |             a_max=1.0))
80 |   return np.mean(dist)
81 | 
82 | 
83 | def valid_feat_ratio(pcd0, pcd1, feat0, feat1, trans_gth, thresh=0.1):
84 |   pcd0_copy = copy.deepcopy(pcd0)
85 |   pcd0_copy.transform(trans_gth)
86 |   inds = find_nn_cpu(feat0, feat1, return_distance=False)
87 |   dist = np.sqrt(((np.array(pcd0_copy.points) - np.array(pcd1.points)[inds])**2).sum(1))
88 |   return np.mean(dist < thresh)
89 | 
90 | 
91 | def evaluate_feature_3dmatch(pcd0, pcd1, feat0, feat1, trans_gth, inlier_thresh=0.1):
92 |   r"""Return the hit ratio (the ratio of inlier correspondences to all correspondences).
93 | 
94 |   inlier_thresh is the inlier threshold in meters.
95 |   """
96 |   if len(pcd0.points) < len(pcd1.points):
97 |     hit = valid_feat_ratio(pcd0, pcd1, feat0, feat1, trans_gth, inlier_thresh)
98 |   else:
99 |     hit = valid_feat_ratio(pcd1, pcd0, feat1, feat0, np.linalg.inv(trans_gth), inlier_thresh)
100 |   return hit
101 | 
102 | 
103 | def get_matching_matrix(source, target, trans, voxel_size, debug_mode):
104 |   source_copy = copy.deepcopy(source)
105 |   target_copy = copy.deepcopy(target)
106 |   source_copy.transform(trans)
107 |   pcd_tree = o3d.geometry.KDTreeFlann(target_copy)
108 |   matching_matrix = np.zeros((len(source_copy.points), len(target_copy.points)))
109 | 
110 |   for i, point in enumerate(source_copy.points):
111 |     [k, idx, _] = pcd_tree.search_radius_vector_3d(point, voxel_size * 1.5)
112 |     if k >= 1:
113 |       matching_matrix[i, idx[0]] = 1  # TODO: keep only the closest match?
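      # note: only the first neighbor returned within 1.5 * voxel_size is
      # marked, so each source point yields at most one correspondence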
114 | 115 | return matching_matrix 116 | 117 | 118 | def get_random_transformation(pcd_input): 119 | 120 | def rot_x(x): 121 | out = np.zeros((3, 3)) 122 | c = math.cos(x) 123 | s = math.sin(x) 124 | out[0, 0] = 1 125 | out[1, 1] = c 126 | out[1, 2] = -s 127 | out[2, 1] = s 128 | out[2, 2] = c 129 | return out 130 | 131 | def rot_y(x): 132 | out = np.zeros((3, 3)) 133 | c = math.cos(x) 134 | s = math.sin(x) 135 | out[0, 0] = c 136 | out[0, 2] = s 137 | out[1, 1] = 1 138 | out[2, 0] = -s 139 | out[2, 2] = c 140 | return out 141 | 142 | def rot_z(x): 143 | out = np.zeros((3, 3)) 144 | c = math.cos(x) 145 | s = math.sin(x) 146 | out[0, 0] = c 147 | out[0, 1] = -s 148 | out[1, 0] = s 149 | out[1, 1] = c 150 | out[2, 2] = 1 151 | return out 152 | 153 | pcd_output = copy.deepcopy(pcd_input) 154 | mean = np.mean(np.asarray(pcd_output.points), axis=0).transpose() 155 | xyz = np.random.uniform(0, 2 * math.pi, 3) 156 | R = np.dot(np.dot(rot_x(xyz[0]), rot_y(xyz[1])), rot_z(xyz[2])) 157 | T = np.zeros((4, 4)) 158 | T[:3, :3] = R 159 | T[:3, 3] = np.dot(-R, mean) 160 | T[3, 3] = 1 161 | return T 162 | -------------------------------------------------------------------------------- /scripts/test_kitti.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d # prevent loading error 2 | 3 | import sys 4 | import logging 5 | import json 6 | import argparse 7 | import numpy as np 8 | from easydict import EasyDict as edict 9 | 10 | import torch 11 | from model import load_model 12 | 13 | from lib.data_loaders import make_data_loader 14 | from util.pointcloud import make_open3d_point_cloud, make_open3d_feature 15 | from lib.timer import AverageMeter, Timer 16 | 17 | import MinkowskiEngine as ME 18 | 19 | ch = logging.StreamHandler(sys.stdout) 20 | logging.getLogger().setLevel(logging.INFO) 21 | logging.basicConfig( 22 | format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch]) 23 | 24 | 25 | def main(config): 26 | test_loader = make_data_loader( 27 | config, config.test_phase, 1, num_threads=config.test_num_workers, shuffle=True) 28 | 29 | num_feats = 1 30 | 31 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 32 | 33 | Model = load_model(config.model) 34 | model = Model( 35 | num_feats, 36 | config.model_n_out, 37 | bn_momentum=config.bn_momentum, 38 | conv1_kernel_size=config.conv1_kernel_size, 39 | normalize_feature=config.normalize_feature) 40 | checkpoint = torch.load(config.save_dir + '/checkpoint.pth') 41 | model.load_state_dict(checkpoint['state_dict']) 42 | model = model.to(device) 43 | model.eval() 44 | 45 | success_meter, rte_meter, rre_meter = AverageMeter(), AverageMeter(), AverageMeter() 46 | data_timer, feat_timer, reg_timer = Timer(), Timer(), Timer() 47 | 48 | test_iter = test_loader.__iter__() 49 | N = len(test_iter) 50 | n_gpu_failures = 0 51 | 52 | # downsample_voxel_size = 2 * config.voxel_size 53 | 54 | for i in range(len(test_iter)): 55 | data_timer.tic() 56 | try: 57 | data_dict = test_iter.next() 58 | except ValueError: 59 | n_gpu_failures += 1 60 | logging.info(f"# Erroneous GPU Pair {n_gpu_failures}") 61 | continue 62 | data_timer.toc() 63 | xyz0, xyz1 = data_dict['pcd0'], data_dict['pcd1'] 64 | T_gth = data_dict['T_gt'] 65 | xyz0np, xyz1np = xyz0.numpy(), xyz1.numpy() 66 | 67 | pcd0 = make_open3d_point_cloud(xyz0np) 68 | pcd1 = make_open3d_point_cloud(xyz1np) 69 | 70 | with torch.no_grad(): 71 | feat_timer.tic() 72 | sinput0 = ME.SparseTensor( 73 | data_dict['sinput0_F'].to(device), 
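          # the sparse feature tensor F pairs with the integer voxel
          # coordinates C produced by the data loader's collation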
coordinates=data_dict['sinput0_C'].to(device)) 74 | F0 = model(sinput0).F.detach() 75 | sinput1 = ME.SparseTensor( 76 | data_dict['sinput1_F'].to(device), coordinates=data_dict['sinput1_C'].to(device)) 77 | F1 = model(sinput1).F.detach() 78 | feat_timer.toc() 79 | 80 | feat0 = make_open3d_feature(F0, 32, F0.shape[0]) 81 | feat1 = make_open3d_feature(F1, 32, F1.shape[0]) 82 | 83 | reg_timer.tic() 84 | distance_threshold = config.voxel_size * 1.0 85 | ransac_result = o3d.registration.registration_ransac_based_on_feature_matching( 86 | pcd0, pcd1, feat0, feat1, distance_threshold, 87 | o3d.registration.TransformationEstimationPointToPoint(False), 4, [ 88 | o3d.registration.CorrespondenceCheckerBasedOnEdgeLength(0.9), 89 | o3d.registration.CorrespondenceCheckerBasedOnDistance(distance_threshold) 90 | ], o3d.registration.RANSACConvergenceCriteria(4000000, 10000)) 91 | T_ransac = torch.from_numpy(ransac_result.transformation.astype(np.float32)) 92 | reg_timer.toc() 93 | 94 | # Translation error 95 | rte = np.linalg.norm(T_ransac[:3, 3] - T_gth[:3, 3]) 96 | rre = np.arccos((np.trace(T_ransac[:3, :3].t() @ T_gth[:3, :3]) - 1) / 2) 97 | 98 | # Check if the ransac was successful. successful if rte < 2m and rre < 5◦ 99 | # http://openaccess.thecvf.com/content_ECCV_2018/papers/Zi_Jian_Yew_3DFeat-Net_Weakly_Supervised_ECCV_2018_paper.pdf 100 | if rte < 2: 101 | rte_meter.update(rte) 102 | 103 | if not np.isnan(rre) and rre < np.pi / 180 * 5: 104 | rre_meter.update(rre) 105 | 106 | if rte < 2 and not np.isnan(rre) and rre < np.pi / 180 * 5: 107 | success_meter.update(1) 108 | else: 109 | success_meter.update(0) 110 | logging.info(f"Failed with RTE: {rte}, RRE: {rre}") 111 | 112 | if i % 10 == 0: 113 | logging.info( 114 | f"{i} / {N}: Data time: {data_timer.avg}, Feat time: {feat_timer.avg}," + 115 | f" Reg time: {reg_timer.avg}, RTE: {rte_meter.avg}," + 116 | f" RRE: {rre_meter.avg}, Success: {success_meter.sum} / {success_meter.count}" 117 | + f" ({success_meter.avg * 100} %)") 118 | data_timer.reset() 119 | feat_timer.reset() 120 | reg_timer.reset() 121 | 122 | logging.info( 123 | f"RTE: {rte_meter.avg}, var: {rte_meter.var}," + 124 | f" RRE: {rre_meter.avg}, var: {rre_meter.var}, Success: {success_meter.sum} " + 125 | f"/ {success_meter.count} ({success_meter.avg * 100} %)") 126 | 127 | 128 | if __name__ == '__main__': 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('--save_dir', default=None, type=str) 131 | parser.add_argument('--test_phase', default='test', type=str) 132 | parser.add_argument('--test_num_thread', default=5, type=int) 133 | parser.add_argument('--kitti_root', type=str, default="/data/kitti/") 134 | args = parser.parse_args() 135 | 136 | config = json.load(open(args.save_dir + '/config.json', 'r')) 137 | config = edict(config) 138 | config.save_dir = args.save_dir 139 | config.test_phase = args.test_phase 140 | config.kitti_root = args.kitti_root 141 | config.kitti_odometry_root = args.kitti_root + '/dataset' 142 | config.test_num_thread = args.test_num_thread 143 | 144 | main(config) 145 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | arg_lists = [] 4 | parser = argparse.ArgumentParser() 5 | 6 | 7 | def add_argument_group(name): 8 | arg = parser.add_argument_group(name) 9 | arg_lists.append(arg) 10 | return arg 11 | 12 | 13 | def str2bool(v): 14 | return v.lower() in ('true', '1') 15 | 16 | 17 | logging_arg = 
add_argument_group('Logging') 18 | logging_arg.add_argument('--out_dir', type=str, default='outputs') 19 | 20 | trainer_arg = add_argument_group('Trainer') 21 | trainer_arg.add_argument('--trainer', type=str, default='HardestContrastiveLossTrainer') 22 | trainer_arg.add_argument('--save_freq_epoch', type=int, default=1) 23 | trainer_arg.add_argument('--batch_size', type=int, default=4) 24 | trainer_arg.add_argument('--val_batch_size', type=int, default=1) 25 | 26 | # Hard negative mining 27 | trainer_arg.add_argument('--use_hard_negative', type=str2bool, default=True) 28 | trainer_arg.add_argument('--hard_negative_sample_ratio', type=float, default=0.05) 29 | trainer_arg.add_argument('--hard_negative_max_num', type=int, default=3000) 30 | trainer_arg.add_argument('--num_pos_per_batch', type=int, default=1024) 31 | trainer_arg.add_argument('--num_hn_samples_per_batch', type=int, default=256) 32 | 33 | # Metric learning loss 34 | trainer_arg.add_argument('--neg_thresh', type=float, default=1.4) 35 | trainer_arg.add_argument('--pos_thresh', type=float, default=0.1) 36 | trainer_arg.add_argument('--neg_weight', type=float, default=1) 37 | 38 | # Data augmentation 39 | trainer_arg.add_argument('--use_random_scale', type=str2bool, default=False) 40 | trainer_arg.add_argument('--min_scale', type=float, default=0.8) 41 | trainer_arg.add_argument('--max_scale', type=float, default=1.2) 42 | trainer_arg.add_argument('--use_random_rotation', type=str2bool, default=True) 43 | trainer_arg.add_argument('--rotation_range', type=float, default=360) 44 | 45 | # Data loader configs 46 | trainer_arg.add_argument('--train_phase', type=str, default="train") 47 | trainer_arg.add_argument('--val_phase', type=str, default="val") 48 | trainer_arg.add_argument('--test_phase', type=str, default="test") 49 | 50 | trainer_arg.add_argument('--stat_freq', type=int, default=40) 51 | trainer_arg.add_argument('--test_valid', type=str2bool, default=True) 52 | trainer_arg.add_argument('--val_max_iter', type=int, default=400) 53 | trainer_arg.add_argument('--val_epoch_freq', type=int, default=1) 54 | trainer_arg.add_argument( 55 | '--positive_pair_search_voxel_size_multiplier', type=float, default=1.5) 56 | 57 | trainer_arg.add_argument('--hit_ratio_thresh', type=float, default=0.1) 58 | 59 | # Triplets 60 | trainer_arg.add_argument('--triplet_num_pos', type=int, default=256) 61 | trainer_arg.add_argument('--triplet_num_hn', type=int, default=512) 62 | trainer_arg.add_argument('--triplet_num_rand', type=int, default=1024) 63 | 64 | # Network specific configurations 65 | net_arg = add_argument_group('Network') 66 | net_arg.add_argument('--model', type=str, default='ResUNetBN2C') 67 | net_arg.add_argument('--model_n_out', type=int, default=32, help='Feature dimension') 68 | net_arg.add_argument('--conv1_kernel_size', type=int, default=5) 69 | net_arg.add_argument('--normalize_feature', type=str2bool, default=True) 70 | net_arg.add_argument('--dist_type', type=str, default='L2') 71 | net_arg.add_argument('--best_val_metric', type=str, default='feat_match_ratio') 72 | 73 | # Optimizer arguments 74 | opt_arg = add_argument_group('Optimizer') 75 | opt_arg.add_argument('--optimizer', type=str, default='SGD') 76 | opt_arg.add_argument('--max_epoch', type=int, default=100) 77 | opt_arg.add_argument('--lr', type=float, default=1e-1) 78 | opt_arg.add_argument('--momentum', type=float, default=0.8) 79 | opt_arg.add_argument('--sgd_momentum', type=float, default=0.9) 80 | opt_arg.add_argument('--sgd_dampening', type=float, default=0.1) 81
| opt_arg.add_argument('--adam_beta1', type=float, default=0.9) 82 | opt_arg.add_argument('--adam_beta2', type=float, default=0.999) 83 | opt_arg.add_argument('--weight_decay', type=float, default=1e-4) 84 | opt_arg.add_argument('--iter_size', type=int, default=1, help='accumulate gradient') 85 | opt_arg.add_argument('--bn_momentum', type=float, default=0.05) 86 | opt_arg.add_argument('--exp_gamma', type=float, default=0.99) 87 | opt_arg.add_argument('--scheduler', type=str, default='ExpLR') 88 | opt_arg.add_argument( 89 | '--icp_cache_path', type=str, default="/home/chrischoy/datasets/FCGF/kitti/icp/") 90 | 91 | misc_arg = add_argument_group('Misc') 92 | misc_arg.add_argument('--use_gpu', type=str2bool, default=True) 93 | misc_arg.add_argument('--weights', type=str, default=None) 94 | misc_arg.add_argument('--weights_dir', type=str, default=None) 95 | misc_arg.add_argument('--resume', type=str, default=None) 96 | misc_arg.add_argument('--resume_dir', type=str, default=None) 97 | misc_arg.add_argument('--train_num_thread', type=int, default=2) 98 | misc_arg.add_argument('--val_num_thread', type=int, default=1) 99 | misc_arg.add_argument('--test_num_thread', type=int, default=2) 100 | misc_arg.add_argument('--fast_validation', type=str2bool, default=False) 101 | misc_arg.add_argument( 102 | '--nn_max_n', 103 | type=int, 104 | default=500, 105 | help='The maximum number of features to find nearest neighbors in batch') 106 | 107 | # Dataset specific configurations 108 | data_arg = add_argument_group('Data') 109 | data_arg.add_argument('--dataset', type=str, default='ThreeDMatchPairDataset') 110 | data_arg.add_argument('--voxel_size', type=float, default=0.025) 111 | data_arg.add_argument( 112 | '--threed_match_dir', type=str, default="/home/chrischoy/datasets/FCGF/threedmatch") 113 | data_arg.add_argument( 114 | '--kitti_root', type=str, default="/home/chrischoy/datasets/FCGF/kitti/") 115 | data_arg.add_argument( 116 | '--kitti_max_time_diff', 117 | type=int, 118 | default=3, 119 | help='max time difference between pairs (non inclusive)') 120 | data_arg.add_argument('--kitti_date', type=str, default='2011_09_26') 121 | 122 | 123 | def get_config(): 124 | args = parser.parse_args() 125 | return args 126 | -------------------------------------------------------------------------------- /model/resunet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | import torch 3 | import MinkowskiEngine as ME 4 | import MinkowskiEngine.MinkowskiFunctional as MEF 5 | from model.common import get_norm 6 | 7 | from model.residual_block import get_block 8 | 9 | 10 | class ResUNet2(ME.MinkowskiNetwork): 11 | NORM_TYPE = None 12 | BLOCK_NORM_TYPE = 'BN' 13 | CHANNELS = [None, 32, 64, 128, 256] 14 | TR_CHANNELS = [None, 32, 64, 64, 128] 15 | 16 | # To use the model, must call initialize_coords before forward pass. 
17 | # Once data is processed, call clear to reset the model before calling initialize_coords 18 | def __init__(self, 19 | in_channels=3, 20 | out_channels=32, 21 | bn_momentum=0.1, 22 | normalize_feature=None, 23 | conv1_kernel_size=None, 24 | D=3): 25 | ME.MinkowskiNetwork.__init__(self, D) 26 | NORM_TYPE = self.NORM_TYPE 27 | BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE 28 | CHANNELS = self.CHANNELS 29 | TR_CHANNELS = self.TR_CHANNELS 30 | self.normalize_feature = normalize_feature 31 | self.conv1 = ME.MinkowskiConvolution( 32 | in_channels=in_channels, 33 | out_channels=CHANNELS[1], 34 | kernel_size=conv1_kernel_size, 35 | stride=1, 36 | dilation=1, 37 | bias=False, 38 | dimension=D) 39 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) 40 | 41 | self.block1 = get_block( 42 | BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D) 43 | 44 | self.conv2 = ME.MinkowskiConvolution( 45 | in_channels=CHANNELS[1], 46 | out_channels=CHANNELS[2], 47 | kernel_size=3, 48 | stride=2, 49 | dilation=1, 50 | bias=False, 51 | dimension=D) 52 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) 53 | 54 | self.block2 = get_block( 55 | BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D) 56 | 57 | self.conv3 = ME.MinkowskiConvolution( 58 | in_channels=CHANNELS[2], 59 | out_channels=CHANNELS[3], 60 | kernel_size=3, 61 | stride=2, 62 | dilation=1, 63 | bias=False, 64 | dimension=D) 65 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) 66 | 67 | self.block3 = get_block( 68 | BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D) 69 | 70 | self.conv4 = ME.MinkowskiConvolution( 71 | in_channels=CHANNELS[3], 72 | out_channels=CHANNELS[4], 73 | kernel_size=3, 74 | stride=2, 75 | dilation=1, 76 | bias=False, 77 | dimension=D) 78 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D) 79 | 80 | self.block4 = get_block( 81 | BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D) 82 | 83 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 84 | in_channels=CHANNELS[4], 85 | out_channels=TR_CHANNELS[4], 86 | kernel_size=3, 87 | stride=2, 88 | dilation=1, 89 | bias=False, 90 | dimension=D) 91 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 92 | 93 | self.block4_tr = get_block( 94 | BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 95 | 96 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 97 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 98 | out_channels=TR_CHANNELS[3], 99 | kernel_size=3, 100 | stride=2, 101 | dilation=1, 102 | bias=False, 103 | dimension=D) 104 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 105 | 106 | self.block3_tr = get_block( 107 | BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 108 | 109 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 110 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 111 | out_channels=TR_CHANNELS[2], 112 | kernel_size=3, 113 | stride=2, 114 | dilation=1, 115 | bias=False, 116 | dimension=D) 117 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 118 | 119 | self.block2_tr = get_block( 120 | BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 121 | 122 | self.conv1_tr = ME.MinkowskiConvolution( 123 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 124 | out_channels=TR_CHANNELS[1], 125 | kernel_size=1, 126 | stride=1, 127 | 
dilation=1, 128 | bias=False, 129 | dimension=D) 130 | 131 | # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D) 132 | 133 | self.final = ME.MinkowskiConvolution( 134 | in_channels=TR_CHANNELS[1], 135 | out_channels=out_channels, 136 | kernel_size=1, 137 | stride=1, 138 | dilation=1, 139 | bias=True, 140 | dimension=D) 141 | 142 | def forward(self, x): 143 | out_s1 = self.conv1(x) 144 | out_s1 = self.norm1(out_s1) 145 | out_s1 = self.block1(out_s1) 146 | out = MEF.relu(out_s1) 147 | 148 | out_s2 = self.conv2(out) 149 | out_s2 = self.norm2(out_s2) 150 | out_s2 = self.block2(out_s2) 151 | out = MEF.relu(out_s2) 152 | 153 | out_s4 = self.conv3(out) 154 | out_s4 = self.norm3(out_s4) 155 | out_s4 = self.block3(out_s4) 156 | out = MEF.relu(out_s4) 157 | 158 | out_s8 = self.conv4(out) 159 | out_s8 = self.norm4(out_s8) 160 | out_s8 = self.block4(out_s8) 161 | out = MEF.relu(out_s8) 162 | 163 | out = self.conv4_tr(out) 164 | out = self.norm4_tr(out) 165 | out = self.block4_tr(out) 166 | out_s4_tr = MEF.relu(out) 167 | 168 | out = ME.cat(out_s4_tr, out_s4) 169 | 170 | out = self.conv3_tr(out) 171 | out = self.norm3_tr(out) 172 | out = self.block3_tr(out) 173 | out_s2_tr = MEF.relu(out) 174 | 175 | out = ME.cat(out_s2_tr, out_s2) 176 | 177 | out = self.conv2_tr(out) 178 | out = self.norm2_tr(out) 179 | out = self.block2_tr(out) 180 | out_s1_tr = MEF.relu(out) 181 | 182 | out = ME.cat(out_s1_tr, out_s1) 183 | out = self.conv1_tr(out) 184 | out = MEF.relu(out) 185 | out = self.final(out) 186 | 187 | if self.normalize_feature: 188 | return ME.SparseTensor( 189 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 190 | coordinate_map_key=out.coordinate_map_key, 191 | coordinate_manager=out.coordinate_manager) 192 | else: 193 | return out 194 | 195 | 196 | class ResUNetBN2(ResUNet2): 197 | NORM_TYPE = 'BN' 198 | 199 | 200 | class ResUNetBN2B(ResUNet2): 201 | NORM_TYPE = 'BN' 202 | CHANNELS = [None, 32, 64, 128, 256] 203 | TR_CHANNELS = [None, 64, 64, 64, 64] 204 | 205 | 206 | class ResUNetBN2C(ResUNet2): 207 | NORM_TYPE = 'BN' 208 | CHANNELS = [None, 32, 64, 128, 256] 209 | TR_CHANNELS = [None, 64, 64, 64, 128] 210 | 211 | 212 | class ResUNetBN2D(ResUNet2): 213 | NORM_TYPE = 'BN' 214 | CHANNELS = [None, 32, 64, 128, 256] 215 | TR_CHANNELS = [None, 64, 64, 128, 128] 216 | 217 | 218 | class ResUNetBN2E(ResUNet2): 219 | NORM_TYPE = 'BN' 220 | CHANNELS = [None, 128, 128, 128, 256] 221 | TR_CHANNELS = [None, 64, 128, 128, 128] 222 | 223 | 224 | class ResUNetIN2(ResUNet2): 225 | NORM_TYPE = 'BN' 226 | BLOCK_NORM_TYPE = 'IN' 227 | 228 | 229 | class ResUNetIN2B(ResUNetBN2B): 230 | NORM_TYPE = 'BN' 231 | BLOCK_NORM_TYPE = 'IN' 232 | 233 | 234 | class ResUNetIN2C(ResUNetBN2C): 235 | NORM_TYPE = 'BN' 236 | BLOCK_NORM_TYPE = 'IN' 237 | 238 | 239 | class ResUNetIN2D(ResUNetBN2D): 240 | NORM_TYPE = 'BN' 241 | BLOCK_NORM_TYPE = 'IN' 242 | 243 | 244 | class ResUNetIN2E(ResUNetBN2E): 245 | NORM_TYPE = 'BN' 246 | BLOCK_NORM_TYPE = 'IN' 247 | -------------------------------------------------------------------------------- /scripts/benchmark_3dmatch.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of unrefactored functions. 
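Run from the repository root as a module, e.g.
`python -m scripts.benchmark_3dmatch --source /path/to/threedmatch_test --target ./features_tmp/ --model checkpoint.pth --extract_features --evaluate_feature_match_recall --with_cuda`.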
3 | """ 4 | import os 5 | import sys 6 | import numpy as np 7 | import argparse 8 | import logging 9 | import open3d as o3d 10 | 11 | from lib.timer import Timer, AverageMeter 12 | 13 | from util.misc import extract_features 14 | 15 | from model import load_model 16 | from util.file import ensure_dir, get_folder_list, get_file_list 17 | from util.trajectory import read_trajectory, write_trajectory 18 | from util.pointcloud import make_open3d_point_cloud, evaluate_feature_3dmatch 19 | from scripts.benchmark_util import do_single_pair_matching, gen_matching_pair, gather_results 20 | 21 | import torch 22 | 23 | import MinkowskiEngine as ME 24 | 25 | ch = logging.StreamHandler(sys.stdout) 26 | logging.getLogger().setLevel(logging.INFO) 27 | logging.basicConfig( 28 | format='%(asctime)s %(message)s', datefmt='%m/%d %H:%M:%S', handlers=[ch]) 29 | 30 | o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error) 31 | 32 | 33 | def extract_features_batch(model, config, source_path, target_path, voxel_size, device): 34 | 35 | folders = get_folder_list(source_path) 36 | assert len(folders) > 0, f"Could not find 3DMatch folders under {source_path}" 37 | logging.info(folders) 38 | list_file = os.path.join(target_path, "list.txt") 39 | f = open(list_file, "w") 40 | timer, tmeter = Timer(), AverageMeter() 41 | num_feat = 0 42 | model.eval() 43 | 44 | for fo in folders: 45 | if 'evaluation' in fo: 46 | continue 47 | files = get_file_list(fo, ".ply") 48 | fo_base = os.path.basename(fo) 49 | f.write("%s %d\n" % (fo_base, len(files))) 50 | for i, fi in enumerate(files): 51 | # Extract features from a file 52 | pcd = o3d.io.read_point_cloud(fi) 53 | save_fn = "%s_%03d" % (fo_base, i) 54 | if i % 100 == 0: 55 | logging.info(f"{i} / {len(files)}: {save_fn}") 56 | 57 | timer.tic() 58 | xyz_down, feature = extract_features( 59 | model, 60 | xyz=np.array(pcd.points), 61 | rgb=None, 62 | normal=None, 63 | voxel_size=voxel_size, 64 | device=device, 65 | skip_check=True) 66 | t = timer.toc() 67 | if i > 0: 68 | tmeter.update(t) 69 | num_feat += len(xyz_down) 70 | 71 | np.savez_compressed( 72 | os.path.join(target_path, save_fn), 73 | points=np.array(pcd.points), 74 | xyz=xyz_down, 75 | feature=feature.detach().cpu().numpy()) 76 | if i % 20 == 0 and i > 0: 77 | logging.info( 78 | f'Average time: {tmeter.avg}, FPS: {num_feat / tmeter.sum}, time / feat: {tmeter.sum / num_feat}, ' 79 | ) 80 | 81 | f.close() 82 | 83 | 84 | def registration(feature_path, voxel_size): 85 | """ 86 | Gather .log files produced in --target folder and run this Matlab script 87 | https://github.com/andyzeng/3dmatch-toolbox#geometric-registration-benchmark 88 | (see Geometric Registration Benchmark section in 89 | http://3dmatch.cs.princeton.edu/) 90 | """ 91 | # List file from the extract_features_batch function 92 | with open(os.path.join(feature_path, "list.txt")) as f: 93 | sets = f.readlines() 94 | sets = [x.strip().split() for x in sets] 95 | for s in sets: 96 | set_name = s[0] 97 | pts_num = int(s[1]) 98 | matching_pairs = gen_matching_pair(pts_num) 99 | results = [] 100 | for m in matching_pairs: 101 | results.append(do_single_pair_matching(feature_path, set_name, m, voxel_size)) 102 | traj = gather_results(results) 103 | logging.info(f"Writing the trajectory to {feature_path}/{set_name}.log") 104 | write_trajectory(traj, "%s.log" % (os.path.join(feature_path, set_name))) 105 | 106 | 107 | def do_single_pair_evaluation(feature_path, 108 | set_name, 109 | traj, 110 | voxel_size, 111 | tau_1=0.1, 112 | tau_2=0.05, 113 | 
num_rand_keypoints=-1): 114 | trans_gth = np.linalg.inv(traj.pose) 115 | i = traj.metadata[0] 116 | j = traj.metadata[1] 117 | name_i = "%s_%03d" % (set_name, i) 118 | name_j = "%s_%03d" % (set_name, j) 119 | 120 | # coord and feat form a sparse tensor. 121 | data_i = np.load(os.path.join(feature_path, name_i + ".npz")) 122 | coord_i, points_i, feat_i = data_i['xyz'], data_i['points'], data_i['feature'] 123 | data_j = np.load(os.path.join(feature_path, name_j + ".npz")) 124 | coord_j, points_j, feat_j = data_j['xyz'], data_j['points'], data_j['feature'] 125 | 126 | # use the keypoints in 3DMatch 127 | if num_rand_keypoints > 0: 128 | # Randomly subsample N points 129 | Ni, Nj = len(points_i), len(points_j) 130 | inds_i = np.random.choice(Ni, min(Ni, num_rand_keypoints), replace=False) 131 | inds_j = np.random.choice(Nj, min(Nj, num_rand_keypoints), replace=False) 132 | 133 | sample_i, sample_j = points_i[inds_i], points_j[inds_j] 134 | 135 | key_points_i = ME.utils.fnv_hash_vec(np.floor(sample_i / voxel_size)) 136 | key_points_j = ME.utils.fnv_hash_vec(np.floor(sample_j / voxel_size)) 137 | 138 | key_coords_i = ME.utils.fnv_hash_vec(np.floor(coord_i / voxel_size)) 139 | key_coords_j = ME.utils.fnv_hash_vec(np.floor(coord_j / voxel_size)) 140 | 141 | inds_i = np.where(np.isin(key_coords_i, key_points_i))[0] 142 | inds_j = np.where(np.isin(key_coords_j, key_points_j))[0] 143 | 144 | coord_i, feat_i = coord_i[inds_i], feat_i[inds_i] 145 | coord_j, feat_j = coord_j[inds_j], feat_j[inds_j] 146 | 147 | coord_i = make_open3d_point_cloud(coord_i) 148 | coord_j = make_open3d_point_cloud(coord_j) 149 | 150 | hit_ratio = evaluate_feature_3dmatch(coord_i, coord_j, feat_i, feat_j, trans_gth, 151 | tau_1) 152 | 153 | # logging.info(f"Hit ratio of {name_i}, {name_j}: {hit_ratio}, {hit_ratio >= tau_2}") 154 | if hit_ratio >= tau_2: 155 | return True 156 | else: 157 | return False 158 | 159 | 160 | def feature_evaluation(source_path, feature_path, voxel_size, num_rand_keypoints=-1): 161 | with open(os.path.join(feature_path, "list.txt")) as f: 162 | sets = f.readlines() 163 | sets = [x.strip().split() for x in sets] 164 | 165 | assert len( 166 | sets 167 | ) > 0, "Empty list file. Make sure to run the feature extraction first with --extract_features."
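  # Feature-match recall: a pair counts as a match when at least tau_2 (5%) of its putative correspondences lie within tau_1 (10cm) of their ground-truth locations.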
168 | 169 | tau_1 = 0.1 # 10cm 170 | tau_2 = 0.05 # 5% inlier 171 | logging.info("%f %f" % (tau_1, tau_2)) 172 | recall = [] 173 | for s in sets: 174 | set_name = s[0] 175 | traj = read_trajectory(os.path.join(source_path, set_name + "_gt.log")) 176 | assert len(traj) > 0, "Empty trajectory file" 177 | results = [] 178 | for i in range(len(traj)): 179 | results.append( 180 | do_single_pair_evaluation(feature_path, set_name, traj[i], voxel_size, tau_1, 181 | tau_2, num_rand_keypoints)) 182 | 183 | mean_recall = np.array(results).mean() 184 | std_recall = np.array(results).std() 185 | recall.append([set_name, mean_recall, std_recall]) 186 | logging.info(f'{set_name}: {mean_recall} +- {std_recall}') 187 | for r in recall: 188 | logging.info("%s : %.4f" % (r[0], r[1])) 189 | scene_r = np.array([r[1] for r in recall]) 190 | logging.info("average : %.4f +- %.4f" % (scene_r.mean(), scene_r.std())) 191 | 192 | 193 | if __name__ == '__main__': 194 | parser = argparse.ArgumentParser() 195 | parser.add_argument( 196 | '--source', default=None, type=str, help='path to 3dmatch test dataset') 197 | parser.add_argument( 198 | '--source_high_res', 199 | default=None, 200 | type=str, 201 | help='path to high_resolution point cloud') 202 | parser.add_argument( 203 | '--target', default=None, type=str, help='path to produce generated data') 204 | parser.add_argument( 205 | '-m', 206 | '--model', 207 | default=None, 208 | type=str, 209 | help='path to latest checkpoint (default: None)') 210 | parser.add_argument( 211 | '--voxel_size', 212 | default=0.05, 213 | type=float, 214 | help='voxel size to preprocess point cloud') 215 | parser.add_argument('--extract_features', action='store_true') 216 | parser.add_argument('--evaluate_feature_match_recall', action='store_true') 217 | parser.add_argument( 218 | '--evaluate_registration', 219 | action='store_true', 220 | help='The target directory must contain extracted features') 221 | parser.add_argument('--with_cuda', action='store_true') 222 | parser.add_argument( 223 | '--num_rand_keypoints', 224 | type=int, 225 | default=5000, 226 | help='Number of random keypoints for each scene') 227 | 228 | args = parser.parse_args() 229 | 230 | device = torch.device('cuda' if args.with_cuda else 'cpu') 231 | 232 | if args.extract_features: 233 | assert args.model is not None 234 | assert args.source is not None 235 | assert args.target is not None 236 | 237 | ensure_dir(args.target) 238 | checkpoint = torch.load(args.model) 239 | config = checkpoint['config'] 240 | 241 | num_feats = 1 242 | Model = load_model(config.model) 243 | model = Model( 244 | num_feats, 245 | config.model_n_out, 246 | bn_momentum=0.05, 247 | normalize_feature=config.normalize_feature, 248 | conv1_kernel_size=config.conv1_kernel_size, 249 | D=3) 250 | model.load_state_dict(checkpoint['state_dict']) 251 | model.eval() 252 | 253 | model = model.to(device) 254 | 255 | with torch.no_grad(): 256 | extract_features_batch(model, config, args.source, args.target, config.voxel_size, 257 | device) 258 | 259 | if args.evaluate_feature_match_recall: 260 | assert (args.target is not None) 261 | with torch.no_grad(): 262 | feature_evaluation(args.source, args.target, args.voxel_size, 263 | args.num_rand_keypoints) 264 | 265 | if args.evaluate_registration: 266 | assert (args.target is not None) 267 | with torch.no_grad(): 268 | registration(args.target, args.voxel_size) 269 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Fully Convolutional Geometric Features, ICCV, 2019 2 | 3 | Extracting geometric features from 3D scans or point clouds is the first step in applications such as registration, reconstruction, and tracking. State-of-the-art methods require computing low-level features as input or extracting patch-based features with limited receptive field. In this work, we present fully-convolutional geometric features, computed in a single pass by a 3D fully-convolutional network. We also present new metric learning losses that dramatically improve performance. Fully-convolutional geometric features are compact, capture broad spatial context, and scale to large scenes. We experimentally validate our approach on both indoor and outdoor datasets. Fully-convolutional geometric features achieve state-of-the-art accuracy without requiring preprocessing, are compact (32 dimensions), and are 600 times faster than the most accurate prior method. 4 | 5 | [ICCV'19 Paper](https://node1.chrischoy.org/data/publications/fcgf/fcgf.pdf) 6 | 7 | ## News 8 | 9 | - 2020-10-02 Measured the FCGF speedup on v0.5 with [MinkowskiEngineBenchmark](https://github.com/chrischoy/MinkowskiEngineBenchmark). The speedup ranges from 2.7x to 7.7x depending on the batch size. 10 | - 2020-09-04 Updates on ME v0.5 further speed up the inference time from 13.2ms to 11.8ms. As a reference, ME v0.4 takes 37ms. 11 | - 2020-08-18 Merged v0.5 into master with v0.5 installation instructions. You can now use the full GPU support for the sparse tensor hi-COO representation for faster training and inference. 12 | - 2020-08-07 [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine) v0.5 improves the **FCGF inference speed by x2.8** (280% speed-up; feed-forward time for ResUNetBN2C on the 3DMatch kitchen point cloud ID-20: 37ms (ME v0.4.3) down to 13.2ms (ME v0.5.0). Measured on TitanXP, Ryzen-3700X). 13 | - 2020-06-15 [Source code](https://github.com/chrischoy/DeepGlobalRegistration) for **Deep Global Registration, CVPR'20 Oral** has been released. Please refer to the repository and the paper for using FCGF for registration. 14 | 15 | ## 3D Feature Accuracy vs. Speed 16 | 17 | | Comparison Table | Speed vs. Accuracy | 18 | |:----------------------------:|:------------------:| 19 | | ![Table](assets/table.png) | ![Accuracy vs. Speed](assets/fps_acc.png) | 20 | 21 | *Feature-match recall and speed in log scale on the 3DMatch benchmark. Our approach is the most accurate and the fastest. The gray region shows the Pareto frontier of the prior methods.* 22 | 23 | 24 | ### Related Works 25 | 26 | 3DMatch by Zeng et al. uses a Siamese convolutional network to learn 3D patch descriptors. 27 | CGF by Khoury et al. maps 3D oriented histograms to a low-dimensional feature space using multi-layer perceptrons. PPFNet and PPF FoldNet by Deng et al. adapt the PointNet architecture for geometric feature description. 3DFeat-Net by Yew and Lee uses a PointNet to extract features in outdoor scenes. 28 | 29 | Our work addresses a number of limitations in the prior work. First, all prior approaches extract a small 3D patch or a set of points and map it to a low-dimensional space. This not only limits the receptive field of the network but is also computationally inefficient since all intermediate representations are computed separately even for overlapping 3D regions. Second, using expensive low-level geometric signatures as input can slow down feature computation.
Lastly, limiting feature extraction to a subset of interest points results in lower spatial resolution for subsequent matching stages and can thus reduce registration accuracy. 30 | 31 | 32 | ### Fully Convolutional Metric Learning, Hardest Contrastive, and Hardest Triplet Loss 33 | 34 | Traditional metric learning assumes that the features are independent and identically distributed (i.i.d.) since a batch is constructed by random sampling. However, in fully-convolutional metric learning, first proposed in [Universal Correspondence Network, Choy 2016](https://github.com/chrischoy/open-ucn), adjacent features are locally correlated, and hard-negative mining can return features adjacent to an anchor, which are false negatives. Filtering out these false negatives is therefore a crucial step, similar to how the Universal Correspondence Network used a distance threshold to filter out false negatives. 35 | 36 | Also, the number of features used in the fully-convolutional setting is orders of magnitude larger than that in standard metric learning algorithms. For instance, FCGF generates ~40k features for a pair of scans (this increases proportionally with the batch size) while a minibatch in traditional metric learning has around 1k features. Thus, it is not feasible to use all pairwise distances within a batch as standard metric learning does. 37 | 38 | Instead, we propose the hardest-contrastive loss and the hardest-triplet loss. Visually, these are simple variants that mine the hardest negative for both features within a positive pair. 39 | One of the key advantages of the hardest-contrastive loss is that you do not need to save the temporary variables used to find the hardest negatives. This small change allows us to reconstruct the loss from the hardest-negative indices and throw away the intermediate results computed over a large number of features. [Here](https://github.com/chrischoy/open-ucn/blob/master/lib/ucn_trainer.py#L435), we used almost 40k features to mine the hardest negatives and destroyed all intermediate variables once the indices of the hardest negatives were found for each positive feature. 40 | 41 | | Contrastive Loss | Triplet Loss | Hardest Contrastive | Hardest Triplet | 42 | |:------------------:|:------------------:|:-------------------:|:------------------:| 43 | | ![1](assets/1.png) | ![2](assets/2.png) | ![3](assets/3.png) | ![4](assets/4.png) | 44 | 45 | *Sampling and negative-mining strategy for each method. Blue: positives, Red: Negatives. Traditional contrastive and triplet losses use random sampling. Our hardest-contrastive and hardest-triplet losses use the hardest negatives.* 46 | 47 | Please refer to our [ICCV'19 paper](https://node1.chrischoy.org/data/publications/fcgf/fcgf.pdf) for more details. 48 | 49 | 50 | ### Visualization of FCGF 51 | 52 | We color-coded FCGF features for pairs of 3D scans that are 10m apart for KITTI and for a 3DMatch benchmark pair of indoor scans. FCGF features are mapped to a scalar space using t-SNE and colorized with the Spectral color map.
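As a rough sketch of this color-coding step (assuming scikit-learn and matplotlib; `color_code_features` is an illustrative helper, not part of this repository):

```python
import numpy as np
from sklearn.manifold import TSNE
from matplotlib import cm

def color_code_features(feats):
  """Map an (N, 32) array of FCGF features to (N, 3) RGB colors via 1-D t-SNE."""
  emb = TSNE(n_components=1).fit_transform(feats).squeeze()      # (N,) scalar embedding
  emb = (emb - emb.min()) / (emb.max() - emb.min() + 1e-12)      # normalize to [0, 1]
  return cm.Spectral(emb)[:, :3]                                 # Spectral colormap, drop alpha
```

The resulting colors can be assigned to an Open3D point cloud with `pcd.colors = o3d.utility.Vector3dVector(...)` to reproduce visualizations like the ones below.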
53 | 54 | | KITTI LIDAR Scan 1 | KITTI LIDAR Scan 2 | 55 | |:--------------------:|:--------------------:| 56 | | ![0](assets/3_1.png) | ![1](assets/3_2.png) | 57 | 58 | | Indoor Scan 1 | Indoor Scan 2 | 59 | |:--------------------------:|:--------------------------:| 60 | | ![0](assets/kitchen_0.png) | ![1](assets/kitchen_1.png) | 61 | 62 | #### FCGF Correspondence Visualizations 63 | 64 | Please follow the link [Youtube Video](https://www.youtube.com/watch?v=d0p0eTaB50k) or click the image to view the YouTube video of FCGF visualizations. 65 | [![](assets/text_scene000.gif)](https://www.youtube.com/watch?v=d0p0eTaB50k) 66 | 67 | ## Requirements 68 | 69 | - Ubuntu 14.04 or higher 70 | - CUDA 11.1 or higher 71 | - Python v3.7 or higher 72 | - PyTorch v1.6 or higher 73 | - [MinkowskiEngine](https://github.com/stanfordvl/MinkowskiEngine) v0.5 or higher 74 | 75 | 76 | ## Installation & Dataset Download 77 | 78 | 79 | We recommend conda for installation. First, create a conda environment with pytorch 1.6 or higher: 80 | 81 | ``` 82 | conda create -n py3-fcgf python=3.7 83 | conda activate py3-fcgf 84 | conda install pytorch -c pytorch 85 | pip install git+https://github.com/NVIDIA/MinkowskiEngine.git 86 | ``` 87 | 88 | Next, download the FCGF git repository and install the requirements from the FCGF root directory. 89 | 90 | ``` 91 | git clone https://github.com/chrischoy/FCGF.git 92 | cd FCGF 93 | # Do the following inside the conda environment 94 | pip install -r requirements.txt 95 | ``` 96 | 97 | For training, download the preprocessed 3DMatch benchmark dataset. 98 | 99 | ``` 100 | ./scripts/download_datasets.sh /path/to/dataset/download/dir 101 | ``` 102 | 103 | For KITTI training, follow the instructions on the [KITTI Odometry website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) to download the KITTI odometry training set. 104 | 105 | 106 | ## Demo: Extracting and color coding FCGF 107 | 108 | After installation, you can run the demo script with 109 | 110 | ``` 111 | python demo.py 112 | ``` 113 | 114 | The demo script will first extract FCGF features from a mesh file generated from a kitchen scene. Next, it will color-code the features independently of their spatial locations. 115 | After the color mapping using t-SNE, the demo script will visualize the color-coded features by coloring the input point cloud. 116 | 117 | ![demo](./assets/demo.png) 118 | 119 | *You may have to rotate the scene to get the above visualization.* 120 | 121 | 122 | ## Training and running the 3DMatch benchmark 123 | 124 | ``` 125 | python train.py --threed_match_dir /path/to/threedmatch/ 126 | ``` 127 | 128 | For benchmarking the trained weights on 3DMatch, download the 3DMatch Geometric Registration Benchmark dataset from [here](http://3dmatch.cs.princeton.edu/) or run 129 | 130 | ``` 131 | bash ./scripts/download_3dmatch_test.sh /path/to/threedmatch_test/ 132 | ``` 133 | 134 | and then run: 135 | 136 | ``` 137 | python -m scripts.benchmark_3dmatch \ 138 |     --source /path/to/threedmatch \ 139 |     --target ./features_tmp/ \ 140 |     --voxel_size 0.025 \ 141 |     --model ~/outputs/checkpoint.pth \ 142 |     --extract_features --evaluate_feature_match_recall --with_cuda 143 | ``` 144 | 145 | 146 | ## Training and testing on the KITTI Odometry custom split 147 | 148 | For KITTI training, follow the instructions on the [KITTI Odometry website](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) to download the KITTI odometry training set.
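The custom split is defined in `config/train_kitti.txt`, `config/val_kitti.txt`, and `config/test_kitti.txt` (odometry sequences 0-5 for training, 6-7 for validation, and 8-10 for testing).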
149 | 150 | ``` 151 | export KITTI_PATH=/path/to/kitti/; ./scripts/train_fcgf_kitti.sh 152 | ``` 153 | 154 | ## Registration Test on 3DMatch 155 | 156 | 157 | 158 | ## Model Zoo 159 | 160 | | Model | Normalized Feature | Dataset | Voxel Size | Feature Dimension | Performance | Link | 161 | |:-----------:|:-------------------:|:-------:|:-------------:|:-----------------:|:--------------------------:|:------:| 162 | | ResUNetBN2C | True | 3DMatch | 2.5cm (0.025) | 32 | FMR: 0.9578 +- 0.0272 | [download](https://node1.chrischoy.org/data/publications/fcgf/2019-08-19_06-17-41.pth) | 163 | | ResUNetBN2C | True | 3DMatch | 2.5cm (0.025) | 16 | FMR: 0.9442 +- 0.0345 | [download](https://node1.chrischoy.org/data/publications/fcgf/2019-09-18_14-15-59.pth) | 164 | | ResUNetBN2C | True | 3DMatch | 5cm (0.05) | 32 | FMR: 0.9372 +- 0.0332 | [download](https://node1.chrischoy.org/data/publications/fcgf/2019-08-16_19-21-47.pth) | 165 | | ResUNetBN2C | False | KITTI | 20cm (0.2) | 32 | RTE: 0.0534m, RRE: 0.1704° | [download](https://node1.chrischoy.org/data/publications/fcgf/2019-07-31_19-30-19.pth) | 166 | | ResUNetBN2C | False | KITTI | 30cm (0.3) | 32 | RTE: 0.0607m, RRE: 0.2280° | [download](https://node1.chrischoy.org/data/publications/fcgf/2019-07-31_19-37-00.pth) | 167 | | ResUNetBN2C | True | KITTI | 30cm (0.3) | 16 | RTE: 0.0670m, RRE: 0.2295° | [download](https://node1.chrischoy.org/data/publications/fcgf/KITTI-v0.3-ResUNetBN2C-conv1-5-nout16.pth) | 168 | | ResUNetBN2C | True | KITTI | 30cm (0.3) | 32 | RTE: 0.0639m, RRE: 0.2253° | [download](https://node1.chrischoy.org/data/publications/fcgf/KITTI-v0.3-ResUNetBN2C-conv1-5-nout32.pth) | 169 | 170 | 171 | ## Raw Data for FCGF Figure 4 172 | 173 | - [Distance threshold data](https://raw.githubusercontent.com/chrischoy/FCGF/master/assets/fig4_dist_thresh.txt) 174 | - [Inlier threshold data](https://raw.githubusercontent.com/chrischoy/FCGF/master/assets/fig4_inlier_thresh.txt) 175 | 176 | 177 | ## Citing FCGF 178 | 179 | FCGF will be presented at ICCV'19: Friday, November 1, 2019, 1030–1300 Poster 4.1 (Hall B) 180 | 181 | ``` 182 | @inproceedings{FCGF2019, 183 | author = {Christopher Choy and Jaesik Park and Vladlen Koltun}, 184 | title = {Fully Convolutional Geometric Features}, 185 | booktitle = {ICCV}, 186 | year = {2019}, 187 | } 188 | ``` 189 | 190 | ## Related Projects 191 | 192 | - A neural network library for high-dimensional sparse tensors: [Minkowski Engine, CVPR'19](https://github.com/StanfordVL/MinkowskiEngine) 193 | - Semantic segmentation on a high-dimensional sparse tensor: [4D Spatio Temporal ConvNets, CVPR'19](https://github.com/chrischoy/SpatioTemporalSegmentation) 194 | - The first fully convolutional metric learning for correspondences: [Universal Correspondence Network, NIPS'16](https://github.com/chrischoy/open-ucn) 195 | - 3D Registration Network with 6-dimensional ConvNets: [Deep Global Registration, CVPR'20](https://github.com/chrischoy/DeepGlobalRegistration) 196 | 197 | 198 | ## Projects using FCGF 199 | 200 | - Gojcic et al., [Learning multiview 3D point cloud registration, CVPR'20](https://arxiv.org/abs/2001.05119) 201 | - Choy et al., [Deep Global Registration, CVPR'20 Oral](https://arxiv.org/abs/2004.11540) 202 | 203 | 204 | ## Acknowledgements 205 | 206 | We want to thank all the ICCV reviewers, especially R2, for suggestions and valuable pointers. 
207 | -------------------------------------------------------------------------------- /model/simpleunet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | import torch 3 | import MinkowskiEngine as ME 4 | import MinkowskiEngine.MinkowskiFunctional as MEF 5 | from model.common import get_norm 6 | 7 | 8 | class SimpleNet(ME.MinkowskiNetwork): 9 | NORM_TYPE = None 10 | CHANNELS = [None, 32, 64, 128] 11 | TR_CHANNELS = [None, 32, 32, 64] 12 | 13 | # To use the model, must call initialize_coords before forward pass. 14 | # Once data is processed, call clear to reset the model before calling initialize_coords 15 | def __init__(self, 16 | in_channels=3, 17 | out_channels=32, 18 | bn_momentum=0.1, 19 | normalize_feature=None, 20 | conv1_kernel_size=None, 21 | D=3): 22 | super(SimpleNet, self).__init__(D) 23 | NORM_TYPE = self.NORM_TYPE 24 | CHANNELS = self.CHANNELS 25 | TR_CHANNELS = self.TR_CHANNELS 26 | self.normalize_feature = normalize_feature 27 | self.conv1 = ME.MinkowskiConvolution( 28 | in_channels=in_channels, 29 | out_channels=CHANNELS[1], 30 | kernel_size=conv1_kernel_size, 31 | stride=1, 32 | dilation=1, 33 | bias=False, 34 | dimension=D) 35 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) 36 | 37 | self.conv2 = ME.MinkowskiConvolution( 38 | in_channels=CHANNELS[1], 39 | out_channels=CHANNELS[2], 40 | kernel_size=3, 41 | stride=2, 42 | dilation=1, 43 | bias=False, 44 | dimension=D) 45 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) 46 | 47 | self.conv3 = ME.MinkowskiConvolution( 48 | in_channels=CHANNELS[2], 49 | out_channels=CHANNELS[3], 50 | kernel_size=3, 51 | stride=2, 52 | dilation=1, 53 | bias=False, 54 | dimension=D) 55 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) 56 | 57 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 58 | in_channels=CHANNELS[3], 59 | out_channels=TR_CHANNELS[3], 60 | kernel_size=3, 61 | stride=2, 62 | dilation=1, 63 | bias=False, 64 | dimension=D) 65 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 66 | 67 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 68 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 69 | out_channels=TR_CHANNELS[2], 70 | kernel_size=3, 71 | stride=2, 72 | dilation=1, 73 | bias=False, 74 | dimension=D) 75 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 76 | 77 | self.conv1_tr = ME.MinkowskiConvolution( 78 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 79 | out_channels=TR_CHANNELS[1], 80 | kernel_size=3, 81 | stride=1, 82 | dilation=1, 83 | bias=False, 84 | dimension=D) 85 | self.norm1_tr = get_norm(NORM_TYPE, TR_CHANNELS[1], bn_momentum=bn_momentum, D=D) 86 | 87 | self.final = ME.MinkowskiConvolution( 88 | in_channels=TR_CHANNELS[1], 89 | out_channels=out_channels, 90 | kernel_size=1, 91 | stride=1, 92 | dilation=1, 93 | bias=True, 94 | dimension=D) 95 | 96 | def forward(self, x): 97 | out_s1 = self.conv1(x) 98 | out_s1 = self.norm1(out_s1) 99 | out = MEF.relu(out_s1) 100 | 101 | out_s2 = self.conv2(out) 102 | out_s2 = self.norm2(out_s2) 103 | out = MEF.relu(out_s2) 104 | 105 | out_s4 = self.conv3(out) 106 | out_s4 = self.norm3(out_s4) 107 | out = MEF.relu(out_s4) 108 | 109 | out = self.conv3_tr(out) 110 | out = self.norm3_tr(out) 111 | out_s2_tr = MEF.relu(out) 112 | 113 | out = ME.cat(out_s2_tr, out_s2) 114 | 115 | out = self.conv2_tr(out) 116 | out = self.norm2_tr(out) 117 | out_s1_tr = 
MEF.relu(out) 118 | 119 | out = ME.cat(out_s1_tr, out_s1) 120 | out = self.conv1_tr(out) 121 | out = self.norm1_tr(out) 122 | out = MEF.relu(out) 123 | 124 | out = self.final(out) 125 | 126 | if self.normalize_feature: 127 | return ME.SparseTensor( 128 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 129 | coordinate_map_key=out.coordinate_map_key, 130 | coordinate_manager=out.coordinate_manager) 131 | else: 132 | return out 133 | 134 | 135 | class SimpleNetIN(SimpleNet): 136 | NORM_TYPE = 'IN' 137 | 138 | 139 | class SimpleNetBN(SimpleNet): 140 | NORM_TYPE = 'BN' 141 | 142 | 143 | class SimpleNetBNE(SimpleNetBN): 144 | CHANNELS = [None, 16, 32, 32] 145 | TR_CHANNELS = [None, 16, 16, 32] 146 | 147 | 148 | class SimpleNetINE(SimpleNetBNE): 149 | NORM_TYPE = 'IN' 150 | 151 | 152 | class SimpleNet2(ME.MinkowskiNetwork): 153 | NORM_TYPE = None 154 | CHANNELS = [None, 32, 64, 128, 256] 155 | TR_CHANNELS = [None, 32, 32, 64, 64] 156 | 157 | # To use the model, must call initialize_coords before forward pass. 158 | # Once data is processed, call clear to reset the model before calling initialize_coords 159 | def __init__(self, in_channels=3, out_channels=32, bn_momentum=0.1, D=3, config=None): 160 | ME.MinkowskiNetwork.__init__(self, D) 161 | NORM_TYPE = self.NORM_TYPE 162 | bn_momentum = config.bn_momentum 163 | CHANNELS = self.CHANNELS 164 | TR_CHANNELS = self.TR_CHANNELS 165 | self.normalize_feature = config.normalize_feature 166 | self.conv1 = ME.MinkowskiConvolution( 167 | in_channels=in_channels, 168 | out_channels=CHANNELS[1], 169 | kernel_size=config.conv1_kernel_size, 170 | stride=1, 171 | dilation=1, 172 | bias=False, 173 | dimension=D) 174 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) 175 | 176 | self.conv2 = ME.MinkowskiConvolution( 177 | in_channels=CHANNELS[1], 178 | out_channels=CHANNELS[2], 179 | kernel_size=3, 180 | stride=2, 181 | dilation=1, 182 | bias=False, 183 | dimension=D) 184 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) 185 | 186 | self.conv3 = ME.MinkowskiConvolution( 187 | in_channels=CHANNELS[2], 188 | out_channels=CHANNELS[3], 189 | kernel_size=3, 190 | stride=2, 191 | dilation=1, 192 | bias=False, 193 | dimension=D) 194 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) 195 | 196 | self.conv4 = ME.MinkowskiConvolution( 197 | in_channels=CHANNELS[3], 198 | out_channels=CHANNELS[4], 199 | kernel_size=3, 200 | stride=2, 201 | dilation=1, 202 | bias=False, 203 | dimension=D) 204 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D) 205 | 206 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 207 | in_channels=CHANNELS[4], 208 | out_channels=TR_CHANNELS[4], 209 | kernel_size=3, 210 | stride=2, 211 | dilation=1, 212 | bias=False, 213 | dimension=D) 214 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 215 | 216 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 217 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 218 | out_channels=TR_CHANNELS[3], 219 | kernel_size=3, 220 | stride=2, 221 | dilation=1, 222 | bias=False, 223 | dimension=D) 224 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 225 | 226 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 227 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 228 | out_channels=TR_CHANNELS[2], 229 | kernel_size=3, 230 | stride=2, 231 | dilation=1, 232 | bias=False, 233 | dimension=D) 234 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], 
bn_momentum=bn_momentum, D=D) 235 | 236 | self.conv1_tr = ME.MinkowskiConvolution( 237 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 238 | out_channels=TR_CHANNELS[1], 239 | kernel_size=3, 240 | stride=1, 241 | dilation=1, 242 | bias=False, 243 | dimension=D) 244 | self.norm1_tr = get_norm(NORM_TYPE, TR_CHANNELS[1], bn_momentum=bn_momentum, D=D) 245 | 246 | self.final = ME.MinkowskiConvolution( 247 | in_channels=TR_CHANNELS[1], 248 | out_channels=out_channels, 249 | kernel_size=1, 250 | stride=1, 251 | dilation=1, 252 | bias=True, 253 | dimension=D) 254 | 255 | def forward(self, x): 256 | out_s1 = self.conv1(x) 257 | out_s1 = self.norm1(out_s1) 258 | out = MEF.relu(out_s1) 259 | 260 | out_s2 = self.conv2(out) 261 | out_s2 = self.norm2(out_s2) 262 | out = MEF.relu(out_s2) 263 | 264 | out_s4 = self.conv3(out) 265 | out_s4 = self.norm3(out_s4) 266 | out = MEF.relu(out_s4) 267 | 268 | out_s8 = self.conv4(out) 269 | out_s8 = self.norm4(out_s8) 270 | out = MEF.relu(out_s8) 271 | 272 | out = self.conv4_tr(out) 273 | out = self.norm4_tr(out) 274 | out_s4_tr = MEF.relu(out) 275 | 276 | out = ME.cat(out_s4_tr, out_s4) 277 | 278 | out = self.conv3_tr(out) 279 | out = self.norm3_tr(out) 280 | out_s2_tr = MEF.relu(out) 281 | 282 | out = ME.cat(out_s2_tr, out_s2) 283 | 284 | out = self.conv2_tr(out) 285 | out = self.norm2_tr(out) 286 | out_s1_tr = MEF.relu(out) 287 | 288 | out = ME.cat(out_s1_tr, out_s1) 289 | out = self.conv1_tr(out) 290 | out = self.norm1_tr(out) 291 | out = MEF.relu(out) 292 | 293 | out = self.final(out) 294 | 295 | if self.normalize_feature: 296 | return ME.SparseTensor( 297 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 298 | coordinate_map_key=out.coordinate_map_key, 299 | coordinate_manager=out.coordinate_manager) 300 | else: 301 | return out 302 | 303 | 304 | class SimpleNetIN2(SimpleNet2): 305 | NORM_TYPE = 'IN' 306 | 307 | 308 | class SimpleNetBN2(SimpleNet2): 309 | NORM_TYPE = 'BN' 310 | 311 | 312 | class SimpleNetBN2B(SimpleNet2): 313 | NORM_TYPE = 'BN' 314 | CHANNELS = [None, 32, 64, 128, 256] 315 | TR_CHANNELS = [None, 64, 64, 64, 64] 316 | 317 | 318 | class SimpleNetBN2C(SimpleNet2): 319 | NORM_TYPE = 'BN' 320 | CHANNELS = [None, 32, 64, 128, 256] 321 | TR_CHANNELS = [None, 32, 64, 64, 128] 322 | 323 | 324 | class SimpleNetBN2D(SimpleNet2): 325 | NORM_TYPE = 'BN' 326 | CHANNELS = [None, 32, 64, 128, 256] 327 | TR_CHANNELS = [None, 32, 64, 64, 128] 328 | 329 | 330 | class SimpleNetBN2E(SimpleNet2): 331 | NORM_TYPE = 'BN' 332 | CHANNELS = [None, 16, 32, 64, 128] 333 | TR_CHANNELS = [None, 16, 32, 32, 64] 334 | 335 | 336 | class SimpleNetIN2E(SimpleNetBN2E): 337 | NORM_TYPE = 'IN' 338 | 339 | 340 | class SimpleNet3(ME.MinkowskiNetwork): 341 | NORM_TYPE = None 342 | CHANNELS = [None, 32, 64, 128, 256, 512] 343 | TR_CHANNELS = [None, 32, 32, 64, 64, 128] 344 | 345 | # To use the model, must call initialize_coords before forward pass. 
346 | # Once data is processed, call clear to reset the model before calling initialize_coords 347 | def __init__(self, in_channels=3, out_channels=32, bn_momentum=0.1, D=3, config=None): 348 | ME.MinkowskiNetwork.__init__(self, D) 349 | NORM_TYPE = self.NORM_TYPE 350 | bn_momentum = config.bn_momentum 351 | CHANNELS = self.CHANNELS 352 | TR_CHANNELS = self.TR_CHANNELS 353 | self.normalize_feature = config.normalize_feature 354 | self.conv1 = ME.MinkowskiConvolution( 355 | in_channels=in_channels, 356 | out_channels=CHANNELS[1], 357 | kernel_size=config.conv1_kernel_size, 358 | stride=1, 359 | dilation=1, 360 | bias=False, 361 | dimension=D) 362 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) 363 | 364 | self.conv2 = ME.MinkowskiConvolution( 365 | in_channels=CHANNELS[1], 366 | out_channels=CHANNELS[2], 367 | kernel_size=3, 368 | stride=2, 369 | dilation=1, 370 | bias=False, 371 | dimension=D) 372 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) 373 | 374 | self.conv3 = ME.MinkowskiConvolution( 375 | in_channels=CHANNELS[2], 376 | out_channels=CHANNELS[3], 377 | kernel_size=3, 378 | stride=2, 379 | dilation=1, 380 | bias=False, 381 | dimension=D) 382 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) 383 | 384 | self.conv4 = ME.MinkowskiConvolution( 385 | in_channels=CHANNELS[3], 386 | out_channels=CHANNELS[4], 387 | kernel_size=3, 388 | stride=2, 389 | dilation=1, 390 | bias=False, 391 | dimension=D) 392 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D) 393 | 394 | self.conv5 = ME.MinkowskiConvolution( 395 | in_channels=CHANNELS[4], 396 | out_channels=CHANNELS[5], 397 | kernel_size=3, 398 | stride=2, 399 | dilation=1, 400 | bias=False, 401 | dimension=D) 402 | self.norm5 = get_norm(NORM_TYPE, CHANNELS[5], bn_momentum=bn_momentum, D=D) 403 | 404 | self.conv5_tr = ME.MinkowskiConvolutionTranspose( 405 | in_channels=CHANNELS[5], 406 | out_channels=TR_CHANNELS[5], 407 | kernel_size=3, 408 | stride=2, 409 | dilation=1, 410 | bias=False, 411 | dimension=D) 412 | self.norm5_tr = get_norm(NORM_TYPE, TR_CHANNELS[5], bn_momentum=bn_momentum, D=D) 413 | 414 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 415 | in_channels=CHANNELS[4] + TR_CHANNELS[5], 416 | out_channels=TR_CHANNELS[4], 417 | kernel_size=3, 418 | stride=2, 419 | dilation=1, 420 | bias=False, 421 | dimension=D) 422 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 423 | 424 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 425 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 426 | out_channels=TR_CHANNELS[3], 427 | kernel_size=3, 428 | stride=2, 429 | dilation=1, 430 | bias=False, 431 | dimension=D) 432 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 433 | 434 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 435 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 436 | out_channels=TR_CHANNELS[2], 437 | kernel_size=3, 438 | stride=2, 439 | dilation=1, 440 | bias=False, 441 | dimension=D) 442 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 443 | 444 | self.conv1_tr = ME.MinkowskiConvolution( 445 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 446 | out_channels=TR_CHANNELS[1], 447 | kernel_size=1, 448 | stride=1, 449 | dilation=1, 450 | bias=True, 451 | dimension=D) 452 | 453 | def forward(self, x): 454 | out_s1 = self.conv1(x) 455 | out_s1 = self.norm1(out_s1) 456 | out = MEF.relu(out_s1) 457 | 458 | out_s2 = self.conv2(out) 
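# conv2-conv5 all use stride 2, halving the voxel resolution at each level (s2, s4, s8, s16) before the transposed convolutions upsample back to full resolution.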
459 | out_s2 = self.norm2(out_s2) 460 | out = MEF.relu(out_s2) 461 | 462 | out_s4 = self.conv3(out) 463 | out_s4 = self.norm3(out_s4) 464 | out = MEF.relu(out_s4) 465 | 466 | out_s8 = self.conv4(out) 467 | out_s8 = self.norm4(out_s8) 468 | out = MEF.relu(out_s8) 469 | 470 | out_s16 = self.conv5(out) 471 | out_s16 = self.norm5(out_s16) 472 | out = MEF.relu(out_s16) 473 | 474 | out = self.conv5_tr(out) 475 | out = self.norm5_tr(out) 476 | out_s8_tr = MEF.relu(out) 477 | 478 | out = ME.cat(out_s8_tr, out_s8) 479 | 480 | out = self.conv4_tr(out) 481 | out = self.norm4_tr(out) 482 | out_s4_tr = MEF.relu(out) 483 | 484 | out = ME.cat(out_s4_tr, out_s4) 485 | 486 | out = self.conv3_tr(out) 487 | out = self.norm3_tr(out) 488 | out_s2_tr = MEF.relu(out) 489 | 490 | out = ME.cat(out_s2_tr, out_s2) 491 | 492 | out = self.conv2_tr(out) 493 | out = self.norm2_tr(out) 494 | out_s1_tr = MEF.relu(out) 495 | 496 | out = ME.cat(out_s1_tr, out_s1) 497 | out = self.conv1_tr(out) 498 | 499 | if self.normalize_feature: 500 | return ME.SparseTensor( 501 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 502 | coordinate_map_key=out.coordinate_map_key, 503 | coordinate_manager=out.coordinate_manager) 504 | else: 505 | return out 506 | 507 | 508 | class SimpleNetIN3(SimpleNet3): 509 | NORM_TYPE = 'IN' 510 | 511 | 512 | class SimpleNetBN3(SimpleNet3): 513 | NORM_TYPE = 'BN' 514 | 515 | 516 | class SimpleNetBN3B(SimpleNet3): 517 | NORM_TYPE = 'BN' 518 | CHANNELS = [None, 32, 64, 128, 256, 512] 519 | TR_CHANNELS = [None, 32, 64, 64, 64, 128] 520 | 521 | 522 | class SimpleNetBN3C(SimpleNet3): 523 | NORM_TYPE = 'BN' 524 | CHANNELS = [None, 32, 64, 128, 256, 512] 525 | TR_CHANNELS = [None, 32, 32, 64, 128, 128] 526 | 527 | 528 | class SimpleNetBN3D(SimpleNet3): 529 | NORM_TYPE = 'BN' 530 | CHANNELS = [None, 32, 64, 128, 256, 512] 531 | TR_CHANNELS = [None, 32, 64, 64, 128, 128] 532 | 533 | 534 | class SimpleNetBN3E(SimpleNet3): 535 | NORM_TYPE = 'BN' 536 | CHANNELS = [None, 16, 32, 64, 128, 256] 537 | TR_CHANNELS = [None, 16, 32, 32, 64, 128] 538 | 539 | 540 | class SimpleNetIN3E(SimpleNetBN3E): 541 | NORM_TYPE = 'IN' 542 | -------------------------------------------------------------------------------- /lib/data_loaders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | # 3 | # Written by Chris Choy 4 | # Distributed under MIT License 5 | import logging 6 | import random 7 | import torch 8 | import torch.utils.data 9 | import numpy as np 10 | import glob 11 | import os 12 | from scipy.linalg import expm, norm 13 | import pathlib 14 | 15 | from util.pointcloud import get_matching_indices, make_open3d_point_cloud 16 | import lib.transforms as t 17 | from util.trajectory import read_trajectory 18 | import MinkowskiEngine as ME 19 | 20 | import open3d as o3d 21 | 22 | kitti_cache = {} 23 | kitti_icp_cache = {} 24 | 25 | 26 | def collate_pair_fn(list_data): 27 | xyz0, xyz1, coords0, coords1, feats0, feats1, matching_inds, trans = list( 28 | zip(*list_data)) 29 | xyz_batch0, xyz_batch1 = [], [] 30 | matching_inds_batch, trans_batch, len_batch = [], [], [] 31 | 32 | batch_id = 0 33 | curr_start_inds = np.zeros((1, 2)) 34 | 35 | def to_tensor(x): 36 | if isinstance(x, torch.Tensor): 37 | return x 38 | elif isinstance(x, np.ndarray): 39 | return torch.from_numpy(x) 40 | else: 41 | raise ValueError(f'Can not convert to torch tensor, {x}') 42 | 43 | for batch_id, _ in enumerate(coords0): 44 | N0 = coords0[batch_id].shape[0] 45 | N1 = coords1[batch_id].shape[0] 46 | 47 |
xyz_batch0.append(to_tensor(xyz0[batch_id])) 48 | xyz_batch1.append(to_tensor(xyz1[batch_id])) 49 | 50 | trans_batch.append(to_tensor(trans[batch_id])) 51 | 52 | matching_inds_batch.append( 53 | torch.from_numpy(np.array(matching_inds[batch_id]) + curr_start_inds)) 54 | len_batch.append([N0, N1]) 55 | 56 | # Move the head 57 | curr_start_inds[0, 0] += N0 58 | curr_start_inds[0, 1] += N1 59 | 60 | coords_batch0, feats_batch0 = ME.utils.sparse_collate(coords0, feats0) 61 | coords_batch1, feats_batch1 = ME.utils.sparse_collate(coords1, feats1) 62 | 63 | # Concatenate all lists 64 | xyz_batch0 = torch.cat(xyz_batch0, 0).float() 65 | xyz_batch1 = torch.cat(xyz_batch1, 0).float() 66 | trans_batch = torch.cat(trans_batch, 0).float() 67 | matching_inds_batch = torch.cat(matching_inds_batch, 0).int() 68 | 69 | return { 70 | 'pcd0': xyz_batch0, 71 | 'pcd1': xyz_batch1, 72 | 'sinput0_C': coords_batch0, 73 | 'sinput0_F': feats_batch0.float(), 74 | 'sinput1_C': coords_batch1, 75 | 'sinput1_F': feats_batch1.float(), 76 | 'correspondences': matching_inds_batch, 77 | 'T_gt': trans_batch, 78 | 'len_batch': len_batch 79 | } 80 | 81 | 82 | # Rotation matrix along axis with angle theta 83 | def M(axis, theta): 84 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)) 85 | 86 | 87 | def sample_random_trans(pcd, randg, rotation_range=360): 88 | T = np.eye(4) 89 | R = M(randg.rand(3) - 0.5, rotation_range * np.pi / 180.0 * (randg.rand(1) - 0.5)) 90 | T[:3, :3] = R 91 | T[:3, 3] = R.dot(-np.mean(pcd, axis=0)) 92 | return T 93 | 94 | 95 | class PairDataset(torch.utils.data.Dataset): 96 | AUGMENT = None 97 | 98 | def __init__(self, 99 | phase, 100 | transform=None, 101 | random_rotation=True, 102 | random_scale=True, 103 | manual_seed=False, 104 | config=None): 105 | self.phase = phase 106 | self.files = [] 107 | self.data_objects = [] 108 | self.transform = transform 109 | self.voxel_size = config.voxel_size 110 | self.matching_search_voxel_size = \ 111 | config.voxel_size * config.positive_pair_search_voxel_size_multiplier 112 | 113 | self.random_scale = random_scale 114 | self.min_scale = config.min_scale 115 | self.max_scale = config.max_scale 116 | self.random_rotation = random_rotation 117 | self.rotation_range = config.rotation_range 118 | self.randg = np.random.RandomState() 119 | if manual_seed: 120 | self.reset_seed() 121 | 122 | def reset_seed(self, seed=0): 123 | logging.info(f"Resetting the data loader seed to {seed}") 124 | self.randg.seed(seed) 125 | 126 | def apply_transform(self, pts, trans): 127 | R = trans[:3, :3] 128 | T = trans[:3, 3] 129 | pts = pts @ R.T + T 130 | return pts 131 | 132 | def __len__(self): 133 | return len(self.files) 134 | 135 | 136 | class ThreeDMatchTestDataset(PairDataset): 137 | DATA_FILES = { 138 | 'test': './config/test_3dmatch.txt' 139 | } 140 | 141 | def __init__(self, 142 | phase, 143 | transform=None, 144 | random_rotation=True, 145 | random_scale=True, 146 | manual_seed=False, 147 | scene_id=None, 148 | config=None, 149 | return_ply_names=False): 150 | 151 | PairDataset.__init__(self, phase, transform, random_rotation, random_scale, 152 | manual_seed, config) 153 | assert phase == 'test', "Supports only the test set." 
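    # self.files collects one (scene_name, i, j, T_gt) tuple per ground-truth pair listed in each scene's gt.log trajectory.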
154 | 155 | self.root = config.threed_match_dir 156 | 157 | subset_names = open(self.DATA_FILES[phase]).read().split() 158 | if scene_id is not None: 159 | subset_names = [subset_names[scene_id]] 160 | for sname in subset_names: 161 | traj_file = os.path.join(self.root, sname + '-evaluation/gt.log') 162 | assert os.path.exists(traj_file) 163 | traj = read_trajectory(traj_file) 164 | for ctraj in traj: 165 | i = ctraj.metadata[0] 166 | j = ctraj.metadata[1] 167 | T_gt = ctraj.pose 168 | self.files.append((sname, i, j, T_gt)) 169 | 170 | self.return_ply_names = return_ply_names 171 | 172 | def __getitem__(self, pair_index): 173 | sname, i, j, T_gt = self.files[pair_index] 174 | ply_name0 = os.path.join(self.root, sname, f'cloud_bin_{i}.ply') 175 | ply_name1 = os.path.join(self.root, sname, f'cloud_bin_{j}.ply') 176 | 177 | if self.return_ply_names: 178 | return sname, ply_name0, ply_name1, T_gt 179 | 180 | pcd0 = o3d.io.read_point_cloud(ply_name0) 181 | pcd1 = o3d.io.read_point_cloud(ply_name1) 182 | pcd0 = np.asarray(pcd0.points) 183 | pcd1 = np.asarray(pcd1.points) 184 | return sname, pcd0, pcd1, T_gt 185 | 186 | 187 | class IndoorPairDataset(PairDataset): 188 | OVERLAP_RATIO = None 189 | AUGMENT = None 190 | 191 | def __init__(self, 192 | phase, 193 | transform=None, 194 | random_rotation=True, 195 | random_scale=True, 196 | manual_seed=False, 197 | config=None): 198 | PairDataset.__init__(self, phase, transform, random_rotation, random_scale, 199 | manual_seed, config) 200 | self.root = root = config.threed_match_dir 201 | logging.info(f"Loading the subset {phase} from {root}") 202 | 203 | subset_names = open(self.DATA_FILES[phase]).read().split() 204 | for name in subset_names: 205 | fname = name + "*%.2f.txt" % self.OVERLAP_RATIO 206 | fnames_txt = glob.glob(root + "/" + fname) 207 | assert len(fnames_txt) > 0, f"Make sure that the path {root} has data {fname}" 208 | for fname_txt in fnames_txt: 209 | with open(fname_txt) as f: 210 | content = f.readlines() 211 | fnames = [x.strip().split() for x in content] 212 | for fname in fnames: 213 | self.files.append([fname[0], fname[1]]) 214 | 215 | def __getitem__(self, idx): 216 | file0 = os.path.join(self.root, self.files[idx][0]) 217 | file1 = os.path.join(self.root, self.files[idx][1]) 218 | data0 = np.load(file0) 219 | data1 = np.load(file1) 220 | xyz0 = data0["pcd"] 221 | xyz1 = data1["pcd"] 222 | color0 = data0["color"] 223 | color1 = data1["color"] 224 | matching_search_voxel_size = self.matching_search_voxel_size 225 | 226 | if self.random_scale and random.random() < 0.95: 227 | scale = self.min_scale + \ 228 | (self.max_scale - self.min_scale) * random.random() 229 | matching_search_voxel_size *= scale 230 | xyz0 = scale * xyz0 231 | xyz1 = scale * xyz1 232 | 233 | if self.random_rotation: 234 | T0 = sample_random_trans(xyz0, self.randg, self.rotation_range) 235 | T1 = sample_random_trans(xyz1, self.randg, self.rotation_range) 236 | trans = T1 @ np.linalg.inv(T0) 237 | 238 | xyz0 = self.apply_transform(xyz0, T0) 239 | xyz1 = self.apply_transform(xyz1, T1) 240 | else: 241 | trans = np.identity(4) 242 | 243 | # Voxelization 244 | _, sel0 = ME.utils.sparse_quantize(xyz0 / self.voxel_size, return_index=True) 245 | _, sel1 = ME.utils.sparse_quantize(xyz1 / self.voxel_size, return_index=True) 246 | 247 | # Make point clouds using voxelized points 248 | pcd0 = make_open3d_point_cloud(xyz0) 249 | pcd1 = make_open3d_point_cloud(xyz1) 250 | 251 | # Select features and points using the returned voxelized indices 252 | pcd0.colors = 
o3d.utility.Vector3dVector(color0[sel0]) 253 | pcd1.colors = o3d.utility.Vector3dVector(color1[sel1]) 254 | pcd0.points = o3d.utility.Vector3dVector(np.array(pcd0.points)[sel0]) 255 | pcd1.points = o3d.utility.Vector3dVector(np.array(pcd1.points)[sel1]) 256 | # Get matches 257 | matches = get_matching_indices(pcd0, pcd1, trans, matching_search_voxel_size) 258 | 259 | # Get features 260 | npts0 = len(pcd0.colors) 261 | npts1 = len(pcd1.colors) 262 | 263 | feats_train0, feats_train1 = [], [] 264 | 265 | feats_train0.append(np.ones((npts0, 1))) 266 | feats_train1.append(np.ones((npts1, 1))) 267 | 268 | feats0 = np.hstack(feats_train0) 269 | feats1 = np.hstack(feats_train1) 270 | 271 | # Get coords 272 | xyz0 = np.array(pcd0.points) 273 | xyz1 = np.array(pcd1.points) 274 | 275 | coords0 = np.floor(xyz0 / self.voxel_size) 276 | coords1 = np.floor(xyz1 / self.voxel_size) 277 | 278 | if self.transform: 279 | coords0, feats0 = self.transform(coords0, feats0) 280 | coords1, feats1 = self.transform(coords1, feats1) 281 | 282 | return (xyz0, xyz1, coords0, coords1, feats0, feats1, matches, trans) 283 | 284 | 285 | class KITTIPairDataset(PairDataset): 286 | AUGMENT = None 287 | DATA_FILES = { 288 | 'train': './config/train_kitti.txt', 289 | 'val': './config/val_kitti.txt', 290 | 'test': './config/test_kitti.txt' 291 | } 292 | TEST_RANDOM_ROTATION = False 293 | IS_ODOMETRY = True 294 | 295 | def __init__(self, 296 | phase, 297 | transform=None, 298 | random_rotation=True, 299 | random_scale=True, 300 | manual_seed=False, 301 | config=None): 302 | # For evaluation, use the odometry dataset training following the 3DFeat eval method 303 | if self.IS_ODOMETRY: 304 | self.root = root = config.kitti_root + '/dataset' 305 | random_rotation = self.TEST_RANDOM_ROTATION 306 | else: 307 | self.date = config.kitti_date 308 | self.root = root = os.path.join(config.kitti_root, self.date) 309 | 310 | self.icp_path = os.path.join(config.kitti_root, 'icp') 311 | pathlib.Path(self.icp_path).mkdir(parents=True, exist_ok=True) 312 | 313 | PairDataset.__init__(self, phase, transform, random_rotation, random_scale, 314 | manual_seed, config) 315 | 316 | logging.info(f"Loading the subset {phase} from {root}") 317 | # Use the kitti root 318 | self.max_time_diff = max_time_diff = config.kitti_max_time_diff 319 | 320 | subset_names = open(self.DATA_FILES[phase]).read().split() 321 | for dirname in subset_names: 322 | drive_id = int(dirname) 323 | inames = self.get_all_scan_ids(drive_id) 324 | for start_time in inames: 325 | for time_diff in range(2, max_time_diff): 326 | pair_time = time_diff + start_time 327 | if pair_time in inames: 328 | self.files.append((drive_id, start_time, pair_time)) 329 | 330 | def get_all_scan_ids(self, drive_id): 331 | if self.IS_ODOMETRY: 332 | fnames = glob.glob(self.root + '/sequences/%02d/velodyne/*.bin' % drive_id) 333 | else: 334 | fnames = glob.glob(self.root + '/' + self.date + 335 | '_drive_%04d_sync/velodyne_points/data/*.bin' % drive_id) 336 | assert len( 337 | fnames) > 0, f"Make sure that the path {self.root} has drive id: {drive_id}" 338 | inames = [int(os.path.split(fname)[-1][:-4]) for fname in fnames] 339 | return inames 340 | 341 | @property 342 | def velo2cam(self): 343 | try: 344 | velo2cam = self._velo2cam 345 | except AttributeError: 346 | R = np.array([ 347 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04, 348 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02 349 | ]).reshape(3, 3) 350 | T = np.array([-4.069766e-03, -7.631618e-02, 
-2.717806e-01]).reshape(3, 1) 351 | velo2cam = np.hstack([R, T]) 352 | self._velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T 353 | return self._velo2cam 354 | 355 | def get_video_odometry(self, drive, indices=None, ext='.txt', return_all=False): 356 | if self.IS_ODOMETRY: 357 | data_path = self.root + '/poses/%02d.txt' % drive 358 | if data_path not in kitti_cache: 359 | kitti_cache[data_path] = np.genfromtxt(data_path) 360 | if return_all: 361 | return kitti_cache[data_path] 362 | else: 363 | return kitti_cache[data_path][indices] 364 | else: 365 | data_path = self.root + '/' + self.date + '_drive_%04d_sync/oxts/data' % drive 366 | odometry = [] 367 | if indices is None: 368 | fnames = glob.glob(self.root + '/' + self.date + 369 | '_drive_%04d_sync/velodyne_points/data/*.bin' % drive) 370 | indices = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 371 | 372 | for index in indices: 373 | filename = os.path.join(data_path, '%010d%s' % (index, ext)) 374 | if filename not in kitti_cache: 375 | kitti_cache[filename] = np.genfromtxt(filename) 376 | odometry.append(kitti_cache[filename]) 377 | 378 | odometry = np.array(odometry) 379 | return odometry 380 | 381 | def odometry_to_positions(self, odometry): 382 | if self.IS_ODOMETRY: 383 | T_w_cam0 = odometry.reshape(3, 4) 384 | T_w_cam0 = np.vstack((T_w_cam0, [0, 0, 0, 1])) 385 | return T_w_cam0 386 | else: 387 | lat, lon, alt, roll, pitch, yaw = odometry.T[:6] 388 | 389 | R = 6378137 # Earth's radius in metres 390 | 391 | # convert to metres 392 | lat, lon = np.deg2rad(lat), np.deg2rad(lon) 393 | mx = R * lon * np.cos(lat) 394 | my = R * lat 395 | 396 | times = odometry.T[-1] 397 | return np.vstack([mx, my, alt, roll, pitch, yaw, times]).T 398 | 399 | def rot3d(self, axis, angle): 400 | ei = np.ones(3, dtype='bool') 401 | ei[axis] = 0 402 | i = np.nonzero(ei)[0] 403 | m = np.eye(3) 404 | c, s = np.cos(angle), np.sin(angle) 405 | m[i[0], i[0]] = c 406 | m[i[0], i[1]] = -s 407 | m[i[1], i[0]] = s 408 | m[i[1], i[1]] = c 409 | return m 410 | 411 | def pos_transform(self, pos): 412 | x, y, z, rx, ry, rz, _ = pos[0] 413 | RT = np.eye(4) 414 | RT[:3, :3] = np.dot(np.dot(self.rot3d(0, rx), self.rot3d(1, ry)), self.rot3d(2, rz)) 415 | RT[:3, 3] = [x, y, z] 416 | return RT 417 | 418 | def get_position_transform(self, pos0, pos1, invert=False): 419 | T0 = self.pos_transform(pos0) 420 | T1 = self.pos_transform(pos1) 421 | return (np.dot(T1, np.linalg.inv(T0)).T if not invert else np.dot( 422 | np.linalg.inv(T1), T0).T) 423 | 424 | def _get_velodyne_fn(self, drive, t): 425 | if self.IS_ODOMETRY: 426 | fname = self.root + '/sequences/%02d/velodyne/%06d.bin' % (drive, t) 427 | else: 428 | fname = self.root + \ 429 | '/' + self.date + '_drive_%04d_sync/velodyne_points/data/%010d.bin' % ( 430 | drive, t) 431 | return fname 432 | 433 | def __getitem__(self, idx): 434 | drive = self.files[idx][0] 435 | t0, t1 = self.files[idx][1], self.files[idx][2] 436 | all_odometry = self.get_video_odometry(drive, [t0, t1]) 437 | positions = [self.odometry_to_positions(odometry) for odometry in all_odometry] 438 | fname0 = self._get_velodyne_fn(drive, t0) 439 | fname1 = self._get_velodyne_fn(drive, t1) 440 | 441 | # XYZ and reflectance 442 | xyzr0 = np.fromfile(fname0, dtype=np.float32).reshape(-1, 4) 443 | xyzr1 = np.fromfile(fname1, dtype=np.float32).reshape(-1, 4) 444 | 445 | xyz0 = xyzr0[:, :3] 446 | xyz1 = xyzr1[:, :3] 447 | 448 | key = '%d_%d_%d' % (drive, t0, t1) 449 | filename = self.icp_path + '/' + key + '.npy' 450 | if key not in kitti_icp_cache: 
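# Ground-truth alignment for a KITTI pair: the odometry poses are only coarse,
# so the initial guess M (built from velo2cam and the two odometry poses) is
# refined with point-to-point ICP on 5 cm-voxelized scans; the refined pose M2
# is cached in memory (kitti_icp_cache) and on disk (self.icp_path) so the
# ICP refinement runs only once per pair.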
451 | if not os.path.exists(filename): 452 | # work on the downsampled xyzs, 0.05m == 5cm 453 | _, sel0 = ME.utils.sparse_quantize(xyz0 / 0.05, return_index=True) 454 | _, sel1 = ME.utils.sparse_quantize(xyz1 / 0.05, return_index=True) 455 | 456 | M = (self.velo2cam @ positions[0].T @ np.linalg.inv(positions[1].T) 457 | @ np.linalg.inv(self.velo2cam)).T 458 | xyz0_t = self.apply_transform(xyz0[sel0], M) 459 | pcd0 = make_open3d_point_cloud(xyz0_t) 460 | pcd1 = make_open3d_point_cloud(xyz1[sel1]) 461 | reg = o3d.registration.registration_icp( 462 | pcd0, pcd1, 0.2, np.eye(4), 463 | o3d.registration.TransformationEstimationPointToPoint(), 464 | o3d.registration.ICPConvergenceCriteria(max_iteration=200)) 465 | pcd0.transform(reg.transformation) 466 | # pcd0.transform(M2) or self.apply_transform(xyz0, M2) 467 | M2 = M @ reg.transformation 468 | # o3d.draw_geometries([pcd0, pcd1]) 469 | # write to a file 470 | np.save(filename, M2) 471 | else: 472 | M2 = np.load(filename) 473 | kitti_icp_cache[key] = M2 474 | else: 475 | M2 = kitti_icp_cache[key] 476 | 477 | if self.random_rotation: 478 | T0 = sample_random_trans(xyz0, self.randg, np.pi / 4) 479 | T1 = sample_random_trans(xyz1, self.randg, np.pi / 4) 480 | trans = T1 @ M2 @ np.linalg.inv(T0) 481 | 482 | xyz0 = self.apply_transform(xyz0, T0) 483 | xyz1 = self.apply_transform(xyz1, T1) 484 | else: 485 | trans = M2 486 | 487 | matching_search_voxel_size = self.matching_search_voxel_size 488 | if self.random_scale and random.random() < 0.95: 489 | scale = self.min_scale + \ 490 | (self.max_scale - self.min_scale) * random.random() 491 | matching_search_voxel_size *= scale 492 | xyz0 = scale * xyz0 493 | xyz1 = scale * xyz1 494 | 495 | # Voxelization 496 | xyz0_th = torch.from_numpy(xyz0) 497 | xyz1_th = torch.from_numpy(xyz1) 498 | 499 | _, sel0 = ME.utils.sparse_quantize(xyz0_th / self.voxel_size, return_index=True) 500 | _, sel1 = ME.utils.sparse_quantize(xyz1_th / self.voxel_size, return_index=True) 501 | 502 | # Make point clouds using voxelized points 503 | pcd0 = make_open3d_point_cloud(xyz0[sel0]) 504 | pcd1 = make_open3d_point_cloud(xyz1[sel1]) 505 | 506 | # Get matches 507 | matches = get_matching_indices(pcd0, pcd1, trans, matching_search_voxel_size) 508 | if len(matches) < 1000: 509 | raise ValueError(f"{drive}, {t0}, {t1}") 510 | 511 | # Get features 512 | npts0 = len(sel0) 513 | npts1 = len(sel1) 514 | 515 | feats_train0, feats_train1 = [], [] 516 | 517 | unique_xyz0_th = xyz0_th[sel0] 518 | unique_xyz1_th = xyz1_th[sel1] 519 | 520 | feats_train0.append(torch.ones((npts0, 1))) 521 | feats_train1.append(torch.ones((npts1, 1))) 522 | 523 | feats0 = torch.cat(feats_train0, 1) 524 | feats1 = torch.cat(feats_train1, 1) 525 | 526 | coords0 = torch.floor(unique_xyz0_th / self.voxel_size) 527 | coords1 = torch.floor(unique_xyz1_th / self.voxel_size) 528 | 529 | if self.transform: 530 | coords0, feats0 = self.transform(coords0, feats0) 531 | coords1, feats1 = self.transform(coords1, feats1) 532 | 533 | return (unique_xyz0_th.float(), unique_xyz1_th.float(), coords0.int(), 534 | coords1.int(), feats0.float(), feats1.float(), matches, trans) 535 | 536 | 537 | class KITTINMPairDataset(KITTIPairDataset): 538 | r""" 539 | Generate KITTI pairs within N meter distance 540 | """ 541 | MIN_DIST = 10 542 | 543 | def __init__(self, 544 | phase, 545 | transform=None, 546 | random_rotation=True, 547 | random_scale=True, 548 | manual_seed=False, 549 | config=None): 550 | if self.IS_ODOMETRY: 551 | self.root = root = os.path.join(config.kitti_root, 'dataset') 
552 | random_rotation = self.TEST_RANDOM_ROTATION 553 | else: 554 | self.date = config.kitti_date 555 | self.root = root = os.path.join(config.kitti_root, self.date) 556 | 557 | self.icp_path = os.path.join(config.kitti_root, 'icp') 558 | pathlib.Path(self.icp_path).mkdir(parents=True, exist_ok=True) 559 | 560 | PairDataset.__init__(self, phase, transform, random_rotation, random_scale, 561 | manual_seed, config) 562 | 563 | logging.info(f"Loading the subset {phase} from {root}") 564 | 565 | subset_names = open(self.DATA_FILES[phase]).read().split() 566 | if self.IS_ODOMETRY: 567 | for dirname in subset_names: 568 | drive_id = int(dirname) 569 | fnames = glob.glob(root + '/sequences/%02d/velodyne/*.bin' % drive_id) 570 | assert len(fnames) > 0, f"Make sure that the path {root} has data {dirname}" 571 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 572 | 573 | all_odo = self.get_video_odometry(drive_id, return_all=True) 574 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo]) 575 | Ts = all_pos[:, :3, 3] 576 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3))**2 577 | pdist = np.sqrt(pdist.sum(-1)) 578 | valid_pairs = pdist > self.MIN_DIST 579 | curr_time = inames[0] 580 | while curr_time in inames: 581 | # Find the min index 582 | next_time = np.where(valid_pairs[curr_time][curr_time:curr_time + 100])[0] 583 | if len(next_time) == 0: 584 | curr_time += 1 585 | else: 586 | # Follow https://github.com/yewzijian/3DFeatNet/blob/master/scripts_data_processing/kitti/process_kitti_data.m#L44 587 | next_time = next_time[0] + curr_time - 1 588 | 589 | if next_time in inames: 590 | self.files.append((drive_id, curr_time, next_time)) 591 | curr_time = next_time + 1 592 | else: 593 | for dirname in subset_names: 594 | drive_id = int(dirname) 595 | fnames = glob.glob(root + '/' + self.date + 596 | '_drive_%04d_sync/velodyne_points/data/*.bin' % drive_id) 597 | assert len(fnames) > 0, f"Make sure that the path {root} has data {dirname}" 598 | inames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 599 | 600 | all_odo = self.get_video_odometry(drive_id, return_all=True) 601 | all_pos = np.array([self.odometry_to_positions(odo) for odo in all_odo]) 602 | Ts = all_pos[:, 0, :3] 603 | 604 | pdist = (Ts.reshape(1, -1, 3) - Ts.reshape(-1, 1, 3))**2 605 | pdist = np.sqrt(pdist.sum(-1)) 606 | 607 | for start_time in inames: 608 | pair_time = np.where( 609 | pdist[start_time][start_time:start_time + 100] > self.MIN_DIST)[0] 610 | if len(pair_time) == 0: 611 | continue 612 | else: 613 | pair_time = pair_time[0] + start_time 614 | 615 | if pair_time in inames: 616 | self.files.append((drive_id, start_time, pair_time)) 617 | 618 | if self.IS_ODOMETRY: 619 | # Remove problematic sequence 620 | for item in [ 621 | (8, 15, 58), 622 | ]: 623 | if item in self.files: 624 | self.files.pop(self.files.index(item)) 625 | 626 | 627 | class ThreeDMatchPairDataset(IndoorPairDataset): 628 | OVERLAP_RATIO = 0.3 629 | DATA_FILES = { 630 | 'train': './config/train_3dmatch.txt', 631 | 'val': './config/val_3dmatch.txt', 632 | 'test': './config/test_3dmatch.txt' 633 | } 634 | 635 | 636 | ALL_DATASETS = [ThreeDMatchPairDataset, KITTIPairDataset, KITTINMPairDataset] 637 | dataset_str_mapping = {d.__name__: d for d in ALL_DATASETS} 638 | 639 | 640 | def make_data_loader(config, phase, batch_size, num_threads=0, shuffle=None): 641 | assert phase in ['train', 'trainval', 'val', 'test'] 642 | if shuffle is None: 643 | shuffle = phase != 'test' 644 | 645 | if 
config.dataset not in dataset_str_mapping.keys(): 646 | logging.error(f'Dataset {config.dataset} does not exist in ' + 647 | ', '.join(dataset_str_mapping.keys())) 648 | 649 | Dataset = dataset_str_mapping[config.dataset] 650 | 651 | use_random_scale = False 652 | use_random_rotation = False 653 | transforms = [] 654 | if phase in ['train', 'trainval']: 655 | use_random_rotation = config.use_random_rotation 656 | use_random_scale = config.use_random_scale 657 | transforms += [t.Jitter()] 658 | 659 | dset = Dataset( 660 | phase, 661 | transform=t.Compose(transforms), 662 | random_scale=use_random_scale, 663 | random_rotation=use_random_rotation, 664 | config=config) 665 | 666 | loader = torch.utils.data.DataLoader( 667 | dset, 668 | batch_size=batch_size, 669 | shuffle=shuffle, 670 | num_workers=num_threads, 671 | collate_fn=collate_pair_fn, 672 | pin_memory=False, 673 | drop_last=True) 674 | 675 | return loader 676 | -------------------------------------------------------------------------------- /lib/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: future_fstrings -*- 2 | # 3 | # Written by Chris Choy 4 | # Distributed under MIT License 5 | import os 6 | import os.path as osp 7 | import gc 8 | import logging 9 | import numpy as np 10 | import json 11 | 12 | import torch 13 | import torch.optim as optim 14 | import torch.nn.functional as F 15 | from tensorboardX import SummaryWriter 16 | 17 | from model import load_model 18 | import util.transform_estimation as te 19 | from lib.metrics import pdist, corr_dist 20 | from lib.timer import Timer, AverageMeter 21 | from lib.eval import find_nn_gpu 22 | 23 | from util.file import ensure_dir 24 | from util.misc import _hash 25 | 26 | import MinkowskiEngine as ME 27 | 28 | 29 | class AlignmentTrainer: 30 | 31 | def __init__( 32 | self, 33 | config, 34 | data_loader, 35 | val_data_loader=None, 36 | ): 37 | num_feats = 1  # occupancy only for 3D Match dataset. For ScanNet, use RGB 3 channels.
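# load_model (model/__init__.py) looks the network class up by the name given
# in config.model; num_feats = 1 matches the all-ones occupancy features
# emitted by the data loaders in lib/data_loaders.py (np.ones((npts, 1))).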
38 | 39 | # Model initialization 40 | Model = load_model(config.model) 41 | model = Model( 42 | num_feats, 43 | config.model_n_out, 44 | bn_momentum=config.bn_momentum, 45 | normalize_feature=config.normalize_feature, 46 | conv1_kernel_size=config.conv1_kernel_size, 47 | D=3) 48 | 49 | if config.weights: 50 | checkpoint = torch.load(config.weights) 51 | model.load_state_dict(checkpoint['state_dict']) 52 | 53 | logging.info(model) 54 | 55 | self.config = config 56 | self.model = model 57 | self.max_epoch = config.max_epoch 58 | self.save_freq = config.save_freq_epoch 59 | self.val_max_iter = config.val_max_iter 60 | self.val_epoch_freq = config.val_epoch_freq 61 | 62 | self.best_val_metric = config.best_val_metric 63 | self.best_val_epoch = -np.inf 64 | self.best_val = -np.inf 65 | 66 | if config.use_gpu and not torch.cuda.is_available(): 67 | logging.warning('No CUDA support on this machine although config.use_gpu is set; ' 68 | 'aborting instead of silently training on CPU.') 69 | raise ValueError('GPU not available, but cuda flag set') 70 | 71 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 72 | 73 | self.optimizer = getattr(optim, config.optimizer)( 74 | model.parameters(), 75 | lr=config.lr, 76 | momentum=config.momentum, 77 | weight_decay=config.weight_decay) 78 | 79 | self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, config.exp_gamma) 80 | 81 | self.start_epoch = 1 82 | self.checkpoint_dir = config.out_dir 83 | 84 | ensure_dir(self.checkpoint_dir) 85 | json.dump( 86 | config, 87 | open(os.path.join(self.checkpoint_dir, 'config.json'), 'w'), 88 | indent=4, 89 | sort_keys=False) 90 | 91 | self.iter_size = config.iter_size 92 | self.batch_size = data_loader.batch_size 93 | self.data_loader = data_loader 94 | self.val_data_loader = val_data_loader 95 | 96 | self.test_valid = self.val_data_loader is not None 97 | self.log_step = int(np.sqrt(self.config.batch_size)) 98 | self.model = self.model.to(self.device) 99 | self.writer = SummaryWriter(logdir=config.out_dir) 100 | 101 | if config.resume is not None: 102 | if osp.isfile(config.resume): 103 | logging.info("=> loading checkpoint '{}'".format(config.resume)) 104 | state = torch.load(config.resume) 105 | self.start_epoch = state['epoch'] 106 | model.load_state_dict(state['state_dict']) 107 | self.scheduler.load_state_dict(state['scheduler']) 108 | self.optimizer.load_state_dict(state['optimizer']) 109 | 110 | if 'best_val' in state.keys(): 111 | self.best_val = state['best_val'] 112 | self.best_val_epoch = state['best_val_epoch'] 113 | self.best_val_metric = state['best_val_metric'] 114 | else: 115 | raise ValueError(f"=> no checkpoint found at '{config.resume}'") 116 | 117 | def train(self): 118 | """ 119 | Full training logic 120 | """ 121 | # Baseline random feature performance 122 | if self.test_valid: 123 | with torch.no_grad(): 124 | val_dict = self._valid_epoch() 125 | 126 | for k, v in val_dict.items(): 127 | self.writer.add_scalar(f'val/{k}', v, 0) 128 | 129 | for epoch in range(self.start_epoch, self.max_epoch + 1): 130 | lr = self.scheduler.get_lr() 131 | logging.info(f" Epoch: {epoch}, LR: {lr}") 132 | self._train_epoch(epoch) 133 | self._save_checkpoint(epoch) 134 | self.scheduler.step() 135 | 136 | if self.test_valid and epoch % self.val_epoch_freq == 0: 137 | with torch.no_grad(): 138 | val_dict = self._valid_epoch() 139 | 140 | for k, v in val_dict.items(): 141 | self.writer.add_scalar(f'val/{k}', v, epoch) 142 | if self.best_val < val_dict[self.best_val_metric]: 143 | logging.info( 144 | f'Saving the best val model with {self.best_val_metric}: {val_dict[self.best_val_metric]}' 145 | ) 146 | self.best_val = val_dict[self.best_val_metric] 147 | self.best_val_epoch = epoch 148 | self._save_checkpoint(epoch, 'best_val_checkpoint') 149 | else: 150 | logging.info( 151 | f'Current best val model with {self.best_val_metric}: {self.best_val} at epoch {self.best_val_epoch}' 152 | ) 153 | 154 | def _save_checkpoint(self, epoch, filename='checkpoint'): 155 | state = { 156 | 'epoch': epoch, 157 | 'state_dict': self.model.state_dict(), 158 | 'optimizer': self.optimizer.state_dict(), 159 | 'scheduler': self.scheduler.state_dict(), 160 | 'config': self.config, 161 | 'best_val': self.best_val, 162 | 'best_val_epoch': self.best_val_epoch, 163 | 'best_val_metric': self.best_val_metric 164 | } 165 | filename = os.path.join(self.checkpoint_dir, f'{filename}.pth') 166 | logging.info("Saving checkpoint: {} ...".format(filename)) 167 | torch.save(state, filename) 168 | 169 | 170 | class ContrastiveLossTrainer(AlignmentTrainer): 171 | 172 | def __init__( 173 | self, 174 | config, 175 | data_loader, 176 | val_data_loader=None, 177 | ): 178 | if val_data_loader is not None: 179 | assert val_data_loader.batch_size == 1, "Val set batch size must be 1 for now." 180 | AlignmentTrainer.__init__(self, config, data_loader, val_data_loader) 181 | self.neg_thresh = config.neg_thresh 182 | self.pos_thresh = config.pos_thresh 183 | self.neg_weight = config.neg_weight 184 | 185 | def apply_transform(self, pts, trans): 186 | R = trans[:3, :3] 187 | T = trans[:3, 3] 188 | return pts @ R.t() + T 189 | 190 | def generate_rand_negative_pairs(self, positive_pairs, hash_seed, N0, N1, N_neg=0): 191 | """ 192 | Generate random negative pairs 193 | """ 194 | if not isinstance(positive_pairs, np.ndarray): 195 | positive_pairs = np.array(positive_pairs, dtype=np.int64) 196 | if N_neg < 1: 197 | N_neg = positive_pairs.shape[0] * 2 198 | pos_keys = _hash(positive_pairs, hash_seed) 199 | 200 | neg_pairs = np.floor(np.random.rand(int(N_neg), 2) * np.array([[N0, N1]])).astype( 201 | np.int64) 202 | neg_keys = _hash(neg_pairs, hash_seed) 203 | mask = np.isin(neg_keys, pos_keys, assume_unique=False) 204 | return neg_pairs[np.logical_not(mask)] 205 | 206 | def _train_epoch(self, epoch): 207 | gc.collect() 208 | self.model.train() 209 | # Epoch starts from 1 210 | total_loss = 0 211 | total_num = 0.0 212 | 213 | data_loader = self.data_loader 214 | data_loader_iter = self.data_loader.__iter__() 215 | 216 | iter_size = self.iter_size 217 | start_iter = (epoch - 1) * (len(data_loader) // iter_size) 218 | 219 | data_meter, data_timer, total_timer = AverageMeter(), Timer(), Timer() 220 | 221 | # Main training 222 | for curr_iter in range(len(data_loader) // iter_size): 223 | self.optimizer.zero_grad() 224 | batch_pos_loss, batch_neg_loss, batch_loss = 0, 0, 0 225 | 226 | data_time = 0 227 | total_timer.tic() 228 | for iter_idx in range(iter_size): 229 | # Accumulate gradients over iter_size sub-batches (Caffe-style iter_size) 230 | data_timer.tic() 231 | input_dict = data_loader_iter.next() 232 | data_time += data_timer.toc(average=False) 233 | 234 | # pairs consist of (xyz0 index, xyz1 index) 235 | sinput0 = ME.SparseTensor( 236 | input_dict['sinput0_F'].to(self.device), 237 | coordinates=input_dict['sinput0_C'].to(self.device)) 238 | F0 = self.model(sinput0).F 239 | 240 | sinput1 = ME.SparseTensor( 241 | input_dict['sinput1_F'].to(self.device), 242 | coordinates=input_dict['sinput1_C'].to(self.device)) 243 | F1 = self.model(sinput1).F 244 | 245 | N0, N1 = len(sinput0), len(sinput1) 246 | 247 | pos_pairs = input_dict['correspondences'] 248 | neg_pairs = self.generate_rand_negative_pairs(pos_pairs, max(N0, N1), N0, N1) 249 | pos_pairs = pos_pairs.long().to(self.device) 250 | neg_pairs = torch.from_numpy(neg_pairs).long().to(self.device) 251 | 252 | neg0 = F0.index_select(0, neg_pairs[:, 0]) 253 | neg1 = F1.index_select(0, neg_pairs[:, 1]) 254 | pos0 = F0.index_select(0, pos_pairs[:, 0]) 255 | pos1 = F1.index_select(0, pos_pairs[:, 1]) 256 | 257 | # Positive loss 258 | pos_loss = (pos0 - pos1).pow(2).sum(1) 259 | 260 | # Negative loss 261 | neg_loss = F.relu(self.neg_thresh - 262 | ((neg0 - neg1).pow(2).sum(1) + 1e-4).sqrt()).pow(2) 263 | 264 | pos_loss_mean = pos_loss.mean() / iter_size 265 | neg_loss_mean = neg_loss.mean() / iter_size 266 | 267 | # Weighted loss 268 | loss = pos_loss_mean + self.neg_weight * neg_loss_mean 269 | loss.backward( 270 | )  # To accumulate gradient, zero gradients only at the beginning of iter_size 271 | batch_loss += loss.item() 272 | batch_pos_loss += pos_loss_mean.item() 273 | batch_neg_loss += neg_loss_mean.item() 274 | 275 | self.optimizer.step() 276 | 277 | torch.cuda.empty_cache() 278 | 279 | total_loss += batch_loss 280 | total_num += 1.0 281 | total_timer.toc() 282 | data_meter.update(data_time) 283 | 284 | # Print logs 285 | if curr_iter % self.config.stat_freq == 0: 286 | self.writer.add_scalar('train/loss', batch_loss, start_iter + curr_iter) 287 | self.writer.add_scalar('train/pos_loss', batch_pos_loss, start_iter + curr_iter) 288 | self.writer.add_scalar('train/neg_loss', batch_neg_loss, start_iter + curr_iter) 289 | logging.info( 290 | "Train Epoch: {} [{}/{}], Current Loss: {:.3e} Pos: {:.3f} Neg: {:.3f}" 291 | .format(epoch, curr_iter, 292 | len(self.data_loader) // 293 | iter_size, batch_loss, batch_pos_loss, batch_neg_loss) + 294 | "\tData time: {:.4f}, Train time: {:.4f}, Iter time: {:.4f}".format( 295 | data_meter.avg, total_timer.avg - data_meter.avg, total_timer.avg)) 296 | data_meter.reset() 297 | total_timer.reset() 298 | 299 | def _valid_epoch(self): 300 | # Change the network to evaluation mode 301 | self.model.eval() 302 | self.val_data_loader.dataset.reset_seed(0) 303 | num_data = 0 304 | hit_ratio_meter, feat_match_ratio, loss_meter, rte_meter, rre_meter = AverageMeter( 305 | ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() 306 | data_timer, feat_timer, matching_timer = Timer(), Timer(), Timer() 307 | tot_num_data = len(self.val_data_loader.dataset) 308 | if self.val_max_iter > 0: 309 | tot_num_data = min(self.val_max_iter, tot_num_data) 310 | data_loader_iter = self.val_data_loader.__iter__() 311 | 312 | for batch_idx in range(tot_num_data): 313 | data_timer.tic() 314 | input_dict = data_loader_iter.next() 315 | data_timer.toc() 316 | 317 | # pairs consist of (xyz0 index, xyz1 index) 318 | feat_timer.tic() 319 | sinput0 = ME.SparseTensor( 320 | input_dict['sinput0_F'].to(self.device), 321 | coordinates=input_dict['sinput0_C'].to(self.device)) 322 | F0 = self.model(sinput0).F 323 | 324 | sinput1 = ME.SparseTensor( 325 | input_dict['sinput1_F'].to(self.device), 326 | coordinates=input_dict['sinput1_C'].to(self.device)) 327 | F1 = self.model(sinput1).F 328 | feat_timer.toc() 329 | 330 | matching_timer.tic() 331 | xyz0, xyz1, T_gt = input_dict['pcd0'], input_dict['pcd1'], input_dict['T_gt'] 332 | xyz0_corr, xyz1_corr = self.find_corr(xyz0, xyz1, F0, F1, subsample_size=5000) 333 | T_est = te.est_quad_linear_robust(xyz0_corr, xyz1_corr) 334 | 335 | loss = corr_dist(T_est, T_gt, xyz0, xyz1, weight=None)
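# Registration metrics computed below from T_est vs. T_gt:
# RTE (relative translation error) = ||t_est - t_gt||_2,
# RRE (relative rotation error) = arccos((trace(R_est^T R_gt) - 1) / 2),
# hit ratio = fraction of putative correspondences within hit_ratio_thresh
# after warping xyz0 by T_gt; feat_match_ratio counts pairs with hit ratio > 0.05.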
336 | loss_meter.update(loss) 337 | 338 | rte = np.linalg.norm(T_est[:3, 3] - T_gt[:3, 3]) 339 | rte_meter.update(rte) 340 | rre = np.arccos((np.trace(T_est[:3, :3].t() @ T_gt[:3, :3]) - 1) / 2) 341 | if not np.isnan(rre): 342 | rre_meter.update(rre) 343 | 344 | hit_ratio = self.evaluate_hit_ratio( 345 | xyz0_corr, xyz1_corr, T_gt, thresh=self.config.hit_ratio_thresh) 346 | hit_ratio_meter.update(hit_ratio) 347 | feat_match_ratio.update(hit_ratio > 0.05) 348 | matching_timer.toc() 349 | 350 | num_data += 1 351 | torch.cuda.empty_cache() 352 | 353 | if batch_idx % 100 == 0 and batch_idx > 0: 354 | logging.info(' '.join([ 355 | f"Validation iter {num_data} / {tot_num_data} : Data Loading Time: {data_timer.avg:.3f},", 356 | f"Feature Extraction Time: {feat_timer.avg:.3f}, Matching Time: {matching_timer.avg:.3f},", 357 | f"Loss: {loss_meter.avg:.3f}, RTE: {rte_meter.avg:.3f}, RRE: {rre_meter.avg:.3f},", 358 | f"Hit Ratio: {hit_ratio_meter.avg:.3f}, Feat Match Ratio: {feat_match_ratio.avg:.3f}" 359 | ])) 360 | data_timer.reset() 361 | 362 | logging.info(' '.join([ 363 | f"Final Loss: {loss_meter.avg:.3f}, RTE: {rte_meter.avg:.3f}, RRE: {rre_meter.avg:.3f},", 364 | f"Hit Ratio: {hit_ratio_meter.avg:.3f}, Feat Match Ratio: {feat_match_ratio.avg:.3f}" 365 | ])) 366 | return { 367 | "loss": loss_meter.avg, 368 | "rre": rre_meter.avg, 369 | "rte": rte_meter.avg, 370 | 'feat_match_ratio': feat_match_ratio.avg, 371 | 'hit_ratio': hit_ratio_meter.avg 372 | } 373 | 374 | def find_corr(self, xyz0, xyz1, F0, F1, subsample_size=-1): 375 | subsample = len(F0) > subsample_size 376 | if subsample_size > 0 and subsample: 377 | N0 = min(len(F0), subsample_size) 378 | N1 = min(len(F1), subsample_size) 379 | inds0 = np.random.choice(len(F0), N0, replace=False) 380 | inds1 = np.random.choice(len(F1), N1, replace=False) 381 | F0, F1 = F0[inds0], F1[inds1] 382 | 383 | # Compute the nn 384 | nn_inds = find_nn_gpu(F0, F1, nn_max_n=self.config.nn_max_n) 385 | if subsample_size > 0 and subsample: 386 | return xyz0[inds0], xyz1[inds1[nn_inds]] 387 | else: 388 | return xyz0, xyz1[nn_inds] 389 | 390 | def evaluate_hit_ratio(self, xyz0, xyz1, T_gth, thresh=0.1): 391 | xyz0 = self.apply_transform(xyz0, T_gth) 392 | dist = np.sqrt(((xyz0 - xyz1)**2).sum(1) + 1e-6) 393 | return (dist < thresh).float().mean().item() 394 | 395 | 396 | class HardestContrastiveLossTrainer(ContrastiveLossTrainer): 397 | 398 | def contrastive_hardest_negative_loss(self, 399 | F0, 400 | F1, 401 | positive_pairs, 402 | num_pos=5192, 403 | num_hn_samples=2048, 404 | thresh=None): 405 | """ 406 | Generate negative pairs 407 | """ 408 | N0, N1 = len(F0), len(F1) 409 | N_pos_pairs = len(positive_pairs) 410 | hash_seed = max(N0, N1) 411 | sel0 = np.random.choice(N0, min(N0, num_hn_samples), replace=False) 412 | sel1 = np.random.choice(N1, min(N1, num_hn_samples), replace=False) 413 | 414 | if N_pos_pairs > num_pos: 415 | pos_sel = np.random.choice(N_pos_pairs, num_pos, replace=False) 416 | sample_pos_pairs = positive_pairs[pos_sel] 417 | else: 418 | sample_pos_pairs = positive_pairs 419 | 420 | # Find negatives for all F1[positive_pairs[:, 1]] 421 | subF0, subF1 = F0[sel0], F1[sel1] 422 | 423 | pos_ind0 = sample_pos_pairs[:, 0].long() 424 | pos_ind1 = sample_pos_pairs[:, 1].long() 425 | posF0, posF1 = F0[pos_ind0], F1[pos_ind1] 426 | 427 | D01 = pdist(posF0, subF1, dist_type='L2') 428 | D10 = pdist(posF1, subF0, dist_type='L2') 429 | 430 | D01min, D01ind = D01.min(1) 431 | D10min, D10ind = D10.min(1) 432 | 433 | if not isinstance(positive_pairs, 
np.ndarray): 434 | positive_pairs = np.array(positive_pairs, dtype=np.int64) 435 | 436 | pos_keys = _hash(positive_pairs, hash_seed) 437 | 438 | D01ind = sel1[D01ind.cpu().numpy()] 439 | D10ind = sel0[D10ind.cpu().numpy()] 440 | neg_keys0 = _hash([pos_ind0.numpy(), D01ind], hash_seed) 441 | neg_keys1 = _hash([D10ind, pos_ind1.numpy()], hash_seed) 442 | 443 | mask0 = torch.from_numpy( 444 | np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False))) 445 | mask1 = torch.from_numpy( 446 | np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False))) 447 | pos_loss = F.relu((posF0 - posF1).pow(2).sum(1) - self.pos_thresh) 448 | neg_loss0 = F.relu(self.neg_thresh - D01min[mask0]).pow(2) 449 | neg_loss1 = F.relu(self.neg_thresh - D10min[mask1]).pow(2) 450 | return pos_loss.mean(), (neg_loss0.mean() + neg_loss1.mean()) / 2 451 | 452 | def _train_epoch(self, epoch): 453 | gc.collect() 454 | self.model.train() 455 | # Epoch starts from 1 456 | total_loss = 0 457 | total_num = 0.0 458 | data_loader = self.data_loader 459 | data_loader_iter = self.data_loader.__iter__() 460 | iter_size = self.iter_size 461 | data_meter, data_timer, total_timer = AverageMeter(), Timer(), Timer() 462 | start_iter = (epoch - 1) * (len(data_loader) // iter_size) 463 | for curr_iter in range(len(data_loader) // iter_size): 464 | self.optimizer.zero_grad() 465 | batch_pos_loss, batch_neg_loss, batch_loss = 0, 0, 0 466 | 467 | data_time = 0 468 | total_timer.tic() 469 | for iter_idx in range(iter_size): 470 | data_timer.tic() 471 | input_dict = data_loader_iter.next() 472 | data_time += data_timer.toc(average=False) 473 | 474 | sinput0 = ME.SparseTensor( 475 | input_dict['sinput0_F'].to(self.device), 476 | coordinates=input_dict['sinput0_C'].to(self.device)) 477 | F0 = self.model(sinput0).F 478 | 479 | sinput1 = ME.SparseTensor( 480 | input_dict['sinput1_F'].to(self.device), 481 | coordinates=input_dict['sinput1_C'].to(self.device)) 482 | 483 | F1 = self.model(sinput1).F 484 | 485 | pos_pairs = input_dict['correspondences'] 486 | pos_loss, neg_loss = self.contrastive_hardest_negative_loss( 487 | F0, 488 | F1, 489 | pos_pairs, 490 | num_pos=self.config.num_pos_per_batch * self.config.batch_size, 491 | num_hn_samples=self.config.num_hn_samples_per_batch * 492 | self.config.batch_size) 493 | 494 | pos_loss /= iter_size 495 | neg_loss /= iter_size 496 | loss = pos_loss + self.neg_weight * neg_loss 497 | loss.backward() 498 | 499 | batch_loss += loss.item() 500 | batch_pos_loss += pos_loss.item() 501 | batch_neg_loss += neg_loss.item() 502 | 503 | self.optimizer.step() 504 | gc.collect() 505 | 506 | torch.cuda.empty_cache() 507 | 508 | total_loss += batch_loss 509 | total_num += 1.0 510 | total_timer.toc() 511 | data_meter.update(data_time) 512 | 513 | if curr_iter % self.config.stat_freq == 0: 514 | self.writer.add_scalar('train/loss', batch_loss, start_iter + curr_iter) 515 | self.writer.add_scalar('train/pos_loss', batch_pos_loss, start_iter + curr_iter) 516 | self.writer.add_scalar('train/neg_loss', batch_neg_loss, start_iter + curr_iter) 517 | logging.info( 518 | "Train Epoch: {} [{}/{}], Current Loss: {:.3e} Pos: {:.3f} Neg: {:.3f}" 519 | .format(epoch, curr_iter, 520 | len(self.data_loader) // 521 | iter_size, batch_loss, batch_pos_loss, batch_neg_loss) + 522 | "\tData time: {:.4f}, Train time: {:.4f}, Iter time: {:.4f}".format( 523 | data_meter.avg, total_timer.avg - data_meter.avg, total_timer.avg)) 524 | data_meter.reset() 525 | total_timer.reset() 526 | 527 | 528 | class 
TripletLossTrainer(ContrastiveLossTrainer): 529 | 530 | def triplet_loss(self, 531 | F0, 532 | F1, 533 | positive_pairs, 534 | num_pos=1024, 535 | num_hn_samples=None,  # unused here; consumed by the HardestTripletLossTrainer override 536 | num_rand_triplet=1024): 537 | """ 538 | Generate negative pairs 539 | """ 540 | N0, N1 = len(F0), len(F1) 541 | num_pos_pairs = len(positive_pairs) 542 | hash_seed = max(N0, N1) 543 | 544 | if num_pos_pairs > num_pos: 545 | pos_sel = np.random.choice(num_pos_pairs, num_pos, replace=False) 546 | sample_pos_pairs = positive_pairs[pos_sel] 547 | else: 548 | sample_pos_pairs = positive_pairs 549 | 550 | pos_ind0 = sample_pos_pairs[:, 0].long() 551 | pos_ind1 = sample_pos_pairs[:, 1].long() 552 | posF0, posF1 = F0[pos_ind0], F1[pos_ind1] 553 | 554 | if not isinstance(positive_pairs, np.ndarray): 555 | positive_pairs = np.array(positive_pairs, dtype=np.int64) 556 | 557 | pos_keys = _hash(positive_pairs, hash_seed) 558 | pos_dist = torch.sqrt((posF0 - posF1).pow(2).sum(1) + 1e-7) 559 | 560 | # Random triplets 561 | rand_inds = np.random.choice( 562 | num_pos_pairs, min(num_pos_pairs, num_rand_triplet), replace=False) 563 | rand_pairs = positive_pairs[rand_inds] 564 | negatives = np.random.choice(N1, min(N1, num_rand_triplet), replace=False) 565 | 566 | # Remove positives from negatives 567 | rand_neg_keys = _hash([rand_pairs[:, 0], negatives], hash_seed) 568 | rand_mask = np.logical_not(np.isin(rand_neg_keys, pos_keys, assume_unique=False)) 569 | anchors, positives = rand_pairs[torch.from_numpy(rand_mask)].T 570 | negatives = negatives[rand_mask] 571 | 572 | rand_pos_dist = torch.sqrt((F0[anchors] - F1[positives]).pow(2).sum(1) + 1e-7) 573 | rand_neg_dist = torch.sqrt((F0[anchors] - F1[negatives]).pow(2).sum(1) + 1e-7) 574 | 575 | loss = F.relu(rand_pos_dist + self.neg_thresh - rand_neg_dist).mean() 576 | 577 | return loss, pos_dist.mean(), rand_neg_dist.mean() 578 | 579 | def _train_epoch(self, epoch): 580 | config = self.config 581 | 582 | gc.collect() 583 | self.model.train() 584 | 585 | # Epoch starts from 1 586 | total_loss = 0 587 | total_num = 0.0 588 | data_loader = self.data_loader 589 | data_loader_iter = self.data_loader.__iter__() 590 | iter_size = self.iter_size 591 | data_meter, data_timer, total_timer = AverageMeter(), Timer(), Timer() 592 | pos_dist_meter, neg_dist_meter = AverageMeter(), AverageMeter() 593 | start_iter = (epoch - 1) * (len(data_loader) // iter_size) 594 | for curr_iter in range(len(data_loader) // iter_size): 595 | self.optimizer.zero_grad() 596 | batch_loss = 0 597 | data_time = 0 598 | total_timer.tic() 599 | for iter_idx in range(iter_size): 600 | data_timer.tic() 601 | input_dict = data_loader_iter.next() 602 | data_time += data_timer.toc(average=False) 603 | 604 | # pairs consist of (xyz0 index, xyz1 index) 605 | sinput0 = ME.SparseTensor( 606 | input_dict['sinput0_F'].to(self.device), 607 | coordinates=input_dict['sinput0_C'].to(self.device)) 608 | F0 = self.model(sinput0).F 609 | 610 | sinput1 = ME.SparseTensor( 611 | input_dict['sinput1_F'].to(self.device), 612 | coordinates=input_dict['sinput1_C'].to(self.device)) 613 | F1 = self.model(sinput1).F 614 | 615 | pos_pairs = input_dict['correspondences'] 616 | loss, pos_dist, neg_dist = self.triplet_loss( 617 | F0, 618 | F1, 619 | pos_pairs, 620 | num_pos=config.triplet_num_pos * config.batch_size, 621 | num_hn_samples=config.triplet_num_hn * config.batch_size, 622 | num_rand_triplet=config.triplet_num_rand * config.batch_size) 623 | loss /= iter_size 624 | loss.backward() 625 | batch_loss += loss.item() 626 |
pos_dist_meter.update(pos_dist) 627 | neg_dist_meter.update(neg_dist) 628 | 629 | self.optimizer.step() 630 | gc.collect() 631 | 632 | torch.cuda.empty_cache() 633 | 634 | total_loss += batch_loss 635 | total_num += 1.0 636 | total_timer.toc() 637 | data_meter.update(data_time) 638 | 639 | if curr_iter % self.config.stat_freq == 0: 640 | self.writer.add_scalar('train/loss', batch_loss, start_iter + curr_iter) 641 | logging.info( 642 | "Train Epoch: {} [{}/{}], Current Loss: {:.3e}, Pos dist: {:.3e}, Neg dist: {:.3e}" 643 | .format(epoch, curr_iter, 644 | len(self.data_loader) // 645 | iter_size, batch_loss, pos_dist_meter.avg, neg_dist_meter.avg) + 646 | "\tData time: {:.4f}, Train time: {:.4f}, Iter time: {:.4f}".format( 647 | data_meter.avg, total_timer.avg - data_meter.avg, total_timer.avg)) 648 | pos_dist_meter.reset() 649 | neg_dist_meter.reset() 650 | data_meter.reset() 651 | total_timer.reset() 652 | 653 | 654 | class HardestTripletLossTrainer(TripletLossTrainer): 655 | 656 | def triplet_loss(self, 657 | F0, 658 | F1, 659 | positive_pairs, 660 | num_pos=1024, 661 | num_hn_samples=512, 662 | num_rand_triplet=1024): 663 | """ 664 | Generate negative pairs 665 | """ 666 | N0, N1 = len(F0), len(F1) 667 | num_pos_pairs = len(positive_pairs) 668 | hash_seed = max(N0, N1) 669 | sel0 = np.random.choice(N0, min(N0, num_hn_samples), replace=False) 670 | sel1 = np.random.choice(N1, min(N1, num_hn_samples), replace=False) 671 | 672 | if num_pos_pairs > num_pos: 673 | pos_sel = np.random.choice(num_pos_pairs, num_pos, replace=False) 674 | sample_pos_pairs = positive_pairs[pos_sel] 675 | else: 676 | sample_pos_pairs = positive_pairs 677 | 678 | # Find negatives for all F1[positive_pairs[:, 1]] 679 | subF0, subF1 = F0[sel0], F1[sel1] 680 | 681 | pos_ind0 = sample_pos_pairs[:, 0].long() 682 | pos_ind1 = sample_pos_pairs[:, 1].long() 683 | posF0, posF1 = F0[pos_ind0], F1[pos_ind1] 684 | 685 | D01 = pdist(posF0, subF1, dist_type='L2') 686 | D10 = pdist(posF1, subF0, dist_type='L2') 687 | 688 | D01min, D01ind = D01.min(1) 689 | D10min, D10ind = D10.min(1) 690 | 691 | if not isinstance(positive_pairs, np.ndarray): 692 | positive_pairs = np.array(positive_pairs, dtype=np.int64) 693 | 694 | pos_keys = _hash(positive_pairs, hash_seed) 695 | 696 | D01ind = sel1[D01ind.cpu().numpy()] 697 | D10ind = sel0[D10ind.cpu().numpy()] 698 | neg_keys0 = _hash([pos_ind0.numpy(), D01ind], hash_seed) 699 | neg_keys1 = _hash([D10ind, pos_ind1.numpy()], hash_seed) 700 | 701 | mask0 = torch.from_numpy( 702 | np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False))) 703 | mask1 = torch.from_numpy( 704 | np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False))) 705 | pos_dist = torch.sqrt((posF0 - posF1).pow(2).sum(1) + 1e-7) 706 | 707 | # Random triplets 708 | rand_inds = np.random.choice( 709 | num_pos_pairs, min(num_pos_pairs, num_rand_triplet), replace=False) 710 | rand_pairs = positive_pairs[rand_inds] 711 | negatives = np.random.choice(N1, min(N1, num_rand_triplet), replace=False) 712 | 713 | # Remove positives from negatives 714 | rand_neg_keys = _hash([rand_pairs[:, 0], negatives], hash_seed) 715 | rand_mask = np.logical_not(np.isin(rand_neg_keys, pos_keys, assume_unique=False)) 716 | anchors, positives = rand_pairs[torch.from_numpy(rand_mask)].T 717 | negatives = negatives[rand_mask] 718 | 719 | rand_pos_dist = torch.sqrt((F0[anchors] - F1[positives]).pow(2).sum(1) + 1e-7) 720 | rand_neg_dist = torch.sqrt((F0[anchors] - F1[negatives]).pow(2).sum(1) + 1e-7) 721 | 722 | loss = F.relu( 723 | 
torch.cat([ 724 | rand_pos_dist + self.neg_thresh - rand_neg_dist, 725 | pos_dist[mask0] + self.neg_thresh - D01min[mask0], 726 | pos_dist[mask1] + self.neg_thresh - D10min[mask1] 727 | ])).mean() 728 | 729 | return loss, pos_dist.mean(), (D01min.mean() + D10min.mean()).item() / 2 730 | --------------------------------------------------------------------------------
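A minimal usage sketch of how the pieces in this section compose, assuming a config object exposing the fields referenced throughout lib/data_loaders.py and lib/trainer.py (dataset, voxel_size, batch_size, optimizer, ...). The repo's actual entry point is train.py, which is not shown in this section; run_training below is a hypothetical helper, not part of the repo:

from lib.data_loaders import make_data_loader
from lib.trainer import HardestContrastiveLossTrainer

def run_training(config):
  # Training loader shuffles and jitters; the validation loader must use
  # batch size 1 (asserted in ContrastiveLossTrainer.__init__).
  train_loader = make_data_loader(config, 'train', config.batch_size, num_threads=4)
  val_loader = make_data_loader(config, 'val', 1, num_threads=1)
  trainer = HardestContrastiveLossTrainer(config, train_loader, val_data_loader=val_loader)
  trainer.train()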