# NOTE(review): flattened repository dump - a directory tree followed by file
# contents, each prefixed with "N |" line-number gutters. This line covers:
#   * pointnet2_ops_lib/MANIFEST.in: `graft pointnet2_ops/_ext-src`, so the
#     C++/CUDA sources and headers are shipped with source distributions.
#   * pointnet2_ops/_version.py: defines __version__ = "3.0.0" (setup.py reads
#     it back via exec()).
#   * .DS_Store and img/pct.png: binary assets, present only as raw-URL links.
#     .DS_Store is macOS Finder metadata and presumably should not be tracked.
#   * pointnet2_ops/__init__.py: imports the pointnet2_modules and
#     pointnet2_utils submodules and re-exports __version__.
├── pointnet2_ops_lib ├── MANIFEST.in ├── pointnet2_ops │ ├── _version.py │ ├── __init__.py │ ├── _ext-src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── group_points.h │ │ │ ├── sampling.h │ │ │ ├── interpolate.h │ │ │ ├── utils.h │ │ │ └── cuda_utils.h │ │ └── src │ │ │ ├── bindings.cpp │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── sampling.cpp │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ └── sampling_gpu.cu │ ├── pointnet2_modules.py │ └── pointnet2_utils.py └── setup.py ├── .DS_Store ├── img └── pct.png ├── README.md ├── dataset.py ├── util.py ├── module.py ├── cls.py └── model.py /pointnet2_ops_lib/MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft pointnet2_ops/_ext-src 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.0" 2 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinglew/PointCloudTransformer/HEAD/.DS_Store -------------------------------------------------------------------------------- /img/pct.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qinglew/PointCloudTransformer/HEAD/img/pct.png -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import pointnet2_ops.pointnet2_modules 2 | import pointnet2_ops.pointnet2_utils 3 | from pointnet2_ops._version import __version__ 4 |
// NOTE(review): flattened dump of the C++ extension's public headers (each file
// shown with "N |" line-number gutters):
//   ball_query.h   - ball_query(new_xyz, xyz, radius, nsample) -> at::Tensor
//   group_points.h - group_points(points, idx), group_points_grad(grad_out, idx, n)
//   sampling.h     - gather_points, gather_points_grad, furthest_point_sampling
//   interpolate.h  - three_nn (returns a std::vector of tensors),
//                    three_interpolate, three_interpolate_grad
// Every `#include` target and the std::vector element type are empty here. The
// angle-bracketed text appears to have been stripped when the dump was made, so
// restore these files from upstream before building.
// The line ends with the path header of src/bindings.cpp.
-------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp:
// NOTE(review): contents of src/bindings.cpp, include/utils.h and the start of
// src/ball_query.cpp (flattened, "N |" gutters).
//   bindings.cpp: PYBIND11_MODULE registers all nine functions declared in the
//     headers under their C++ names (gather_points, gather_points_grad,
//     furthest_point_sampling, three_nn, three_interpolate,
//     three_interpolate_grad, ball_query, group_points, group_points_grad).
//   utils.h: CHECK_CUDA / CHECK_CONTIGUOUS / CHECK_IS_INT / CHECK_IS_FLOAT wrap
//     AT_ASSERT in do { ... } while (0), so each expands to a single statement.
//     PyTorch documents TORCH_CHECK for validating user-supplied arguments;
//     consider switching (confirm against the targeted PyTorch version).
//   ball_query.cpp: forward-declares query_ball_point_kernel_wrapper, which is
//     defined in src/ball_query_gpu.cu.
// As in the headers, the `#include` targets in utils.h are empty (stripped).
-------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7
// NOTE(review): this span continues the flattened dump ("N |" gutters) with the
// rest of src/ball_query.cpp, pointnet2_ops_lib/setup.py, include/cuda_utils.h,
// src/ball_query_gpu.cu and the start of README.md.
//   ball_query(): requires new_xyz and xyz to be contiguous float tensors,
//     allocates a zero-filled int tensor of shape
//     (new_xyz.size(0), new_xyz.size(1), nsample) on new_xyz's device, and fills
//     it through query_ball_point_kernel_wrapper. Non-CUDA input hits
//     AT_ASSERT(false, "CPU not supported").
//   setup.py: builds pointnet2_ops._ext as a CUDAExtension from every
//     _ext-src/src/*.cpp and *.cu file. The globs are relative to the working
//     directory, so setup.py is presumably run from pointnet2_ops_lib.
//     It assigns TORCH_CUDA_ARCH_LIST unconditionally, discarding any value the
//     user exported; consider os.environ.setdefault(...) so other GPU
//     architectures can be targeted. Newer CUDA toolkits may reject the 3.7
//     entry - verify against the toolkit in use.
//   cuda_utils.h: opt_n_threads(w) rounds log2(w) down (via floating-point logs)
//     and returns that power of two clamped to [1, TOTAL_THREADS = 512].
//     For w <= 0 the log is -inf/NaN and the int conversion is undefined;
//     callers presumably never pass 0 - TODO confirm.
//     CUDA_CHECK_ERRORS prints the error and calls exit(-1), terminating the
//     process instead of raising a Python exception.
//   query_ball_point_kernel: each block handles batch element blockIdx.x, and
//     its threads step through the m query centres by blockDim.x. Each centre
//     scans xyz in index order and keeps the first nsample points whose squared
//     distance is < radius^2. On the first hit every slot is filled with that
//     index, so unused slots repeat it. A centre with no neighbour keeps the
//     zeros written by ball_query().
//     The <<<...>>> launch configurations are empty in this dump (stripped).
| 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr(), 26 | xyz.data_ptr(), idx.data_ptr()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | from setuptools import find_packages, setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | this_dir = osp.dirname(osp.abspath(__file__)) 9 | _ext_src_root = osp.join("pointnet2_ops", "_ext-src") 10 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 11 | osp.join(_ext_src_root, "src", "*.cu") 12 | ) 13 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 14 | 15 | requirements = ["torch>=1.4"] 16 | 17 | exec(open(osp.join("pointnet2_ops", "_version.py")).read()) 18 | 19 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 20 | setup( 21 | name="pointnet2_ops", 22 | version=__version__, 23 | author="Erik Wijmans", 24 | packages=find_packages(), 25 | install_requires=requirements, 26 | ext_modules=[ 27 | CUDAExtension( 28 | name="pointnet2_ops._ext", 29 | sources=_ext_sources, 30 | extra_compile_args={ 31 | "cxx": ["-O3"], 32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 33 | }, 34 |
include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 35 | ) 36 | ], 37 | cmdclass={"build_ext": BuildExtension}, 38 | include_package_data=True, 39 | ) 40 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index =
blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Point Cloud Transformer 2 | 3 | ![](img/pct.png) 4 | 5 | ## Description 6 | 7 | Implementation of PCT(Point Cloud Transformer) in PyTorch. 8 | 9 | > **Abstract**: The irregular domain and lack of ordering make it challenging to design deep neural networks for point cloud processing. This paper presents a novel framework named 10 | > Point Cloud Transformer(PCT) for point cloud learning. PCT is based on Transformer, which achieves huge success in natural language processing and displays great potential in image processing.
It is inherently permutation invariant for processing a sequence of points, making it well-suited for point cloud learning. To better capture local context within the point cloud, we enhance input embedding with the support of farthest point sampling and nearest neighbor search. Extensive experiments demonstrate that the PCT achieves the state-of-the-art performance on shape classification, part segmentation and normal estimation tasks. 11 | 12 | ## Environment 13 | 14 | * Ubuntu 18.04 LTS 15 | * CUDA 11.0 16 | * PyTorch 1.7.0 17 | 18 | ## Training 19 | 20 | Before you execute the training code, you need to install the module in `pointnet2_ops_lib`: 21 | 22 | ```shell 23 | pip install pointnet2_ops_lib/. 24 | ``` 25 | 26 | To train the model, use the following command: 27 | 28 | ```shell 29 | python cls.py --model=pct --exp_name=pct_cls --num_points=1024 --use_sgd=True --batch_size=32 --epochs 250 --lr 0.0001 30 | ``` 31 | 32 | Modify the parameters if you want to switch to another model or change other settings. 33 | 34 | ## Testing 35 | 36 | To test the model, use the following command: 37 | 38 | ```shell 39 | python cls.py --exp_name=test --num_points=1024 --use_sgd=True --eval=True --model_path=checkpoints/pct_cls/models/model.t7 --test_batch_size=32 40 | ``` 41 | 42 | Modify the parameters if you want to switch to another model or change other settings. 43 | 44 | ## Citation 45 | 46 | 1. https://arxiv.org/pdf/2012.09688.pdf 47 | 2. https://github.com/MenghaoGuo/PCT 48 | 3.
https://github.com/uyzhang/PCT_Pytorch 49 |
// NOTE(review): flattened dump ("N |" gutters) of src/group_points.cpp,
// src/group_points_gpu.cu and the start of src/sampling.cpp.
//   group_points(points, idx): requires contiguous float points and contiguous
//     int idx, allocates a zero-filled float tensor of shape
//     (points.size(0), points.size(1), idx.size(1), idx.size(2)), and fills it
//     with group_points_kernel_wrapper.
//   group_points_grad(grad_out, idx, n): allocates a zero-filled float tensor
//     of shape (grad_out.size(0), grad_out.size(1), n) and accumulates into it
//     with group_points_grad_kernel_wrapper.
//   Both hit AT_ASSERT(false, "CPU not supported") for non-CUDA input.
//   group_points_kernel / group_points_grad_kernel: per the in-file comments,
//     points is (b, c, n), idx is (b, npoints, nsample) and the grouped tensor
//     is (b, c, npoints, nsample). Each block handles batch element blockIdx.x,
//     and its threads (flattened 2-D thread index) step through the
//     c * npoints (channel, point) pairs. The backward kernel uses atomicAdd,
//     presumably because different (point, sample) slots can hold the same
//     source index.
//   sampling.cpp: forward declarations of the gather / furthest-point-sampling
//     kernel wrappers, then gather_points, which allocates a zero-filled float
//     (points.size(0), points.size(1), idx.size(1)) output. The span ends
//     mid-function.
//   The <<<...>>> launch configurations are empty in this dump (stripped).
-------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57
| } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample *
c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if
(points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data_ptr(), 32 | idx.data_ptr(), output.data_ptr()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr(), 58 | idx.data_ptr(), 59 | output.data_ptr()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 79 | furthest_point_sampling_kernel_wrapper( 80 | points.size(0), points.size(1), nsamples, points.data_ptr(), 81 | tmp.data_ptr(), output.data_ptr()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import h5py 4 | import numpy as np 5 | from torch.utils.data import 
Dataset 6 | 7 | 8 | def download(): 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | DATA_DIR = os.path.join(BASE_DIR, 'data') 11 | if not os.path.exists(DATA_DIR): 12 | os.mkdir(DATA_DIR) 13 | if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')): 14 | www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip' 15 | zipfile = os.path.basename(www) 16 | os.system('wget --no-check-certificate %s; unzip %s' % (www, zipfile)) 17 | os.system('mv %s %s' % (zipfile[:-4], DATA_DIR)) 18 | os.system('rm %s' % (zipfile)) 19 | 20 | 21 | def load_data(partition): 22 | download() 23 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | DATA_DIR = os.path.join(BASE_DIR, 'data') # you can modify here to assign the path where dataset's root located at 25 | all_data = [] 26 | all_label = [] 27 | for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)): 28 | f = h5py.File(h5_name) 29 | data = f['data'][:].astype('float32') 30 | label = f['label'][:].astype('int64') 31 | f.close() 32 | all_data.append(data) 33 | all_label.append(label) 34 | all_data = np.concatenate(all_data, axis=0) 35 | all_label = np.concatenate(all_label, axis=0) 36 | return all_data, all_label 37 | 38 | 39 | def random_point_dropout(pc, max_dropout_ratio=0.875): 40 | ''' batch_pc: BxNx3 ''' 41 | # for b in range(batch_pc.shape[0]): 42 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 43 | drop_idx = np.where(np.random.random((pc.shape[0]))<=dropout_ratio)[0] 44 | # print ('use random drop', len(drop_idx)) 45 | 46 | if len(drop_idx)>0: 47 | pc[drop_idx,:] = pc[0,:] # set to the first point 48 | return pc 49 | 50 | 51 | def translate_pointcloud(pointcloud): 52 | xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) 53 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 54 | 55 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 56 | return 
translated_pointcloud 57 | 58 | 59 | def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02): 60 | N, C = pointcloud.shape 61 | pointcloud += np.clip(sigma * np.random.randn(N, C), -1*clip, clip) 62 | return pointcloud 63 | 64 | 65 | class ModelNet40(Dataset): 66 | def __init__(self, num_points, partition='train'): 67 | self.data, self.label = load_data(partition) 68 | self.num_points = num_points 69 | self.partition = partition 70 | 71 | def __getitem__(self, item): 72 | pointcloud = self.data[item][:self.num_points] 73 | label = self.label[item] 74 | if self.partition == 'train': 75 | pointcloud = random_point_dropout(pointcloud) # open for dgcnn not for our idea for all 76 | pointcloud = translate_pointcloud(pointcloud) 77 | np.random.shuffle(pointcloud) 78 | return pointcloud, label 79 | 80 | def __len__(self): 81 | return self.data.shape[0] 82 | 83 | 84 | if __name__ == '__main__': 85 | train = ModelNet40(1024) 86 | test = ModelNet40(1024, 'test') 87 | for data, label in train: 88 | print(data.shape) 89 | print(label.shape) 90 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if 
(unknowns.is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data_ptr(), knows.data_ptr(), 34 | dist2.data_ptr(), idx.data_ptr()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 64 | output.data_ptr()); 65 | } else { 66 | AT_ASSERT(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 |
at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data_ptr(), idx.data_ptr(), 93 | weight.data_ptr(), output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 |
// NOTE(review): the preceding lines finish src/interpolate.cpp (flattened dump,
// "N |" gutters):
//   three_nn(unknowns, knows) returns {dist2, idx}, a float and an int tensor,
//     both of shape (unknowns.size(0), unknowns.size(1), 3).
//   three_interpolate(points, idx, weight) returns a float tensor of shape
//     (points.size(0), points.size(1), idx.size(1)).
//   three_interpolate_grad(grad_out, idx, weight, m) returns a zero-initialised
//     (grad_out.size(0), grad_out.size(1), m) tensor that the kernel fills.
//   All three reject non-CUDA input with AT_ASSERT(false, "CPU not supported").
// What follows is src/interpolate_gpu.cu (it continues past this span):
//   three_nn_kernel: for each unknown point, scans all m known points and keeps
//     the three smallest squared distances and their indices. The distances are
//     tracked in double and seeded with 1e40. If m < 3, the unfilled slots keep
//     index 0 and distance 1e40, which exceeds float range when stored into
//     dist2; presumably callers always have m >= 3 - TODO confirm.
//   three_interpolate_kernel: out = w1*p[i1] + w2*p[i2] + w3*p[i3] for each
//     (channel, point). The gradient kernel scatters grad_out * w back to the
//     three source indices with atomicAdd.
//   The <<<...>>> launch configurations are empty in this dump (stripped).
-------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k;
49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b,
c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 |
# NOTE(review): pointnet2_ops/pointnet2_modules.py follows (flattened, "N |"
# gutters). The dump is truncated inside PointnetSAModuleMSG's docstring.
#   build_shared_mlp(mlp_spec, bn): for each consecutive pair in mlp_spec, adds
#     a kernel-size-1 Conv2d (with a bias only when bn is False), a BatchNorm2d
#     when bn is True, and an in-place ReLU. The layers are returned as one
#     nn.Sequential.
#   _PointnetSAModuleBase: __init__ leaves npoint, groupers and mlps as None, so
#     forward() only works once a subclass assigns groupers and mlps
#     (len(None) would raise TypeError). forward() picks npoint centres with
#     furthest_point_sample + gather_operation (new_xyz is None when npoint is
#     None), runs each grouper and its MLP, and max-pools over the last axis.
#     The per-grouper features are then concatenated along dim 1.
-------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from pointnet2_ops import pointnet2_utils 7 | 8 | 9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True): 10 | layers = [] 11 | for i in range(1, len(mlp_spec)): 12 | layers.append(
13 | nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) 14 | ) 15 | if bn: 16 | layers.append(nn.BatchNorm2d(mlp_spec[i])) 17 | layers.append(nn.ReLU(True)) 18 | 19 | return nn.Sequential(*layers) 20 | 21 | 22 | class _PointnetSAModuleBase(nn.Module): 23 | def __init__(self): 24 | super(_PointnetSAModuleBase, self).__init__() 25 | self.npoint = None 26 | self.groupers = None 27 | self.mlps = None 28 | 29 | def forward( 30 | self, xyz: torch.Tensor, features: Optional[torch.Tensor] 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | r""" 33 | Parameters 34 | ---------- 35 | xyz : torch.Tensor 36 | (B, N, 3) tensor of the xyz coordinates of the features 37 | features : torch.Tensor 38 | (B, C, N) tensor of the descriptors of the features 39 | 40 | Returns 41 | ------- 42 | new_xyz : torch.Tensor 43 | (B, npoint, 3) tensor of the new features' xyz 44 | new_features : torch.Tensor 45 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 46 | """ 47 | 48 | new_features_list = [] 49 | 50 | xyz_flipped = xyz.transpose(1, 2).contiguous() 51 | new_xyz = ( 52 | pointnet2_utils.gather_operation( 53 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) 54 | ) 55 | .transpose(1, 2) 56 | .contiguous() 57 | if self.npoint is not None 58 | else None 59 | ) 60 | 61 | for i in range(len(self.groupers)): 62 | new_features = self.groupers[i]( 63 | xyz, new_xyz, features 64 | ) # (B, C, npoint, nsample) 65 | 66 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 67 | new_features = F.max_pool2d( 68 | new_features, kernel_size=[1, new_features.size(3)] 69 | ) # (B, mlp[-1], npoint, 1) 70 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 71 | 72 | new_features_list.append(new_features) 73 | 74 | return new_xyz, torch.cat(new_features_list, dim=1) 75 | 76 | 77 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 78 | r"""Pointnet set abstraction layer with multiscale grouping 79 | 80 | Parameters
81 | ---------- 82 | npoint : int 83 | Number of features 84 | radii : list of float32 85 | list of radii to group with 86 | nsamples : list of int32 87 | Number of samples in each ball query 88 | mlps : list of list of int32 89 | Spec of the pointnet before the global max_pool for each scale 90 | bn : bool 91 | Use batchnorm 92 | """ 93 | 94 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): 95 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None 96 | super(PointnetSAModuleMSG, self).__init__() 97 | 98 | assert len(radii) == len(nsamples) == len(mlps) 99 | 100 | self.npoint = npoint 101 | self.groupers = nn.ModuleList() 102 | self.mlps = nn.ModuleList() 103 | for i in range(len(radii)): 104 | radius = radii[i] 105 | nsample = nsamples[i] 106 | self.groupers.append( 107 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 108 | if npoint is not None 109 | else pointnet2_utils.GroupAll(use_xyz) 110 | ) 111 | mlp_spec = mlps[i] 112 | if use_xyz: 113 | mlp_spec[0] += 3 114 | 115 | self.mlps.append(build_shared_mlp(mlp_spec, bn)) 116 | 117 | 118 | class PointnetSAModule(PointnetSAModuleMSG): 119 | r"""Pointnet set abstrction layer 120 | 121 | Parameters 122 | ---------- 123 | npoint : int 124 | Number of features 125 | radius : float 126 | Radius of ball 127 | nsample : int 128 | Number of samples in the ball query 129 | mlp : list 130 | Spec of the pointnet before the global max_pool 131 | bn : bool 132 | Use batchnorm 133 | """ 134 | 135 | def __init__( 136 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True 137 | ): 138 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None 139 | super(PointnetSAModule, self).__init__( 140 | mlps=[mlp], 141 | npoint=npoint, 142 | radii=[radius], 143 | nsamples=[nsample], 144 | bn=bn, 145 | use_xyz=use_xyz, 146 | ) 147 | 148 | 149 | class PointnetFPModule(nn.Module): 150 | r"""Propigates the features of one 
set to another 151 | 152 | Parameters 153 | ---------- 154 | mlp : list 155 | Pointnet module parameters 156 | bn : bool 157 | Use batchnorm 158 | """ 159 | 160 | def __init__(self, mlp, bn=True): 161 | # type: (PointnetFPModule, List[int], bool) -> None 162 | super(PointnetFPModule, self).__init__() 163 | self.mlp = build_shared_mlp(mlp, bn=bn) 164 | 165 | def forward(self, unknown, known, unknow_feats, known_feats): 166 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 167 | r""" 168 | Parameters 169 | ---------- 170 | unknown : torch.Tensor 171 | (B, n, 3) tensor of the xyz positions of the unknown features 172 | known : torch.Tensor 173 | (B, m, 3) tensor of the xyz positions of the known features 174 | unknow_feats : torch.Tensor 175 | (B, C1, n) tensor of the features to be propigated to 176 | known_feats : torch.Tensor 177 | (B, C2, m) tensor of features to be propigated 178 | 179 | Returns 180 | ------- 181 | new_features : torch.Tensor 182 | (B, mlp[-1], n) tensor of the features of the unknown features 183 | """ 184 | 185 | if known is not None: 186 | dist, idx = pointnet2_utils.three_nn(unknown, known) 187 | dist_recip = 1.0 / (dist + 1e-8) 188 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 189 | weight = dist_recip / norm 190 | 191 | interpolated_feats = pointnet2_utils.three_interpolate( 192 | known_feats, idx, weight 193 | ) 194 | else: 195 | interpolated_feats = known_feats.expand( 196 | *(known_feats.size()[0:2] + [unknown.size(1)]) 197 | ) 198 | 199 | if unknow_feats is not None: 200 | new_features = torch.cat( 201 | [interpolated_feats, unknow_feats], dim=1 202 | ) # (B, C2 + C1, n) 203 | else: 204 | new_features = interpolated_feats 205 | 206 | new_features = new_features.unsqueeze(-1) 207 | new_features = self.mlp(new_features) 208 | 209 | return new_features.squeeze(-1) 210 | -------------------------------------------------------------------------------- /util.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from pointnet2_ops import pointnet2_utils 4 | 5 | 6 | def cal_loss(pred, ground_truth, smoothing=True): 7 | ''' Calculate cross entropy loss, apply label smoothing if needed. ''' 8 | 9 | ground_truth = ground_truth.contiguous().view(-1) 10 | 11 | if smoothing: 12 | eps = 0.2 13 | n_class = pred.size(1) 14 | 15 | one_hot = torch.zeros_like(pred).scatter(1, ground_truth.view(-1, 1), 1) 16 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 17 | log_prb = F.log_softmax(pred, dim=1) 18 | 19 | loss = -(one_hot * log_prb).sum(dim=1).mean() 20 | else: 21 | loss = F.cross_entropy(pred, ground_truth, reduction='mean') 22 | 23 | return loss 24 | 25 | 26 | def square_distance(src, dst): 27 | """ 28 | Calculate Euclid distance between each two points. 29 | src^T * dst = xn * xm + yn * ym + zn * zm; 30 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 31 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 32 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 33 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 34 | 35 | Input: 36 | src: source points, [B, N, C] 37 | dst: target points, [B, M, C] 38 | 39 | Output: 40 | dist: per-point square distance, [B, N, M] 41 | """ 42 | B, N, _ = src.shape 43 | _, M, _ = dst.shape 44 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 45 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 46 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 47 | return dist 48 | 49 | 50 | def query_ball_point(radius, nsample, xyz, new_xyz): 51 | """ 52 | Ball query. 
53 | 54 | Input: 55 | radius: local region radius 56 | nsample: max sample number in local region 57 | xyz: all points, [B, N, 3] 58 | new_xyz: query points, [B, S, 3] 59 | 60 | Output: 61 | group_idx: grouped points index, [B, S, nsample] 62 | """ 63 | device = xyz.device 64 | B, N, C = xyz.shape 65 | _, S, _ = new_xyz.shape 66 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 67 | sqrdists = square_distance(new_xyz, xyz) 68 | group_idx[sqrdists > radius ** 2] = N 69 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 70 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 71 | mask = group_idx == N 72 | group_idx[mask] = group_first[mask] 73 | return group_idx 74 | 75 | 76 | def knn_point(k, xyz, new_xyz): 77 | """ 78 | K nearest neighborhood. 79 | 80 | Input: 81 | k: max sample number in local region 82 | xyz: all points, [B, N, C] 83 | new_xyz: query points, [B, S, C] 84 | 85 | Output: 86 | group_idx: grouped points index, [B, S, k] 87 | """ 88 | sqrdists = square_distance(new_xyz, xyz) 89 | _, group_idx = torch.topk(sqrdists, k, dim=-1, largest=False, sorted=False) 90 | return group_idx 91 | 92 | 93 | def index_points(points, idx): 94 | """ 95 | Input: 96 | points: input points data, [B, N, C] 97 | idx: sample index data, [B, S] 98 | 99 | Output: 100 | new_points:, indexed points data, [B, S, C] 101 | """ 102 | device = points.device 103 | B = points.shape[0] 104 | view_shape = list(idx.shape) 105 | view_shape[1:] = [1] * (len(view_shape) - 1) 106 | repeat_shape = list(idx.shape) 107 | repeat_shape[0] = 1 108 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 109 | new_points = points[batch_indices, idx, :] 110 | return new_points 111 | 112 | 113 | def sample_and_ball_group(s, radius, n, coords, features): 114 | """ 115 | Sampling by FPS and grouping by ball query. 
116 | 117 | Input: 118 | s[int]: number of points to be sampled by FPS 119 | k[int]: number of points to be grouped into a neighbor by ball query 120 | n[int]: fix number of points in ball neighbor 121 | coords[tensor]: input points coordinates data with size of [B, N, 3] 122 | features[tensor]: input points features data with size of [B, N, D] 123 | 124 | Returns: 125 | new_coords[tensor]: sampled and grouped points coordinates by FPS with size of [B, s, k, 3] 126 | new_features[tensor]: sampled and grouped points features by FPS with size of [B, s, k, 2D] 127 | """ 128 | batch_size = coords.shape[0] 129 | coords = coords.contiguous() 130 | 131 | # FPS sampling 132 | fps_idx = pointnet2_utils.furthest_point_sample(coords, s).long() # [B, s] 133 | new_coords = index_points(coords, fps_idx) # [B, s, 3] 134 | new_features = index_points(features, fps_idx) # [B, s, D] 135 | 136 | # ball_query grouping 137 | idx = query_ball_point(radius, n, coords, new_coords) # [B, s, n] 138 | grouped_features = index_points(features, idx) # [B, s, n, D] 139 | 140 | # Matrix sub 141 | grouped_features_norm = grouped_features - new_features.view(batch_size, s, 1, -1) # [B, s, n, D] 142 | 143 | # Concat, my be different in many networks 144 | aggregated_features = torch.cat([grouped_features_norm, new_features.view(batch_size, s, 1, -1).repeat(1, 1, n, 1)], dim=-1) # [B, s, n, 2D] 145 | 146 | return new_coords, aggregated_features # [B, s, 3], [B, s, n, 2D] 147 | 148 | 149 | def sample_and_knn_group(s, k, coords, features): 150 | """ 151 | Sampling by FPS and grouping by KNN. 
152 | 153 | Input: 154 | s[int]: number of points to be sampled by FPS 155 | k[int]: number of points to be grouped into a neighbor by KNN 156 | coords[tensor]: input points coordinates data with size of [B, N, 3] 157 | features[tensor]: input points features data with size of [B, N, D] 158 | 159 | Returns: 160 | new_coords[tensor]: sampled and grouped points coordinates by FPS with size of [B, s, k, 3] 161 | new_features[tensor]: sampled and grouped points features by FPS with size of [B, s, k, 2D] 162 | """ 163 | batch_size = coords.shape[0] 164 | coords = coords.contiguous() 165 | 166 | # FPS sampling 167 | fps_idx = pointnet2_utils.furthest_point_sample(coords, s).long() # [B, s] 168 | new_coords = index_points(coords, fps_idx) # [B, s, 3] 169 | new_features = index_points(features, fps_idx) # [B, s, D] 170 | 171 | # K-nn grouping 172 | idx = knn_point(k, coords, new_coords) # [B, s, k] 173 | grouped_features = index_points(features, idx) # [B, s, k, D] 174 | 175 | # Matrix sub 176 | grouped_features_norm = grouped_features - new_features.view(batch_size, s, 1, -1) # [B, s, k, D] 177 | 178 | # Concat 179 | aggregated_features = torch.cat([grouped_features_norm, new_features.view(batch_size, s, 1, -1).repeat(1, 1, k, 1)], dim=-1) # [B, s, k, 2D] 180 | 181 | return new_coords, aggregated_features # [B, s, 3], [B, s, k, 2D] 182 | 183 | 184 | class Logger(): 185 | def __init__(self, path): 186 | self.f = open(path, 'a') 187 | 188 | def cprint(self, text): 189 | print(text) 190 | self.f.write(text+'\n') 191 | self.f.flush() 192 | 193 | def close(self): 194 | self.f.close() 195 | 196 | 197 | if __name__ == '__main__': 198 | points = torch.rand(32, 1024, 3).to('cuda') 199 | features = torch.rand(32, 1024, 128).to('cuda') 200 | new_points, new_features = sample_and_knn_group(512, 32, points, features) 201 | print(new_points.size()) 202 | print(new_features.size()) 203 | -------------------------------------------------------------------------------- /module.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from util import sample_and_knn_group 7 | 8 | 9 | class Embedding(nn.Module): 10 | """ 11 | Input Embedding layer which consist of 2 stacked LBR layer. 12 | """ 13 | 14 | def __init__(self, in_channels=3, out_channels=128): 15 | super(Embedding, self).__init__() 16 | 17 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) 18 | self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) 19 | 20 | self.bn1 = nn.BatchNorm1d(out_channels) 21 | self.bn2 = nn.BatchNorm1d(out_channels) 22 | 23 | def forward(self, x): 24 | """ 25 | Input 26 | x: [B, in_channels, N] 27 | 28 | Output 29 | x: [B, out_channels, N] 30 | """ 31 | x = F.relu(self.bn1(self.conv1(x))) 32 | x = F.relu(self.bn2(self.conv2(x))) 33 | return x 34 | 35 | 36 | class SA(nn.Module): 37 | """ 38 | Self Attention module. 
39 | """ 40 | 41 | def __init__(self, channels): 42 | super(SA, self).__init__() 43 | 44 | self.da = channels // 4 45 | 46 | self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 47 | self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 48 | self.q_conv.weight = self.k_conv.weight 49 | self.v_conv = nn.Conv1d(channels, channels, 1) 50 | 51 | self.trans_conv = nn.Conv1d(channels, channels, 1) 52 | self.after_norm = nn.BatchNorm1d(channels) 53 | 54 | self.act = nn.ReLU() 55 | self.softmax = nn.Softmax(dim=-1) 56 | 57 | def forward(self, x): 58 | """ 59 | Input 60 | x: [B, de, N] 61 | 62 | Output 63 | x: [B, de, N] 64 | """ 65 | # compute query, key and value matrix 66 | x_q = self.q_conv(x).permute(0, 2, 1) # [B, N, da] 67 | x_k = self.k_conv(x) # [B, da, N] 68 | x_v = self.v_conv(x) # [B, de, N] 69 | 70 | # compute attention map and scale, the sorfmax 71 | energy = torch.bmm(x_q, x_k) / (math.sqrt(self.da)) # [B, N, N] 72 | attention = self.softmax(energy) # [B, N, N] 73 | 74 | # weighted sum 75 | x_s = torch.bmm(x_v, attention) # [B, de, N] 76 | x_s = self.act(self.after_norm(self.trans_conv(x_s))) 77 | 78 | # residual 79 | x = x + x_s 80 | 81 | return x 82 | 83 | 84 | class SG(nn.Module): 85 | """ 86 | SG(sampling and grouping) module. 
87 | """ 88 | 89 | def __init__(self, s, in_channels, out_channels): 90 | super(SG, self).__init__() 91 | 92 | self.s = s 93 | 94 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) 95 | self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) 96 | self.bn1 = nn.BatchNorm1d(out_channels) 97 | self.bn2 = nn.BatchNorm1d(out_channels) 98 | 99 | def forward(self, x, coords): 100 | """ 101 | Input: 102 | x: features with size of [B, in_channels//2, N] 103 | coords: coordinates data with size of [B, N, 3] 104 | """ 105 | x = x.permute(0, 2, 1) # (B, N, in_channels//2) 106 | new_xyz, new_feature = sample_and_knn_group(s=self.s, k=32, coords=coords, features=x) # [B, s, 3], [B, s, 32, in_channels] 107 | b, s, k, d = new_feature.size() 108 | new_feature = new_feature.permute(0, 1, 3, 2) 109 | new_feature = new_feature.reshape(-1, d, k) # [Bxs, in_channels, 32] 110 | batch_size = new_feature.size(0) 111 | new_feature = F.relu(self.bn1(self.conv1(new_feature))) # [Bxs, in_channels, 32] 112 | new_feature = F.relu(self.bn2(self.conv2(new_feature))) # [Bxs, in_channels, 32] 113 | new_feature = F.adaptive_max_pool1d(new_feature, 1).view(batch_size, -1) # [Bxs, in_channels] 114 | new_feature = new_feature.reshape(b, s, -1).permute(0, 2, 1) # [B, in_channels, s] 115 | return new_xyz, new_feature 116 | 117 | 118 | class NeighborEmbedding(nn.Module): 119 | def __init__(self, samples=[512, 256]): 120 | super(NeighborEmbedding, self).__init__() 121 | 122 | self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False) 123 | self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) 124 | self.bn1 = nn.BatchNorm1d(64) 125 | self.bn2 = nn.BatchNorm1d(64) 126 | 127 | self.sg1 = SG(s=samples[0], in_channels=128, out_channels=128) 128 | self.sg2 = SG(s=samples[1], in_channels=256, out_channels=256) 129 | 130 | def forward(self, x): 131 | """ 132 | Input: 133 | x: [B, 3, N] 134 | """ 135 | xyz = x.permute(0, 2, 1) # [B, N ,3] 136 | 137 | features = 
F.relu(self.bn1(self.conv1(x))) # [B, 64, N] 138 | features = F.relu(self.bn2(self.conv2(features))) # [B, 64, N] 139 | 140 | xyz1, features1 = self.sg1(features, xyz) # [B, 128, 512] 141 | _, features2 = self.sg2(features1, xyz1) # [B, 256, 256] 142 | 143 | return features2 144 | 145 | 146 | class OA(nn.Module): 147 | """ 148 | Offset-Attention Module. 149 | """ 150 | 151 | def __init__(self, channels): 152 | super(OA, self).__init__() 153 | 154 | self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 155 | self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 156 | self.q_conv.weight = self.k_conv.weight 157 | self.v_conv = nn.Conv1d(channels, channels, 1) 158 | 159 | self.trans_conv = nn.Conv1d(channels, channels, 1) 160 | self.after_norm = nn.BatchNorm1d(channels) 161 | 162 | self.act = nn.ReLU() 163 | self.softmax = nn.Softmax(dim=-1) # change dim to -2 and change the sum(dim=1, keepdims=True) to dim=2 164 | 165 | def forward(self, x): 166 | """ 167 | Input: 168 | x: [B, de, N] 169 | 170 | Output: 171 | x: [B, de, N] 172 | """ 173 | x_q = self.q_conv(x).permute(0, 2, 1) 174 | x_k = self.k_conv(x) 175 | x_v = self.v_conv(x) 176 | 177 | energy = torch.bmm(x_q, x_k) 178 | attention = self.softmax(energy) 179 | attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) # here 180 | 181 | x_r = torch.bmm(x_v, attention) 182 | x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) 183 | x = x + x_r 184 | 185 | return x 186 | 187 | 188 | if __name__ == '__main__': 189 | """ 190 | Please be careful to excute the testing code, because 191 | it may cause the GPU out of memory. 
192 | """ 193 | 194 | pc = torch.rand(32, 3, 1024).to('cuda') 195 | 196 | # testing for Embedding 197 | embedding = Embedding().to('cuda') 198 | out = embedding(pc) 199 | print("Embedding output size:", out.size()) 200 | 201 | # testing for SA 202 | sa = SA(channels=out.size(1)).to('cuda') 203 | out = sa(out) 204 | print("SA output size:", out.size()) 205 | 206 | # testing for SG 207 | coords = torch.rand(32, 1024, 3).to('cuda') 208 | features = torch.rand(32, 64, 1024).to('cuda') 209 | sg = SG(512, 128, 128).to('cuda') 210 | new_coords, out = sg(features, coords) 211 | print("SG output size:", new_coords.size(), out.size()) 212 | 213 | # testing for NeighborEmbedding 214 | ne = NeighborEmbedding().to('cuda') 215 | out = ne(pc) 216 | print("NeighborEmbedding output size:", out.size()) 217 | 218 | # testing for OA 219 | oa = OA(256).to('cuda') 220 | out = oa(out) 221 | print("OA output size:", out.size()) 222 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // 
input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Ouput idxs (b, m) 69 | template 70 | __global__ void furthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ? 
d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | furthest_point_sampling_kernel<512> 185 | <<>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | furthest_point_sampling_kernel<256> 189 | <<>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | furthest_point_sampling_kernel<128> 193 | 
<<>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | furthest_point_sampling_kernel<64> 197 | <<>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | furthest_point_sampling_kernel<32> 201 | <<>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | furthest_point_sampling_kernel<16> 205 | <<>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | furthest_point_sampling_kernel<8> 209 | <<>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | furthest_point_sampling_kernel<4> 213 | <<>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | furthest_point_sampling_kernel<2> 217 | <<>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | furthest_point_sampling_kernel<1> 221 | <<>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | furthest_point_sampling_kernel<512> 225 | <<>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /cls.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | import sklearn.metrics as metrics 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | from torch.optim.lr_scheduler import CosineAnnealingLR 12 | 13 | from dataset import ModelNet40 14 | from model import NaivePCTCls, SPCTCls, PCTCls 15 | from util import cal_loss, Logger 16 | 17 | 18 | models = {'navie_pct': NaivePCTCls, 19 | 'spct': SPCTCls, 20 | 'pct': PCTCls} 21 | 22 | 23 | def _init_(args): 24 | if not os.path.exists('checkpoints'): 25 | os.makedirs('checkpoints') 26 | if not os.path.exists('checkpoints/' + args.exp_name): 27 | os.makedirs('checkpoints/' + args.exp_name) 28 | if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'): 29 | os.makedirs('checkpoints/' + 
args.exp_name + '/' + 'models') 30 | 31 | 32 | def train(args, io): 33 | train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=8, 34 | batch_size=args.batch_size, shuffle=True, drop_last=True) 35 | test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=8, 36 | batch_size=args.test_batch_size, shuffle=True, drop_last=False) 37 | 38 | device = torch.device("cuda" if args.cuda else "cpu") 39 | 40 | model = models[args.model]().to(device) 41 | model = nn.DataParallel(model) 42 | 43 | if args.use_sgd: 44 | print("Use SGD") 45 | opt = optim.SGD(model.parameters(), lr=args.lr*100, momentum=args.momentum, weight_decay=1e-4) 46 | else: 47 | print("Use Adam") 48 | opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4) 49 | 50 | scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr) 51 | 52 | criterion = cal_loss 53 | best_test_acc = 0 54 | 55 | for epoch in range(args.epochs): 56 | train_loss = 0.0 57 | count = 0.0 # numbers of data 58 | model.train() 59 | train_pred = [] 60 | train_true = [] 61 | idx = 0 # iterations 62 | total_time = 0.0 63 | for data, label in (train_loader): 64 | data, label = data.to(device), label.to(device).squeeze() 65 | data = data.permute(0, 2, 1) 66 | batch_size = data.size()[0] 67 | opt.zero_grad() 68 | 69 | start_time = time.time() 70 | logits = model(data) 71 | loss = criterion(logits, label) 72 | loss.backward() 73 | opt.step() 74 | end_time = time.time() 75 | total_time += (end_time - start_time) 76 | 77 | preds = logits.max(dim=1)[1] 78 | count += batch_size 79 | train_loss += loss.item() * batch_size 80 | train_true.append(label.cpu().numpy()) 81 | train_pred.append(preds.detach().cpu().numpy()) 82 | idx += 1 83 | 84 | print ('train total time is',total_time) 85 | train_true = np.concatenate(train_true) 86 | train_pred = np.concatenate(train_pred) 87 | outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (epoch, 88 | 
train_loss * 1.0 / count, 89 | metrics.accuracy_score(train_true, train_pred), 90 | metrics.balanced_accuracy_score(train_true, train_pred)) 91 | io.cprint(outstr) 92 | 93 | #################### 94 | # Test 95 | #################### 96 | test_loss = 0.0 97 | count = 0.0 98 | model.eval() 99 | test_pred = [] 100 | test_true = [] 101 | total_time = 0.0 102 | for data, label in test_loader: 103 | data, label = data.to(device), label.to(device).squeeze() 104 | data = data.permute(0, 2, 1) 105 | batch_size = data.size()[0] 106 | start_time = time.time() 107 | logits = model(data) 108 | end_time = time.time() 109 | total_time += (end_time - start_time) 110 | loss = criterion(logits, label) 111 | preds = logits.max(dim=1)[1] 112 | count += batch_size 113 | test_loss += loss.item() * batch_size 114 | test_true.append(label.cpu().numpy()) 115 | test_pred.append(preds.detach().cpu().numpy()) 116 | print ('test total time is', total_time) 117 | test_true = np.concatenate(test_true) 118 | test_pred = np.concatenate(test_pred) 119 | test_acc = metrics.accuracy_score(test_true, test_pred) 120 | avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred) 121 | outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f' % (epoch, 122 | test_loss*1.0/count, 123 | test_acc, 124 | avg_per_class_acc) 125 | io.cprint(outstr) 126 | if test_acc >= best_test_acc: 127 | best_test_acc = test_acc 128 | torch.save(model.state_dict(), 'checkpoints/%s/models/model.t7' % args.exp_name) 129 | 130 | scheduler.step() 131 | 132 | 133 | def test(args, io): 134 | test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), 135 | batch_size=args.test_batch_size, shuffle=True, drop_last=False) 136 | 137 | device = torch.device("cuda" if args.cuda else "cpu") 138 | 139 | model = models[args.model]().to(device) 140 | model = nn.DataParallel(model) 141 | 142 | model.load_state_dict(torch.load(args.model_path)) 143 | model = model.eval() 144 | test_true = [] 145 | 
test_pred = [] 146 | 147 | for data, label in test_loader: 148 | data, label = data.to(device), label.to(device).squeeze() 149 | data = data.permute(0, 2, 1) 150 | logits = model(data) 151 | preds = logits.max(dim=1)[1] 152 | test_true.append(label.cpu().numpy()) 153 | test_pred.append(preds.detach().cpu().numpy()) 154 | 155 | test_true = np.concatenate(test_true) 156 | test_pred = np.concatenate(test_pred) 157 | test_acc = metrics.accuracy_score(test_true, test_pred) 158 | avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred) 159 | outstr = 'Test :: test acc: %.6f, test avg acc: %.6f' % (test_acc, avg_per_class_acc) 160 | io.cprint(outstr) 161 | 162 | 163 | if __name__ == "__main__": 164 | # Training settings 165 | parser = argparse.ArgumentParser(description='Point Cloud Recognition') 166 | parser.add_argument('--exp_name', type=str, default='exp', metavar='N', 167 | help='Name of the experiment') 168 | parser.add_argument('--model', type=str, default='pct', choices=['navie_pct', 'spct', 'pct'], 169 | help='which model you want to use') 170 | parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N', 171 | choices=['modelnet40']) 172 | parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size', 173 | help='Size of batch)') 174 | parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size', 175 | help='Size of batch)') 176 | parser.add_argument('--epochs', type=int, default=250, metavar='N', 177 | help='number of episode to train ') 178 | parser.add_argument('--use_sgd', type=bool, default=True, 179 | help='Use SGD') 180 | parser.add_argument('--lr', type=float, default=0.0001, metavar='LR', 181 | help='learning rate (default: 0.001, 0.1 if using sgd)') 182 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 183 | help='SGD momentum (default: 0.9)') 184 | parser.add_argument('--no_cuda', type=bool, default=False, 185 | help='enables CUDA training') 186 | 
parser.add_argument('--seed', type=int, default=1, metavar='S', 187 | help='random seed (default: 1)') 188 | parser.add_argument('--eval', type=bool, default=False, 189 | help='evaluate the model') 190 | parser.add_argument('--num_points', type=int, default=1024, 191 | help='num of points to use') 192 | parser.add_argument('--dropout', type=float, default=0.5, 193 | help='dropout rate') 194 | parser.add_argument('--model_path', type=str, default='', metavar='N', 195 | help='Pretrained model path') 196 | args = parser.parse_args() 197 | 198 | _init_(args) 199 | 200 | io = Logger('checkpoints/' + args.exp_name + '/run.log') 201 | io.cprint(str(args)) 202 | 203 | args.cuda = not args.no_cuda and torch.cuda.is_available() 204 | torch.manual_seed(args.seed) 205 | if args.cuda: 206 | io.cprint( 207 | 'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices') 208 | torch.cuda.manual_seed(args.seed) 209 | else: 210 | io.cprint('Using CPU') 211 | 212 | if not args.eval: 213 | train(args, io) 214 | else: 215 | test(args, io) 216 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from module import Embedding, NeighborEmbedding, OA, SA 6 | 7 | 8 | class NaivePCT(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.embedding = Embedding(3, 128) 13 | 14 | self.sa1 = SA(128) 15 | self.sa2 = SA(128) 16 | self.sa3 = SA(128) 17 | self.sa4 = SA(128) 18 | 19 | self.linear = nn.Sequential( 20 | nn.Conv1d(512, 1024, kernel_size=1, bias=False), 21 | nn.BatchNorm1d(1024), 22 | nn.LeakyReLU(negative_slope=0.2) 23 | ) 24 | 25 | def forward(self, x): 26 | x = self.embedding(x) 27 | 28 | x1 = self.sa1(x) 29 | x2 = self.sa2(x1) 30 | x3 = self.sa3(x2) 31 | x4 = self.sa4(x3) 32 | x = torch.cat([x1, x2, x3, x4], dim=1) 
33 | 34 | x = self.linear(x) 35 | 36 | # x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) 37 | x_max = torch.max(x, dim=-1)[0] 38 | x_mean = torch.mean(x, dim=-1) 39 | 40 | return x, x_max, x_mean 41 | 42 | 43 | class SPCT(nn.Module): 44 | def __init__(self): 45 | super().__init__() 46 | 47 | self.embedding = Embedding(3, 128) 48 | 49 | self.sa1 = OA(128) 50 | self.sa2 = OA(128) 51 | self.sa3 = OA(128) 52 | self.sa4 = OA(128) 53 | 54 | self.linear = nn.Sequential( 55 | nn.Conv1d(512, 1024, kernel_size=1, bias=False), 56 | nn.BatchNorm1d(1024), 57 | nn.LeakyReLU(negative_slope=0.2) 58 | ) 59 | 60 | def forward(self, x): 61 | x = self.embedding(x) 62 | 63 | x1 = self.sa1(x) 64 | x2 = self.sa2(x1) 65 | x3 = self.sa3(x2) 66 | x4 = self.sa4(x3) 67 | x = torch.cat([x1, x2, x3, x4], dim=1) 68 | 69 | x = self.linear(x) 70 | 71 | # x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) 72 | x_max = torch.max(x, dim=-1)[0] 73 | x_mean = torch.mean(x, dim=-1) 74 | 75 | return x, x_max, x_mean 76 | 77 | 78 | class PCT(nn.Module): 79 | def __init__(self, samples=[512, 256]): 80 | super().__init__() 81 | 82 | self.neighbor_embedding = NeighborEmbedding(samples) 83 | 84 | self.oa1 = OA(256) 85 | self.oa2 = OA(256) 86 | self.oa3 = OA(256) 87 | self.oa4 = OA(256) 88 | 89 | self.linear = nn.Sequential( 90 | nn.Conv1d(1280, 1024, kernel_size=1, bias=False), 91 | nn.BatchNorm1d(1024), 92 | nn.LeakyReLU(negative_slope=0.2) 93 | ) 94 | 95 | def forward(self, x): 96 | x = self.neighbor_embedding(x) 97 | 98 | x1 = self.oa1(x) 99 | x2 = self.oa2(x1) 100 | x3 = self.oa3(x2) 101 | x4 = self.oa4(x3) 102 | 103 | x = torch.cat([x, x1, x2, x3, x4], dim=1) 104 | 105 | x = self.linear(x) 106 | 107 | # x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) 108 | x_max = torch.max(x, dim=-1)[0] 109 | x_mean = torch.mean(x, dim=-1) 110 | 111 | return x, x_max, x_mean 112 | 113 | 114 | class Classification(nn.Module): 115 | def __init__(self, num_categories=40): 116 | super().__init__() 117 | 118 | 
self.linear1 = nn.Linear(1024, 512, bias=False) 119 | self.linear2 = nn.Linear(512, 256) 120 | self.linear3 = nn.Linear(256, num_categories) 121 | 122 | self.bn1 = nn.BatchNorm1d(512) 123 | self.bn2 = nn.BatchNorm1d(256) 124 | 125 | self.dp1 = nn.Dropout(p=0.5) 126 | self.dp2 = nn.Dropout(p=0.5) 127 | 128 | def forward(self, x): 129 | x = F.relu(self.bn1(self.linear1(x))) 130 | x = self.dp1(x) 131 | x = F.relu(self.bn2(self.linear2(x))) 132 | x = self.dp2(x) 133 | x = self.linear3(x) 134 | return x 135 | 136 | 137 | class Segmentation(nn.Module): 138 | def __init__(self, part_num): 139 | super().__init__() 140 | 141 | self.part_num = part_num 142 | 143 | self.label_conv = nn.Sequential( 144 | nn.Conv1d(16, 64, kernel_size=1, bias=False), 145 | nn.BatchNorm1d(64), 146 | nn.LeakyReLU(negative_slope=0.2) 147 | ) 148 | 149 | self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1) 150 | self.convs2 = nn.Conv1d(512, 256, 1) 151 | self.convs3 = nn.Conv1d(256, self.part_num, 1) 152 | 153 | self.bns1 = nn.BatchNorm1d(512) 154 | self.bns2 = nn.BatchNorm1d(256) 155 | 156 | self.dp1 = nn.Dropout(0.5) 157 | 158 | def forward(self, x, x_max, x_mean, cls_label): 159 | batch_size, _, N = x.size() 160 | 161 | x_max_feature = x_max.unsqueeze(-1).repeat(1, 1, N) 162 | x_mean_feature = x_mean.unsqueeze(-1).repeat(1, 1, N) 163 | 164 | cls_label_one_hot = cls_label.view(batch_size, 16, 1) 165 | cls_label_feature = self.label_conv(cls_label_one_hot).repeat(1, 1, N) 166 | 167 | x = torch.cat([x, x_max_feature, x_mean_feature, cls_label_feature], dim=1) # 1024 * 3 + 64 168 | 169 | x = F.relu(self.bns1(self.convs1(x))) 170 | x = self.dp1(x) 171 | x = F.relu(self.bns2(self.convs2(x))) 172 | x = self.convs3(x) 173 | 174 | return x 175 | 176 | 177 | class NormalEstimation(nn.Module): 178 | def __init__(self): 179 | super().__init__() 180 | 181 | self.convs1 = nn.Conv1d(1024 * 3, 512, 1) 182 | self.convs2 = nn.Conv1d(512, 256, 1) 183 | self.convs3 = nn.Conv1d(256, 3, 1) 184 | 185 | self.bns1 = 
nn.BatchNorm1d(512) 186 | self.bns2 = nn.BatchNorm1d(256) 187 | 188 | self.dp1 = nn.Dropout(0.5) 189 | 190 | def forward(self, x, x_max, x_mean): 191 | N = x.size(2) 192 | 193 | x_max_feature = x_max.unsqueeze(-1).repeat(1, 1, N) 194 | x_mean_feature = x_mean.unsqueeze(-1).repeat(1, 1, N) 195 | 196 | x = torch.cat([x_max_feature, x_mean_feature, x], dim=1) 197 | 198 | x = F.relu(self.bns1(self.convs1(x))) 199 | x = self.dp1(x) 200 | x = F.relu(self.bns2(self.convs2(x))) 201 | x = self.convs3(x) 202 | 203 | return x 204 | 205 | 206 | """ 207 | Classification networks. 208 | """ 209 | 210 | class NaivePCTCls(nn.Module): 211 | def __init__(self, num_categories=40): 212 | super().__init__() 213 | 214 | self.encoder = NaivePCT() 215 | self.cls = Classification(num_categories) 216 | 217 | def forward(self, x): 218 | _, x, _ = self.encoder(x) 219 | x = self.cls(x) 220 | return x 221 | 222 | 223 | class SPCTCls(nn.Module): 224 | def __init__(self, num_categories=40): 225 | super().__init__() 226 | 227 | self.encoder = SPCT() 228 | self.cls = Classification(num_categories) 229 | 230 | def forward(self, x): 231 | _, x, _ = self.encoder(x) 232 | x = self.cls(x) 233 | return x 234 | 235 | 236 | class PCTCls(nn.Module): 237 | def __init__(self, num_categories=40): 238 | super().__init__() 239 | 240 | self.encoder = PCT() 241 | self.cls = Classification(num_categories) 242 | 243 | def forward(self, x): 244 | _, x, _ = self.encoder(x) 245 | x = self.cls(x) 246 | return x 247 | 248 | 249 | """ 250 | Part Segmentation Networks. 
251 | """ 252 | 253 | class NaivePCTSeg(nn.Module): 254 | def __init__(self, part_num=50): 255 | super().__init__() 256 | 257 | self.encoder = NaivePCT() 258 | self.seg = Segmentation(part_num) 259 | 260 | def forward(self, x, cls_label): 261 | x, x_max, x_mean = self.encoder(x) 262 | x = self.seg(x, x_max, x_mean, cls_label) 263 | return x 264 | 265 | 266 | class SPCTSeg(nn.Module): 267 | def __init__(self, part_num=50): 268 | super().__init__() 269 | 270 | self.encoder = SPCT() 271 | self.seg = Segmentation(part_num) 272 | 273 | def forward(self, x, cls_label): 274 | x, x_max, x_mean = self.encoder(x) 275 | x = self.seg(x, x_max, x_mean, cls_label) 276 | return x 277 | 278 | 279 | class PCTSeg(nn.Module): 280 | def __init__(self, part_num=50): 281 | super().__init__() 282 | 283 | self.encoder = PCT(samples=[1024, 1024]) 284 | self.seg = Segmentation(part_num) 285 | 286 | def forward(self, x, cls_label): 287 | x, x_max, x_mean = self.encoder(x) 288 | x = self.seg(x, x_max, x_mean, cls_label) 289 | return x 290 | 291 | 292 | """ 293 | Normal Estimation networks. 
294 | """ 295 | 296 | class NaivePCTNormalEstimation(nn.Module): 297 | def __init__(self): 298 | super().__init__() 299 | 300 | self.encoder = NaivePCT() 301 | self.ne = NormalEstimation() 302 | 303 | def forward(self, x): 304 | x, x_max, x_mean = self.encoder(x) 305 | x = self.ne(x, x_max, x_mean) 306 | return x 307 | 308 | 309 | class SPCTNormalEstimation(nn.Module): 310 | def __init__(self): 311 | super().__init__() 312 | 313 | self.encoder = SPCT() 314 | self.ne = NormalEstimation() 315 | 316 | def forward(self, x): 317 | x, x_max, x_mean = self.encoder(x) 318 | x = self.ne(x, x_max, x_mean) 319 | return x 320 | 321 | 322 | class PCTNormalEstimation(nn.Module): 323 | def __init__(self): 324 | super().__init__() 325 | 326 | self.encoder = PCT(samples=[1024, 1024]) 327 | self.ne = NormalEstimation() 328 | 329 | def forward(self, x): 330 | x, x_max, x_mean = self.encoder(x) 331 | x = self.ne(x, x_max, x_mean) 332 | return x 333 | 334 | 335 | if __name__ == '__main__': 336 | pc = torch.rand(4, 3, 1024).to('cuda') 337 | cls_label = torch.rand(4, 16).to('cuda') 338 | 339 | # testing for cls networks 340 | naive_pct_cls = NaivePCTCls().to('cuda') 341 | spct_cls = SPCTCls().to('cuda') 342 | pct_cls = PCTCls().to('cuda') 343 | 344 | print(naive_pct_cls(pc).size()) 345 | print(spct_cls(pc).size()) 346 | print(pct_cls(pc).size()) 347 | 348 | # testing for segmentation networks 349 | naive_pct_seg = NaivePCTSeg().to('cuda') 350 | spct_seg = SPCTSeg().to('cuda') 351 | pct_seg = PCTSeg().to('cuda') 352 | 353 | print(naive_pct_seg(pc, cls_label).size()) 354 | print(spct_seg(pc, cls_label).size()) 355 | print(pct_seg(pc, cls_label).size()) 356 | 357 | # testing for normal estimation networks 358 | naive_pct_ne = NaivePCTNormalEstimation().to('cuda') 359 | spct_ne = SPCTNormalEstimation().to('cuda') 360 | pct_ne = PCTNormalEstimation().to('cuda') 361 | 362 | print(naive_pct_ne(pc).size()) 363 | print(spct_ne(pc).size()) 364 | print(pct_ne(pc).size()) 365 | 
# ------------------------------------------------------------------------------
# /pointnet2_ops_lib/pointnet2_ops/pointnet2_utils.py
# ------------------------------------------------------------------------------
"""Autograd wrappers around the ``pointnet2_ops`` C++/CUDA extension.

Each ``torch.autograd.Function`` below forwards to one kernel exported by the
``_ext`` extension module and defines the matching backward pass. If the
prebuilt ``pointnet2_ops._ext`` module cannot be imported, the extension is
JIT-compiled from ``_ext-src`` at import time.
"""
import torch
import torch.nn as nn
import warnings
from torch.autograd import Function
from typing import *

try:
    import pointnet2_ops._ext as _ext
except ImportError:
    from torch.utils.cpp_extension import load
    import glob
    import os.path as osp
    import os

    warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")

    _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")
    _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob(
        osp.join(_ext_src_root, "src", "*.cu")
    )
    _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))

    # Only supply a default architecture list: a value the user has already
    # exported (e.g. to target a newer GPU) must not be silently overwritten.
    os.environ.setdefault("TORCH_CUDA_ARCH_LIST", "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5")
    _ext = load(
        "_ext",
        sources=_ext_sources,
        extra_include_paths=[osp.join(_ext_src_root, "include")],
        extra_cflags=["-O3"],
        extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
        with_cuda=True,
    )


class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz, npoint):
        # type: (Any, torch.Tensor, int) -> torch.Tensor
        r"""
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            number of features in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set
        """
        out = _ext.furthest_point_sampling(xyz, npoint)

        ctx.mark_non_differentiable(out)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient slot per forward input (xyz, npoint); neither is
        # differentiable because the output is marked non-differentiable.
        return None, None


furthest_point_sample = FurthestPointSampling.apply


class GatherOperation(Function):
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Gather the entries of ``features`` selected by ``idx`` along the last axis.

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor

        idx : torch.Tensor
            (B, npoint) tensor of the features to gather

        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """
        # Backward only needs the length N of the gathered axis; storing it
        # instead of saving ``features`` avoids keeping that tensor alive.
        ctx.N = features.size(2)
        ctx.save_for_backward(idx)

        return _ext.gather_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        (idx,) = ctx.saved_tensors

        grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, ctx.N)
        # ``idx`` holds integer indices and has no gradient.
        return grad_features, None


