├── pkg
│   ├── src
│   │   ├── emd.cpp
│   │   └── cuda
│   │       └── emd.cu
│   ├── layer
│   │   ├── __init__.py
│   │   └── emd_loss_layer.py
│   ├── include
│   │   ├── cuda_helper.h
│   │   ├── emd.h
│   │   └── cuda
│   │       └── emd.cuh
│   └── emd_loss_layer.py
├── test_emd_loss.py
├── setup.py
└── README.md

/pkg/src/emd.cpp:
--------------------------------------------------------------------------------
#include "emd.h"

--------------------------------------------------------------------------------
/pkg/layer/__init__.py:
--------------------------------------------------------------------------------
from .emd_loss_layer import EMDLoss
--------------------------------------------------------------------------------
/test_emd_loss.py:
--------------------------------------------------------------------------------
import torch

import time

from emd import EMDLoss

dist = EMDLoss()

p1 = torch.rand(1, 5, 3).cuda().double()
p2 = torch.rand(1, 10, 3).cuda().double()
p1.requires_grad = True
p2.requires_grad = True

s = time.time()
cost = dist(p1, p2)
emd_time = time.time() - s

print('Time: ', emd_time)
print(cost)
loss = torch.sum(cost)
print(loss)
loss.backward()
print(p1.grad)
print(p2.grad)

--------------------------------------------------------------------------------
/pkg/include/cuda_helper.h:
--------------------------------------------------------------------------------
#ifndef CUDA_HELPER_H_
#define CUDA_HELPER_H_

#define CUDA_CHECK(err) \
    if (cudaSuccess != err) \
    { \
        fprintf(stderr, "CUDA kernel failed: %s (%s:%d)\n", \
            cudaGetErrorString(err), __FILE__, __LINE__); \
        std::exit(-1); \
    }

#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), \
    #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), \
    #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

#endif
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='PyTorch EMD',
    version='0.0',
    author='Marc Eder',
    author_email='meder@cs.unc.edu',
    description='A PyTorch module for the earth mover\'s distance loss',
    ext_package='_emd_ext',
    ext_modules=[
        CUDAExtension(
            name='_emd',
            sources=[
                'pkg/src/emd.cpp',
                'pkg/src/cuda/emd.cu',
            ],
            include_dirs=['pkg/include'],
        ),
    ],
    packages=[
        'emd',
    ],
    package_dir={
        'emd': 'pkg/layer'
    },
    cmdclass={'build_ext': BuildExtension},
)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch EMDLoss
PyTorch 1.0 implementation of the approximate Earth Mover's Distance

This is a PyTorch wrapper of CUDA code for computing an approximation to the Earth Mover's Distance loss.

Original source code can be found [here](https://github.com/fxia22/pointGAN/tree/74b6c432c5eaa1e0a833e755f450df2ee2c5488e/emd). This repository updates the code to be compatible with PyTorch 1.0 and extends the implementation to handle arbitrary dimensions of data.
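Basic usage mirrors `test_emd_loss.py`: build the extension (see the installation note below), then feed the module two batches of CUDA point sets that share the same last dimension. A minimal sketch:

```python
import torch
from emd import EMDLoss

dist = EMDLoss()

# B x N x D and B x M x D point sets; N and M may differ, D must match
p1 = torch.rand(1, 5, 3).cuda().double().requires_grad_(True)
p2 = torch.rand(1, 10, 3).cuda().double().requires_grad_(True)

cost = dist(p1, p2)    # one approximate EMD value per batch element
loss = torch.sum(cost)
loss.backward()        # gradients are populated in p1.grad and p2.grad
```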
Installation should be as simple as running `python setup.py install`.

**Limitations and Known Bugs:**
- Point sets stored as double tensors must have dimensionality (size of the last axis) <= 11, while float tensors must have dimensionality <= 23. This is due to the use of CUDA shared memory in the computation, which is limited by the hardware to 48kB per block.
- When handling larger point sets (M, N > ~2000), the CUDA kernel will fail. I think this is due to an overflow error in computing the approximate matching. Any suggestions to fix this would be greatly appreciated. I have pinpointed the source of the bug [here](https://github.com/meder411/PyTorch-EMDLoss/blob/master/pkg/include/cuda/emd.cuh#L160).

--------------------------------------------------------------------------------
/pkg/layer/emd_loss_layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

import _emd_ext._emd as emd


class EMDFunction(torch.autograd.Function):
    @staticmethod
    def forward(self, xyz1, xyz2):
        cost, match = emd.emd_forward(xyz1, xyz2)
        self.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    def backward(self, grad_output):
        xyz1, xyz2, match = self.saved_tensors
        grad_xyz1, grad_xyz2 = emd.emd_backward(xyz1, xyz2, match)
        return grad_xyz1, grad_xyz2


class EMDLoss(nn.Module):
    '''
    Computes the (approximate) Earth Mover's Distance between two point sets.

    IMPLEMENTATION LIMITATIONS:
    - Double tensors must have <= 11 dimensions
    - Float tensors must have <= 23 dimensions
    This is due to the use of CUDA shared memory in the computation. This
    shared memory is limited by the hardware to 48kB.
    '''

    def __init__(self):
        super(EMDLoss, self).__init__()

    def forward(self, xyz1, xyz2):

        assert xyz1.shape[-1] == xyz2.shape[-1], 'Both point sets must have the same dimensionality'
        if xyz1.dtype == torch.float64 and xyz1.shape[-1] > 11:
            raise ValueError('Tensors of type double can have a maximum of 11 dimensions')
        if xyz1.dtype == torch.float32 and xyz1.shape[-1] > 23:
            raise ValueError('Tensors of type float can have a maximum of 23 dimensions')

        return EMDFunction.apply(xyz1, xyz2)
--------------------------------------------------------------------------------
/pkg/include/emd.h:
--------------------------------------------------------------------------------
#ifndef EMD_H_
#define EMD_H_

#include <torch/extension.h>
#include <vector>

#include "cuda_helper.h"


std::vector<at::Tensor> emd_forward_cuda(
    at::Tensor xyz1,
    at::Tensor xyz2);

std::vector<at::Tensor> emd_backward_cuda(
    at::Tensor xyz1,
    at::Tensor xyz2,
    at::Tensor match);

// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
// CALL FUNCTION IMPLEMENTATIONS
// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *

std::vector<at::Tensor> emd_forward(
    at::Tensor xyz1,
    at::Tensor xyz2)
{
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);

    return emd_forward_cuda(xyz1, xyz2);
}

std::vector<at::Tensor> emd_backward(
    at::Tensor xyz1,
    at::Tensor xyz2,
    at::Tensor match)
{
    CHECK_INPUT(xyz1);
    CHECK_INPUT(xyz2);
    CHECK_INPUT(match);

    return emd_backward_cuda(xyz1, xyz2, match);
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("emd_forward", &emd_forward, "Compute Earth Mover's Distance approximation");
    m.def("emd_backward", &emd_backward, "Compute the gradient of the Earth Mover's Distance approximation");
}


#endif
--------------------------------------------------------------------------------
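The two functions bound above are exactly what `EMDFunction` in `pkg/layer/emd_loss_layer.py` calls. For illustration, a small sketch of driving the raw extension directly (assuming the package has been installed so that `_emd_ext._emd` is importable; inputs must be contiguous CUDA tensors):

import torch
import _emd_ext._emd as emd

xyz1 = torch.rand(2, 64, 3).cuda()   # B x N1 x D
xyz2 = torch.rand(2, 128, 3).cuda()  # B x N2 x D

# Forward: per-batch matching cost (shape B) and the approximate matching
# weights (shape B x N1 x N2) that the backward pass reuses.
cost, match = emd.emd_forward(xyz1, xyz2)

# Backward: gradients of the matching cost with respect to both point sets.
grad_xyz1, grad_xyz2 = emd.emd_backward(xyz1, xyz2, match)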
/pkg/emd_loss_layer.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

import _emd_ext._emd as emd


class EMDFunction(torch.autograd.Function):
    @staticmethod
    def forward(self, xyz1, xyz2):
        cost, match = emd.emd_forward(xyz1, xyz2)
        self.save_for_backward(xyz1, xyz2, match)
        return cost

    @staticmethod
    def backward(self, grad_output):
        xyz1, xyz2, match = self.saved_tensors
        grad_xyz1, grad_xyz2 = emd.emd_backward(xyz1, xyz2, match)
        return grad_xyz1, grad_xyz2


class EMDLoss(nn.Module):
    '''
    Computes the (approximate) Earth Mover's Distance between two point sets.

    IMPLEMENTATION LIMITATIONS:
    - Double tensors must have <= 11 dimensions
    - Float tensors must have <= 23 dimensions
    This is due to the use of CUDA shared memory in the computation. This
    shared memory is limited by the hardware to 48kB.
    '''

    def __init__(self):
        super(EMDLoss, self).__init__()

    def forward(self, xyz1, xyz2):
        '''
        xyz1: B x N x D point set
        xyz2: B x M x D point set
        '''

        assert xyz1.shape[-1] == xyz2.shape[-1], 'Both point sets must have the same dimensionality'
        if xyz1.dtype == torch.float64 and xyz1.shape[-1] > 11:
            raise ValueError('Tensors of type double can have a maximum of 11 dimensions')
        if xyz1.dtype == torch.float32 and xyz1.shape[-1] > 23:
            raise ValueError('Tensors of type float can have a maximum of 23 dimensions')

        return EMDFunction.apply(xyz1, xyz2)
--------------------------------------------------------------------------------
/pkg/src/cuda/emd.cu:
--------------------------------------------------------------------------------
#include <ATen/ATen.h>

#include <vector>

#include "cuda/emd.cuh"


std::vector<at::Tensor> emd_forward_cuda(
    at::Tensor xyz1, // B x N1 x D
    at::Tensor xyz2) // B x N2 x D
{
    // Some useful values
    const int64_t batch_size = xyz1.size(0);
    const int64_t num_pts_1 = xyz1.size(1);
    const int64_t num_pts_2 = xyz2.size(1);
    const int64_t dim = xyz1.size(2);

    // Allocate necessary data structures
    at::Tensor match = at::zeros({batch_size, num_pts_1, num_pts_2},
        xyz1.options());
    at::Tensor cost = at::zeros({batch_size}, xyz1.options());
    at::Tensor temp = at::zeros({batch_size, 2 * (num_pts_1 + num_pts_2)},
        xyz1.options());

    // Find the approximate matching
    approx_match(
        batch_size, num_pts_1, num_pts_2, dim,
        xyz1,
        xyz2,
        match,
        temp
    );

    // Compute the matching cost
    match_cost(
        batch_size, num_pts_1, num_pts_2, dim,
        xyz1,
        xyz2,
        match,
        cost
    );

    return {cost, match};
}


std::vector<at::Tensor> emd_backward_cuda(
    at::Tensor xyz1,
    at::Tensor xyz2,
    at::Tensor match)
{
    // Some useful values
    const int64_t batch_size = xyz1.size(0);
    const int64_t num_pts_1 = xyz1.size(1);
    const int64_t num_pts_2 = xyz2.size(1);
    const int64_t dim = xyz1.size(2);

    // Allocate necessary data structures
    at::Tensor grad_xyz1 = at::zeros_like(xyz1);
    at::Tensor grad_xyz2 = at::zeros_like(xyz2);

    // Compute the gradient with respect to the two inputs (xyz1 and xyz2)
    match_cost_grad(
        batch_size, num_pts_1, num_pts_2, dim,
        xyz1,
        xyz2,
        match,
        grad_xyz1,
        grad_xyz2
    );

    return {grad_xyz1, grad_xyz2};
}
--------------------------------------------------------------------------------
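The dimensionality limits quoted in the README follow directly from the dynamic shared-memory request made by the kernel launch in `emd.cuh` below: `BLOCK_SIZE * (d + 1) * sizeof(scalar_t)` bytes (d coordinates plus one weight slot per buffered point) against the 48kB per-block limit. A quick check of those bounds:

SHARED_MEM_BYTES = 48 * 1024  # per-block shared-memory limit cited in the README
BLOCK_SIZE = 512              # as #define'd in emd.cuh

# Each buffered point occupies (d + 1) scalars: d coordinates plus one weight.
max_d_double = SHARED_MEM_BYTES // (BLOCK_SIZE * 8) - 1  # 8-byte doubles -> 11
max_d_float = SHARED_MEM_BYTES // (BLOCK_SIZE * 4) - 1   # 4-byte floats  -> 23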
/pkg/include/cuda/emd.cuh:
--------------------------------------------------------------------------------
#ifndef EMD_CUH_
#define EMD_CUH_

#include "cuda_helper.h"

#define BLOCK_SIZE 512

template <typename T>
__global__ void approx_match_kernel(
    const int64_t b, const int64_t n, const int64_t m, const int64_t d,
    const T * __restrict__ xyz1,
    const T * __restrict__ xyz2,
    T * __restrict__ match,
    T * temp)
{
    // Pointers to temporary storage for this current block
    // Starting point for batch = blockIdx.x * (n + m) * 2
    T * remainL = temp + blockIdx.x * (n + m) * 2;             // Start of set 1
    T * remainR = temp + blockIdx.x * (n + m) * 2 + n;         // Start of set 2
    T * ratioL = temp + blockIdx.x * (n + m) * 2 + n + m;      // Start of set 1
    T * ratioR = temp + blockIdx.x * (n + m) * 2 + n + m + n;  // Start of set 2

    // Ratio of two point sets
    T multiL;
    T multiR;
    if (n >= m)
    {
        multiL = 1;
        multiR = n / m;
    }
    else
    {
        multiL = m / n;
        multiR = 1;
    }

    // Dynamic shared memory buffer for templated function
    // https://stackoverflow.com/a/27570775/3427580
    extern __shared__ __align__(sizeof(T)) unsigned char my_buf[];
    T *buf = reinterpret_cast<T*>(my_buf);

    // For each batch
    for (int64_t i = blockIdx.x; i < b; i += gridDim.x)
    {
        // Initialize match values
        for (int64_t j = threadIdx.x; j < n*m; j += blockDim.x)
        {
            match[i*n*m+j] = 0;
        }
        for (int64_t j = threadIdx.x; j < n; j += blockDim.x)
        {
            remainL[j] = multiL;
        }
        for (int64_t j = threadIdx.x; j < m; j += blockDim.x)
        {
            remainR[j] = multiR;
        }
        __syncthreads();

        for (int64_t j = 7; j >= -2; j--)
        {
            T level = -pow(4.0, j);
            if (j == -2) { level = 0; }

            // Iterate over blocks
            for (int64_t k0 = 0; k0 < n; k0 += blockDim.x)
            {
                // Current thread linear index
                int64_t k = k0 + threadIdx.x;

                // Initialize a really small, non-zero sum
                T suml = T(1e-9);

                // Iterate over grid
                for (int64_t l0 = 0; l0 < m; l0 += BLOCK_SIZE)
                {
                    // End of the block or m, whichever comes first
                    int64_t lend = min(m, l0 + BLOCK_SIZE) - l0;

                    // Put points from the second set into the shared buffer
                    for (int64_t l = threadIdx.x; l < lend; l += blockDim.x)
                    {
                        for (int64_t z = 0; z < d; z++)
                        {
                            buf[l*(d+1)+z] = xyz2[i*m*d+l0*d+l*d+z];
                        }
                        buf[l*(d+1)+d] = remainR[l0+l];
                    }
                    __syncthreads();

                    for (int64_t l = 0; l < lend; l++)
                    {
                        T v = 0;
                        for (int64_t z = 0; z < d; z++)
                        {
                            if (k < n)
                            {
                                v += (buf[l*(d+1)+z] - xyz1[i*n*d+k*d+z]) *
                                    (buf[l*(d+1)+z] - xyz1[i*n*d+k*d+z]);
                            }
                            else
                            {
                                v += buf[l*(d+1)+z] * buf[l*(d+1)+z];
                            }
                        }
                        v *= level;
                        suml += exp(v)*buf[l*(d+1)+d];
                    }
                    __syncthreads();
                }
                if (k < n) { ratioL[k] = remainL[k] / suml; }
            }
            __syncthreads();

            // Iterate over blocks again (now for second point set)
            for (int64_t l0 = 0; l0 < m; l0 += blockDim.x)
            {
                int64_t l = l0 + threadIdx.x;
                T sumr = 0;
                for (int64_t k0 = 0; k0 < n; k0 += BLOCK_SIZE)
                {
                    int64_t kend = min(n, k0 + BLOCK_SIZE) - k0;
                    for (int64_t k = threadIdx.x; k < kend; k += blockDim.x)
                    {
                        for (int64_t z = 0; z < d; z++)
                        {
                            buf[k*(d+1)+z] = xyz1[i*n*d+k0*d+k*d+z];
                        }
                        buf[k*(d+1)+d] = ratioL[k0+k];
                    }
                    __syncthreads();

                    for (int64_t k = 0; k < kend; k++)
                    {
                        T v = 0;
                        // Distance between this thread's xyz2 point l and the
                        // buffered xyz1 point k
                        for (int64_t z = 0; z < d; z++)
                        {
                            if (l < m)
                            {
                                v += (xyz2[i*m*d+l*d+z] - buf[k*(d+1)+z]) *
                                    (xyz2[i*m*d+l*d+z] - buf[k*(d+1)+z]);
                            }
                            else
                            {
                                v += buf[k*(d+1)+z] * buf[k*(d+1)+z];
                            }
                        }
                        v *= level;
                        sumr += exp(v)*buf[k*(d+1)+d];
                    }
                    __syncthreads();
                }

                if (l < m)
                {
                    sumr *= remainR[l];
                    T consumption = fmin(remainR[l] / (sumr + 1e-9), 1.0);
                    // ******************************
                    // SOURCE OF THE ISSUE: sumr
                    // Any variable that is a function of sumr causes an error
                    // Specifically the assignments below
                    // It's an issue only with large m and n. Maybe it's an
                    // overflow issue?
                    // ******************************
                    ratioR[l] = consumption * remainR[l];
                    remainR[l] = fmax(0.0, remainR[l] - sumr);
                }
            }
            __syncthreads();

            for (int64_t k0 = 0; k0 < n; k0 += blockDim.x)
            {
                int64_t k = k0 + threadIdx.x;
                T suml = 0;

                for (int64_t l0 = 0; l0 < m; l0 += BLOCK_SIZE)
                {
                    int64_t lend = min(m, l0 + BLOCK_SIZE) - l0;
                    for (int64_t l = threadIdx.x; l < lend; l += blockDim.x)
                    {
                        for (int64_t z = 0; z < d; z++)
                        {
                            buf[l*(d+1)+z] = xyz2[i*m*d+l0*d+l*d+z];
                        }
                        buf[l*(d+1)+d] = ratioR[l0+l];
                    }
                    __syncthreads();

                    T rl = ratioL[k];
                    if (k < n)
                    {
                        for (int64_t l = 0; l < lend; l++)
                        {
                            T v = 0;
                            for (int64_t z = 0; z < d; z++)
                            {
                                if (k < n)
                                {
                                    v += (buf[l*(d+1)+z] - xyz1[i*n*d+k*d+z]) *
                                        (buf[l*(d+1)+z] - xyz1[i*n*d+k*d+z]);
                                }
                                else
                                {
                                    v += buf[l*(d+1)+z] * buf[l*(d+1)+z];
                                }
                            }
                            v *= level;

                            T w = __expf(v)*buf[l*(d+1)+d]*rl;
                            match[i*n*m+(l0+l)*n+k] += w;
                            suml += w;
                        }
                    }
                    __syncthreads();
                }

                if (k < n) { remainL[k] = fmaxf(0.0f, remainL[k] - suml); }
            }
            __syncthreads();
        }
    }
}

void approx_match(
    const int64_t b, const int64_t n,
    const int64_t m, const int64_t d,
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    at::Tensor match,
    at::Tensor temp)
{
    AT_DISPATCH_FLOATING_TYPES(match.type(), "approx_match_kernel", ([&] {
        approx_match_kernel
            <<<32, 512, BLOCK_SIZE*(d+1)*sizeof(scalar_t)>>>(
                b, n, m, d,
                xyz1.data<scalar_t>(),
                xyz2.data<scalar_t>(),
                match.data<scalar_t>(),
                temp.data<scalar_t>());
    }));
    cudaDeviceSynchronize();
    CUDA_CHECK(cudaGetLastError())
}


template <typename T>
__global__ void match_cost_kernel(
    const int64_t b, const int64_t n, const int64_t m, const int64_t d,
    const T * __restrict__ xyz1,
    const T * __restrict__ xyz2,
    const T * __restrict__ match,
    T * __restrict__ out)
{
    // First 512 elements is used for sum computation
    // Remaining buffer is a general buffer
    extern __shared__ __align__(sizeof(T)) unsigned char my_buf[];
    T *buf = reinterpret_cast<T*>(my_buf);

    for (int64_t i = blockIdx.x; i < b; i += gridDim.x)
    {
        // ... (kernel body missing from this listing)
    }
}

void match_cost(
    const int64_t b, const int64_t n,
    const int64_t m, const int64_t d,
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match,
    at::Tensor out)
{
    AT_DISPATCH_FLOATING_TYPES(xyz1.type(), "match_cost_kernel", ([&] {
        match_cost_kernel
            // NOTE: the launch configuration below is an assumption (the
            // original line is garbled in this listing); shared memory is
            // sized for the 512 partial sums plus the general point buffer.
            <<<32, 512, (512 + BLOCK_SIZE*(d+1))*sizeof(scalar_t)>>>(
                b, n, m, d,
                xyz1.data<scalar_t>(),
                xyz2.data<scalar_t>(),
                match.data<scalar_t>(),
                out.data<scalar_t>());
    }));
    CUDA_CHECK(cudaGetLastError())
}


template <typename T>
__global__ void match_cost_grad2_kernel(
    const int64_t b, const int64_t n,
    const int64_t m, const int64_t d,
    const T * __restrict__ xyz1,
    const T * __restrict__ xyz2,
    const T * __restrict__ match,
    T * __restrict__ grad2)
{
    extern __shared__ __align__(sizeof(T)) unsigned char my_buf[];
    T *sum_grad = reinterpret_cast<T*>(my_buf);

    for (int64_t i = blockIdx.x; i < b; i += gridDim.x)
    {
        int64_t kbeg = m*blockIdx.y / gridDim.y;
        int64_t kend = m*(blockIdx.y+1) / gridDim.y;
        for (int64_t k = kbeg; k < kend; k++)
        {
            for (int64_t j = threadIdx.x; j < n; j += blockDim.x)
            {
                T v = 0;
                for (int64_t z = 0; z < d; z++)
                {
                    v += (xyz2[(i*m+k)*d+z] - xyz1[(i*n+j)*d+z]) *
                        (xyz2[(i*m+k)*d+z] - xyz1[(i*n+j)*d+z]);
                }
                T w = match[i*n*m+k*n+j] * rsqrtf(fmaxf(v, 1e-20f));

                for (int64_t z = 0; z < d; z++)
                {
                    sum_grad[threadIdx.x*d+z] += xyz1[(i*n+j)*d+z] * w;
                }
            }
            // Parallel reduction of the per-thread gradient sums
            for (int64_t j = 1; j < blockDim.x; j <<= 1)
            {
                __syncthreads();
                int64_t j1 = threadIdx.x;
                int64_t j2 = threadIdx.x + j;
                if ((j1 & j) == 0 && j2 < blockDim.x)
                {
                    for (int64_t z = 0; z < d; z++)
                    {
                        sum_grad[j1*d+z] += sum_grad[j2*d+z];
                    }
                }
            }
            if (threadIdx.x == 0)
            {
                for (int64_t z = 0; z < d; z++)
                {
                    grad2[(i*m+k)*d+z] = sum_grad[z];
                }
            }
            __syncthreads();
        }
    }
}


template <typename T>
__global__ void match_cost_grad1_kernel(
    const int64_t b, const int64_t n,
    const int64_t m, const int64_t d,
    const T * __restrict__ xyz1,
    const T * __restrict__ xyz2,
    const T * __restrict__ match,
    T * __restrict__ grad1)
{
    for (int64_t i = blockIdx.x; i < b; i += gridDim.x)
    {
        for (int64_t l = threadIdx.x; l < n; l += blockDim.x)
        {
            for (int64_t k = 0; k < m; k++)
            {
                T v = 0;
                for (int64_t z = 0; z < d; z++)
                {
                    v += (xyz1[i*n*d+l*d+z] - xyz2[i*m*d+k*d+z]) *
                        (xyz1[i*n*d+l*d+z] - xyz2[i*m*d+k*d+z]);
                }
                T w = match[i*n*m+k*n+l] * rsqrtf(fmaxf(v, 1e-20f));

                for (int64_t z = 0; z < d; z++)
                {
                    grad1[i*n*d+l*d+z] +=
                        (xyz1[i*n*d+l*d+z] - xyz2[i*m*d+k*d+z]) * w;
                }
            }
        }
    }
}

void match_cost_grad(
    const int64_t b, const int64_t n,
    const int64_t m, const int64_t d,
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor match,
    at::Tensor grad1,
    at::Tensor grad2)
{
    AT_DISPATCH_FLOATING_TYPES(xyz1.type(), "match_cost_grad1_kernel", ([&] {
        match_cost_grad1_kernel<<<32, 512>>>(
            b, n, m, d,
            xyz1.data<scalar_t>(),
            xyz2.data<scalar_t>(),
            match.data<scalar_t>(),
            grad1.data<scalar_t>());
    }));
    CUDA_CHECK(cudaGetLastError())

    AT_DISPATCH_FLOATING_TYPES(xyz1.type(), "match_cost_grad2_kernel", ([&] {
        // NOTE: the launch configuration below is an assumption (the original
        // line is garbled in this listing): a 2D grid, 512 threads per block,
        // and shared memory for one d-vector of gradient scratch per thread.
        match_cost_grad2_kernel<<<dim3(32, 32), 512, 512*d*sizeof(scalar_t)>>>(
            b, n, m, d,
            xyz1.data<scalar_t>(),
            xyz2.data<scalar_t>(),
            match.data<scalar_t>(),
            grad2.data<scalar_t>());
    }));
    CUDA_CHECK(cudaGetLastError())
}


#endif
--------------------------------------------------------------------------------