├── .gitignore
├── LICENSE
├── README.md
├── chamfer2D
│   ├── chamfer2D.cu
│   ├── chamfer_cuda.cpp
│   ├── dist_chamfer_2D.py
│   └── setup.py
├── chamfer3D
│   ├── chamfer3D.cu
│   ├── chamfer_cuda.cpp
│   ├── dist_chamfer_3D.py
│   └── setup.py
├── chamfer5D
│   ├── chamfer5D.cu
│   ├── chamfer_cuda.cpp
│   ├── dist_chamfer_5D.py
│   └── setup.py
├── chamfer6D
│   ├── chamfer6D.cu
│   ├── chamfer_cuda.cpp
│   ├── dist_chamfer_6D.py
│   └── setup.py
├── chamfer_python.py
├── fscore.py
└── unit_test.py

/.gitignore:
--------------------------------------------------------------------------------
*__pycache__*
/tmp
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 ThibaultGROUEIX

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Requirements: `pip install torch ninja`

# PyTorch Chamfer Distance

Includes a **CUDA** version and a **PYTHON** version with standard PyTorch operations.
NB: in this repo, `dist1` and `dist2` are squared point-cloud Euclidean distances, so adapt your thresholds accordingly (see the short example after the feature list below).

- [x] F-score



### CUDA VERSION

- [x] JIT compilation
- [x] Supports multi-GPU
- [x] 2D point clouds.
- [x] 3D point clouds.
- [x] 5D point clouds.
- [x] Contiguous() safe.
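Because the returned distances are squared, a Euclidean tolerance `t` corresponds to a threshold of `t**2` on `dist1`/`dist2`. A minimal sketch of this (the tolerance value is illustrative only; the API is the one shown in the Usage section below):

```python
import torch, chamfer3D.dist_chamfer_3D, fscore

chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
a, b = torch.rand(1, 100, 3).cuda(), torch.rand(1, 100, 3).cuda()
dist1, dist2, idx1, idx2 = chamLoss(a, b)  # squared nearest-neighbour distances
# a Euclidean tolerance of 0.01 becomes 0.01 ** 2 = 1e-4 on squared distances
f_score, precision, recall = fscore.fscore(dist1, dist2, threshold=0.01 ** 2)
```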

### Python Version

- [x] Supports any dimension



### Usage

```python
import torch, chamfer3D.dist_chamfer_3D, fscore
chamLoss = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
points1 = torch.rand(32, 1000, 3).cuda()
points2 = torch.rand(32, 2000, 3, requires_grad=True).cuda()
dist1, dist2, idx1, idx2 = chamLoss(points1, points2)
f_score, precision, recall = fscore.fscore(dist1, dist2)
```



### Add it to your project as a submodule

```shell
git submodule add https://github.com/ThibaultGROUEIX/ChamferDistancePytorch
```



### Benchmark: [forward + backward] pass

- [x] CUDA 10.1, NVIDIA driver 435, PyTorch 1.4
- [x] p1 : 32 x 2000 x dim
- [x] p2 : 32 x 1000 x dim

| *Timing (ms)* | 2D | 3D | 5D |
| ---------- | -------- | ------- | ------- |
| **Cuda Compiled** | **1.2** | 1.4 | 1.8 |
| **Cuda JIT** | 1.3 | **1.4** | **1.5** |
| **Python** | 37 | 37 | 37 |


| *Memory (MB)* | 2D | 3D | 5D |
| ---------- | -------- | ------- | ------- |
| **Cuda Compiled** | 529 | 529 | 549 |
| **Cuda JIT** | **520** | **529** | **549** |
| **Python** | 2495 | 2495 | 2495 |



### What is the chamfer distance?

[Stanford course](http://graphics.stanford.edu/courses/cs468-17-spring/LectureSlides/L14%20-%203d%20deep%20learning%20on%20point%20cloud%20representation%20(analysis).pdf) on 3D deep learning



### Acknowledgment

Original backbone from [Fei Xia](https://github.com/fxia22/pointGAN/blob/master/nndistance/src/nnd_cuda.cu).

JIT cool trick from [Christian Diller](https://github.com/chrdiller)

### Troubleshooting

- `Undefined symbol: Zxxxxxxxxxxxxxxxxx`:

--> Fix: Make sure to `import torch` before you `import chamfer`.
87 | --> Use pytorch.version >= 1.1.0 88 | 89 | - [RuntimeError: Ninja is required to load C++ extension](https://github.com/zhanghang1989/PyTorch-Encoding/issues/167) 90 | 91 | ```shell 92 | wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip 93 | sudo unzip ninja-linux.zip -d /usr/local/bin/ 94 | sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force 95 | ``` 96 | 97 | 98 | 99 | 100 | 101 | #### TODO: 102 | 103 | * Discuss behaviour of torch.min() and tensor.min() which causes issues in some pytorch versions 104 | -------------------------------------------------------------------------------- /chamfer2D/chamfer2D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*2]; 15 | for (int i=blockIdx.x;ibest){ 117 | result[(i*n+j)]=best; 118 | result_i[(i*n+j)]=best_i; 119 | } 120 | } 121 | __syncthreads(); 122 | } 123 | } 124 | } 125 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 126 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 127 | 128 | const auto batch_size = xyz1.size(0); 129 | const auto n = xyz1.size(1); //num_points point cloud A 130 | const auto m = xyz2.size(1); //num_points point cloud B 131 | 132 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 133 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 134 | 135 | cudaError_t err = cudaGetLastError(); 136 | if (err != cudaSuccess) { 137 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 138 | //THError("aborting"); 139 | return 0; 140 | } 141 | return 1; 142 | 143 | 144 | } 145 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 146 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 171 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 172 | 173 | cudaError_t err = cudaGetLastError(); 174 | if (err != cudaSuccess) { 175 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 176 | //THError("aborting"); 177 | return 0; 178 | } 179 | return 1; 180 | 181 | } 182 | 183 | -------------------------------------------------------------------------------- /chamfer2D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor 
dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /chamfer2D/dist_chamfer_2D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | chamfer_found = importlib.find_loader("chamfer_2D") is not None 7 | if not chamfer_found: 8 | ## Cool trick from https://github.com/chrdiller 9 | print("Jitting Chamfer 2D") 10 | cur_path = os.path.dirname(os.path.abspath(__file__)) 11 | build_path = cur_path.replace('chamfer2D', 'tmp') 12 | os.makedirs(build_path, exist_ok=True) 13 | 14 | from torch.utils.cpp_extension import load 15 | chamfer_2D = load(name="chamfer_2D", 16 | sources=[ 17 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 18 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer2D.cu"]), 19 | ], build_directory=build_path) 20 | print("Loaded JIT 2D CUDA chamfer distance") 21 | 22 | else: 23 | import chamfer_2D 24 | print("Loaded compiled 2D CUDA chamfer distance") 25 | 26 | # Chamfer's distance module @thibaultgroueix 27 | # GPU tensors only 28 | class chamfer_2DFunction(Function): 29 | @staticmethod 30 | def forward(ctx, xyz1, xyz2): 31 | batchsize, n, dim = xyz1.size() 32 | assert dim==2, "Wrong last dimension for the chamfer distance 's input! Check with .size()" 33 | _, m, dim = xyz2.size() 34 | assert dim==2, "Wrong last dimension for the chamfer distance 's input! 
Check with .size()" 35 | device = xyz1.device 36 | 37 | device = xyz1.device 38 | 39 | dist1 = torch.zeros(batchsize, n) 40 | dist2 = torch.zeros(batchsize, m) 41 | 42 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 43 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 44 | 45 | dist1 = dist1.to(device) 46 | dist2 = dist2.to(device) 47 | idx1 = idx1.to(device) 48 | idx2 = idx2.to(device) 49 | torch.cuda.set_device(device) 50 | 51 | chamfer_2D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 52 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 53 | return dist1, dist2, idx1, idx2 54 | 55 | @staticmethod 56 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 57 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 58 | graddist1 = graddist1.contiguous() 59 | graddist2 = graddist2.contiguous() 60 | device = graddist1.device 61 | 62 | gradxyz1 = torch.zeros(xyz1.size()) 63 | gradxyz2 = torch.zeros(xyz2.size()) 64 | 65 | gradxyz1 = gradxyz1.to(device) 66 | gradxyz2 = gradxyz2.to(device) 67 | chamfer_2D.backward( 68 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 69 | ) 70 | return gradxyz1, gradxyz2 71 | 72 | 73 | class chamfer_2DDist(nn.Module): 74 | def __init__(self): 75 | super(chamfer_2DDist, self).__init__() 76 | 77 | def forward(self, input1, input2): 78 | input1 = input1.contiguous() 79 | input2 = input2.contiguous() 80 | return chamfer_2DFunction.apply(input1, input2) 81 | -------------------------------------------------------------------------------- /chamfer2D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_2D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_2D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer2D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /chamfer3D/chamfer3D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*3]; 15 | for (int i=blockIdx.x;ibest){ 127 | result[(i*n+j)]=best; 128 | result_i[(i*n+j)]=best_i; 129 | } 130 | } 131 | __syncthreads(); 132 | } 133 | } 134 | } 135 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 136 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 137 | 138 | const auto batch_size = xyz1.size(0); 139 | const auto n = xyz1.size(1); //num_points point cloud A 140 | const auto m = xyz2.size(1); //num_points point cloud B 141 | 142 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 143 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 144 | 145 | cudaError_t err = cudaGetLastError(); 146 | if (err != cudaSuccess) { 147 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 148 | //THError("aborting"); 149 | return 0; 150 | } 151 | return 1; 152 | 153 | 
154 | } 155 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 156 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 185 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 186 | 187 | cudaError_t err = cudaGetLastError(); 188 | if (err != cudaSuccess) { 189 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 190 | //THError("aborting"); 191 | return 0; 192 | } 193 | return 1; 194 | 195 | } 196 | 197 | -------------------------------------------------------------------------------- /chamfer3D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /chamfer3D/dist_chamfer_3D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | chamfer_found = importlib.find_loader("chamfer_3D") is not None 7 | if not chamfer_found: 8 | ## Cool trick from https://github.com/chrdiller 9 | print("Jitting Chamfer 3D") 10 | cur_path = os.path.dirname(os.path.abspath(__file__)) 11 | build_path = cur_path.replace('chamfer3D', 'tmp') 12 | os.makedirs(build_path, exist_ok=True) 13 | 14 | from torch.utils.cpp_extension import load 15 | chamfer_3D = load(name="chamfer_3D", 16 | sources=[ 17 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 18 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), 19 | ], build_directory=build_path) 20 | print("Loaded JIT 3D CUDA chamfer distance") 21 | 22 | else: 23 | import chamfer_3D 24 | print("Loaded compiled 3D CUDA chamfer distance") 25 | 26 | 27 | # Chamfer's distance module @thibaultgroueix 28 | # GPU tensors only 29 | class chamfer_3DFunction(Function): 30 | @staticmethod 31 | def forward(ctx, xyz1, xyz2): 32 | batchsize, n, dim = xyz1.size() 33 | assert dim==3, "Wrong last dimension for the chamfer distance 's input! 
Check with .size()" 34 | _, m, dim = xyz2.size() 35 | assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()" 36 | device = xyz1.device 37 | 38 | device = xyz1.device 39 | 40 | dist1 = torch.zeros(batchsize, n) 41 | dist2 = torch.zeros(batchsize, m) 42 | 43 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 44 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 45 | 46 | dist1 = dist1.to(device) 47 | dist2 = dist2.to(device) 48 | idx1 = idx1.to(device) 49 | idx2 = idx2.to(device) 50 | torch.cuda.set_device(device) 51 | 52 | chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 53 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 54 | return dist1, dist2, idx1, idx2 55 | 56 | @staticmethod 57 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 58 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 59 | graddist1 = graddist1.contiguous() 60 | graddist2 = graddist2.contiguous() 61 | device = graddist1.device 62 | 63 | gradxyz1 = torch.zeros(xyz1.size()) 64 | gradxyz2 = torch.zeros(xyz2.size()) 65 | 66 | gradxyz1 = gradxyz1.to(device) 67 | gradxyz2 = gradxyz2.to(device) 68 | chamfer_3D.backward( 69 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 70 | ) 71 | return gradxyz1, gradxyz2 72 | 73 | 74 | class chamfer_3DDist(nn.Module): 75 | def __init__(self): 76 | super(chamfer_3DDist, self).__init__() 77 | 78 | def forward(self, input1, input2): 79 | input1 = input1.contiguous() 80 | input2 = input2.contiguous() 81 | return chamfer_3DFunction.apply(input1, input2) 82 | -------------------------------------------------------------------------------- /chamfer3D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_3D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_3D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /chamfer5D/chamfer5D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=2048; 14 | __shared__ float buf[batch*5]; 15 | for (int i=blockIdx.x;ibest){ 147 | result[(i*n+j)]=best; 148 | result_i[(i*n+j)]=best_i; 149 | } 150 | } 151 | __syncthreads(); 152 | } 153 | } 154 | } 155 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 156 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 157 | 158 | const auto batch_size = xyz1.size(0); 159 | const auto n = xyz1.size(1); //num_points point cloud A 160 | const auto m = xyz2.size(1); //num_points point cloud B 161 | 162 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 163 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 164 | 165 | cudaError_t err = cudaGetLastError(); 166 | if (err != cudaSuccess) { 167 | printf("error in nnd 
updateOutput: %s\n", cudaGetErrorString(err)); 168 | //THError("aborting"); 169 | return 0; 170 | } 171 | return 1; 172 | 173 | 174 | } 175 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 176 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 213 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 214 | 215 | cudaError_t err = cudaGetLastError(); 216 | if (err != cudaSuccess) { 217 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 218 | //THError("aborting"); 219 | return 0; 220 | } 221 | return 1; 222 | 223 | } 224 | -------------------------------------------------------------------------------- /chamfer5D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /chamfer5D/dist_chamfer_5D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | 7 | chamfer_found = importlib.find_loader("chamfer_5D") is not None 8 | if not chamfer_found: 9 | ## Cool trick from https://github.com/chrdiller 10 | print("Jitting Chamfer 5D") 11 | cur_path = os.path.dirname(os.path.abspath(__file__)) 12 | build_path = cur_path.replace('chamfer5D', 'tmp') 13 | os.makedirs(build_path, exist_ok=True) 14 | 15 | from torch.utils.cpp_extension import load 16 | chamfer_5D = load(name="chamfer_5D", 17 | sources=[ 18 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 19 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer5D.cu"]), 20 | ], build_directory=build_path) 21 | print("Loaded JIT 5D CUDA chamfer distance") 22 | 23 | else: 24 | import chamfer_5D 25 | print("Loaded compiled 5D CUDA chamfer distance") 26 | 27 | 28 | # Chamfer's distance module @thibaultgroueix 29 | # GPU tensors only 30 | class chamfer_5DFunction(Function): 31 | @staticmethod 32 | def forward(ctx, xyz1, xyz2): 33 | batchsize, n, dim = xyz1.size() 34 | assert dim==5, "Wrong last dimension for the chamfer 
distance 's input! Check with .size()" 35 | _, m, dim = xyz2.size() 36 | assert dim==5, "Wrong last dimension for the chamfer distance 's input! Check with .size()" 37 | device = xyz1.device 38 | 39 | device = xyz1.device 40 | 41 | dist1 = torch.zeros(batchsize, n) 42 | dist2 = torch.zeros(batchsize, m) 43 | 44 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 45 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 46 | 47 | dist1 = dist1.to(device) 48 | dist2 = dist2.to(device) 49 | idx1 = idx1.to(device) 50 | idx2 = idx2.to(device) 51 | torch.cuda.set_device(device) 52 | 53 | chamfer_5D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 54 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 55 | return dist1, dist2, idx1, idx2 56 | 57 | @staticmethod 58 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 59 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 60 | graddist1 = graddist1.contiguous() 61 | graddist2 = graddist2.contiguous() 62 | device = graddist1.device 63 | 64 | gradxyz1 = torch.zeros(xyz1.size()) 65 | gradxyz2 = torch.zeros(xyz2.size()) 66 | 67 | gradxyz1 = gradxyz1.to(device) 68 | gradxyz2 = gradxyz2.to(device) 69 | chamfer_5D.backward( 70 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 71 | ) 72 | return gradxyz1, gradxyz2 73 | 74 | 75 | class chamfer_5DDist(nn.Module): 76 | def __init__(self): 77 | super(chamfer_5DDist, self).__init__() 78 | 79 | def forward(self, input1, input2): 80 | input1 = input1.contiguous() 81 | input2 = input2.contiguous() 82 | return chamfer_5DFunction.apply(input1, input2) 83 | -------------------------------------------------------------------------------- /chamfer5D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_5D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_5D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer5D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /chamfer6D/chamfer6D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=2048; 14 | __shared__ float buf[batch*6]; 15 | for (int i=blockIdx.x;ibest){ 157 | result[(i*n+j)]=best; 158 | result_i[(i*n+j)]=best_i; 159 | } 160 | } 161 | __syncthreads(); 162 | } 163 | } 164 | } 165 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 166 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 167 | 168 | const auto batch_size = xyz1.size(0); 169 | const auto n = xyz1.size(1); //num_points point cloud A 170 | const auto m = xyz2.size(1); //num_points point cloud B 171 | 172 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 173 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 174 | 175 | cudaError_t err = cudaGetLastError(); 176 | if (err != cudaSuccess) { 177 | 
printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 178 | //THError("aborting"); 179 | return 0; 180 | } 181 | return 1; 182 | 183 | 184 | } 185 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 186 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 227 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 228 | 229 | cudaError_t err = cudaGetLastError(); 230 | if (err != cudaSuccess) { 231 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 232 | //THError("aborting"); 233 | return 0; 234 | } 235 | return 1; 236 | 237 | } 238 | -------------------------------------------------------------------------------- /chamfer6D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /chamfer6D/dist_chamfer_6D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | 7 | chamfer_found = importlib.find_loader("chamfer_6D") is not None 8 | if not chamfer_found: 9 | ## Cool trick from https://github.com/chrdiller 10 | print("Jitting Chamfer 6D") 11 | cur_path = os.path.dirname(os.path.abspath(__file__)) 12 | build_path = cur_path.replace('chamfer6D', 'tmp') 13 | os.makedirs(build_path, exist_ok=True) 14 | 15 | from torch.utils.cpp_extension import load 16 | chamfer_6D = load(name="chamfer_6D", 17 | sources=[ 18 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 19 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer6D.cu"]), 20 | ], build_directory=build_path) 21 | print("Loaded JIT 6D CUDA chamfer distance") 22 | 23 | else: 24 | import chamfer_6D 25 | print("Loaded compiled 6D CUDA chamfer distance") 26 | 27 | 28 | # Chamfer's distance module @thibaultgroueix 29 | # GPU tensors only 30 | class chamfer_6DFunction(Function): 31 | @staticmethod 32 | def forward(ctx, xyz1, xyz2): 33 | batchsize, n, dim = xyz1.size() 34 | assert dim==6, "Wrong last 
dimension for the chamfer distance 's input! Check with .size()" 35 | _, m, dim = xyz2.size() 36 | assert dim==6, "Wrong last dimension for the chamfer distance 's input! Check with .size()" 37 | device = xyz1.device 38 | 39 | device = xyz1.device 40 | 41 | dist1 = torch.zeros(batchsize, n) 42 | dist2 = torch.zeros(batchsize, m) 43 | 44 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 45 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 46 | 47 | dist1 = dist1.to(device) 48 | dist2 = dist2.to(device) 49 | idx1 = idx1.to(device) 50 | idx2 = idx2.to(device) 51 | torch.cuda.set_device(device) 52 | 53 | chamfer_6D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 54 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 55 | return dist1, dist2, idx1, idx2 56 | 57 | @staticmethod 58 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 59 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 60 | graddist1 = graddist1.contiguous() 61 | graddist2 = graddist2.contiguous() 62 | device = graddist1.device 63 | 64 | gradxyz1 = torch.zeros(xyz1.size()) 65 | gradxyz2 = torch.zeros(xyz2.size()) 66 | 67 | gradxyz1 = gradxyz1.to(device) 68 | gradxyz2 = gradxyz2.to(device) 69 | chamfer_6D.backward( 70 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 71 | ) 72 | return gradxyz1, gradxyz2 73 | 74 | 75 | class chamfer_6DDist(nn.Module): 76 | def __init__(self): 77 | super(chamfer_6DDist, self).__init__() 78 | 79 | def forward(self, input1, input2): 80 | input1 = input1.contiguous() 81 | input2 = input2.contiguous() 82 | return chamfer_6DFunction.apply(input1, input2) 83 | -------------------------------------------------------------------------------- /chamfer6D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_6D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_6D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer6D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /chamfer_python.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pairwise_dist(x, y): 5 | xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t()) 6 | rx = xx.diag().unsqueeze(0).expand_as(xx) 7 | ry = yy.diag().unsqueeze(0).expand_as(yy) 8 | P = rx.t() + ry - 2 * zz 9 | return P 10 | 11 | 12 | def NN_loss(x, y, dim=0): 13 | dist = pairwise_dist(x, y) 14 | values, indices = dist.min(dim=dim) 15 | return values.mean() 16 | 17 | 18 | def batched_pairwise_dist(a, b): 19 | x, y = a.double(), b.double() 20 | bs, num_points_x, points_dim = x.size() 21 | bs, num_points_y, points_dim = y.size() 22 | 23 | xx = torch.pow(x, 2).sum(2) 24 | yy = torch.pow(y, 2).sum(2) 25 | zz = torch.bmm(x, y.transpose(2, 1)) 26 | rx = xx.unsqueeze(1).expand(bs, num_points_y, num_points_x) # Diagonal elements xx 27 | ry = yy.unsqueeze(1).expand(bs, num_points_x, num_points_y) # Diagonal elements yy 28 | P = rx.transpose(2, 1) + ry - 2 * zz 29 | return P 30 | 31 | def distChamfer(a, b): 32 | """ 33 | :param a: Pointclouds Batch x nul_points x dim 34 | :param b: Pointclouds Batch x nul_points x dim 35 | :return: 36 | -closest point on b of points from a 37 | -closest point on a of points from b 38 | -idx of closest point on 
b of points from a
        -idx of closest point on a of points from b
    Works for point clouds of any dimension
    """
    P = batched_pairwise_dist(a, b)
    return torch.min(P, 2)[0].float(), torch.min(P, 1)[0].float(), torch.min(P, 2)[1].int(), torch.min(P, 1)[1].int()

--------------------------------------------------------------------------------
/fscore.py:
--------------------------------------------------------------------------------
import torch

def fscore(dist1, dist2, threshold=0.001):
    """
    Calculates the F-score between two point clouds for a given threshold value.
    :param dist1: Batch, N-Points
    :param dist2: Batch, N-Points
    :param threshold: float
    :return: fscore, precision, recall
    """
    # NB: in this repo, dist1 and dist2 are squared point-cloud Euclidean distances, so adapt the threshold accordingly.
    precision_1 = torch.mean((dist1 < threshold).float(), dim=1)
    precision_2 = torch.mean((dist2 < threshold).float(), dim=1)
    fscore = 2 * precision_1 * precision_2 / (precision_1 + precision_2)
    fscore[torch.isnan(fscore)] = 0
    return fscore, precision_1, precision_2

--------------------------------------------------------------------------------
/unit_test.py:
--------------------------------------------------------------------------------
import torch, time
import chamfer2D.dist_chamfer_2D
import chamfer3D.dist_chamfer_3D
import chamfer5D.dist_chamfer_5D
import chamfer_python

cham2D = chamfer2D.dist_chamfer_2D.chamfer_2DDist()
cham3D = chamfer3D.dist_chamfer_3D.chamfer_3DDist()
cham5D = chamfer5D.dist_chamfer_5D.chamfer_5DDist()

from torch.autograd import Variable
from fscore import fscore


def test_chamfer(distChamfer, dim):
    points1 = torch.rand(4, 100, dim).cuda()
    points2 = torch.rand(4, 200, dim, requires_grad=True).cuda()
    dist1, dist2, idx1, idx2 = distChamfer(points1, points2)

    loss = torch.sum(dist1)
    loss.backward()

    mydist1, mydist2, myidx1, myidx2 = chamfer_python.distChamfer(points1, points2)
    d1 = (dist1 - mydist1) ** 2
    d2 = (dist2 - mydist2) ** 2
    assert (
        torch.mean(d1) + torch.mean(d2) < 0.00000001
    ), "chamfer cuda and chamfer normal are not giving the same results"

    xd1 = idx1 - myidx1
    xd2 = idx2 - myidx2
    assert (
        torch.norm(xd1.float()) + torch.norm(xd2.float()) == 0
    ), "chamfer cuda and chamfer normal are not giving the same results"
    print("fscore:", fscore(dist1, dist2))
    print("Unit test passed")


def timings(distChamfer, dim):
    p1 = torch.rand(32, 2000, dim).cuda()
    p2 = torch.rand(32, 1000, dim).cuda()
    print("Timings : Start CUDA version")
    start = time.time()
    num_it = 100
    for i in range(num_it):
        points1 = Variable(p1, requires_grad=True)
        points2 = Variable(p2)
        mydist1, mydist2, idx1, idx2 = distChamfer(points1, points2)
        loss = torch.sum(mydist1)
        loss.backward()
    print(f"Elapsed time forward backward is {(time.time() - start)/num_it} seconds.")

    print("Timings : Start Pythonic version")
    start = time.time()
    for i in range(num_it):
        points1 = Variable(p1, requires_grad=True)
        points2 = Variable(p2)
        mydist1, mydist2, idx1, idx2 = chamfer_python.distChamfer(points1, points2)
        loss = torch.sum(mydist1)
        loss.backward()
    print(f"Elapsed time forward backward is {(time.time() - start)/num_it} seconds.")


dims = [2, 3, 5]
for i, cham in enumerate([cham2D, cham3D, cham5D]):
    print(f"testing Chamfer {dims[i]}D")
    test_chamfer(cham, dims[i])
    timings(cham, dims[i])

--------------------------------------------------------------------------------
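For reference, a minimal CPU-only sketch of the pure-Python fallback (`chamfer_python.distChamfer` and `fscore.fscore`); the batch size, point counts, 4-D dimensionality and threshold below are arbitrary illustration values:

```python
import torch
import chamfer_python
from fscore import fscore

# Two small batched point clouds; the pure-Python version supports any
# dimension and runs on CPU, unlike the CUDA extensions above.
a = torch.rand(2, 50, 4)
b = torch.rand(2, 80, 4)

dist1, dist2, idx1, idx2 = chamfer_python.distChamfer(a, b)   # squared distances + nearest-neighbour indices
chamfer_loss = torch.mean(dist1) + torch.mean(dist2)          # symmetric (squared) chamfer distance
f, precision, recall = fscore(dist1, dist2, threshold=0.1 ** 2)
print(chamfer_loss.item(), f.mean().item())
```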