├── LICENSE ├── MANIFEST.in ├── pointnet2_ops ├── __init__.py ├── _ext-src │ ├── include │ │ ├── ball_query.h │ │ ├── cuda_utils.h │ │ ├── group_points.h │ │ ├── interpolate.h │ │ ├── sampling.h │ │ └── utils.h │ └── src │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── bindings.cpp │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ ├── sampling.cpp │ │ └── sampling_gpu.cu ├── _version.py ├── pointnet2_modules.py └── pointnet2_utils.py ├── pyproject.toml ├── setup.cfg └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Adam Fishman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | 24 | ---------------------------------------------------------------------------- 25 | Much of the CUDA kernel code comes from https://github.com/sshaoshuai/Pointnet2.PyTorch 26 | Here is the original license for that code 27 | 28 | MIT License 29 | 30 | Copyright (c) 2019 Shaoshuai Shi 31 | 32 | Permission is hereby granted, free of charge, to any person obtaining a copy 33 | of this software and associated documentation files (the "Software"), to deal 34 | in the Software without restriction, including without limitation the rights 35 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 36 | copies of the Software, and to permit persons to whom the Software is 37 | furnished to do so, subject to the following conditions: 38 | 39 | The above copyright notice and this permission notice shall be included in all 40 | copies or substantial portions of the Software. 41 | 42 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 43 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 44 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 45 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 46 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 47 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 48 | SOFTWARE. 
49 | 50 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft pointnet2_ops/_ext-src 2 | -------------------------------------------------------------------------------- /pointnet2_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import pointnet2_ops.pointnet2_modules 2 | import pointnet2_ops.pointnet2_utils 3 | from pointnet2_ops._version import __version__ 4 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query( 5 | at::Tensor new_xyz, 6 | at::Tensor xyz, 7 | const float radius, 8 | const int nsample); 9 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #define TOTAL_THREADS 512 16 | 17 | inline int opt_n_threads(int work_size) { 18 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 19 | 20 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 21 | } 22 | 23 | inline dim3 opt_block_config(int x, int y) { 24 | const int x_threads = opt_n_threads(x); 25 | const int y_threads = 26 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 27 | dim3 block_config(x_threads, y_threads, 1); 28 | 29 | return block_config; 30 | } 31 | 32 | #define CUDA_CHECK_ERRORS() \ 33 | do { \ 34 | cudaError_t err = cudaGetLastError(); \ 35 | if (cudaSuccess != err) { \ 36 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 
37 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 38 | __FILE__); \ 39 | exit(-1); \ 40 | } \ 41 | } while (0) 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knowns); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 6 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) 
CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 8 | 9 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 10 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | at::Tensor query_ball_point_kernel_wrapper( 5 | int b, 6 | int n, 7 | int m, 8 | float radius, 9 | int nsample, 10 | const at::Tensor new_xyz, 11 | const at::Tensor xyz); 12 | 13 | at::Tensor ball_query( 14 | at::Tensor new_xyz, 15 | at::Tensor xyz, 16 | const float radius, 17 | const int nsample) { 18 | CHECK_INPUT(new_xyz); 19 | CHECK_INPUT(xyz); 20 | 21 | return query_ball_point_kernel_wrapper( 22 | xyz.size(0), 23 | xyz.size(1), 24 | new_xyz.size(1), 25 | radius, 26 | nsample, 27 | new_xyz, 28 | xyz); 29 | } 30 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "cuda_utils.h" 7 | 8 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 9 | // output: idx(b, m, nsample) 10 | template 11 | __global__ void query_ball_point_kernel( 12 | int b, 13 | int n, 14 | int m, 15 | float radius, 16 | int nsample, 17 | const scalar_t *__restrict__ new_xyz, 18 | const scalar_t *__restrict__ xyz, 19 | int *__restrict__ idx) { 20 | int bs_idx = blockIdx.y; 21 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 22 | if (bs_idx >= b || pt_idx >= m) return; 23 | 24 | new_xyz += bs_idx * m * 3 + pt_idx * 3; 25 | xyz += bs_idx * n * 3; 26 | idx += bs_idx * m * nsample + pt_idx * nsample; 27 | 28 | float radius2 = radius * radius; 29 | scalar_t new_x = new_xyz[0]; 30 | scalar_t new_y = new_xyz[1]; 31 | scalar_t new_z = new_xyz[2]; 32 | 33 | int cnt = 0; 34 | 
for (int k = 0; k < n; ++k) { 35 | scalar_t x = xyz[k * 3 + 0]; 36 | scalar_t y = xyz[k * 3 + 1]; 37 | scalar_t z = xyz[k * 3 + 2]; 38 | scalar_t d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z); 39 | if (d2 < radius2){ 40 | if (cnt == 0){ 41 | for (int l = 0; l < nsample; ++l) { 42 | idx[l] = k; 43 | } 44 | } 45 | idx[cnt] = k; 46 | ++cnt; 47 | if (cnt >= nsample) break; 48 | } 49 | } 50 | } 51 | 52 | at::Tensor query_ball_point_kernel_wrapper( 53 | int b, 54 | int n, 55 | int m, 56 | float radius, 57 | int nsample, 58 | const at::Tensor new_xyz, 59 | const at::Tensor xyz) { 60 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 61 | at::Tensor idx = torch::zeros( 62 | {new_xyz.size(0), new_xyz.size(1), nsample}, 63 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 64 | 65 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 66 | xyz.scalar_type(), "query_ball_cuda", ([&] { 67 | query_ball_point_kernel<<>>( 68 | b, 69 | n, 70 | m, 71 | radius, 72 | nsample, 73 | new_xyz.data_ptr(), 74 | xyz.data_ptr(), 75 | idx.data_ptr()); 76 | })); 77 | CUDA_CHECK_ERRORS(); 78 | return idx; 79 | } 80 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | 
-------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | at::Tensor group_points_kernel_wrapper( 5 | int b, 6 | int c, 7 | int n, 8 | int npoints, 9 | int nsample, 10 | const at::Tensor points, 11 | const at::Tensor idx); 12 | 13 | at::Tensor group_points_grad_kernel_wrapper( 14 | int b, 15 | int c, 16 | int n, 17 | int npoints, 18 | int nsample, 19 | const at::Tensor grad_out, 20 | const at::Tensor idx); 21 | 22 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 23 | CHECK_INPUT(points); 24 | CHECK_INPUT(idx); 25 | CHECK_IS_INT(idx); 26 | 27 | return group_points_kernel_wrapper( 28 | points.size(0), 29 | points.size(1), 30 | points.size(2), 31 | idx.size(1), 32 | idx.size(2), 33 | points, 34 | idx); 35 | } 36 | 37 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 38 | CHECK_INPUT(grad_out); 39 | CHECK_INPUT(idx); 40 | CHECK_IS_INT(idx); 41 | // Maybe should also check if grad_out is double/float/half? 
42 | 43 | return group_points_grad_kernel_wrapper( 44 | grad_out.size(0), 45 | grad_out.size(1), 46 | n, 47 | idx.size(1), 48 | idx.size(2), 49 | grad_out, 50 | idx); 51 | } 52 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: points(b, c, n) idx(b, npoints, nsample) 8 | // output: out(b, c, npoints, nsample) 9 | template 10 | __global__ void group_points_kernel( 11 | int b, 12 | int c, 13 | int n, 14 | int npoints, 15 | int nsample, 16 | const scalar_t *__restrict__ points, 17 | const int *__restrict__ idx, 18 | scalar_t *__restrict__ out) { 19 | int bs_idx = blockIdx.z; 20 | int c_idx = blockIdx.y; 21 | int index = blockIdx.x * blockDim.x + threadIdx.x; 22 | int pt_idx = index / nsample; 23 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 24 | 25 | int sample_idx = index % nsample; 26 | 27 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 28 | int in_idx = bs_idx * c * n + c_idx * n + idx[0]; 29 | int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 30 | 31 | out[out_idx] = points[in_idx]; 32 | } 33 | 34 | at::Tensor group_points_kernel_wrapper( 35 | int b, 36 | int c, 37 | int n, 38 | int npoints, 39 | int nsample, 40 | const at::Tensor points, 41 | const at::Tensor idx) { 42 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 43 | at::Tensor out = torch::zeros( 44 | {points.size(0), points.size(1), idx.size(1), idx.size(2)}, 45 | at::device(points.device()).dtype(points.scalar_type())); 46 | 47 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 48 | points.scalar_type(), "group_points_cuda", ([&] { 49 | group_points_kernel<<>>( 50 | b, 51 | c, 52 | n, 53 | npoints, 54 | nsample, 55 | points.data_ptr(), 56 | idx.data_ptr(), 57 | 
out.data_ptr()); 58 | })); 59 | CUDA_CHECK_ERRORS(); 60 | return out; 61 | } 62 | 63 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 64 | // output: grad_points(b, c, n) 65 | template 66 | __global__ void group_points_grad_kernel( 67 | int b, 68 | int c, 69 | int n, 70 | int npoints, 71 | int nsample, 72 | const scalar_t *__restrict__ grad_out, 73 | const int *__restrict__ idx, 74 | scalar_t *__restrict__ grad_points) { 75 | int bs_idx = blockIdx.z; 76 | int c_idx = blockIdx.y; 77 | int index = blockIdx.x * blockDim.x + threadIdx.x; 78 | int pt_idx = index / nsample; 79 | if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return; 80 | 81 | int sample_idx = index % nsample; 82 | grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx; 83 | idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 84 | 85 | atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]); 86 | } 87 | 88 | at::Tensor group_points_grad_kernel_wrapper( 89 | int b, 90 | int c, 91 | int n, 92 | int npoints, 93 | int nsample, 94 | const at::Tensor grad_out, 95 | const at::Tensor idx) { 96 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 97 | at::Tensor grad_points = torch::zeros( 98 | {grad_out.size(0), grad_out.size(1), n}, 99 | at::device(grad_out.device()).dtype(grad_out.scalar_type())); 100 | 101 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 102 | grad_out.scalar_type(), "group_points_grad_cuda", ([&] { 103 | group_points_grad_kernel<<>>( 104 | b, 105 | c, 106 | n, 107 | npoints, 108 | nsample, 109 | grad_out.data_ptr(), 110 | idx.data_ptr(), 111 | grad_points.data_ptr()); 112 | })); 113 | CUDA_CHECK_ERRORS(); 114 | return grad_points; 115 | } 116 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include 
"utils.h" 3 | 4 | std::vector three_nn_kernel_wrapper( 5 | int b, 6 | int n, 7 | int m, 8 | const at::Tensor unknown, 9 | const at::Tensor known); 10 | 11 | at::Tensor three_interpolate_kernel_wrapper( 12 | int b, 13 | int c, 14 | int m, 15 | int n, 16 | const at::Tensor points, 17 | const at::Tensor idx, 18 | const at::Tensor weight); 19 | 20 | at::Tensor three_interpolate_grad_kernel_wrapper( 21 | int b, 22 | int c, 23 | int n, 24 | int m, 25 | const at::Tensor grad_out, 26 | const at::Tensor idx, 27 | const at::Tensor weight); 28 | 29 | std::vector three_nn(at::Tensor unknowns, at::Tensor knowns) { 30 | CHECK_INPUT(unknowns); 31 | CHECK_INPUT(knowns); 32 | // TODO maybe include a check that its floating/scalar type 33 | 34 | return three_nn_kernel_wrapper( 35 | unknowns.size(0), 36 | unknowns.size(1), 37 | knowns.size(1), 38 | unknowns, 39 | knowns); 40 | } 41 | 42 | at::Tensor three_interpolate( 43 | at::Tensor points, 44 | at::Tensor idx, 45 | at::Tensor weight) { 46 | CHECK_INPUT(points); 47 | CHECK_INPUT(idx); 48 | CHECK_INPUT(weight); 49 | CHECK_IS_INT(idx); 50 | // TODO maybe check types for point and weight? 51 | 52 | return three_interpolate_kernel_wrapper( 53 | points.size(0), 54 | points.size(1), 55 | points.size(2), 56 | idx.size(1), 57 | points, 58 | idx, 59 | weight); 60 | } 61 | at::Tensor three_interpolate_grad( 62 | at::Tensor grad_out, 63 | at::Tensor idx, 64 | at::Tensor weight, 65 | const int m) { 66 | CHECK_INPUT(grad_out); 67 | CHECK_INPUT(idx); 68 | CHECK_INPUT(weight); 69 | CHECK_IS_INT(idx); 70 | // TODO maybe check type for weight and grad_out? 
71 | 72 | return three_interpolate_grad_kernel_wrapper( 73 | grad_out.size(0), 74 | grad_out.size(1), 75 | grad_out.size(2), 76 | m, 77 | grad_out, 78 | idx, 79 | weight); 80 | } 81 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "cuda_utils.h" 8 | 9 | // input: unknown(b, n, 3) known(b, m, 3) 10 | // output: dist2(b, n, 3), idx(b, n, 3) 11 | template 12 | __global__ void three_nn_kernel( 13 | int b, 14 | int n, 15 | int m, 16 | const scalar_t *__restrict__ unknown, 17 | const scalar_t *__restrict__ known, 18 | scalar_t *__restrict__ dist2, 19 | int *__restrict__ idx) { 20 | int bs_idx = blockIdx.y; 21 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 22 | if (bs_idx >= b || pt_idx >= n) return; 23 | 24 | unknown += bs_idx * n * 3 + pt_idx * 3; 25 | known += bs_idx * m * 3; 26 | dist2 += bs_idx * n * 3 + pt_idx * 3; 27 | idx += bs_idx * n * 3 + pt_idx * 3; 28 | 29 | scalar_t ux = unknown[0]; 30 | scalar_t uy = unknown[1]; 31 | scalar_t uz = unknown[2]; 32 | 33 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 34 | int besti1 = 0, besti2 = 0, besti3 = 0; 35 | for (int k = 0; k < m; ++k) { 36 | scalar_t x = known[k * 3 + 0]; 37 | scalar_t y = known[k * 3 + 1]; 38 | scalar_t z = known[k * 3 + 2]; 39 | scalar_t d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 40 | if (d < best1) { 41 | best3 = best2; besti3 = besti2; 42 | best2 = best1; besti2 = besti1; 43 | best1 = d; besti1 = k; 44 | } 45 | else if (d < best2) { 46 | best3 = best2; besti3 = besti2; 47 | best2 = d; besti2 = k; 48 | } 49 | else if (d < best3) { 50 | best3 = d; besti3 = k; 51 | } 52 | } 53 | dist2[0] = best1; dist2[1] = best2; dist2[2] = best3; 54 | idx[0] = besti1; idx[1] = besti2; idx[2] = besti3; 55 | } 56 | 57 | std::vector 
three_nn_kernel_wrapper( 58 | int b, 59 | int n, 60 | int m, 61 | const at::Tensor unknown, 62 | const at::Tensor known) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | at::Tensor idx = torch::zeros( 65 | {unknown.size(0), unknown.size(1), 3}, 66 | at::device(unknown.device()).dtype(at::ScalarType::Int)); 67 | at::Tensor dist2 = 68 | torch::zeros({unknown.size(0), unknown.size(1), 3}, 69 | at::device(unknown.device()).dtype(unknown.scalar_type())); 70 | 71 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 72 | unknown.scalar_type(), "three_nn_kernel_cuda", ([&] { 73 | three_nn_kernel<<>>( 74 | b, 75 | n, 76 | m, 77 | unknown.data_ptr(), 78 | known.data_ptr(), 79 | dist2.data_ptr(), 80 | idx.data_ptr()); 81 | })); 82 | 83 | CUDA_CHECK_ERRORS(); 84 | return {dist2, idx}; 85 | } 86 | 87 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 88 | // output: out(b, c, n) 89 | template 90 | __global__ void three_interpolate_kernel( 91 | int b, 92 | int c, 93 | int m, 94 | int n, 95 | const scalar_t *__restrict__ points, 96 | const int *__restrict__ idx, 97 | const scalar_t *__restrict__ weight, 98 | scalar_t *__restrict__ out) { 99 | int bs_idx = blockIdx.z; 100 | int c_idx = blockIdx.y; 101 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 102 | 103 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 104 | 105 | weight += bs_idx * n * 3 + pt_idx * 3; 106 | points += bs_idx * c * m + c_idx * m; 107 | idx += bs_idx * n * 3 + pt_idx * 3; 108 | out += bs_idx * c * n + c_idx * n; 109 | 110 | out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]]; 111 | } 112 | 113 | at::Tensor three_interpolate_kernel_wrapper( 114 | int b, 115 | int c, 116 | int m, 117 | int n, 118 | const at::Tensor points, 119 | const at::Tensor idx, 120 | const at::Tensor weight) { 121 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 122 | at::Tensor out = 123 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 124 | 
at::device(points.device()).dtype(points.scalar_type())); 125 | 126 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 127 | points.scalar_type(), "three_interpolate_cuda", ([&] { 128 | three_interpolate_kernel<<>>( 129 | b, 130 | c, 131 | m, 132 | n, 133 | points.data_ptr(), 134 | idx.data_ptr(), 135 | weight.data_ptr(), 136 | out.data_ptr()); 137 | })); 138 | 139 | CUDA_CHECK_ERRORS(); 140 | return out; 141 | } 142 | 143 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 144 | // output: grad_points(b, c, m) 145 | 146 | template 147 | __global__ void three_interpolate_grad_kernel( 148 | int b, 149 | int c, 150 | int n, 151 | int m, 152 | const scalar_t *__restrict__ grad_out, 153 | const int *__restrict__ idx, 154 | const scalar_t *__restrict__ weight, 155 | scalar_t *__restrict__ grad_points) { 156 | int bs_idx = blockIdx.z; 157 | int c_idx = blockIdx.y; 158 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 159 | 160 | if (bs_idx >= b || c_idx >= c || pt_idx >= n) return; 161 | 162 | grad_out += bs_idx * c * n + c_idx * n + pt_idx; 163 | weight += bs_idx * n * 3 + pt_idx * 3; 164 | grad_points += bs_idx * c * m + c_idx * m; 165 | idx += bs_idx * n * 3 + pt_idx * 3; 166 | 167 | 168 | atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]); 169 | atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]); 170 | atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]); 171 | } 172 | 173 | at::Tensor three_interpolate_grad_kernel_wrapper( 174 | int b, 175 | int c, 176 | int n, 177 | int m, 178 | const at::Tensor grad_out, 179 | const at::Tensor idx, 180 | const at::Tensor weight) { 181 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 182 | at::Tensor grad_points = torch::zeros( 183 | {grad_out.size(0), grad_out.size(1), m}, 184 | at::device(grad_out.device()).dtype(grad_out.scalar_type())); 185 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 186 | grad_out.scalar_type(), "three_interpolate_grad_cuda", ([&] { 187 | three_interpolate_grad_kernel<<>>( 188 | b, 189 | 
c, 190 | n, 191 | m, 192 | grad_out.data_ptr(), 193 | idx.data_ptr(), 194 | weight.data_ptr(), 195 | grad_points.data_ptr()); 196 | })); 197 | CUDA_CHECK_ERRORS(); 198 | return grad_points; 199 | } 200 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | at::Tensor gather_points_kernel_wrapper( 5 | int b, 6 | int c, 7 | int n, 8 | int npoints, 9 | const at::Tensor points, 10 | const at::Tensor idx); 11 | 12 | at::Tensor gather_points_grad_kernel_wrapper( 13 | int b, 14 | int c, 15 | int n, 16 | int npoints, 17 | const at::Tensor grad_out, 18 | const at::Tensor idx); 19 | 20 | at::Tensor furthest_point_sampling_kernel_wrapper( 21 | int b, 22 | int n, 23 | int m, 24 | const at::Tensor points); 25 | 26 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 27 | CHECK_INPUT(points); 28 | CHECK_INPUT(idx); 29 | CHECK_IS_INT(idx); 30 | // TODO check types for points? 31 | 32 | 33 | return gather_points_kernel_wrapper( 34 | points.size(0), 35 | points.size(1), 36 | points.size(2), 37 | idx.size(1), 38 | points, 39 | idx); 40 | } 41 | 42 | at::Tensor gather_points_grad( 43 | at::Tensor grad_out, 44 | at::Tensor idx, 45 | const int n) { 46 | CHECK_INPUT(grad_out); 47 | CHECK_INPUT(idx); 48 | CHECK_IS_INT(idx); 49 | // TODO Check scalar type for grad_out? 
50 | 51 | return gather_points_grad_kernel_wrapper( 52 | grad_out.size(0), 53 | grad_out.size(1), 54 | n, 55 | idx.size(1), 56 | grad_out, 57 | idx); 58 | } 59 | 60 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 61 | CHECK_INPUT(points); 62 | return furthest_point_sampling_kernel_wrapper( 63 | points.size(0), 64 | points.size(1), 65 | nsamples, 66 | points); 67 | } 68 | -------------------------------------------------------------------------------- /pointnet2_ops/_ext-src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: points(b, c, n) idx(b, m) 8 | // output: out(b, c, m) 9 | template 10 | __global__ void gather_points_kernel(int b, int c, int n, int m, 11 | const scalar_t *__restrict__ points, 12 | const int *__restrict__ idx, 13 | scalar_t *__restrict__ out) { 14 | int bs_idx = blockIdx.z; 15 | int c_idx = blockIdx.y; 16 | int pt_idx = blockIdx.x * blockDim.x + threadIdx.x; 17 | if (bs_idx >= b || c_idx >= c || pt_idx >= m) return; 18 | 19 | out += bs_idx * c * m + c_idx * m + pt_idx; 20 | idx += bs_idx * m + pt_idx; 21 | points += bs_idx * c * n + c_idx * n; 22 | out[0] = points[idx[0]]; 23 | } 24 | 25 | at::Tensor gather_points_kernel_wrapper( 26 | int b, 27 | int c, 28 | int n, 29 | int npoints, 30 | const at::Tensor points, 31 | const at::Tensor idx) { 32 | at::Tensor out = torch::zeros( 33 | {points.size(0), points.size(1), idx.size(1)}, 34 | at::device(points.device()).dtype(points.scalar_type())); 35 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 36 | points.scalar_type(), "gather_points_kernel_cuda", ([&] { 37 | gather_points_kernel<<>>( 38 | b, 39 | c, 40 | n, 41 | npoints, 42 | points.data_ptr(), 43 | idx.data_ptr(), 44 | out.data_ptr()); 45 | })); 46 | CUDA_CHECK_ERRORS(); 47 | return out; 48 | } 49 | 50 | // input: grad_out(b, c, m) idx(b, m) 51 | // output: grad_points(b, c, n) 52 | 
// Host-side launcher for gather_points_grad_kernel.
//
// Arguments:
//   b, c, n, npoints - forwarded unchanged to the kernel (as b, c, n, m)
//   grad_out         - incoming gradient; its scalar type selects the kernel
//                      instantiation via AT_DISPATCH_FLOATING_TYPES_AND_HALF
//   idx              - index tensor forwarded to the kernel
//
// Returns a zero-initialised tensor of shape
// (grad_out.size(0), grad_out.size(1), n) on grad_out's device, which is
// passed to the kernel as its output buffer.
//
// NOTE(review): the kernel's template arguments, the <<<...>>> launch
// configuration and the data_ptr template arguments appear to be missing from
// this copy of the source (`<<>>`, `data_ptr()`); confirm against the
// repository before building.
at::Tensor gather_points_grad_kernel_wrapper(
    int b,
    int c,
    int n,
    int npoints,
    const at::Tensor grad_out,
    const at::Tensor idx) {
  // NOTE(review): the output dtype is hard-coded to Float while the kernel is
  // dispatched on grad_out.scalar_type(); for half/double gradients the two
  // presumably disagree — confirm whether this should be
  // grad_out.scalar_type().
  at::Tensor grad_points =
      torch::zeros({grad_out.size(0), grad_out.size(1), n},
                   at::device(grad_out.device()).dtype(at::ScalarType::Float));

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_out.scalar_type(), "gather_points_grad_cuda", ([&] {
        gather_points_grad_kernel<<>>(
            b,
            c,
            n,
            npoints,
            grad_out.data_ptr(),
            idx.data_ptr(),
            grad_points.data_ptr());
      }));
  // Surface any launch/runtime error raised by the kernel above.
  CUDA_CHECK_ERRORS();
  return grad_points;
}
// Iterative furthest point sampling.
//
// Input  dataset: (b, n, 3), temp: (b, n)
// Output idx:     (b, m)
//
// Each block works on the batch element given by blockIdx.x. Index 0 is always
// the first sample. For every further sample j, each thread walks the points
// k = tid, tid + block_size, ..., lowers temp[k] to the minimum of its current
// value and the squared distance to the most recently chosen point, and keeps
// the point with the largest such value. The per-thread candidates are then
// reduced in shared memory and the winner becomes sample j.
//
// `temp` is read before it is ever written here, so it presumably has to be
// pre-filled with a large value by the caller — TODO confirm.
//
// NOTE(review): the template parameter list appears to be missing from this
// copy of the source; `scalar_t` and `block_size` are used below but are not
// declared here.
template
__global__ void furthest_point_sampling_kernel(
    int b,
    int n,
    int m,
    const scalar_t *__restrict__ dataset,
    scalar_t *__restrict__ temp,
    int *__restrict__ idx) {
  if (m <= 0) return;
  // Per-thread best distance / best index, reduced after every pass.
  __shared__ scalar_t dists[block_size];
  __shared__ int dists_i[block_size];

  // Advance all pointers to this block's batch element.
  int batch_index = blockIdx.x;
  dataset += batch_index * n * 3;
  temp += batch_index * n;
  idx += batch_index * m;

  int tid = threadIdx.x;
  const int stride = block_size;

  // The first selected point is always index 0.
  int old = 0;
  if (threadIdx.x == 0)
    idx[0] = old;

  __syncthreads();
  for (int j = 1; j < m; j++) {
    int besti = 0;
    scalar_t best = -1;
    // Coordinates of the most recently selected point.
    scalar_t x1 = dataset[old * 3 + 0];
    scalar_t y1 = dataset[old * 3 + 1];
    scalar_t z1 = dataset[old * 3 + 2];
    for (int k = tid; k < n; k += stride) {
      scalar_t x2, y2, z2;
      x2 = dataset[k * 3 + 0];
      y2 = dataset[k * 3 + 1];
      z2 = dataset[k * 3 + 2];
      // Disabled filter that skipped points very close to the origin:
      // scalar_t mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
      // if (mag <= 1e-3)
      // continue;

      // Squared distance to the latest sample; temp[k] keeps the running
      // minimum over all samples chosen so far.
      scalar_t d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
      scalar_t d2 = min(d, temp[k]);
      temp[k] = d2;
      // Track this thread's furthest candidate (index first, then value, so
      // both comparisons see the same `best`).
      besti = d2 > best ? k : besti;
      best = d2 > best ? d2 : best;
    }
    dists[tid] = best;
    dists_i[tid] = besti;
    __syncthreads();

    // Reduction over the per-thread candidates: each step folds the upper
    // half of the active range into the lower half via __update. The
    // block_size comparisons select which steps apply for this block size.
    if (block_size >= 1024) {
      if (tid < 512) {
        __update(dists, dists_i, tid, tid + 512);
      }
      __syncthreads();
    }

    if (block_size >= 512) {
      if (tid < 256) {
        __update(dists, dists_i, tid, tid + 256);
      }
      __syncthreads();
    }
    if (block_size >= 256) {
      if (tid < 128) {
        __update(dists, dists_i, tid, tid + 128);
      }
      __syncthreads();
    }
    if (block_size >= 128) {
      if (tid < 64) {
        __update(dists, dists_i, tid, tid + 64);
      }
      __syncthreads();
    }
    if (block_size >= 64) {
      if (tid < 32) {
        __update(dists, dists_i, tid, tid + 32);
      }
      __syncthreads();
    }
    if (block_size >= 32) {
      if (tid < 16) {
        __update(dists, dists_i, tid, tid + 16);
      }
      __syncthreads();
    }
    if (block_size >= 16) {
      if (tid < 8) {
        __update(dists, dists_i, tid, tid + 8);
      }
      __syncthreads();
    }
    if (block_size >= 8) {
      if (tid < 4) {
        __update(dists, dists_i, tid, tid + 4);
      }
      __syncthreads();
    }
    if (block_size >= 4) {
      if (tid < 2) {
        __update(dists, dists_i, tid, tid + 2);
      }
      __syncthreads();
    }
    if (block_size >= 2) {
      if (tid < 1) {
        __update(dists, dists_i, tid, tid + 1);
      }
      __syncthreads();
    }

    // Slot 0 now holds the winner; every thread reads it as the next `old`.
    // NOTE(review): there is no barrier between this read and the write to
    // dists_i[tid] at the end of the next iteration's scan, so thread 0 could
    // overwrite slot 0 before a slower thread has read it — confirm whether a
    // __syncthreads() is needed here.
    old = dists_i[0];
    if (tid == 0)
      idx[j] = old;
  }
}
half can be 241 | at::Tensor tmp = 242 | torch::full({points.size(0), points.size(1)}, 65e3, 243 | at::device(points.device()).dtype(points.scalar_type())); 244 | 245 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 246 | points.scalar_type(), "furthest_point_sampling_cuda", ([&] { 247 | switch (n_threads) { 248 | case 512: 249 | furthest_point_sampling_kernel<<>>( 250 | b, 251 | n, 252 | m, 253 | points.data_ptr(), 254 | tmp.data_ptr(), 255 | idx.data_ptr()); 256 | break; 257 | case 256: 258 | furthest_point_sampling_kernel<<>>( 259 | b, 260 | n, 261 | m, 262 | points.data_ptr(), 263 | tmp.data_ptr(), 264 | idx.data_ptr()); 265 | break; 266 | case 128: 267 | furthest_point_sampling_kernel<<>>( 268 | b, 269 | n, 270 | m, 271 | points.data_ptr(), 272 | tmp.data_ptr(), 273 | idx.data_ptr()); 274 | break; 275 | case 64: 276 | furthest_point_sampling_kernel<<>>( 277 | b, 278 | n, 279 | m, 280 | points.data_ptr(), 281 | tmp.data_ptr(), 282 | idx.data_ptr()); 283 | break; 284 | case 32: 285 | furthest_point_sampling_kernel<<>>( 286 | b, 287 | n, 288 | m, 289 | points.data_ptr(), 290 | tmp.data_ptr(), 291 | idx.data_ptr()); 292 | break; 293 | case 16: 294 | furthest_point_sampling_kernel<<>>( 295 | b, 296 | n, 297 | m, 298 | points.data_ptr(), 299 | tmp.data_ptr(), 300 | idx.data_ptr()); 301 | break; 302 | case 8: 303 | furthest_point_sampling_kernel<<>>( 304 | b, 305 | n, 306 | m, 307 | points.data_ptr(), 308 | tmp.data_ptr(), 309 | idx.data_ptr()); 310 | break; 311 | case 4: 312 | furthest_point_sampling_kernel<<>>( 313 | b, 314 | n, 315 | m, 316 | points.data_ptr(), 317 | tmp.data_ptr(), 318 | idx.data_ptr()); 319 | break; 320 | case 2: 321 | furthest_point_sampling_kernel<<>>( 322 | b, 323 | n, 324 | m, 325 | points.data_ptr(), 326 | tmp.data_ptr(), 327 | idx.data_ptr()); 328 | break; 329 | case 1: 330 | furthest_point_sampling_kernel<<>>( 331 | b, 332 | n, 333 | m, 334 | points.data_ptr(), 335 | tmp.data_ptr(), 336 | idx.data_ptr()); 337 | break; 338 | default: 339 | 
class _PointnetSAModuleBase(nn.Module):
    """Shared forward pass for the set-abstraction layers.

    Subclasses are expected to fill in ``npoint``, ``groupers`` and ``mlps``;
    this base class only initialises them to ``None``.
    """

    def __init__(self):
        super().__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None

    def forward(
        self, xyz: torch.Tensor, features: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Sample centroids, group around them and pool each group.

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor
            (B, C, N) tensor of the descriptors of the features

        Returns
        -------
        new_xyz : torch.Tensor
            (B, npoint, 3) tensor of the new features' xyz, or ``None`` when
            ``self.npoint`` is ``None``
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors
        """
        xyz_channels_first = xyz.transpose(1, 2).contiguous()

        if self.npoint is None:
            new_xyz = None
        else:
            # Pick npoint centroids with furthest point sampling, then gather
            # their coordinates and go back to the (B, npoint, 3) layout.
            centroid_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            centroids = pointnet2_utils.gather_operation(
                xyz_channels_first, centroid_idx
            )
            new_xyz = centroids.transpose(1, 2).contiguous()

        pooled_per_scale = []
        for scale, grouper in enumerate(self.groupers):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            embedded = self.mlps[scale](grouped)  # (B, mlp[-1], npoint, nsample)
            # Max over the sample axis, then drop the now-singleton axis.
            pooled = F.max_pool2d(
                embedded, kernel_size=[1, embedded.size(3)]
            )  # (B, mlp[-1], npoint, 1)
            pooled_per_scale.append(pooled.squeeze(-1))  # (B, mlp[-1], npoint)

        return new_xyz, torch.cat(pooled_per_scale, dim=1)
class PointnetFPModule(nn.Module):
    r"""Propagates the features of one set to another

    Parameters
    ----------
    mlp : list
        Pointnet module parameters
    bn : bool
        Use batchnorm
    """

    def __init__(self, mlp, bn=True):
        # type: (PointnetFPModule, List[int], bool) -> None
        super(PointnetFPModule, self).__init__()
        self.mlp = build_shared_mlp(mlp, bn=bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor
            (B, m, 3) tensor of the xyz positions of the known features, or
            ``None`` to broadcast ``known_feats`` to every unknown point
        unknow_feats : torch.Tensor
            (B, C1, n) tensor of the features to be propagated to, or ``None``
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated. When ``known`` is
            ``None`` the last dimension must be expandable to n (i.e. size 1).

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """

        if known is not None:
            # Inverse-distance weights over the three nearest known points,
            # normalised so they sum to one per unknown point.
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )
        else:
            # torch.Size is a tuple, so it cannot be concatenated with a list;
            # unpack the leading sizes and pass the new length separately.
            interpolated_feats = known_feats.expand(
                *known_feats.size()[0:2], unknown.size(1)
            )

        if unknow_feats is not None:
            new_features = torch.cat(
                [interpolated_feats, unknow_feats], dim=1
            )  # (B, C2 + C1, n)
        else:
            new_features = interpolated_feats

        # The shared MLP is built from 2D convolutions, so add and later
        # remove a trailing singleton dimension.
        new_features = new_features.unsqueeze(-1)
        new_features = self.mlp(new_features)

        return new_features.squeeze(-1)
class ThreeInterpolate(Function):
    r"""Autograd op that interpolates features from three weighted neighbors."""

    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Performs weight linear interpolation on 3 features
        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        ctx.save_for_backward(idx, weight, features)

        return _ext.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None, None]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features

        None

        None
        """
        idx, weight, features = ctx.saved_tensors
        # Only the size of the source set is needed to shape the gradient.
        m = features.size(2)

        grad_features = _ext.three_interpolate_grad(
            grad_out.contiguous(), idx, weight, m
        )

        # No gradient is computed for idx or weight; return None (as the
        # docstring states) instead of allocating zero tensors on every call.
        return grad_features, None, None
class BallQuery(Function):
    r"""Non-differentiable ball query backed by the compiled extension."""

    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Find, for each query center, indices of features inside its ball.

        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        # The extension takes the query centers first, then the point set.
        neighbor_idx = _ext.ball_query(new_xyz, xyz, radius, nsample)
        ctx.mark_non_differentiable(neighbor_idx)
        return neighbor_idx

    @staticmethod
    def backward(ctx, grad_out):
        # The output is marked non-differentiable, so there is nothing to return.
        return tuple()
272 | nsample, 273 | use_xyz, 274 | normalize_xyz, 275 | ) 276 | 277 | def forward(self, xyz, new_xyz, features=None): 278 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 279 | r""" 280 | Parameters 281 | ---------- 282 | xyz : torch.Tensor 283 | xyz coordinates of the features (B, N, 3) 284 | new_xyz : torch.Tensor 285 | centriods (B, npoint, 3) 286 | features : torch.Tensor 287 | Descriptors of the features (B, C, N) 288 | 289 | Returns 290 | ------- 291 | new_features : torch.Tensor 292 | (B, 3 + C, npoint, nsample) tensor 293 | """ 294 | 295 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 296 | xyz_trans = xyz.transpose(1, 2).contiguous() 297 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 298 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 299 | if self.normalize_xyz: 300 | grouped_xyz /= self.radius 301 | 302 | if features is not None: 303 | grouped_features = grouping_operation(features, idx) 304 | if self.use_xyz: 305 | new_features = torch.cat( 306 | [grouped_xyz, grouped_features], dim=1 307 | ) # (B, C + 3, npoint, nsample) 308 | else: 309 | new_features = grouped_features 310 | else: 311 | assert ( 312 | self.use_xyz 313 | ), "Cannot have not features and not use xyz as a feature!" 
class GroupAll(nn.Module):
    r"""
    Puts every point into a single group.

    Parameters
    ---------
    use_xyz : bool
        Whether the coordinates are prepended to the features when features
        are given
    """

    def __init__(self, use_xyz=True):
        # type: (GroupAll, bool) -> None
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor]
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, C + 3, 1, N) tensor; (B, C, 1, N) when ``use_xyz`` is False,
            or (B, 3, 1, N) when no features are given
        """
        # Channels-first coordinates with a singleton "group" axis: (B, 3, 1, N)
        xyz_as_group = xyz.transpose(1, 2).unsqueeze(2)

        if features is None:
            return xyz_as_group

        features_as_group = features.unsqueeze(2)  # (B, C, 1, N)
        if not self.use_xyz:
            return features_as_group

        return torch.cat([xyz_as_group, features_as_group], dim=1)  # (B, 3 + C, 1, N)
"""Build script for the ``pointnet2_ops`` package and its CUDA extension."""
import glob
import os
import os.path as osp

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

this_dir = osp.dirname(osp.abspath(__file__))
_ext_src_root = osp.join("pointnet2_ops", "_ext-src")
# Source paths are kept relative to the project root, as in the original script.
_ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
    osp.join(_ext_src_root, "src", "*.cu")
)
# Collected for reference; not currently passed to the extension below.
_ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))

# Define __version__ by executing the package's version file. The path is
# anchored at this file's directory and the handle is closed deterministically.
with open(
    osp.join(this_dir, "pointnet2_ops", "_version.py"), encoding="utf-8"
) as _version_file:
    exec(_version_file.read())

# Default set of CUDA architectures to compile for; an architecture list that
# the user already exported is respected rather than overwritten.
os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "7.0 7.5 8.0 8.6 8.9")

setup(
    name="pointnet2_ops",
    version=__version__,
    author="Erik Wijmans (Modified by Adam Fishman)",
    packages=find_packages(),
    ext_modules=[
        CUDAExtension(
            name="pointnet2_ops._ext",
            sources=_ext_sources,
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": ["-O3", "-Xfatbin", "-compress-all"],
            },
            include_dirs=[osp.join(this_dir, _ext_src_root, "include")],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    include_package_data=True,
)