gather_operation = GatherOperation.apply


class ThreeNN(Function):
    @staticmethod
    def forward(ctx, unknown, known):
        # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
        r"""
        Find the three nearest neighbors of unknown in known
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of query (unknown) features
        known : torch.Tensor
            (B, m, 3) tensor of known features to search in

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) index of 3 nearest neighbors
        """
        dist2, idx = _ext.three_nn(unknown, known)
        # The extension returns squared distances.
        dist = torch.sqrt(dist2)

        ctx.mark_non_differentiable(dist, idx)

        return dist, idx

    @staticmethod
    def backward(ctx, grad_dist, grad_idx):
        # One gradient slot per forward input (unknown, known); both outputs
        # are non-differentiable.
        return None, None


three_nn = ThreeNN.apply


class ThreeInterpolate(Function):
    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""
        Performs weight linear interpolation on 3 features
        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        ctx.m = features.size(2)
        # ``features`` itself is only needed to differentiate w.r.t. ``weight``;
        # otherwise keep just its size so the tensor is not held until backward.
        ctx.save_for_backward(
            idx, weight, features if ctx.needs_input_grad[2] else None
        )

        return _ext.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[Optional[torch.Tensor], None, Optional[torch.Tensor]]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor or None
            (B, c, m) tensor with gradients of features
        None
            ``idx`` is an integer index tensor and is not differentiable
        grad_weight : torch.Tensor or None
            (B, n, 3) tensor with gradients of weight, computed only when
            ``weight`` requires grad
        """
        idx, weight, features = ctx.saved_tensors
        grad_out = grad_out.contiguous()

        grad_features = None
        if ctx.needs_input_grad[0]:
            grad_features = _ext.three_interpolate_grad(grad_out, idx, weight, ctx.m)

        grad_weight = None
        if ctx.needs_input_grad[2]:
            # out[b, c, i] = sum_k weight[b, i, k] * features[b, c, idx[b, i, k]],
            # hence d out[b, c, i] / d weight[b, i, k] = features[b, c, idx[b, i, k]].
            batch, channels, _ = features.size()
            n, k = idx.size(1), idx.size(2)
            gather_idx = idx.long().reshape(batch, 1, n * k).expand(batch, channels, n * k)
            neighbor_features = features.gather(2, gather_idx).view(batch, channels, n, k)
            grad_weight = (grad_out.unsqueeze(-1) * neighbor_features).sum(dim=1)

        return grad_features, None, grad_weight


three_interpolate = ThreeInterpolate.apply


class GroupingOperation(Function):
    @staticmethod
    def forward(ctx, features, idx):
        # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        # Backward only needs N; avoid keeping ``features`` alive.
        ctx.N = features.size(2)
        ctx.save_for_backward(idx)

        return _ext.group_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, None]
        r"""

        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
            ``idx`` is an integer index tensor and is not differentiable
        """
        (idx,) = ctx.saved_tensors

        grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, ctx.N)

        return grad_features, None


grouping_operation = GroupingOperation.apply


class BallQuery(Function):
    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
        r"""

        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        output = _ext.ball_query(new_xyz, xyz, radius, nsample)

        ctx.mark_non_differentiable(output)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient slot per forward input (radius, nsample, xyz, new_xyz).
        return None, None, None, None


