├── .gitignore ├── LICENSE ├── README.md ├── pcrnet ├── data_utils │ ├── __init__.py │ └── dataloaders.py ├── losses │ ├── __init__.py │ ├── chamfer_distance.py │ ├── cuda │ │ ├── chamfer_distance │ │ │ ├── __init__.py │ │ │ ├── chamfer_distance.cpp │ │ │ ├── chamfer_distance.cu │ │ │ └── chamfer_distance.py │ │ └── emd_torch │ │ │ ├── pkg │ │ │ ├── emd_loss_layer.py │ │ │ ├── include │ │ │ │ ├── cuda │ │ │ │ │ └── emd.cuh │ │ │ │ ├── cuda_helper.h │ │ │ │ └── emd.h │ │ │ ├── layer │ │ │ │ ├── __init__.py │ │ │ │ └── emd_loss_layer.py │ │ │ └── src │ │ │ │ ├── cuda │ │ │ │ └── emd.cu │ │ │ │ └── emd.cpp │ │ │ └── setup.py │ └── emd.py ├── models │ ├── __init__.py │ ├── pcrnet.py │ ├── pointnet.py │ └── pooling.py ├── ops │ ├── __init__.py │ ├── data_utils.py │ ├── quaternion.py │ └── transform_functions.py └── pretrained │ ├── exp_ipcrnet │ └── models │ │ └── best_model.t7 │ └── exp_ipcrnet_v1 │ └── models │ ├── best_model.t7 │ └── best_ptnet_model.t7 ├── requirements.txt ├── test_pcrnet.py └── train_pcrnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | *__pycache__ 3 | checkpoints/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2010-2019 Google, Inc. http://angularjs.org 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Point Cloud Registration Network in PyTorch. 2 | 3 | Source Code Author: Vinit Sarode 4 | 5 | **[[Paper]](https://arxiv.org/abs/1908.07906)** 6 | **[[Github Link]](https://github.com/vinits5/pcrnet)** 7 | 8 | #### This is a pytorch implementation of PCRNet paper. 9 | 10 | ### Requirements: 11 | 1. Cuda 10 12 | 2. pytorch==1.4.0 13 | 3. transforms3d==0.3.1 14 | 4. 
h5py==2.9.0
15 | 
16 | ### How to use the code:
17 | 
18 | #### Train Iterative-PCRNet:
19 | python train_pcrnet.py
20 | 
21 | ### Citation
22 | 
23 | ```
24 | @InProceedings{vsarode2019pcrnet,
25 | 	author = {Sarode, Vinit and Li, Xueqian and Goforth, Hunter and Aoki, Yasuhiro and Arun Srivatsan, Rangaprasad and Lucey, Simon and Choset, Howie},
26 | 	title = {PCRNet: Point Cloud Registration Network using PointNet Encoding},
27 | 	month = {Aug},
28 | 	year = {2019}
29 | }
30 | ```
31 | 
32 | This code builds upon the code provided in Deep Closest Point [DCP](https://github.com/WangYueFt/dcp.git). We thank the authors of the paper for sharing their code.
--------------------------------------------------------------------------------
/pcrnet/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloaders import ModelNet40Data
2 | from .dataloaders import RegistrationData
3 | from .dataloaders import download_modelnet40, deg_to_rad, create_random_transform
--------------------------------------------------------------------------------
/pcrnet/data_utils/dataloaders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import Dataset
5 | from torch.utils.data import DataLoader
6 | import numpy as np
7 | import os
8 | import h5py
9 | from .. ops import transform_functions
10 | import shlex
11 | import json
12 | import glob
13 | 
14 | def download_modelnet40():
15 | 	BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | 	DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
17 | 	if not os.path.exists(DATA_DIR):
18 | 		os.mkdir(DATA_DIR)
19 | 	if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
20 | 		www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
21 | 		zipfile = os.path.basename(www)
22 | 		www += ' --no-check-certificate'
23 | 		os.system('wget %s; unzip %s' % (www, zipfile))
24 | 		os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
25 | 		os.system('rm %s' % (zipfile))
26 | 
27 | def load_data(train):
28 | 	if train: partition = 'train'
29 | 	else: partition = 'test'
30 | 	BASE_DIR = os.path.dirname(os.path.abspath(__file__))
31 | 	DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data')
32 | 	all_data = []
33 | 	all_label = []
34 | 	for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)):
35 | 		f = h5py.File(h5_name, 'r')
36 | 		data = f['data'][:].astype('float32')
37 | 		label = f['label'][:].astype('int64')
38 | 		f.close()
39 | 		all_data.append(data)
40 | 		all_label.append(label)
41 | 	all_data = np.concatenate(all_data, axis=0)
42 | 	all_label = np.concatenate(all_label, axis=0)
43 | 	return all_data, all_label
44 | 
45 | def deg_to_rad(deg):
46 | 	return np.pi / 180 * deg
47 | 
48 | def create_random_transform(dtype, max_rotation_deg, max_translation):
49 | 	max_rotation = deg_to_rad(max_rotation_deg)
50 | 	rot = np.random.uniform(-max_rotation, max_rotation, [1, 3])
51 | 	trans = np.random.uniform(-max_translation, max_translation, [1, 3])
52 | 	quat = transform_functions.euler_to_quaternion(rot, "xyz")
53 | 
54 | 	vec = np.concatenate([quat, trans], axis=1)
55 | 	vec = torch.tensor(vec, dtype=dtype)
56 | 	return vec
57 | 
58 | class ModelNet40Data(Dataset):
59 | 	def __init__(
60 | 		self,
61 | 		train=True,
62 | 		num_points=1024,
63 | 		download=True,
64 | 		randomize_data=False
65 | 	):
66 | 		super(ModelNet40Data, self).__init__()
67 | 		if download: download_modelnet40()
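		# load_data() concatenates every downloaded ply_data_*.h5 shard into one
		# float32 array of 2048-point clouds and a matching int64 label array.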
self.data, self.labels = load_data(train) 69 | if not train: self.shapes = self.read_classes_ModelNet40() 70 | self.num_points = num_points 71 | self.randomize_data = randomize_data 72 | 73 | def __getitem__(self, idx): 74 | if self.randomize_data: current_points = self.randomize(idx) 75 | else: current_points = self.data[idx].copy() 76 | 77 | current_points = torch.from_numpy(current_points).float() 78 | label = torch.from_numpy(self.labels[idx]).type(torch.LongTensor) 79 | 80 | return current_points, label 81 | 82 | def __len__(self): 83 | return self.data.shape[0] 84 | 85 | def randomize(self, idx): 86 | pt_idxs = np.arange(0, self.num_points) 87 | np.random.shuffle(pt_idxs) 88 | return self.data[idx, pt_idxs].copy() 89 | 90 | def get_shape(self, label): 91 | return self.shapes[label] 92 | 93 | def read_classes_ModelNet40(self): 94 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 95 | DATA_DIR = os.path.join(BASE_DIR, os.pardir, 'data') 96 | file = open(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'shape_names.txt'), 'r') 97 | shape_names = file.read() 98 | shape_names = np.array(shape_names.split('\n')[:-1]) 99 | return shape_names 100 | 101 | class RegistrationData(Dataset): 102 | def __init__(self, algorithm, data_class=ModelNet40Data(), is_testing=False): 103 | super(RegistrationData, self).__init__() 104 | self.algorithm = 'iPCRNet' 105 | self.is_testing = is_testing 106 | 107 | self.set_class(data_class) 108 | if self.algorithm == 'PCRNet' or self.algorithm == 'iPCRNet': 109 | from .. ops.transform_functions import PCRNetTransform 110 | self.transforms = PCRNetTransform(len(data_class), angle_range=45, translation_range=1) 111 | 112 | def __len__(self): 113 | return len(self.data_class) 114 | 115 | def set_class(self, data_class): 116 | self.data_class = data_class 117 | 118 | def __getitem__(self, index): 119 | template, label = self.data_class[index] 120 | self.transforms.index = index # for fixed transformations in PCRNet. 121 | source = self.transforms(template) 122 | igt = self.transforms.igt 123 | if self.is_testing: 124 | return template, source, igt, self.transforms.igt_rotation, self.transforms.igt_translation 125 | else: 126 | return template, source, igt -------------------------------------------------------------------------------- /pcrnet/losses/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .emd import EMDLoss 3 | except: 4 | print("Sorry EMD loss is not compatible with your system!") 5 | try: 6 | from .chamfer_distance import ChamferDistanceLoss 7 | except: 8 | print("Sorry ChamferDistance loss is not compatible with your system!") -------------------------------------------------------------------------------- /pcrnet/losses/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def pairwise_distances(a: torch.Tensor, b: torch.Tensor, p=2): 6 | """ 7 | Compute the pairwise distance_tensor matrix between a and b which both have size [m, n, d]. The result is a tensor of 8 | size [m, n, n] whose entry [m, i, j] contains the distance_tensor between a[m, i, :] and b[m, j, :]. 9 | :param a: A tensor containing m batches of n points of dimension d. i.e. of size [m, n, d] 10 | :param b: A tensor containing m batches of n points of dimension d. i.e. 
of size [m, n, d] 11 | :param p: Norm to use for the distance_tensor 12 | :return: A tensor containing the pairwise distance_tensor between each pair of inputs in a batch. 13 | """ 14 | 15 | if len(a.shape) != 3: 16 | raise ValueError("Invalid shape for a. Must be [m, n, d] but got", a.shape) 17 | if len(b.shape) != 3: 18 | raise ValueError("Invalid shape for a. Must be [m, n, d] but got", b.shape) 19 | return (a.unsqueeze(2) - b.unsqueeze(1)).abs().pow(p).sum(3) 20 | 21 | def chamfer(a, b): 22 | """ 23 | Compute the chamfer distance between two sets of vectors, a, and b 24 | :param a: A m-sized minibatch of point sets in R^d. i.e. shape [m, n_a, d] 25 | :param b: A m-sized minibatch of point sets in R^d. i.e. shape [m, n_b, d] 26 | :return: A [m] shaped tensor storing the Chamfer distance between each minibatch entry 27 | """ 28 | M = pairwise_distances(a, b) 29 | dist1 = torch.mean(torch.sqrt(M.min(1)[0])) 30 | dist2 = torch.mean(torch.sqrt(M.min(2)[0])) 31 | return (dist1 + dist2) / 2.0 32 | 33 | 34 | def chamfer_distance(template: torch.Tensor, source: torch.Tensor): 35 | try: 36 | from .cuda.chamfer_distance import ChamferDistance 37 | cost_p0_p1, cost_p1_p0 = ChamferDistance()(template, source) 38 | cost_p0_p1 = torch.mean(torch.sqrt(cost_p0_p1)) 39 | cost_p1_p0 = torch.mean(torch.sqrt(cost_p1_p0)) 40 | chamfer_loss = (cost_p0_p1 + cost_p1_p0)/2.0 41 | except: 42 | chamfer_loss = chamfer(template, source) 43 | return chamfer_loss 44 | 45 | 46 | class ChamferDistanceLoss(nn.Module): 47 | def __init__(self): 48 | super(ChamferDistanceLoss, self).__init__() 49 | 50 | def forward(self, template, source): 51 | return chamfer_distance(template, source) -------------------------------------------------------------------------------- /pcrnet/losses/cuda/chamfer_distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance import ChamferDistance 2 | -------------------------------------------------------------------------------- /pcrnet/losses/cuda/chamfer_distance/chamfer_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | int ChamferDistanceKernelLauncher( 5 | const int b, const int n, 6 | const float* xyz, 7 | const int m, 8 | const float* xyz2, 9 | float* result, 10 | int* result_i, 11 | float* result2, 12 | int* result2_i); 13 | 14 | int ChamferDistanceGradKernelLauncher( 15 | const int b, const int n, 16 | const float* xyz1, 17 | const int m, 18 | const float* xyz2, 19 | const float* grad_dist1, 20 | const int* idx1, 21 | const float* grad_dist2, 22 | const int* idx2, 23 | float* grad_xyz1, 24 | float* grad_xyz2); 25 | 26 | 27 | void chamfer_distance_forward_cuda( 28 | const at::Tensor xyz1, 29 | const at::Tensor xyz2, 30 | const at::Tensor dist1, 31 | const at::Tensor dist2, 32 | const at::Tensor idx1, 33 | const at::Tensor idx2) 34 | { 35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 36 | xyz2.size(1), xyz2.data(), 37 | dist1.data(), idx1.data(), 38 | dist2.data(), idx2.data()); 39 | } 40 | 41 | void chamfer_distance_backward_cuda( 42 | const at::Tensor xyz1, 43 | const at::Tensor xyz2, 44 | at::Tensor gradxyz1, 45 | at::Tensor gradxyz2, 46 | at::Tensor graddist1, 47 | at::Tensor graddist2, 48 | at::Tensor idx1, 49 | at::Tensor idx2) 50 | { 51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 52 | xyz2.size(1), xyz2.data(), 53 | graddist1.data(), idx1.data(), 54 | 
graddist2.data(), idx2.data(), 55 | gradxyz1.data(), gradxyz2.data()); 56 | } 57 | 58 | 59 | void nnsearch( 60 | const int b, const int n, const int m, 61 | const float* xyz1, 62 | const float* xyz2, 63 | float* dist, 64 | int* idx) 65 | { 66 | for (int i = 0; i < b; i++) { 67 | for (int j = 0; j < n; j++) { 68 | const float x1 = xyz1[(i*n+j)*3+0]; 69 | const float y1 = xyz1[(i*n+j)*3+1]; 70 | const float z1 = xyz1[(i*n+j)*3+2]; 71 | double best = 0; 72 | int besti = 0; 73 | for (int k = 0; k < m; k++) { 74 | const float x2 = xyz2[(i*m+k)*3+0] - x1; 75 | const float y2 = xyz2[(i*m+k)*3+1] - y1; 76 | const float z2 = xyz2[(i*m+k)*3+2] - z1; 77 | const double d=x2*x2+y2*y2+z2*z2; 78 | if (k==0 || d < best){ 79 | best = d; 80 | besti = k; 81 | } 82 | } 83 | dist[i*n+j] = best; 84 | idx[i*n+j] = besti; 85 | } 86 | } 87 | } 88 | 89 | 90 | void chamfer_distance_forward( 91 | const at::Tensor xyz1, 92 | const at::Tensor xyz2, 93 | const at::Tensor dist1, 94 | const at::Tensor dist2, 95 | const at::Tensor idx1, 96 | const at::Tensor idx2) 97 | { 98 | const int batchsize = xyz1.size(0); 99 | const int n = xyz1.size(1); 100 | const int m = xyz2.size(1); 101 | 102 | const float* xyz1_data = xyz1.data(); 103 | const float* xyz2_data = xyz2.data(); 104 | float* dist1_data = dist1.data(); 105 | float* dist2_data = dist2.data(); 106 | int* idx1_data = idx1.data(); 107 | int* idx2_data = idx2.data(); 108 | 109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 111 | } 112 | 113 | 114 | void chamfer_distance_backward( 115 | const at::Tensor xyz1, 116 | const at::Tensor xyz2, 117 | at::Tensor gradxyz1, 118 | at::Tensor gradxyz2, 119 | at::Tensor graddist1, 120 | at::Tensor graddist2, 121 | at::Tensor idx1, 122 | at::Tensor idx2) 123 | { 124 | const int b = xyz1.size(0); 125 | const int n = xyz1.size(1); 126 | const int m = xyz2.size(1); 127 | 128 | const float* xyz1_data = xyz1.data(); 129 | const float* xyz2_data = xyz2.data(); 130 | float* gradxyz1_data = gradxyz1.data(); 131 | float* gradxyz2_data = gradxyz2.data(); 132 | float* graddist1_data = graddist1.data(); 133 | float* graddist2_data = graddist2.data(); 134 | const int* idx1_data = idx1.data(); 135 | const int* idx2_data = idx2.data(); 136 | 137 | for (int i = 0; i < b*n*3; i++) 138 | gradxyz1_data[i] = 0; 139 | for (int i = 0; i < b*m*3; i++) 140 | gradxyz2_data[i] = 0; 141 | for (int i = 0;i < b; i++) { 142 | for (int j = 0; j < n; j++) { 143 | const float x1 = xyz1_data[(i*n+j)*3+0]; 144 | const float y1 = xyz1_data[(i*n+j)*3+1]; 145 | const float z1 = xyz1_data[(i*n+j)*3+2]; 146 | const int j2 = idx1_data[i*n+j]; 147 | 148 | const float x2 = xyz2_data[(i*m+j2)*3+0]; 149 | const float y2 = xyz2_data[(i*m+j2)*3+1]; 150 | const float z2 = xyz2_data[(i*m+j2)*3+2]; 151 | const float g = graddist1_data[i*n+j]*2; 152 | 153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); 154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); 155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); 156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); 157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); 158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); 159 | } 160 | for (int j = 0; j < m; j++) { 161 | const float x1 = xyz2_data[(i*m+j)*3+0]; 162 | const float y1 = xyz2_data[(i*m+j)*3+1]; 163 | const float z1 = xyz2_data[(i*m+j)*3+2]; 164 | const int j2 = idx2_data[i*m+j]; 165 | const float x2 = xyz1_data[(i*n+j2)*3+0]; 166 | const float y2 = xyz1_data[(i*n+j2)*3+1]; 167 | const float z2 = 
xyz1_data[(i*n+j2)*3+2]; 168 | const float g = graddist2_data[i*m+j]*2; 169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); 170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); 171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); 172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); 173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); 174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); 175 | } 176 | } 177 | } 178 | 179 | 180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); 182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); 183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); 184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); 185 | } 186 | -------------------------------------------------------------------------------- /pcrnet/losses/cuda/chamfer_distance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /pcrnet/losses/cuda/chamfer_distance/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.cpp_extension import load 3 | import os 4 | 5 | script_dir = os.path.dirname(__file__) 6 | sources = [ 7 | os.path.join(script_dir, "chamfer_distance.cpp"), 8 | os.path.join(script_dir, "chamfer_distance.cu"), 9 | ] 10 | 11 | cd = load(name="cd", sources=sources) 12 | 13 | 14 | class ChamferDistanceFunction(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, xyz1, xyz2): 17 | batchsize, n, _ = xyz1.size() 18 | _, m, _ = xyz2.size() 19 | xyz1 = xyz1.contiguous() 20 | xyz2 = xyz2.contiguous() 21 | dist1 = torch.zeros(batchsize, n) 22 | dist2 = torch.zeros(batchsize, m) 23 | 24 | idx1 = 
torch.zeros(batchsize, n, dtype=torch.int) 25 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 26 | 27 | if not xyz1.is_cuda: 28 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 29 | else: 30 | dist1 = dist1.cuda() 31 | dist2 = dist2.cuda() 32 | idx1 = idx1.cuda() 33 | idx2 = idx2.cuda() 34 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 35 | 36 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 37 | 38 | return dist1, dist2 39 | 40 | @staticmethod 41 | def backward(ctx, graddist1, graddist2): 42 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 43 | 44 | graddist1 = graddist1.contiguous() 45 | graddist2 = graddist2.contiguous() 46 | 47 | gradxyz1 = torch.zeros(xyz1.size()) 48 | gradxyz2 = torch.zeros(xyz2.size()) 49 | 50 | if not graddist1.is_cuda: 51 | cd.backward( 52 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 53 | ) 54 | else: 55 | gradxyz1 = gradxyz1.cuda() 56 | gradxyz2 = gradxyz2.cuda() 57 | cd.backward_cuda( 58 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 59 | ) 60 | 61 | return gradxyz1, gradxyz2 62 | 63 | 64 | class ChamferDistance(torch.nn.Module): 65 | def forward(self, xyz1, xyz2): 66 | return ChamferDistanceFunction.apply(xyz1, xyz2) 67 | -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/pkg/emd_loss_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import _emd_ext._emd as emd 5 | 6 | 7 | class EMDFunction(torch.autograd.Function): 8 | @staticmethod 9 | def forward(self, xyz1, xyz2): 10 | cost, match = emd.emd_forward(xyz1, xyz2) 11 | self.save_for_backward(xyz1, xyz2, match) 12 | return cost 13 | 14 | 15 | @staticmethod 16 | def backward(self, grad_output): 17 | xyz1, xyz2, match = self.saved_tensors 18 | grad_xyz1, grad_xyz2 = emd.emd_backward(xyz1, xyz2, match) 19 | return grad_xyz1, grad_xyz2 20 | 21 | 22 | 23 | 24 | class EMDLoss(nn.Module): 25 | ''' 26 | Computes the (approximate) Earth Mover's Distance between two point sets. 27 | 28 | IMPLEMENTATION LIMITATIONS: 29 | - Double tensors must have <=11 dimensions 30 | - Float tensors must have <=23 dimensions 31 | This is due to the use of CUDA shared memory in the computation. This shared memory is limited by the hardware to 48kB. 32 | ''' 33 | 34 | def __init__(self): 35 | super(EMDLoss, self).__init__() 36 | 37 | def forward(self, xyz1, xyz2): 38 | 39 | assert xyz1.shape[-1] == xyz2.shape[-1], 'Both point sets must have the same dimensions!' 40 | assert xyz1.shape[1] == xyz2.shape[1], 'Both Point Clouds must have same number of points in it.' 
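        # Unlike the near-identical copy in pkg/layer/, this wrapper also insists on
        # equal point counts; EMDFunction returns one approximate matching cost per
        # batch entry (shape [B]).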
41 | return EMDFunction.apply(xyz1, xyz2) -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/pkg/include/cuda/emd.cuh: -------------------------------------------------------------------------------- 1 | #ifndef EMD_CUH_ 2 | #define EMD_CUH_ 3 | 4 | #include "cuda_helper.h" 5 | 6 | template 7 | __global__ void approxmatch(const int b, const int n, const int m, const T * __restrict__ xyz1, const T * __restrict__ xyz2, T * __restrict__ match, T * temp){ 8 | T * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; 9 | T multiL,multiR; 10 | if (n>=m){ 11 | multiL=1; 12 | multiR=n/m; 13 | }else{ 14 | multiL=m/n; 15 | multiR=1; 16 | } 17 | const int Block=1024; 18 | __shared__ T buf[Block*4]; 19 | for (int i=blockIdx.x;i=-2;j--){ 28 | T level=-powf(4.0f,j); 29 | if (j==-2){ 30 | level=0; 31 | } 32 | for (int k0=0;k0>>( 191 | b, n, m, 192 | xyz1.data(), 193 | xyz2.data(), 194 | match.data(), 195 | temp.data()); 196 | })); 197 | cudaDeviceSynchronize(); 198 | CUDA_CHECK(cudaGetLastError()) 199 | } 200 | 201 | template 202 | __global__ void matchcost(const int b, const int n, const int m, const T * __restrict__ xyz1, const T * __restrict__ xyz2, const T * __restrict__ match, T * __restrict__ out){ 203 | __shared__ T allsum[512]; 204 | const int Block=1024; 205 | __shared__ T buf[Block*3]; 206 | for (int i=blockIdx.x;i>>( 249 | b, n, m, 250 | xyz1.data(), 251 | xyz2.data(), 252 | match.data(), 253 | out.data()); 254 | })); 255 | CUDA_CHECK(cudaGetLastError()) 256 | } 257 | 258 | template 259 | __global__ void matchcostgrad2(const int b, const int n, const int m,const T * __restrict__ xyz1, const T * __restrict__ xyz2, const T * __restrict__ match, T * __restrict__ grad2){ 260 | __shared__ T sum_grad[256*3]; 261 | for (int i=blockIdx.x;i 302 | __global__ void matchcostgrad1(const int b, const int n, const int m, const T * __restrict__ xyz1, const T * __restrict__ xyz2, const T * __restrict__ match, T * __restrict__ grad1){ 303 | for (int i=blockIdx.x;i>>( 328 | b, n, m, 329 | xyz1.data(), 330 | xyz2.data(), 331 | match.data(), 332 | grad1.data()); 333 | })); 334 | CUDA_CHECK(cudaGetLastError()) 335 | 336 | AT_DISPATCH_FLOATING_TYPES(xyz1.type(), "matchcostgrad2", ([&] { 337 | matchcostgrad2<<>>( 338 | b, n, m, 339 | xyz1.data(), 340 | xyz2.data(), 341 | match.data(), 342 | grad2.data()); 343 | })); 344 | CUDA_CHECK(cudaGetLastError()) 345 | } 346 | 347 | #endif -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/pkg/include/cuda_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_HELPER_H_ 2 | #define CUDA_HELPER_H_ 3 | 4 | #define CUDA_CHECK(err) \ 5 | if (cudaSuccess != err) \ 6 | { \ 7 | fprintf(stderr, "CUDA kernel failed: %s (%s:%d)\n", \ 8 | cudaGetErrorString(err), __FILE__, __LINE__); \ 9 | std::exit(-1); \ 10 | } 11 | 12 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), \ 13 | #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), \ 15 | #x " must be contiguous") 16 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 17 | 18 | #endif -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/pkg/include/emd.h: -------------------------------------------------------------------------------- 1 | #ifndef EMD_H_ 2 | #define EMD_H_ 3 | 4 | 
#include 5 | #include 6 | 7 | #include "cuda_helper.h" 8 | 9 | 10 | std::vector emd_forward_cuda( 11 | at::Tensor xyz1, 12 | at::Tensor xyz2); 13 | 14 | std::vector emd_backward_cuda( 15 | at::Tensor xyz1, 16 | at::Tensor xyz2, 17 | at::Tensor match); 18 | 19 | // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 20 | // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 21 | // CALL FUNCTION IMPLEMENTATIONS 22 | // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 23 | // * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * 24 | 25 | std::vector emd_forward( 26 | at::Tensor xyz1, 27 | at::Tensor xyz2) 28 | { 29 | CHECK_INPUT(xyz1); 30 | CHECK_INPUT(xyz2); 31 | 32 | return emd_forward_cuda(xyz1, xyz2); 33 | } 34 | 35 | std::vector emd_backward( 36 | at::Tensor xyz1, 37 | at::Tensor xyz2, 38 | at::Tensor match) 39 | { 40 | CHECK_INPUT(xyz1); 41 | CHECK_INPUT(xyz2); 42 | CHECK_INPUT(match); 43 | 44 | return emd_backward_cuda(xyz1, xyz2, match); 45 | } 46 | 47 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 48 | m.def("emd_forward", &emd_forward, "Compute Earth Mover's Distance"); 49 | m.def("emd_backward", &emd_backward, "Compute Gradients for Earth Mover's Distance"); 50 | } 51 | 52 | 53 | 54 | #endif -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/pkg/layer/__init__.py: -------------------------------------------------------------------------------- 1 | from .emd_loss_layer import EMDLoss -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/pkg/layer/emd_loss_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import _emd_ext._emd as emd 5 | 6 | 7 | class EMDFunction(torch.autograd.Function): 8 | @staticmethod 9 | def forward(self, xyz1, xyz2): 10 | cost, match = emd.emd_forward(xyz1, xyz2) 11 | self.save_for_backward(xyz1, xyz2, match) 12 | return cost 13 | 14 | 15 | @staticmethod 16 | def backward(self, grad_output): 17 | xyz1, xyz2, match = self.saved_tensors 18 | grad_xyz1, grad_xyz2 = emd.emd_backward(xyz1, xyz2, match) 19 | return grad_xyz1, grad_xyz2 20 | 21 | 22 | 23 | 24 | class EMDLoss(nn.Module): 25 | ''' 26 | Computes the (approximate) Earth Mover's Distance between two point sets. 27 | 28 | IMPLEMENTATION LIMITATIONS: 29 | - Double tensors must have <=11 dimensions 30 | - Float tensors must have <=23 dimensions 31 | This is due to the use of CUDA shared memory in the computation. This shared memory is limited by the hardware to 48kB. 
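    Calling EMDLoss()(xyz1, xyz2) on [B, N, 3] point sets returns a [B]-shaped tensor of per-cloud matching costs.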
32 | ''' 33 | 34 | def __init__(self): 35 | super(EMDLoss, self).__init__() 36 | 37 | def forward(self, xyz1, xyz2): 38 | 39 | assert xyz1.shape[-1] == xyz2.shape[-1], 'Both point sets must have the same dimensionality' 40 | return EMDFunction.apply(xyz1, xyz2) -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/pkg/src/cuda/emd.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "cuda/emd.cuh" 6 | 7 | 8 | std::vector emd_forward_cuda( 9 | at::Tensor xyz1, // B x N1 x D 10 | at::Tensor xyz2) // B x N2 x D 11 | { 12 | // Some useful values 13 | const int batch_size = xyz1.size(0); 14 | const int num_pts_1 = xyz1.size(1); 15 | const int num_pts_2 = xyz2.size(1); 16 | 17 | // Allocate necessary data structures 18 | at::Tensor match = at::zeros({batch_size, num_pts_1, num_pts_2}, 19 | xyz1.options()); 20 | at::Tensor cost = at::zeros({batch_size}, xyz1.options()); 21 | at::Tensor temp = at::zeros({batch_size, 2 * (num_pts_1 + num_pts_2)}, 22 | xyz1.options()); 23 | 24 | // Find the approximate matching 25 | approxmatchLauncher( 26 | batch_size, num_pts_1, num_pts_2, 27 | xyz1, 28 | xyz2, 29 | match, 30 | temp 31 | ); 32 | 33 | // Compute the matching cost 34 | matchcostLauncher( 35 | batch_size, num_pts_1, num_pts_2, 36 | xyz1, 37 | xyz2, 38 | match, 39 | cost 40 | ); 41 | 42 | return {cost, match}; 43 | } 44 | 45 | std::vector emd_backward_cuda( 46 | at::Tensor xyz1, 47 | at::Tensor xyz2, 48 | at::Tensor match) 49 | { 50 | // Some useful values 51 | const int batch_size = xyz1.size(0); 52 | const int num_pts_1 = xyz1.size(1); 53 | const int num_pts_2 = xyz2.size(1); 54 | 55 | // Allocate necessary data structures 56 | at::Tensor grad_xyz1 = at::zeros_like(xyz1); 57 | at::Tensor grad_xyz2 = at::zeros_like(xyz2); 58 | 59 | // Compute the gradient with respect to the two inputs (xyz1 and xyz2) 60 | matchcostgradLauncher( 61 | batch_size, num_pts_1, num_pts_2, 62 | xyz1, 63 | xyz2, 64 | match, 65 | grad_xyz1, 66 | grad_xyz2 67 | ); 68 | 69 | return {grad_xyz1, grad_xyz2}; 70 | } -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/pkg/src/emd.cpp: -------------------------------------------------------------------------------- 1 | #include "emd.h" 2 | -------------------------------------------------------------------------------- /pcrnet/losses/cuda/emd_torch/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='PyTorch EMD', 7 | version='0.0', 8 | author='Vinit Sarode', 9 | author_email='vinitsarode5@gmail.com', 10 | description='A PyTorch module for the earth mover\'s distance loss', 11 | ext_package='_emd_ext', 12 | ext_modules=[ 13 | CUDAExtension( 14 | name='_emd', 15 | sources=[ 16 | 'pkg/src/emd.cpp', 17 | 'pkg/src/cuda/emd.cu', 18 | ], 19 | include_dirs=['pkg/include'], 20 | ), 21 | ], 22 | packages=[ 23 | 'emd', 24 | ], 25 | package_dir={ 26 | 'emd' : 'pkg/layer' 27 | }, 28 | cmdclass={'build_ext': BuildExtension}, 29 | ) -------------------------------------------------------------------------------- /pcrnet/losses/emd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def emd(template: torch.Tensor, 
source: torch.Tensor):
6 | 	from emd import EMDLoss
7 | 	emd_loss = torch.mean(EMDLoss()(template, source)) / (template.size()[1])
8 | 	return emd_loss
9 | 
10 | 
11 | class EMDLoss(nn.Module):
12 | 	def __init__(self):
13 | 		super(EMDLoss, self).__init__()
14 | 
15 | 	def forward(self, template, source):
16 | 		return emd(template, source)
--------------------------------------------------------------------------------
/pcrnet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .pointnet import PointNet
2 | from .pooling import Pooling
3 | from .pcrnet import iPCRNet
--------------------------------------------------------------------------------
/pcrnet/models/pcrnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .pointnet import PointNet
5 | from .pooling import Pooling
6 | from .. ops.transform_functions import PCRNetTransform as transform
7 | 
8 | 
9 | class iPCRNet(nn.Module):
10 | 	def __init__(self, feature_model=PointNet(), dropout=0.0, pooling='max'):
11 | 		super().__init__()
12 | 		self.feature_model = feature_model
13 | 		self.pooling = Pooling(pooling)
14 | 
15 | 		self.linear = [nn.Linear(self.feature_model.emb_dims * 2, 1024), nn.ReLU(),
16 | 					   nn.Linear(1024, 1024), nn.ReLU(),
17 | 					   nn.Linear(1024, 512), nn.ReLU(),
18 | 					   nn.Linear(512, 512), nn.ReLU(),
19 | 					   nn.Linear(512, 256), nn.ReLU()]
20 | 
21 | 		if dropout > 0.0:
22 | 			self.linear.append(nn.Dropout(dropout))
23 | 		self.linear.append(nn.Linear(256, 7))
24 | 
25 | 		self.linear = nn.Sequential(*self.linear)
26 | 
27 | 	# Single Pass Alignment Module (SPAM)
28 | 	def spam(self, template_features, source, est_R, est_t):
29 | 		batch_size = source.size(0)
30 | 
31 | 		self.source_features = self.pooling(self.feature_model(source))
32 | 		y = torch.cat([template_features, self.source_features], dim=1)
33 | 		pose_7d = self.linear(y)
34 | 		pose_7d = transform.create_pose_7d(pose_7d)
35 | 
36 | 		# Find the current rotation and translation increments.
37 | 		identity = torch.eye(3).to(source).view(1,3,3).expand(batch_size, 3, 3).contiguous()
38 | 		est_R_temp = transform.quaternion_rotate(identity, pose_7d).permute(0, 2, 1)
39 | 		est_t_temp = transform.get_translation(pose_7d).view(-1, 1, 3)
40 | 
41 | 		# Update the translation vector.
42 | 		est_t = torch.bmm(est_R_temp, est_t.permute(0, 2, 1)).permute(0, 2, 1) + est_t_temp
43 | 		# Update the rotation matrix.
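		# Compose the new increment with the running estimate by left-multiplication:
		# R_k = R_step @ R_{k-1}, mirroring the translation update above.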
44 | est_R = torch.bmm(est_R_temp, est_R) 45 | 46 | source = transform.quaternion_transform(source, pose_7d) # Ps' = est_R*Ps + est_t 47 | return est_R, est_t, source 48 | 49 | def forward(self, template, source, max_iteration=8): 50 | est_R = torch.eye(3).to(template).view(1, 3, 3).expand(template.size(0), 3, 3).contiguous() # (Bx3x3) 51 | est_t = torch.zeros(1,3).to(template).view(1, 1, 3).expand(template.size(0), 1, 3).contiguous() # (Bx1x3) 52 | template_features = self.pooling(self.feature_model(template)) 53 | 54 | if max_iteration == 1: 55 | est_R, est_t, source = self.spam(template_features, source, est_R, est_t) 56 | else: 57 | for i in range(max_iteration): 58 | est_R, est_t, source = self.spam(template_features, source, est_R, est_t) 59 | 60 | result = {'est_R': est_R, # source -> template 61 | 'est_t': est_t, # source -> template 62 | 'est_T': transform.convert2transformation(est_R, est_t), # source -> template 63 | 'r': template_features - self.source_features, 64 | 'transformed_source': source} 65 | return result 66 | 67 | 68 | if __name__ == '__main__': 69 | template, source = torch.rand(10,1024,3), torch.rand(10,1024,3) 70 | pn = PointNet() 71 | 72 | net = iPCRNet(pn) 73 | result = net(template, source) 74 | import ipdb; ipdb.set_trace() -------------------------------------------------------------------------------- /pcrnet/models/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .pooling import Pooling 5 | 6 | 7 | class PointNet(torch.nn.Module): 8 | def __init__(self, emb_dims=1024, input_shape="bnc"): 9 | # emb_dims: Embedding Dimensions for PointNet. 10 | # input_shape: Shape of Input Point Cloud (b: batch, n: no of points, c: channels) 11 | super(PointNet, self).__init__() 12 | if input_shape not in ["bcn", "bnc"]: 13 | raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ") 14 | self.input_shape = input_shape 15 | self.emb_dims = emb_dims 16 | self.layers = self.create_structure() 17 | 18 | def create_structure(self): 19 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 20 | self.conv2 = torch.nn.Conv1d(64, 64, 1) 21 | self.conv3 = torch.nn.Conv1d(64, 64, 1) 22 | self.conv4 = torch.nn.Conv1d(64, 128, 1) 23 | self.conv5 = torch.nn.Conv1d(128, self.emb_dims, 1) 24 | self.relu = torch.nn.ReLU() 25 | 26 | layers = [self.conv1, self.relu, 27 | self.conv2, self.relu, 28 | self.conv3, self.relu, 29 | self.conv4, self.relu, 30 | self.conv5, self.relu] 31 | return layers 32 | 33 | 34 | def forward(self, input_data): 35 | # input_data: Point Cloud having shape input_shape. 36 | # output: PointNet features (Batch x emb_dims) 37 | if self.input_shape == "bnc": 38 | num_points = input_data.shape[1] 39 | input_data = input_data.permute(0, 2, 1) 40 | else: 41 | num_points = input_data.shape[2] 42 | if input_data.shape[1] != 3: 43 | raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]") 44 | 45 | output = input_data 46 | for idx, layer in enumerate(self.layers): 47 | output = layer(output) 48 | 49 | return output 50 | 51 | 52 | if __name__ == '__main__': 53 | # Test the code. 
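	# With the default input_shape='bnc', a (10, 1024, 3) batch should produce
	# (10, emb_dims, 1024) point-wise features; pooling down to (10, emb_dims)
	# is left to the separate Pooling module.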
54 | 	x = torch.rand((10,1024,3))
55 | 
56 | 	pn = PointNet()
57 | 	y = pn(x)
58 | 	print("Network Architecture: ")
59 | 	print(pn)
60 | 	print("Input Shape of PointNet: ", x.shape, "\nOutput Shape of PointNet: ", y.shape)
--------------------------------------------------------------------------------
/pcrnet/models/pooling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class Pooling(torch.nn.Module):
7 | 	def __init__(self, pool_type='max'):
8 | 		self.pool_type = pool_type
9 | 		super(Pooling, self).__init__()
10 | 
11 | 	def forward(self, input):
12 | 		if self.pool_type == 'max':
13 | 			return torch.max(input, 2)[0].contiguous()
14 | 		elif self.pool_type == 'avg' or self.pool_type == 'average':
15 | 			return torch.mean(input, 2).contiguous()
16 | 		else:
17 | 			raise ValueError("Unknown pool_type: %s" % self.pool_type)
--------------------------------------------------------------------------------
/pcrnet/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinits5/pcrnet_pytorch/ca20d4ea5071f5e5aa9a4c7143eb34be80a406a7/pcrnet/ops/__init__.py
--------------------------------------------------------------------------------
/pcrnet/ops/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def mean_shift(template, source, p0_zero_mean, p1_zero_mean):
4 | 	template_mean = torch.eye(3).view(1, 3, 3).expand(template.size(0), 3, 3).to(template)	# [B, 3, 3]
5 | 	source_mean = torch.eye(3).view(1, 3, 3).expand(source.size(0), 3, 3).to(source)		# [B, 3, 3]
6 | 
7 | 	if p0_zero_mean:
8 | 		p0_m = template.mean(dim=1)	# [B, N, 3] -> [B, 3]
9 | 		template_mean = torch.cat([template_mean, p0_m.unsqueeze(-1)], dim=2)
10 | 		one_ = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(template_mean.shape[0], 1, 1).to(template_mean)	# (Bx1x4)
11 | 		template_mean = torch.cat([template_mean, one_], dim=1)
12 | 		template = template - p0_m.unsqueeze(1)
13 | 	# else:
14 | 	# 	q0 = template
15 | 
16 | 	if p1_zero_mean:
17 | 		#print(numpy.any(numpy.isnan(p1.numpy())))
18 | 		p1_m = source.mean(dim=1)	# [B, N, 3] -> [B, 3]
19 | 		source_mean = torch.cat([source_mean, -p1_m.unsqueeze(-1)], dim=2)	# trans(-p1_m): undo the source's own mean shift
20 | 		one_ = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(source_mean.shape[0], 1, 1).to(source_mean)	# (Bx1x4)
21 | 		source_mean = torch.cat([source_mean, one_], dim=1)
22 | 		source = source - p1_m.unsqueeze(1)
23 | 	# else:
24 | 	# 	q1 = source
25 | 	return template, source, template_mean, source_mean
26 | 
27 | def postprocess_data(result, p0, p1, a0, a1, p0_zero_mean, p1_zero_mean):
28 | 	#output' = trans(p0_m) * output * trans(-p1_m)
29 | 	#        = [I, p0_m;] * [R, t;] * [I, -p1_m;]
30 | 	#          [0, 1   ]    [0, 1 ]   [0, 1    ]
31 | 	est_g = result['est_T']
32 | 	if p0_zero_mean:
33 | 		est_g = a0.to(est_g).bmm(est_g)
34 | 	if p1_zero_mean:
35 | 		est_g = est_g.bmm(a1.to(est_g))
36 | 	result['est_T'] = est_g
37 | 
38 | 	est_gs = result['est_T_series'] # [M, B, 4, 4]
39 | 	if p0_zero_mean:
40 | 		est_gs = a0.unsqueeze(0).contiguous().to(est_gs).matmul(est_gs)
41 | 	if p1_zero_mean:
42 | 		est_gs = est_gs.matmul(a1.unsqueeze(0).contiguous().to(est_gs))
43 | 	result['est_T_series'] = est_gs
44 | 
45 | 	return result
46 | 
--------------------------------------------------------------------------------
/pcrnet/ops/quaternion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018-present, Facebook, Inc.
2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch 9 | import numpy as np 10 | 11 | # PyTorch-backed implementations 12 | 13 | 14 | def qmul(q, r): 15 | """ 16 | Multiply quaternion(s) q with quaternion(s) r. 17 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 18 | Returns q*r as a tensor of shape (*, 4). 19 | """ 20 | assert q.shape[-1] == 4 21 | assert r.shape[-1] == 4 22 | 23 | original_shape = q.shape 24 | 25 | # Compute outer product 26 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 27 | 28 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 29 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 30 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 31 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 32 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 33 | 34 | 35 | def qrot(q, v): 36 | """ 37 | Rotate vector(s) v about the rotation described by quaternion(s) q. 38 | Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v, 39 | where * denotes any number of dimensions. 40 | Returns a tensor of shape (*, 3). 41 | """ 42 | assert q.shape[-1] == 4 43 | assert v.shape[-1] == 3 44 | assert q.shape[:-1] == v.shape[:-1] 45 | 46 | original_shape = list(v.shape) 47 | q = q.view(-1, 4) 48 | v = v.view(-1, 3) 49 | 50 | qvec = q[:, 1:] 51 | uv = torch.cross(qvec, v, dim=1) 52 | uuv = torch.cross(qvec, uv, dim=1) 53 | return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape) 54 | 55 | 56 | def qeuler(q, order, epsilon=0): 57 | """ 58 | Convert quaternion(s) q to Euler angles. 59 | Expects a tensor of shape (*, 4), where * denotes any number of dimensions. 60 | Returns a tensor of shape (*, 3). 
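    The epsilon argument clamps asin inputs to [-1 + epsilon, 1 - epsilon] to avoid NaNs near gimbal lock.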
61 | """ 62 | assert q.shape[-1] == 4 63 | 64 | original_shape = list(q.shape) 65 | original_shape[-1] = 3 66 | q = q.view(-1, 4) 67 | 68 | q0 = q[:, 0] 69 | q1 = q[:, 1] 70 | q2 = q[:, 2] 71 | q3 = q[:, 3] 72 | 73 | if order == "xyz": 74 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 75 | y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon)) 76 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 77 | elif order == "yzx": 78 | x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 79 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 80 | z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon)) 81 | elif order == "zxy": 82 | x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon)) 83 | y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 84 | z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3)) 85 | elif order == "xzy": 86 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 87 | y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3)) 88 | z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon)) 89 | elif order == "yxz": 90 | x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon)) 91 | y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2)) 92 | z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3)) 93 | elif order == "zyx": 94 | x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2)) 95 | y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon)) 96 | z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3)) 97 | else: 98 | raise 99 | 100 | return torch.stack((x, y, z), dim=1).view(original_shape) 101 | 102 | 103 | # Numpy-backed implementations 104 | 105 | 106 | def qmul_np(q, r): 107 | q = torch.from_numpy(q).contiguous() 108 | r = torch.from_numpy(r).contiguous() 109 | return qmul(q, r).numpy() 110 | 111 | 112 | def qrot_np(q, v): 113 | q = torch.from_numpy(q).contiguous() 114 | v = torch.from_numpy(v).contiguous() 115 | return qrot(q, v).numpy() 116 | 117 | 118 | def qeuler_np(q, order, epsilon=0, use_gpu=False): 119 | if use_gpu: 120 | q = torch.from_numpy(q).cuda() 121 | return qeuler(q, order, epsilon).cpu().numpy() 122 | else: 123 | q = torch.from_numpy(q).contiguous() 124 | return qeuler(q, order, epsilon).numpy() 125 | 126 | 127 | def qfix(q): 128 | """ 129 | Enforce quaternion continuity across the time dimension by selecting 130 | the representation (q or -q) with minimal distance (or, equivalently, maximal dot product) 131 | between two consecutive frames. 132 | 133 | Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints. 134 | Returns a tensor of the same shape. 135 | """ 136 | assert len(q.shape) == 3 137 | assert q.shape[-1] == 4 138 | 139 | result = q.copy() 140 | dot_products = np.sum(q[1:] * q[:-1], axis=2) 141 | mask = dot_products < 0 142 | mask = (np.cumsum(mask, axis=0) % 2).astype(bool) 143 | result[1:][mask] *= -1 144 | return result 145 | 146 | 147 | def expmap_to_quaternion(e): 148 | """ 149 | Convert axis-angle rotations (aka exponential maps) to quaternions. 150 | Stable formula from "Practical Parameterization of Rotations Using the Exponential Map". 151 | Expects a tensor of shape (*, 3), where * denotes any number of dimensions. 152 | Returns a tensor of shape (*, 4). 
153 | """ 154 | assert e.shape[-1] == 3 155 | 156 | original_shape = list(e.shape) 157 | original_shape[-1] = 4 158 | e = e.reshape(-1, 3) 159 | 160 | theta = np.linalg.norm(e, axis=1).reshape(-1, 1) 161 | w = np.cos(0.5 * theta).reshape(-1, 1) 162 | xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e 163 | return np.concatenate((w, xyz), axis=1).reshape(original_shape) 164 | 165 | 166 | def euler_to_quaternion(e, order): 167 | """ 168 | Convert Euler angles to quaternions. 169 | """ 170 | assert e.shape[-1] == 3 171 | 172 | original_shape = list(e.shape) 173 | original_shape[-1] = 4 174 | 175 | e = e.reshape(-1, 3) 176 | 177 | x = e[:, 0] 178 | y = e[:, 1] 179 | z = e[:, 2] 180 | 181 | rx = np.stack( 182 | (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1 183 | ) 184 | ry = np.stack( 185 | (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1 186 | ) 187 | rz = np.stack( 188 | (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1 189 | ) 190 | 191 | result = None 192 | for coord in order: 193 | if coord == "x": 194 | r = rx 195 | elif coord == "y": 196 | r = ry 197 | elif coord == "z": 198 | r = rz 199 | else: 200 | raise 201 | if result is None: 202 | result = r 203 | else: 204 | result = qmul_np(result, r) 205 | 206 | # Reverse antipodal representation to have a non-negative "w" 207 | if order in ["xyz", "yzx", "zxy"]: 208 | result *= -1 209 | 210 | return result.reshape(original_shape) 211 | 212 | 213 | def qinv(q): 214 | # expectes q in (w,x,y,z) format 215 | w = q[:, 0:1] 216 | v = q[:, 1:] 217 | inv = torch.cat([w, -v], dim=1) 218 | return inv 219 | -------------------------------------------------------------------------------- /pcrnet/ops/transform_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from . import quaternion # works with (w, x, y, z) quaternions 6 | 7 | 8 | def quat2mat(quat): 9 | x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3] 10 | 11 | B = quat.size(0) 12 | 13 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 14 | wx, wy, wz = w*x, w*y, w*z 15 | xy, xz, yz = x*y, x*z, y*z 16 | 17 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, 18 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, 19 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3) 20 | return rotMat 21 | 22 | def transform_point_cloud(point_cloud: torch.Tensor, rotation: torch.Tensor, translation: torch.Tensor): 23 | if len(rotation.size()) == 2: 24 | rot_mat = quat2mat(rotation) 25 | else: 26 | rot_mat = rotation 27 | return (torch.matmul(rot_mat, point_cloud.permute(0, 2, 1)) + translation.unsqueeze(2)).permute(0, 2, 1) 28 | 29 | def convert2transformation(rotation_matrix: torch.Tensor, translation_vector: torch.Tensor): 30 | one_ = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(rotation_matrix.shape[0], 1, 1).to(rotation_matrix) # (Bx1x4) 31 | transformation_matrix = torch.cat([rotation_matrix, translation_vector.unsqueeze(-1)], dim=2) # (Bx3x4) 32 | transformation_matrix = torch.cat([transformation_matrix, one_], dim=1) # (Bx4x4) 33 | return transformation_matrix 34 | 35 | def qmul(q, r): 36 | """ 37 | Multiply quaternion(s) q with quaternion(s) r. 38 | Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions. 39 | Returns q*r as a tensor of shape (*, 4). 
40 | """ 41 | assert q.shape[-1] == 4 42 | assert r.shape[-1] == 4 43 | 44 | original_shape = q.shape 45 | 46 | # Compute outer product 47 | terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4)) 48 | 49 | w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3] 50 | x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2] 51 | y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1] 52 | z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0] 53 | return torch.stack((w, x, y, z), dim=1).view(original_shape) 54 | 55 | def qmul_np(q, r): 56 | q = torch.from_numpy(q).contiguous() 57 | r = torch.from_numpy(r).contiguous() 58 | return qmul(q, r).numpy() 59 | 60 | def euler_to_quaternion(e, order): 61 | """ 62 | Convert Euler angles to quaternions. 63 | """ 64 | assert e.shape[-1] == 3 65 | 66 | original_shape = list(e.shape) 67 | original_shape[-1] = 4 68 | 69 | e = e.reshape(-1, 3) 70 | 71 | x = e[:, 0] 72 | y = e[:, 1] 73 | z = e[:, 2] 74 | 75 | rx = np.stack( 76 | (np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1 77 | ) 78 | ry = np.stack( 79 | (np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1 80 | ) 81 | rz = np.stack( 82 | (np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1 83 | ) 84 | 85 | result = None 86 | for coord in order: 87 | if coord == "x": 88 | r = rx 89 | elif coord == "y": 90 | r = ry 91 | elif coord == "z": 92 | r = rz 93 | else: 94 | raise 95 | if result is None: 96 | result = r 97 | else: 98 | result = qmul_np(result, r) 99 | 100 | # Reverse antipodal representation to have a non-negative "w" 101 | if order in ["xyz", "yzx", "zxy"]: 102 | result *= -1 103 | 104 | return result.reshape(original_shape) 105 | 106 | 107 | class PCRNetTransform: 108 | def __init__(self, data_size, angle_range=45, translation_range=1): 109 | self.angle_range = angle_range 110 | self.translation_range = translation_range 111 | self.dtype = torch.float32 112 | self.transformations = [self.create_random_transform(torch.float32, self.angle_range, self.translation_range) for _ in range(data_size)] 113 | self.index = 0 114 | 115 | @staticmethod 116 | def deg_to_rad(deg): 117 | return np.pi / 180 * deg 118 | 119 | def create_random_transform(self, dtype, max_rotation_deg, max_translation): 120 | max_rotation = self.deg_to_rad(max_rotation_deg) 121 | rot = np.random.uniform(-max_rotation, max_rotation, [1, 3]) 122 | trans = np.random.uniform(-max_translation, max_translation, [1, 3]) 123 | quat = euler_to_quaternion(rot, "xyz") 124 | 125 | vec = np.concatenate([quat, trans], axis=1) 126 | vec = torch.tensor(vec, dtype=dtype) 127 | return vec 128 | 129 | @staticmethod 130 | def create_pose_7d(vector: torch.Tensor): 131 | # Normalize the quaternion. 
132 | pre_normalized_quaternion = vector[:, 0:4] 133 | normalized_quaternion = F.normalize(pre_normalized_quaternion, dim=1) 134 | 135 | # B x 7 vector of 4 quaternions and 3 translation parameters 136 | translation = vector[:, 4:] 137 | vector = torch.cat([normalized_quaternion, translation], dim=1) 138 | return vector.view([-1, 7]) 139 | 140 | @staticmethod 141 | def get_quaternion(pose_7d: torch.Tensor): 142 | return pose_7d[:, 0:4] 143 | 144 | @staticmethod 145 | def get_translation(pose_7d: torch.Tensor): 146 | return pose_7d[:, 4:] 147 | 148 | @staticmethod 149 | def quaternion_rotate(point_cloud: torch.Tensor, pose_7d: torch.Tensor): 150 | ndim = point_cloud.dim() 151 | if ndim == 2: 152 | N, _ = point_cloud.shape 153 | assert pose_7d.shape[0] == 1 154 | # repeat transformation vector for each point in shape 155 | quat = PCRNetTransform.get_quaternion(pose_7d).expand([N, -1]) 156 | rotated_point_cloud = quaternion.qrot(quat, point_cloud) 157 | 158 | elif ndim == 3: 159 | B, N, _ = point_cloud.shape 160 | quat = PCRNetTransform.get_quaternion(pose_7d).unsqueeze(1).expand([-1, N, -1]).contiguous() 161 | rotated_point_cloud = quaternion.qrot(quat, point_cloud) 162 | 163 | return rotated_point_cloud 164 | 165 | @staticmethod 166 | def quaternion_transform(point_cloud: torch.Tensor, pose_7d: torch.Tensor): 167 | transformed_point_cloud = PCRNetTransform.quaternion_rotate(point_cloud, pose_7d) + PCRNetTransform.get_translation(pose_7d).view(-1, 1, 3).repeat(1, point_cloud.shape[1], 1) # Ps' = R*Ps + t 168 | return transformed_point_cloud 169 | 170 | @staticmethod 171 | def convert2transformation(rotation_matrix: torch.Tensor, translation_vector: torch.Tensor): 172 | one_ = torch.tensor([[[0.0, 0.0, 0.0, 1.0]]]).repeat(rotation_matrix.shape[0], 1, 1).to(rotation_matrix) # (Bx1x4) 173 | transformation_matrix = torch.cat([rotation_matrix, translation_vector[:,0,:].unsqueeze(-1)], dim=2) # (Bx3x4) 174 | transformation_matrix = torch.cat([transformation_matrix, one_], dim=1) # (Bx4x4) 175 | return transformation_matrix 176 | 177 | def __call__(self, template): 178 | self.igt = self.transformations[self.index] 179 | igt = self.create_pose_7d(self.igt) 180 | self.igt_rotation = self.quaternion_rotate(torch.eye(3), igt).permute(1, 0) # [3x3] 181 | self.igt_translation = self.get_translation(igt) # [1x3] 182 | source = self.quaternion_rotate(template, igt) + self.get_translation(igt) 183 | return source -------------------------------------------------------------------------------- /pcrnet/pretrained/exp_ipcrnet/models/best_model.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinits5/pcrnet_pytorch/ca20d4ea5071f5e5aa9a4c7143eb34be80a406a7/pcrnet/pretrained/exp_ipcrnet/models/best_model.t7 -------------------------------------------------------------------------------- /pcrnet/pretrained/exp_ipcrnet_v1/models/best_model.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinits5/pcrnet_pytorch/ca20d4ea5071f5e5aa9a4c7143eb34be80a406a7/pcrnet/pretrained/exp_ipcrnet_v1/models/best_model.t7 -------------------------------------------------------------------------------- /pcrnet/pretrained/exp_ipcrnet_v1/models/best_ptnet_model.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vinits5/pcrnet_pytorch/ca20d4ea5071f5e5aa9a4c7143eb34be80a406a7/pcrnet/pretrained/exp_ipcrnet_v1/models/best_ptnet_model.t7 
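Note: the best_model.t7 / best_ptnet_model.t7 files above are binary PyTorch checkpoints. A minimal loading sketch, assuming (not confirmed by this listing) that each .t7 holds a state_dict saved with torch.save:

import torch
from pcrnet.models import PointNet, iPCRNet

# Assumed checkpoint path and format; adjust to your local layout.
model = iPCRNet(feature_model=PointNet())
state_dict = torch.load('pcrnet/pretrained/exp_ipcrnet/models/best_model.t7', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()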
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | transforms3d==0.4.1
3 | h5py==3.8.0
4 | numpy==1.24.3
5 | tqdm==4.65.0
6 | tensorboardX==2.6
7 | open3d==0.17.0
8 | scikit-learn==1.2.2
9 | scipy==1.10.1
10 | torchvision==0.15.2
--------------------------------------------------------------------------------
/test_pcrnet.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d
2 | import argparse
3 | import os
4 | import sys
5 | import logging
6 | import numpy
7 | import numpy as np
8 | import torch
9 | import torch.utils.data
10 | import torchvision
11 | from torch.utils.data import DataLoader
12 | from tensorboardX import SummaryWriter
13 | from tqdm import tqdm
14 | import transforms3d
15 | 
16 | # Only needed if the files are in an examples folder.
17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18 | if BASE_DIR[-8:] == 'examples':
19 |     sys.path.append(os.path.join(BASE_DIR, os.pardir))
20 |     os.chdir(os.path.join(BASE_DIR, os.pardir))
21 | 
22 | from pcrnet.models import PointNet, iPCRNet
23 | from pcrnet.losses import ChamferDistanceLoss
24 | from pcrnet.data_utils import RegistrationData, ModelNet40Data
25 | 
26 | 
27 | def display_open3d(template, source, transformed_source):
28 |     template_ = o3d.geometry.PointCloud()
29 |     source_ = o3d.geometry.PointCloud()
30 |     transformed_source_ = o3d.geometry.PointCloud()
31 |     template_.points = o3d.utility.Vector3dVector(template)
32 |     source_.points = o3d.utility.Vector3dVector(source)
33 |     transformed_source_.points = o3d.utility.Vector3dVector(transformed_source)
34 |     template_.paint_uniform_color([1, 0, 0])
35 |     source_.paint_uniform_color([0, 1, 0])
36 |     transformed_source_.paint_uniform_color([0, 0, 1])
37 |     o3d.visualization.draw_geometries([template_, source_, transformed_source_])
38 | 
39 | # Find error metrics.
40 | def find_errors(igt_R, pred_R, igt_t, pred_t):
41 |     # igt_R: Rotation matrix [3, 3] (source = igt_R * template)
42 |     # pred_R: Registration algorithm's rotation matrix [3, 3] (template = pred_R * source)
43 |     # igt_t: translation vector [1, 3] (source = template + igt_t)
44 |     # pred_t: Registration algorithm's translation vector [1, 3] (template = source + pred_t)
45 | 
46 |     # Euclidean distance between the ground-truth translation and the predicted translation.
47 |     igt_t = -np.matmul(igt_R.T, igt_t.T).T # gt translation vector (source -> template)
48 |     translation_error = np.sqrt(np.sum(np.square(igt_t - pred_t)))
49 | 
50 |     # Convert the residual rotation matrix to axis-angle representation and report the angle as rotation error.
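    # Given the conventions above (source = igt_R * template and
    # template = pred_R * source), a perfect prediction makes the matrix
    # product below the identity; the axis-angle decomposition of that
    # residual therefore yields the rotation error.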
51 |     error_mat = np.dot(igt_R, pred_R) # residual rotation matrix [3, 3]
52 |     _, angle = transforms3d.axangles.mat2axangle(error_mat)
53 |     return translation_error, abs(angle*(180/np.pi))
54 | 
55 | def compute_accuracy(igt_R, pred_R, igt_t, pred_t):
56 |     errors_temp = []
57 |     for igt_R_i, pred_R_i, igt_t_i, pred_t_i in zip(igt_R, pred_R, igt_t, pred_t):
58 |         errors_temp.append(find_errors(igt_R_i, pred_R_i, igt_t_i, pred_t_i))
59 |     return np.mean(errors_temp, axis=0)
60 | 
61 | def test_one_epoch(device, model, test_loader):
62 |     model.eval()
63 |     test_loss = 0.0
64 |     pred = 0.0
65 |     count = 0
66 |     errors = []
67 | 
68 |     for i, data in enumerate(tqdm(test_loader)):
69 |         template, source, igt, igt_R, igt_t = data
70 | 
71 |         template = template.to(device)
72 |         source = source.to(device)
73 |         igt = igt.to(device)
74 | 
75 |         source_original = source.clone()
76 |         template_original = template.clone()
77 |         igt_t = igt_t - torch.mean(source, dim=1).unsqueeze(1)
78 |         source = source - torch.mean(source, dim=1, keepdim=True)
79 |         template = template - torch.mean(template, dim=1, keepdim=True)
80 | 
81 |         output = model(template, source)
82 |         est_R = output['est_R']
83 |         est_t = output['est_t']
84 | 
85 |         errors.append(compute_accuracy(igt_R.detach().cpu().numpy(), est_R.detach().cpu().numpy(),
86 |                                        igt_t.detach().cpu().numpy(), est_t.detach().cpu().numpy()))
87 | 
88 |         transformed_source = torch.bmm(est_R, source.permute(0, 2, 1)).permute(0, 2, 1) + est_t
89 |         display_open3d(template.detach().cpu().numpy()[0], source_original.detach().cpu().numpy()[0], transformed_source.detach().cpu().numpy()[0])
90 | 
91 |         loss_val = ChamferDistanceLoss()(template, output['transformed_source'])
92 | 
93 |         test_loss += loss_val.item()
94 |         count += 1
95 | 
96 |     test_loss = float(test_loss)/count
97 |     errors = np.mean(np.array(errors), axis=0)
98 |     return test_loss, errors[0], errors[1]
99 | 
100 | def test(args, model, test_loader):
101 |     test_loss, translation_error, rotation_error = test_one_epoch(args.device, model, test_loader)
102 |     print("Test Loss: {}, Rotation Error: {} & Translation Error: {}".format(test_loss, rotation_error, translation_error))
103 | 
104 | def options():
105 |     parser = argparse.ArgumentParser(description='Point Cloud Registration')
106 |     parser.add_argument('--exp_name', type=str, default='exp_ipcrnet', metavar='N',
107 |                         help='Name of the experiment')
108 |     parser.add_argument('--dataset_path', type=str, default='ModelNet40',
109 |                         metavar='PATH', help='path to the input dataset') # like '/path/to/ModelNet40'
110 |     parser.add_argument('--eval', action='store_true', help='Evaluate the network.')
111 | 
112 |     # settings for input data
113 |     parser.add_argument('--dataset_type', default='modelnet', choices=['modelnet', 'shapenet2'],
114 |                         metavar='DATASET', help='dataset type (default: modelnet)')
115 |     parser.add_argument('--num_points', default=1024, type=int,
116 |                         metavar='N', help='points in point-cloud (default: 1024)')
117 | 
118 |     # settings for PointNet
119 |     parser.add_argument('--emb_dims', default=1024, type=int,
120 |                         metavar='K', help='dim. of the feature vector (default: 1024)')
121 |     parser.add_argument('--symfn', default='max', choices=['max', 'avg'],
122 |                         help='symmetric function (default: max)')
123 | 
124 |     # settings for training
125 |     parser.add_argument('-j', '--workers', default=4, type=int,
126 |                         metavar='N', help='number of data loading workers (default: 4)')
127 |     parser.add_argument('-b', '--batch_size', default=20, type=int,
128 |                         metavar='N', help='mini-batch size (default: 20)')
129 |     parser.add_argument('--pretrained', default='pcrnet/pretrained/exp_ipcrnet/models/best_model.t7', type=str,
130 |                         metavar='PATH', help='path to a pretrained model file')
131 |     parser.add_argument('--device', default='cuda:0', type=str,
132 |                         metavar='DEVICE', help='use CUDA if available')
133 | 
134 |     args = parser.parse_args()
135 |     return args
136 | 
137 | def main():
138 |     args = options()
139 | 
140 |     testset = RegistrationData('PCRNet', ModelNet40Data(train=False), is_testing=True)
141 |     test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=args.workers)
142 | 
143 |     if not torch.cuda.is_available():
144 |         args.device = 'cpu'
145 |     args.device = torch.device(args.device)
146 | 
147 |     # Create PointNet Model.
148 |     ptnet = PointNet(emb_dims=args.emb_dims)
149 |     model = iPCRNet(feature_model=ptnet)
150 |     model = model.to(args.device)
151 | 
152 |     if args.pretrained:
153 |         assert os.path.isfile(args.pretrained)
154 |         model.load_state_dict(torch.load(args.pretrained, map_location='cpu'))
155 |         model.to(args.device)
156 | 
157 |     test(args, model, test_loader)
158 | 
159 | if __name__ == '__main__':
160 |     main()
--------------------------------------------------------------------------------
/train_pcrnet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import logging
5 | import numpy
6 | import numpy as np
7 | import torch
8 | import torch.utils.data
9 | import torchvision
10 | from torch.utils.data import DataLoader
11 | from tensorboardX import SummaryWriter
12 | from tqdm import tqdm
13 | 
14 | # Only needed if the files are in an examples folder.
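# If this script is launched from inside an "examples" directory, the parent
# directory is appended to sys.path and made the working directory, so the
# pcrnet package resolves without installation.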
15 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
16 | if BASE_DIR[-8:] == 'examples':
17 |     sys.path.append(os.path.join(BASE_DIR, os.pardir))
18 |     os.chdir(os.path.join(BASE_DIR, os.pardir))
19 | 
20 | from pcrnet.models import PointNet
21 | from pcrnet.models import iPCRNet
22 | from pcrnet.losses import ChamferDistanceLoss
23 | from pcrnet.data_utils import RegistrationData, ModelNet40Data
24 | 
25 | def _init_(args):
26 |     if not os.path.exists('checkpoints'):
27 |         os.makedirs('checkpoints')
28 |     if not os.path.exists('checkpoints/' + args.exp_name):
29 |         os.makedirs('checkpoints/' + args.exp_name)
30 |     if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
31 |         os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
32 |     os.system('cp train_pcrnet.py checkpoints' + '/' + args.exp_name + '/' + 'train_pcrnet.py.backup')
33 |     os.system('cp pcrnet/models/pcrnet.py checkpoints' + '/' + args.exp_name + '/' + 'pcrnet.py.backup')
34 | 
35 | 
36 | class IOStream:
37 |     def __init__(self, path):
38 |         self.f = open(path, 'a')
39 | 
40 |     def cprint(self, text):
41 |         print(text)
42 |         self.f.write(text + '\n')
43 |         self.f.flush()
44 | 
45 |     def close(self):
46 |         self.f.close()
47 | 
48 | def test_one_epoch(device, model, test_loader):
49 |     model.eval()
50 |     test_loss = 0.0
51 |     pred = 0.0
52 |     count = 0
53 |     for i, data in enumerate(tqdm(test_loader)):
54 |         template, source, igt = data
55 | 
56 |         template = template.to(device)
57 |         source = source.to(device)
58 |         igt = igt.to(device)
59 | 
60 |         # mean subtraction
61 |         source = source - torch.mean(source, dim=1, keepdim=True)
62 |         template = template - torch.mean(template, dim=1, keepdim=True)
63 | 
64 |         output = model(template, source)
65 |         loss_val = ChamferDistanceLoss()(template, output['transformed_source'])
66 | 
67 |         test_loss += loss_val.item()
68 |         count += 1
69 | 
70 |     test_loss = float(test_loss)/count
71 |     return test_loss
72 | 
73 | def test(args, model, test_loader, textio):
74 |     test_loss = test_one_epoch(args.device, model, test_loader)
75 |     textio.cprint('Validation Loss: %f'%(test_loss))
76 | 
77 | def train_one_epoch(device, model, train_loader, optimizer):
78 |     model.train()
79 |     train_loss = 0.0
80 |     pred = 0.0
81 |     count = 0
82 |     for i, data in enumerate(tqdm(train_loader)):
83 |         template, source, igt = data
84 | 
85 |         template = template.to(device)
86 |         source = source.to(device)
87 |         igt = igt.to(device)
88 | 
89 |         # mean subtraction
90 |         source = source - torch.mean(source, dim=1, keepdim=True)
91 |         template = template - torch.mean(template, dim=1, keepdim=True)
92 | 
93 |         output = model(template, source)
94 |         loss_val = ChamferDistanceLoss()(template, output['transformed_source'])
95 |         # print(loss_val.item())
96 | 
97 |         # forward + backward + optimize
98 |         optimizer.zero_grad()
99 |         loss_val.backward()
100 |         optimizer.step()
101 | 
102 |         train_loss += loss_val.item()
103 |         count += 1
104 | 
105 |     train_loss = float(train_loss)/count
106 |     return train_loss
107 | 
108 | def train(args, model, train_loader, test_loader, boardio, textio, checkpoint):
109 |     learnable_params = filter(lambda p: p.requires_grad, model.parameters())
110 |     if args.optimizer == 'Adam':
111 |         optimizer = torch.optim.Adam(learnable_params)
112 |     else:
113 |         optimizer = torch.optim.SGD(learnable_params, lr=0.1)
114 | 
115 |     if checkpoint is not None:
116 |         min_loss = checkpoint['min_loss']
117 |         optimizer.load_state_dict(checkpoint['optimizer'])
118 | 
119 |     best_test_loss = min_loss if checkpoint is not None else np.inf  # resume best loss when restarting from a checkpoint
120 | 
121 |     for epoch in range(args.start_epoch, args.epochs):
122 |         train_loss = train_one_epoch(args.device, model, train_loader, optimizer)
123 |         test_loss = test_one_epoch(args.device, model, test_loader)
124 | 
125 |         if test_loss