ball_query = BallQuery.apply


class QueryAndGroup(nn.Module):
    r"""
    Groups with a ball query of radius

    Parameters
    ---------
    radius : float32
        Radius of ball
    nsample : int32
        Maximum number of features to gather in the ball
    use_xyz : bool
        Whether to prepend the relative xyz coordinates to the grouped features
    """

    def __init__(self, radius, nsample, use_xyz=True):
        # type: (QueryAndGroup, float, int, bool) -> None
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (QueryAndGroup, torch.Tensor, torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) tensor
        """

        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
        # Express each grouped point relative to its ball center.
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, C + 3, npoint, nsample)
            else:
                new_features = grouped_features
        else:
            assert (
                self.use_xyz
            ), "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz

        return new_features


class GroupAll(nn.Module):
    r"""
    Groups all features

    Parameters
    ---------
    use_xyz : bool
        Whether to prepend the xyz coordinates to the grouped features
    """

    def __init__(self, use_xyz=True):
        # type: (GroupAll, bool) -> None
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        # type: (GroupAll, torch.Tensor, torch.Tensor, Optional[torch.Tensor]) -> torch.Tensor
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, C + 3, 1, N) tensor
        """

        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is not None:
            grouped_features = features.unsqueeze(2)
            if self.use_xyz:
                new_features = torch.cat(
                    [grouped_xyz, grouped_features], dim=1
                )  # (B, 3 + C, 1, N)
            else:
                new_features = grouped_features
        else:
            new_features = grouped_xyz

        return new_features
# ------------------------------------------------------------------------------