├── code ├── utils │ ├── __init__.py │ ├── knn_patch.py │ ├── helper_function.py │ ├── plot_training_validation_loss.py │ ├── fp_sampling.py │ ├── create_data.py │ ├── write_hdf5.py │ └── data_provider.py ├── __init__.py ├── dataset.py ├── evaluate.py ├── train.py ├── loss.py └── model.py ├── pointnet2 ├── utils │ ├── .gitignore │ ├── __init__.py │ ├── linalg_utils.py │ ├── pointnet2_modules.py │ ├── pointnet2_utils.py │ └── pytorch_utils.py ├── _ext-src │ ├── include │ │ ├── ball_query.h │ │ ├── group_points.h │ │ ├── sampling.h │ │ ├── interpolate.h │ │ ├── utils.h │ │ └── cuda_utils.h │ └── src │ │ ├── bindings.cpp │ │ ├── ball_query.cpp │ │ ├── ball_query_gpu.cu │ │ ├── group_points.cpp │ │ ├── group_points_gpu.cu │ │ ├── sampling.cpp │ │ ├── interpolate.cpp │ │ ├── interpolate_gpu.cu │ │ └── sampling_gpu.cu └── __init__.py ├── requirements.txt ├── pyTorchChamferDistance ├── chamfer_distance │ ├── __init__.py │ ├── chamfer_distance.py │ ├── chamfer_distance.cu │ └── chamfer_distance.cpp ├── README.md └── LICENSE.md ├── sampling ├── setup.py ├── cuda_utils.h ├── sampling.cpp └── sampling_cuda.cu ├── LICENSE ├── setup.py ├── .gitignore └── README.md /code/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pointnet2/utils/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | _ext 3 | -------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | numpy 3 | pprint 4 | enum34 5 | future 6 | tqdm 7 | scipy 8 | sklearn 9 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance import ChamferDistance, ChamferDistance_PPU 2 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | from . import pointnet2_utils 9 | from . 
import pointnet2_modules 10 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /pointnet2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | 9 | __version__ = "2.1.1" 10 | 11 | try: 12 | __POINTNET2_SETUP__ 13 | except NameError: 14 | __POINTNET2_SETUP__ = False 15 | 16 | if not __POINTNET2_SETUP__: 17 | from pointnet2 import utils 18 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /sampling/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='sampling', 6 | ext_modules=[ 7 | CUDAExtension('sampling', [ 8 | 'sampling.cpp', 9 | 'sampling_cuda.cu',], 10 | extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}, 11 | include_dirs=["."]) 12 | ], 13 | 14 | cmdclass={ 15 | 'build_ext': BuildExtension 16 | }) 17 | -------------------------------------------------------------------------------- /code/utils/knn_patch.py: -------------------------------------------------------------------------------- 1 | from sklearn.neighbors import NearestNeighbors 2 | import numpy as np 3 | 4 | def extract_knn_patch(queries, pc, k): 5 | """ 6 | queries [M, C] 7 | pc [P, C] 8 | """ 9 | #print(queries.shape) 10 | #print(pc.shape) 11 | knn_search = NearestNeighbors(n_neighbors=k, algorithm='auto') 12 | knn_search.fit(pc[:, 0:3]) 13 | knn_idx = knn_search.kneighbors(queries, return_distance=False) 14 | k_patches = np.take(pc, knn_idx, axis=0) # M, K, C 15 | return k_patches 16 | -------------------------------------------------------------------------------- /code/utils/helper_function.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def get_best_epoch(f_pointer): 5 | f_pointer.seek(0, 0) # begining of file 6 | read_states = f_pointer.readlines() 7 | read_min_loss = min(read_states, key=lambda k: float(k.split(", ")[1].split(" ")[1][0:-1])) 8 | best_epoch = int(read_min_loss.split(", ")[0].split(" ")[1]) 9 | return best_epoch, read_min_loss 10 | 11 | 12 | 13 | def get_current_state(f_pointer): 14 | f_pointer.seek(0, 0) # begining of file 15 | read_states = f_pointer.readlines() 16 | last_epoch = int(read_states[-1].split(", ")[0].split(" ")[1]) if 
len(read_states) else -1 17 | return last_epoch 18 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | -------------------------------------------------------------------------------- /code/utils/plot_training_validation_loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | 7 | state = open("../results/step_1/state_1.txt", "a+") 8 | 9 | l = state.readlines() 10 | 11 | state.seek(0, 0) 12 | l = state.readlines() 13 | 14 | t_loss = [float(i.split(", ")[1].split(" ")[-1]) for i in l] 15 | 16 | v_loss = [float(i.split(", ")[2].split(" ")[-1][0:-1]) for i in l] 17 | 18 | plt.plot(t_loss, 'r--', label='training loss') 19 | plt.plot(v_loss, '-b', label='validation loss') 20 | plt.yticks(np.arange(min(t_loss), max(t_loss)+1, 1)) 21 | plt.xlabel("n iteration") 22 | plt.legend(loc='upper left') 23 | plt.title("T Loss vs V Loss") 24 | 25 | plt.show() -------------------------------------------------------------------------------- /pyTorchChamferDistance/README.md: -------------------------------------------------------------------------------- 1 | # Chamfer Distance for pyTorch 2 | 3 | This is an implementation of the Chamfer Distance as a module for pyTorch. It is written as a custom C++/CUDA extension. 4 | 5 | As it is using pyTorch's [JIT compilation](https://pytorch.org/tutorials/advanced/cpp_extension.html), there are no additional prerequisite steps that have to be taken. Simply import the module as shown below; CUDA and C++ code will be compiled on the first run. 6 | 7 | ### Usage 8 | ```python 9 | from chamfer_distance import ChamferDistance 10 | chamfer_dist = ChamferDistance() 11 | 12 | #... 13 | # points and points_reconstructed are n_points x 3 matrices 14 | 15 | dist1, dist2 = chamfer_dist(points, points_reconstructed) 16 | loss = (torch.mean(dist1)) + (torch.mean(dist2)) 17 | 18 | 19 | #... 
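# Illustrative continuation (not part of the original snippet): dist1 holds the
# per-point distances from points to their nearest neighbours in
# points_reconstructed, and dist2 the reverse direction, so the symmetric loss
# above can be backpropagated directly, e.g. loss.backward()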
20 | ``` 21 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_CHECK(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_CHECK(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /code/utils/fp_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sampling 3 | 4 | class FurthestPointSampling(torch.autograd.Function): 5 | 6 | @staticmethod 7 | def forward(ctx, xyz, npoint): 8 | r""" 9 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 10 | minimum distance 11 | Parameters 12 | ---------- 13 | xyz : torch.Tensor 14 | (B, N, 3) tensor where N > npoint 15 | npoint : int32 16 | number of features in the sampled set 17 | Returns 18 | ------- 19 | torch.LongTensor 20 | (B, npoint) tensor containing the indices 21 | """ 22 | B, N, _ = xyz.size() 23 | 24 | idx = torch.empty([B, npoint], dtype=torch.int32, device=xyz.device) 25 | temp = torch.full([B, N], 1e10, dtype=torch.float32, device=xyz.device) 26 | 27 | sampling.furthest_sampling( 28 | B, N, npoint, xyz, temp, idx 29 | ) 30 | ctx.mark_non_differentiable(idx) 31 | return idx 32 | 33 | 34 | furthest_point_sample = FurthestPointSampling.apply 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Rajat Sharma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.type().is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.type().is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data(), 26 | xyz.data(), idx.data()); 27 | } else { 28 | AT_CHECK(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /code/dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | import numpy as np 7 | from code.utils import data_provider 8 | import h5py 9 | 10 | class PointCLoudDataset(Dataset): 11 | def __init__(self, hdf5_file_path, transform=True): 12 | with h5py.File(hdf5_file_path, "r") as h5: 13 | print(h5["poisson_4096"].shape) 14 | gt = h5["poisson_4096"][:, :, :] 15 | 16 | if transform: 17 | centroid = np.mean(gt[:,:,0:3], axis=1, keepdims=True) 18 | gt[:,:,0:3] = gt[:,:,0:3] - centroid 19 | furthest_distance = np.amax(np.sqrt(np.sum(gt[:,:,0:3] ** 2, axis=-1)),axis=1,keepdims=True) 20 | gt[:, :, 0:3] = gt[:,:,0:3] / np.expand_dims(furthest_distance,axis=-1) 21 | 22 | self.gt_set = torch.from_numpy(gt) 23 | print(self.gt_set.shape) 24 | 25 | def __len__(self): 26 | return len(self.gt_set) 27 | 28 
| 29 | def __getitem__(self, idx): 30 | gt = self.gt_set[idx] 31 | random_1024 = np.random.choice(4096, 1024) #random sampling 32 | ip = gt[random_1024] 33 | 34 | return (ip, gt) 35 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, absolute_import, with_statement, print_function 2 | from setuptools import setup, find_packages 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | import glob 5 | 6 | try: 7 | import builtins 8 | except: 9 | import __builtin__ as builtins 10 | 11 | builtins.__POINTNET2_SETUP__ = True 12 | import pointnet2 13 | 14 | _ext_src_root = "pointnet2/_ext-src" 15 | _ext_sources = glob.glob("{}/src/*.cpp".format(_ext_src_root)) + glob.glob( 16 | "{}/src/*.cu".format(_ext_src_root) 17 | ) 18 | _ext_headers = glob.glob("{}/include/*".format(_ext_src_root)) 19 | 20 | requirements = ["h5py", "pprint", "enum34", "future"] 21 | 22 | setup( 23 | name="pointnet2", 24 | version=pointnet2.__version__, 25 | author="Erik Wijmans", 26 | packages=find_packages(), 27 | install_requires=requirements, 28 | ext_modules=[ 29 | CUDAExtension( 30 | name="pointnet2._ext", 31 | sources=_ext_sources, 32 | extra_compile_args={ 33 | "cxx": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 34 | "nvcc": ["-O2", "-I{}".format("{}/include".format(_ext_src_root))], 35 | }, 36 | ) 37 | ], 38 | cmdclass={"build_ext": BuildExtension}, 39 | ) 40 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /sampling/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | 7 | #define TOTAL_THREADS 512 8 | 9 | inline int opt_n_threads(int work_size) 10 | { 11 | // round work_size to power of 2 betwwen 512 and 1 12 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 13 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 14 | } 15 | 16 | inline dim3 opt_block_config(int x, int y) 17 | { 18 | const int x_threads = opt_n_threads(x); 19 | const int y_threads = 20 | std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 21 | dim3 block_config(x_threads, y_threads, 1); 
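    // Illustrative example: for x = 2048, y = 1024 this yields x_threads = 512
    // (capped at TOTAL_THREADS) and y_threads = 1, i.e. a (512, 1, 1) block.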
22 | 23 | return block_config; 24 | } 25 | 26 | #define CUDA_CHECK_ERRORS() \ 27 | do \ 28 | { \ 29 | cudaError_t err = cudaGetLastError(); \ 30 | if (cudaSuccess != err) \ 31 | { \ 32 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 33 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 34 | __FILE__); \ 35 | exit(-1); \ 36 | } \ 37 | } while (0) 38 | #endif 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 107 | # dataset file 108 | *.h5 109 | 110 | # checkoints 111 | checkpoint/ 112 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y 
- y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.type().is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.type().is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), points.data(), 29 | idx.data(), output.data()); 30 | } else { 31 | AT_CHECK(false, "CPU not supported"); 32 | } 33 | 34 | return output; 35 | } 36 | 37 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 38 | CHECK_CONTIGUOUS(grad_out); 39 | CHECK_CONTIGUOUS(idx); 40 | CHECK_IS_FLOAT(grad_out); 41 | CHECK_IS_INT(idx); 42 | 43 | if (grad_out.type().is_cuda()) { 44 | CHECK_CUDA(idx); 45 | } 46 | 47 | at::Tensor output = 48 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 49 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 50 | 51 | if (grad_out.type().is_cuda()) { 52 | group_points_grad_kernel_wrapper( 53 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 54 | grad_out.data(), idx.data(), output.data()); 55 | } else { 56 | AT_CHECK(false, "CPU not supported"); 57 | } 58 | 59 | return output; 60 | } 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Efficient Point Cloud Upsampling and Normal Estimation using Deep Learning for Robust Surface Reconstruction 2 | 3 | To run the project add root folder of the project to python path.```export PYTHONPATH="ROOTPATH_OF_PROJECT:$PYTHONPATH"``` e.g., ```export PYTHONPATH="/home/user/point-normals-upsampling:$PYTHONPATH"``` 4 | 5 | ## Setup 6 | - Use anaconda for python3.7. Install ```requirements.txt```. 
Install PyTorch and the CUDA toolkit: ```conda install pytorch torchvision cudatoolkit=10.1 -c pytorch``` 7 | - To build the PointNet++ module, run ```python setup.py build_ext --inplace``` in the root folder of the project 8 | - To build the sampling module, run ```python setup.py install``` in the sampling folder of the project 9 | - Add the absolute paths of chamfer_distance.cpp and chamfer_distance.cu in chamfer_distance.py 10 | 11 | ## Training 12 | 13 | - This repo uses the [PU-NET](https://raw.githubusercontent.com/yulequan/PU-Net) dataset for training. Download the HDF5-format patch dataset from [GoogleDrive](https://drive.google.com/file/d/1wMtNGvliK_pUTogfzMyrz57iDb_jSQR8/view?usp=sharing) 14 | - For training and evaluation, run all commands inside the code folder. 15 | - Training: ```python train.py --num_points 1024 --checkpoint_path .. --batch_size 20 --epochs 400 --h5_data_file dataset_path``` e.g., ```python train.py --num_points 1024 --checkpoint_path .. --batch_size 20 --epochs 400 --h5_data_file ../data.h5``` 16 | - Evaluation: ```python evaluate.py --test_file filename(.xyz) --num_points num (default=1024) --patch_num_ratio num (default=4) --trained_model checkpoint_path``` e.g., ```python evaluate.py --test_file ../test.xyz --num_points 1024 --patch_num_ratio 4 --trained_model ../checkpoint``` 17 | - All results are saved in the results folder in the root directory 18 | 19 | ## Acknowledgement 20 | - **PointNet++ PyTorch Implementation**: [erikwijmans/Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch) 21 | - **Official PointNet++**: [charlesq34/pointnet2](https://github.com/charlesq34/pointnet2) 22 | - **PyTorch Chamfer Distance**: [chrdiller/pyTorchChamferDistance](https://github.com/chrdiller/pyTorchChamferDistance) 23 | - **Patch-based Progressive 3D Point Set Upsampling**: [yifita/3PU_pytorch](https://github.com/yifita/3PU_pytorch) 24 | -------------------------------------------------------------------------------- /pointnet2/utils/linalg_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch 9 | from enum import Enum 10 | import numpy as np 11 | 12 | PDist2Order = Enum("PDist2Order", "d_first d_second") 13 | 14 | 15 | def pdist2(X, Z=None, order=PDist2Order.d_second): 16 | # type: (torch.Tensor, torch.Tensor, PDist2Order) -> torch.Tensor 17 | r""" Calculates the pairwise distance between X and Z 18 | 19 | D[b, i, j] = l2 distance X[b, i] and Z[b, j] 20 | 21 | Parameters 22 | --------- 23 | X : torch.Tensor 24 | X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d 25 | Z: torch.Tensor 26 | Z is a (B, M, d) tensor.
If Z is None, then Z = X 27 | 28 | Returns 29 | ------- 30 | torch.Tensor 31 | Distance matrix is size (B, N, M) 32 | """ 33 | 34 | if order == PDist2Order.d_second: 35 | if X.dim() == 2: 36 | X = X.unsqueeze(0) 37 | if Z is None: 38 | Z = X 39 | G = np.matmul(X, Z.transpose(-2, -1)) 40 | S = (X * X).sum(-1, keepdim=True) 41 | R = S.transpose(-2, -1) 42 | else: 43 | if Z.dim() == 2: 44 | Z = Z.unsqueeze(0) 45 | G = np.matmul(X, Z.transpose(-2, -1)) 46 | S = (X * X).sum(-1, keepdim=True) 47 | R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1) 48 | else: 49 | if X.dim() == 2: 50 | X = X.unsqueeze(0) 51 | if Z is None: 52 | Z = X 53 | G = np.matmul(X.transpose(-2, -1), Z) 54 | R = (X * X).sum(-2, keepdim=True) 55 | S = R.transpose(-2, -1) 56 | else: 57 | if Z.dim() == 2: 58 | Z = Z.unsqueeze(0) 59 | G = np.matmul(X.transpose(-2, -1), Z) 60 | S = (X * X).sum(-2, keepdim=True).transpose(-2, -1) 61 | R = (Z * Z).sum(-2, keepdim=True) 62 | 63 | return torch.abs(R + S - 2 * G).squeeze(0) 64 | 65 | 66 | def pdist2_slow(X, Z=None): 67 | if Z is None: 68 | Z = X 69 | D = torch.zeros(X.size(0), X.size(2), Z.size(2)) 70 | 71 | for b in range(D.size(0)): 72 | for i in range(D.size(1)): 73 | for j in range(D.size(2)): 74 | D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j]) 75 | return D 76 | 77 | 78 | if __name__ == "__main__": 79 | X = torch.randn(2, 3, 5) 80 | Z = torch.randn(2, 3, 3) 81 | 82 | print(pdist2(X, order=PDist2Order.d_first)) 83 | print(pdist2_slow(X)) 84 | print(torch.dist(pdist2(X, order=PDist2Order.d_first), pdist2_slow(X))) 85 | -------------------------------------------------------------------------------- /code/utils/create_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from subprocess import call 4 | import os 5 | import shutil 6 | 7 | 8 | ModelNet40_Path = "/home/rajat/Desktop/research-project/work/python2/ModelNet40" 9 | 10 | 11 | def create_data(datapath, dir_type): 12 | 13 | 14 | if os.path.isdir("../preprocessed_{0}".format(dir_type)): shutil.rmtree("../preprocessed_{0}".format(dir_type)) 15 | os.mkdir("../preprocessed_{0}".format(dir_type)) 16 | 17 | count_off = 0 18 | for _dir in os.listdir(datapath): 19 | if os.path.isdir(os.path.join(datapath, _dir)): 20 | off_dir = os.path.join(datapath, _dir, dir_type) 21 | for off_file in os.listdir(off_dir): 22 | if off_file.endswith(".off"): 23 | op_off_file = off_file.split(".")[0] 24 | if not os.path.exists(("../preprocessed_{0}/{1}_2500.pcd").format(dir_type, op_off_file)): 25 | # convert obj to ply 26 | call(["meshlab.meshlabserver", "-i", os.path.join(off_dir, off_file), "-o", ("../preprocessed_{0}/{1}.ply").format(dir_type, op_off_file)]) 27 | 28 | # create ip (25000 sample), target for each step (1, 2, 3) 29 | # step 1 tgt = 10000 30 | # step 2 tgt = 40000 31 | # step 3 tgt = 160000 32 | # step 4 tgt = 640000 33 | call(["pcl_mesh_sampling", "-n_samples", "2500", "-leaf_size", "0.001", "-no_vis_result", ("../preprocessed_{0}/{1}.ply").format(dir_type, op_off_file), ("../preprocessed_{0}/{1}_2500.pcd").format(dir_type, op_off_file)]) 34 | call(["pcl_mesh_sampling", "-n_samples", "10000", "-leaf_size", "0.001", "-no_vis_result", ("../preprocessed_{0}/{1}.ply").format(dir_type, op_off_file), ("../preprocessed_{0}/{1}_10000.pcd").format(dir_type, op_off_file)]) 35 | #call(["pcl_mesh_sampling", "-n_samples", "40000", "-leaf_size", "0.001", "-no_vis_result", ("../preprocessed_{0}/{1}.ply").format(dir_type, op_off_file), 
("../preprocessed_{0}/{1}_40000.pcd").format(dir_type, op_off_file)]) 36 | #call(["pcl_mesh_sampling", "-n_samples", "160000", "-leaf_size", "0.001", "-no_vis_result", ("../preprocessed_{0}/{1}.ply").format(dir_type, op_off_file), ("../preprocessed_{0}/{1}_160000.pcd").format(dir_type, op_off_file)]) 37 | #call(["pcl_mesh_sampling", "-n_samples", "640000", "-leaf_size", "0.001", "-no_vis_result", ("../preprocessed_{0}/{1}.ply").format(dir_type, op_off_file), ("../preprocessed_{0}/{1}_640000.pcd").format(dir_type, op_off_file)]) 38 | count_off += 1 39 | 40 | 41 | print(count_off) 42 | 43 | 44 | create_data(ModelNet40_Path, "test") -------------------------------------------------------------------------------- /code/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from model import PointCloudNet 7 | from code.utils import fp_sampling, knn_patch, helper_function 8 | import os 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--num_points', default=1024, type=int, 13 | help='Number of points per patch') 14 | parser.add_argument('--patch_num_ratio', default=4, type=int, 15 | help='Number of points per patch') 16 | parser.add_argument('--trained_model', type=str, 17 | help='Trained model directory') 18 | parser.add_argument('--test_file', type=str, 19 | help='XYZ file for testing') 20 | FLAGS = parser.parse_args() 21 | 22 | 23 | if not os.path.exists("../results"): 24 | os.mkdir("../results") 25 | 26 | NUM_POINTS = FLAGS.num_points 27 | PATCH_NUM_RATIO = FLAGS.patch_num_ratio 28 | TRAINED_MODEL = FLAGS.trained_model 29 | TEST_FILE = FLAGS.test_file 30 | f_name = TEST_FILE.split("/")[-1] 31 | 32 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 33 | 34 | #normaliaze data and extract patches 35 | pc = torch.tensor(np.loadtxt(TEST_FILE)).float().to(device) 36 | num_patches = int(pc.shape[0] / NUM_POINTS * PATCH_NUM_RATIO) 37 | fps_idx = fp_sampling.furthest_point_sample(torch.unsqueeze(pc[:, 0:3], dim=0).contiguous(), num_patches) 38 | patches = torch.tensor(knn_patch.extract_knn_patch(pc[torch.squeeze(fps_idx, dim=0).cpu().numpy(), 0:3].cpu().numpy(), pc.cpu().numpy(), NUM_POINTS)).to(device) 39 | print(patches.shape) 40 | 41 | centroid = torch.mean(patches[:, :, 0:3], dim=1, keepdim=True) 42 | patches[:, :, 0:3] = patches[:, :, 0:3] - centroid 43 | furthest_distance = torch.max(torch.sqrt(torch.sum(patches[:, :, 0:3] ** 2, dim=-1)), dim=1,keepdim=True).values 44 | patches[:, :, 0:3] = patches[:, :, 0:3] / torch.unsqueeze(furthest_distance, dim=-1) 45 | 46 | 47 | # read best epoch from trained model 48 | trained_model_state = open("{0}/state.txt".format(TRAINED_MODEL), "r") 49 | 50 | best_epoch, read_min_loss = helper_function.get_best_epoch(trained_model_state) 51 | print(best_epoch, read_min_loss) 52 | print("Best epoch (i.e., minimum loss) for {0}".format(read_min_loss)) 53 | 54 | #initialize model 55 | net = PointCloudNet(3, 6, True, NUM_POINTS).to(device) 56 | 57 | model = torch.load("{0}/epoch_{1}.pt".format(TRAINED_MODEL, best_epoch)) 58 | net.load_state_dict(model["model_state_dict"]) 59 | net.eval() 60 | 61 | 62 | up_patches = net(patches) 63 | 64 | #denormalize and merge patches 65 | up_patches[:, :, 0:3] = up_patches[:, :, 0:3] * torch.unsqueeze(furthest_distance, dim=-1) + centroid 66 | up_points = torch.cat([p for p in up_patches], dim=0) 67 | fps_idx = 
fp_sampling.furthest_point_sample(torch.unsqueeze(up_points[:, 0:3], dim=0).contiguous(), pc.shape[0] * 4) 68 | up_points = up_points[torch.squeeze(fps_idx, dim=0).cpu().numpy(), :].detach().cpu().numpy() 69 | np.savetxt("../results/{0}".format(f_name), up_points, fmt='%.6f', delimiter=" ", newline="\n") 70 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float 
*grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.type().is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.type().is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data(), 32 | idx.data(), output.data()); 33 | } else { 34 | AT_CHECK(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.type().is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.type().is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data(), 58 | idx.data(), output.data()); 59 | } else { 60 | AT_CHECK(false, "CPU not supported"); 61 | } 62 | 63 | return output; 64 | } 65 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 66 | CHECK_CONTIGUOUS(points); 67 | CHECK_IS_FLOAT(points); 68 | 69 | at::Tensor output = 70 | torch::zeros({points.size(0), nsamples}, 71 | at::device(points.device()).dtype(at::ScalarType::Int)); 72 | 73 | at::Tensor tmp = 74 | torch::full({points.size(0), points.size(1)}, 1e10, 75 | at::device(points.device()).dtype(at::ScalarType::Float)); 76 | 77 | if (points.type().is_cuda()) { 78 | furthest_point_sampling_kernel_wrapper( 79 | points.size(0), points.size(1), nsamples, points.data(), 80 | tmp.data(), output.data()); 81 | } else { 82 | AT_CHECK(false, "CPU not supported"); 83 | } 84 | 85 | return output; 86 | } 87 | -------------------------------------------------------------------------------- /sampling/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | // CUDA forward declarations 5 | 6 | at::Tensor furthest_sampling_cuda_forward( 7 | int b, int n, int m, 8 | at::Tensor input, 9 | at::Tensor temp, 10 | at::Tensor idx); 11 | 12 | at::Tensor gather_points_cuda_forward(int b, int c, int n, int npoints, 13 | at::Tensor points, at::Tensor idx, 14 | at::Tensor out); 15 | 16 | at::Tensor gather_points_cuda_backward(int b, int c, int n, int npoints, 17 | at::Tensor grad_out, at::Tensor idx, at::Tensor grad_points); 18 | 19 | // C++ interface 20 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 21 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 22 | #define CHECK_INPUT(x) \ 23 | CHECK_CUDA(x); \ 24 | CHECK_CONTIGUOUS(x) 25 | 26 | at::Tensor furthest_sampling_forward( 27 | int b, int n, int m, 28 | at::Tensor input, 29 | at::Tensor temp, 30 | at::Tensor idx) 31 | { 32 | CHECK_INPUT(input); 33 | CHECK_INPUT(temp); 34 | return furthest_sampling_cuda_forward(b, n, m, input, temp, idx); 35 | } 
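// Note: this binding is the entry point used by code/utils/fp_sampling.py, which
// calls sampling.furthest_sampling(B, N, npoint, xyz, temp, idx) with temp
// pre-filled with 1e10 and idx an int32 output buffer of shape (B, npoint).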
36 | 37 | at::Tensor gather_points_forward(int b, int c, int n, int npoints, 38 | at::Tensor points_tensor, 39 | at::Tensor idx_tensor, 40 | at::Tensor out_tensor) 41 | { 42 | CHECK_INPUT(points_tensor); 43 | CHECK_INPUT(idx_tensor); 44 | return gather_points_cuda_forward(b, c, n, npoints, points_tensor, idx_tensor, out_tensor); 45 | } 46 | 47 | at::Tensor gather_points_backward(int b, int c, int n, int npoints, 48 | at::Tensor grad_out_tensor, 49 | at::Tensor idx_tensor, 50 | at::Tensor grad_points_tensor) 51 | { 52 | return gather_points_cuda_backward(b, c, n, npoints, grad_out_tensor, idx_tensor, grad_points_tensor); 53 | } 54 | 55 | at::Tensor ball_query_cuda_forward(int b, int n, int m, float radius, 56 | int nsample, at::Tensor new_xyz, 57 | at::Tensor xyz, at::Tensor out_idx); 58 | 59 | at::Tensor ball_query_forward(at::Tensor query, at::Tensor xyz, const float radius, 60 | const int nsample) 61 | { 62 | CHECK_INPUT(query); 63 | CHECK_INPUT(xyz); 64 | 65 | at::Tensor idx = 66 | torch::zeros({query.size(0), query.size(1), nsample}, 67 | at::device(query.device()).dtype(at::ScalarType::Int)); 68 | 69 | if (query.type().is_cuda()) 70 | { 71 | ball_query_cuda_forward(xyz.size(0), xyz.size(1), query.size(1), 72 | radius, nsample, query, 73 | xyz, idx); 74 | } 75 | else 76 | { 77 | AT_CHECK(false, "CPU not supported"); 78 | } 79 | 80 | return idx; 81 | } 82 | 83 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 84 | { 85 | m.def("furthest_sampling", &furthest_sampling_forward, "furthest point sampling (no gradient)"); 86 | m.def("gather_forward", &gather_points_forward, "gather npoints points along an axis"); 87 | m.def("gather_backward", &gather_points_backward, "gather npoints points along an axis backward"); 88 | m.def("ball_query", &ball_query_forward, "ball query"); 89 | } -------------------------------------------------------------------------------- /code/utils/write_hdf5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import h5py 5 | import numpy as np 6 | 7 | 8 | 9 | def read_off_file_points(f_name): 10 | file = open(f_name, "r") 11 | lines = file.readlines() 12 | 13 | lines = np.array([[float(j) for j in i.split(" ")] for i in lines[11:]]) 14 | 15 | #print(lines[0], lines[-1]) 16 | return lines.transpose() 17 | 18 | 19 | def valid_point_cloud(_dir, f_prefix): 20 | 21 | for step in [2500, 10000, 40000, 160000, 640000]: 22 | file = open(_dir + f_prefix + "_{0}.pcd".format(step), "r") 23 | point_count = int(file.readlines()[9].split(" ")[1]) 24 | if point_count != step: 25 | print(f_prefix) 26 | return False 27 | 28 | return True 29 | 30 | #with h5py.File("../data_2500_test.hdf5", "w") as hf_2500, \ 31 | with h5py.File("../data_10000_test.hdf5", "w") as hf_10000:#, \ 32 | #h5py.File("../data_40000_test.hdf5", "w") as hf_40000, \ 33 | #h5py.File("../data_160000_test.hdf5", "w") as hf_160000: 34 | 35 | #s = 2468 # dataset size (test) 36 | s = 9843 # dataset size (train) 37 | 38 | # create data set for each sampled points 39 | #trn_set_2500 = hf_2500.create_dataset("train_set", [s, 3, 2500], dtype="f") 40 | #tgt_set_2500 = hf_2500.create_dataset("tgt_set", [s, 3, 10000], dtype="f") 41 | #trn_set_2500 = hf_2500["train_set"] 42 | #tgt_set_2500 = hf_2500["tgt_set"] 43 | trn_set_10000 = hf_10000.create_dataset("train_set", [s, 3, 10000], dtype="f") 44 | tgt_set_10000 = hf_10000.create_dataset("tgt_set", [s, 3, 40000], dtype="f") 45 | #trn_set_10000 = hf_10000["train_set"] 46 | #tgt_set_10000 = 
hf_10000["tgt_set"] 47 | #trn_set_40000 = hf_40000.create_dataset("train_set", [s, 3, 40000], dtype="f") 48 | #tgt_set_40000 = hf_40000.create_dataset("tgt_set", [s, 3, 160000], dtype="f") 49 | #trn_set_40000 = hf_40000["train_set"] 50 | #tgt_set_40000 = hf_40000["tgt_set"] 51 | #trn_set_160000 = hf_160000.create_dataset("train_set", [s, 3, 160000], dtype="f") 52 | #tgt_set_160000 = hf_160000.create_dataset("tgt_set", [s, 3, 640000], dtype="f") 53 | 54 | 55 | c = 0 56 | 57 | temp = [] 58 | 59 | for pcd_files in os.listdir("../preprocessed_train"): 60 | f = pcd_files.split("_") 61 | #print(f) 62 | f_prefix = "_".join(f[0:-1]) 63 | if f_prefix not in temp: 64 | if valid_point_cloud("../preprocessed_train/", f_prefix): 65 | #print("Adding dataset for {0}".format(f_prefix)) 66 | #trn_set_2500[c] = read_off_file_points("../preprocessed_test/" + "_".join([f_prefix, "2500.pcd"])) 67 | #tgt_set_2500[c] = read_off_file_points("../preprocessed_test/" + "_".join([f_prefix, "10000.pcd"])) 68 | tgt_set_10000[c] = read_off_file_points("../preprocessed_train/" + "_".join([f_prefix, "40000.pcd"])) 69 | #tgt_set_40000[c] = read_off_file_points("../preprocessed_test/" + "_".join([f_prefix, "160000.pcd"])) 70 | #tgt_set_160000[c] = read_off_file_points("../preprocessed_test/" + "_".join([f_prefix, "640000.pcd"])) 71 | c += 1 72 | temp.append(f_prefix) 73 | 74 | print(c) -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from torch.utils.cpp_extension import load 5 | 6 | cd = load(name="cd", 7 | sources=["/home/user/point-normals-upsampling/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp", 8 | "/home/user/point-normals-upsampling/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu"]) 9 | 10 | class ChamferDistanceFunction(torch.autograd.Function): 11 | @staticmethod 12 | def forward(ctx, xyz1, xyz2): 13 | batchsize, n, _ = xyz1.size() 14 | _, m, _ = xyz2.size() 15 | xyz1 = xyz1.contiguous() 16 | xyz2 = xyz2.contiguous() 17 | dist1 = torch.zeros(batchsize, n) 18 | dist2 = torch.zeros(batchsize, m) 19 | 20 | idx1 = torch.zeros(batchsize, n, dtype=torch.int) 21 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 22 | 23 | if not xyz1.is_cuda: 24 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 25 | else: 26 | dist1 = dist1.cuda() 27 | dist2 = dist2.cuda() 28 | idx1 = idx1.cuda() 29 | idx2 = idx2.cuda() 30 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 31 | 32 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 33 | 34 | return dist1, dist2 35 | 36 | @staticmethod 37 | def backward(ctx, graddist1, graddist2): 38 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 39 | 40 | graddist1 = graddist1.contiguous() 41 | graddist2 = graddist2.contiguous() 42 | 43 | gradxyz1 = torch.zeros(xyz1.size()) 44 | gradxyz2 = torch.zeros(xyz2.size()) 45 | 46 | if not graddist1.is_cuda: 47 | cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 48 | else: 49 | gradxyz1 = gradxyz1.cuda() 50 | gradxyz2 = gradxyz2.cuda() 51 | cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2) 52 | 53 | return gradxyz1, gradxyz2 54 | 55 | 56 | class ChamferDistance(torch.nn.Module): 57 | def forward(self, xyz1, xyz2): 58 | return ChamferDistanceFunction.apply(xyz1, xyz2) 59 | 60 | 61 | class ChamferDistance_PPU(torch.nn.Module): 62 | def __init__(self, threshold=None, 
forward_weight=1.0): 63 | super(ChamferDistance_PPU, self).__init__() 64 | self.__threshold = threshold 65 | self.forward_weight = forward_weight 66 | 67 | def forward(self, xyz1, xyz2): 68 | pred2gt, gt2pred = ChamferDistanceFunction.apply(xyz1, xyz2) 69 | 70 | if self.__threshold is not None: 71 | threshold = self.__threshold 72 | forward_threshold = torch.mean( 73 | pred2gt, dim=1, keepdim=True) * threshold 74 | backward_threshold = torch.mean( 75 | gt2pred, dim=1, keepdim=True) * threshold 76 | # only care about distance within threshold (ignore strong outliers) 77 | pred2gt = torch.where( 78 | pred2gt < forward_threshold, pred2gt, torch.zeros_like(pred2gt)) 79 | gt2pred = torch.where( 80 | gt2pred < backward_threshold, gt2pred, torch.zeros_like(gt2pred)) 81 | 82 | # pred2gt is for each element in gt, the closest distance to this element 83 | pred2gt = torch.mean(pred2gt, dim=1) 84 | gt2pred = torch.mean(gt2pred, dim=1) 85 | return pred2gt * self.forward_weight, gt2pred 86 | #CD_dist = self.forward_weight * pred2gt + gt2pred 87 | # CD_dist_norm = CD_dist/radius 88 | #cd_loss = torch.mean(CD_dist) 89 | #return cd_loss 90 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.type().is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.type().is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data(), knows.data(), 34 | dist2.data(), idx.data()); 35 | } else { 36 | AT_CHECK(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.type().is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.type().is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | points.data(), idx.data(), weight.data(), 64 | output.data()); 65 | } else { 66 | AT_CHECK(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | 
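// Backward pass of three_interpolate: accumulates grad_out onto the m source
// points using the same idx/weight triples as the forward pass.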
at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.type().is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.type().is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data(), idx.data(), weight.data(), 93 | output.data()); 94 | } else { 95 | AT_CHECK(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import argparse 5 | import torch 6 | import torch.optim as optim 7 | from torch.utils.data import DataLoader 8 | import sys 9 | print(sys.path) 10 | from dataset import PointCLoudDataset 11 | from code.loss import knn_loss, l2_normal_loss, ChamferDistance 12 | from code.model import PointCloudNet 13 | from code.utils.helper_function import get_current_state 14 | from code.utils import data_provider 15 | 16 | 17 | 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--num_points', default=1024, type=int, 21 | help='Number of points per patch') 22 | parser.add_argument('--checkpoint_path', default="..", type=str, 23 | help='Folder path to save checkpoint after each epoch') 24 | parser.add_argument('--batch_size', default=20, type=int, 25 | help='Batch Size') 26 | parser.add_argument('--epochs', default=500, type=int, 27 | help='Number of epochs') 28 | parser.add_argument('--lr', default=5e-4, type=float, 29 | help='Learning Rate') 30 | parser.add_argument('--weight_decay', default=1e-5, type=float, 31 | help='Weight Decay') 32 | parser.add_argument('--add_noise', default=True, type=bool, 33 | help='Add Gaussian Noise') 34 | parser.add_argument('--h5_data_file', type=str, 35 | help='Training h5 file path') 36 | FLAGS = parser.parse_args() 37 | 38 | 39 | NUM_POINTS = FLAGS.num_points 40 | CHECKPOINT_PATH = FLAGS.checkpoint_path 41 | BATCH_SIZE = FLAGS.batch_size 42 | EPOCHS = FLAGS.epochs 43 | LR = FLAGS.lr 44 | WD = FLAGS.weight_decay 45 | ADD_NOISE = FLAGS.add_noise 46 | H5_FILE = FLAGS.h5_data_file 47 | 48 | print(FLAGS) 49 | 50 | 51 | 52 | 53 | # use gpu if available 54 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 55 | 56 | # create directory to save checkpoints 57 | checkpoint = os.path.join(CHECKPOINT_PATH, "checkpoint") 58 | if not os.path.exists(checkpoint): 59 | os.mkdir(checkpoint) 60 | 61 | 62 | 63 | # initialize chamfer loss function 64 | chamfer_dist = ChamferDistance() 65 | 66 | # initialize upsampling network 67 | net = PointCloudNet(3, 6, True, NUM_POINTS).to(device) 68 | 69 | #define optimizer 70 | optimizer = optim.Adam(net.parameters(), lr=LR, weight_decay=WD) 71 | 72 | 73 | # read and get last state of check point if already training started 74 | state = open(os.path.join(checkpoint, "state.txt"), "a+") 75 | start_epoch = get_current_state(state) 76 | 77 | if start_epoch == -1: 78 | start_epoch = 0 79 | elif start_epoch == EPOCHS: 80 | print("Final epoch {0} already 
Trained.".format(start_epoch)) 81 | else: 82 | # get last state 83 | model = torch.load("{0}/epoch_{1}.pt".format(checkpoint, start_epoch)) 84 | net.load_state_dict(model["model_state_dict"]) 85 | optimizer.load_state_dict(model["optimizer_state_dict"]) 86 | 87 | print("Loaded Checkpoint ::: last trained epoch epoch_{0} with loss {1}".format(start_epoch, model["loss"])) 88 | 89 | 90 | print("TRAININNG STARTED") 91 | print("Starting from epoch {0}".format(start_epoch + 1)) 92 | 93 | 94 | pcDataset = PointCLoudDataset(H5_FILE, transform=True) 95 | trainloader = DataLoader(pcDataset, batch_size=BATCH_SIZE, shuffle=True) 96 | 97 | trainloader_len = len(trainloader) 98 | print(trainloader_len) 99 | 100 | # start training 101 | for epoch in range(start_epoch, EPOCHS): 102 | 103 | running_loss = 0.0 104 | for i, data in enumerate(trainloader): 105 | 106 | inputs, targets = data_provider.data_augmentation(data[0].numpy(), data[1].numpy()) 107 | if ADD_NOISE: 108 | inputs = data_provider.add_noise(inputs) 109 | inputs, targets = torch.from_numpy(inputs).float().to(device), torch.from_numpy(targets).float().to(device) 110 | 111 | # make gradient zeros 112 | optimizer.zero_grad() 113 | 114 | # forward + backward + optimize 115 | predicted_outputs = net(inputs) 116 | dist1, dist2 = chamfer_dist(targets[:, :, 0:3], predicted_outputs[:, :, 0:3]) 117 | n_loss = l2_normal_loss(targets, predicted_outputs, device) 118 | cosine_normal_loss, normal_neighbor_loss, point_neighbor_loss = knn_loss(predicted_outputs, 15, 15, device) 119 | loss = (torch.mean(dist1)) + (torch.mean(dist2)) + (0.1 * point_neighbor_loss) + (0.05 * n_loss) + (0.0001 * cosine_normal_loss) + (0.0001 * normal_neighbor_loss) 120 | 121 | loss.backward() 122 | optimizer.step() 123 | 124 | running_loss += loss.item() 125 | 126 | if i % 50 == 49: 127 | print("EPOCH {0}, BATCH {1}, LOSS {2}".format(epoch + 1, i + 1, loss.item())) 128 | 129 | print("EPOCH {0}, LOSS {1}".format(epoch + 1, running_loss / trainloader_len)) 130 | 131 | torch.save({ 132 | 'epoch': epoch + 1, 133 | 'model_state_dict': net.state_dict(), 134 | 'optimizer_state_dict': optimizer.state_dict(), 135 | 'loss': running_loss / trainloader_len 136 | }, "{0}/epoch_{1}.pt".format(checkpoint, epoch + 1)) 137 | 138 | state.write("EPOCH {0}, TRAINING_LOSS {1}\n".format(epoch + 1, running_loss / trainloader_len)) 139 | 140 | -------------------------------------------------------------------------------- /code/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as f 3 | from pyTorchChamferDistance.chamfer_distance import ChamferDistance 4 | 5 | def knns_dist(xyz1, xyz2, k, device): 6 | """ 7 | Parameters 8 | ---------- 9 | samples: Number of points in xyz1 10 | xyz1: B * N * 6 11 | xyz2: B * N * 6 12 | k: number of points in xyz2 which are least distant to xyz1 13 | 14 | Returns 15 | ---------- 16 | k number of points in xyz2 which are least distant to xyz1 17 | """ 18 | samples = xyz1.shape[1] 19 | xyz1_xyz1 = torch.bmm(xyz1, xyz1.transpose(2, 1)) 20 | xyz2_xyz2 = torch.bmm(xyz2, xyz2.transpose(2, 1)) 21 | xyz1_xyz2 = torch.bmm(xyz1, xyz2.transpose(2, 1)) 22 | diag_ind_x = torch.arange(0, samples).to(device) 23 | rx = xyz1_xyz1[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as(xyz1_xyz2.transpose(2,1)) 24 | ry = xyz2_xyz2[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as(xyz1_xyz2) 25 | pair_wise_loss = rx.transpose(2,1) + ry - 2 * xyz1_xyz2 26 | 27 | top_min_k = torch.topk(pair_wise_loss, k, dim=2, 
largest=False) 28 | 29 | return top_min_k 30 | 31 | 32 | def l2_normal_loss(xyz1, xyz2, device): 33 | """ 34 | Parameters 35 | ---------- 36 | xyz1: B * N * 6 37 | xyz2: B * N * 6 38 | 39 | Returns 40 | ---------- 41 | l2 normal loss for points which are closer in points1 & points2 42 | """ 43 | batch = xyz1.shape[0] 44 | num_points = xyz1.shape[1] 45 | channels = xyz1.shape[2] 46 | 47 | # get indices of points1, points2 which have minimum distgance 48 | get_knn = knns_dist(xyz1[:, :, 0:3], xyz2[:, :, 0:3], k=1, device=device) 49 | 50 | k_indices = get_knn.indices 51 | k_values = get_knn.values 52 | 53 | k_points = torch.gather(xyz2.unsqueeze(1).expand(-1, xyz1.size(1), -1, -1), 54 | 2, 55 | k_indices.unsqueeze(-1).expand(-1, -1, -1, xyz2.size(-1))) 56 | 57 | #dist = torch.mean(torch.sum((points1.view(batch, num_points, 1, channels)[:, :, :, 0:3] - k_points[:, :, :, 0:3]) ** 2, dim=-1)) 58 | normal_loss = torch.mean(torch.sum((xyz1.view(batch, num_points, 1, channels)[:, :, :, 3:6] - k_points[:, :, :, 3:6]) ** 2, dim=-1)) 59 | return normal_loss #, dist 60 | 61 | def knn_loss(xyz, k_point, k_normal, device): 62 | """ 63 | Parameters 64 | ---------- 65 | points: B * N * 6 66 | k_point: number of neighbour for point regularization 67 | k_normal: number of neighbour for normal regularization 68 | 69 | Returns 70 | ---------- 71 | cosine_normal_loss 72 | normal_neighbor_loss 73 | point_neighbor_loss 74 | """ 75 | k = max(k_point, k_normal) 76 | k = k + 1 # a point also includes itself in knn search 77 | batch = xyz.shape[0] 78 | num_points = xyz.shape[1] 79 | channels = xyz.shape[2] 80 | 81 | get_knn = knns_dist(xyz[:, :, 0:3], xyz[:, :, 0:3], k, device) 82 | 83 | k_indices = get_knn.indices 84 | 85 | kv = get_knn.values 86 | kp = torch.gather(xyz.unsqueeze(1).expand(-1, xyz.size(1), -1, -1), 87 | 2, 88 | k_indices.unsqueeze(-1).expand(-1, -1, -1, xyz.size(-1))) 89 | 90 | #print(kp.shape) 91 | #print(kv.shape) 92 | #print(kp[:, :, 0, :].view(batch, num_points, 1, channels)[:, :, :, 0:3].shape) 93 | p_dist = kp[:, :, 0, :].view(batch, num_points, 1, channels)[:, :, :, 0:3] - kp[:, :, 0:k_point+1, 0:3] 94 | # remove first column of each row as it is the same point from where min distance is calculated 95 | p_dist = p_dist[:, :, 1:, :] 96 | point_neighbor_loss = torch.mean(torch.sum(p_dist ** 2, dim=-1)) 97 | #print(p_dist) 98 | #print(point_neighbor_loss) 99 | 100 | n_dist = kp[:, :, 0, :].view(batch, num_points, 1, channels)[:, :, :, 3:6] - kp[:, :, 0:k_normal+1, 3:6] 101 | # remove first column of each row as it is the same point from where min distance is calculated 102 | n_dist = n_dist[:, :, 1:, :] 103 | #print(n_dist) 104 | normal_neighbor_loss = torch.mean(torch.sum(n_dist ** 2, dim=-1)) 105 | 106 | #print(normal_neighbor_loss) 107 | 108 | dot_product = f.normalize(p_dist, p=2, dim=-1) * f.normalize(kp[:, :, 0, :].view(batch, num_points, 1, channels)[:, :, :, 3:6], p=2, dim=-1) 109 | #print(dot_product) 110 | cosine_normal_loss = torch.mean(torch.abs(torch.sum(dot_product, dim=-1))) 111 | #print(normal_loss) 112 | return cosine_normal_loss, normal_neighbor_loss, point_neighbor_loss 113 | 114 | 115 | if __name__ == "__main__": 116 | 117 | a = torch.tensor( 118 | [ 119 | [[1, 2, 3, 0.1, 0.2, 0.3], [3, 2, 1, 0.3, 0.2, 0.1], [2, 4, 2, 0.2, 0.4, 0.2], [5, 1, 2, 0.5, 0.1, 0.2]], 120 | [[5, 3, 2, 0.5, 0.3, 0.2], [0, 1, 0, 0.0, 0.1, 0.0], [3, 0, 6, 0.3, 0.0, 0.6], [3, 2, 1, 0.3, 0.2, 0.1]] 121 | ], 122 | dtype=torch.float 123 | ) 124 | b = torch.tensor( 125 | [ 126 | [[1.1, 2.1, 3.1, 0.1, 0.2, 
0.3], [3, 2, 1, 0.3, 0.2, 0.1], [2, 4, 2, 0.2, 0.4, 0.2], [15, 11, 12, 0.5, 0.1, 0.2]], 127 | [[5.1, 3.1, 2.1, 0.5, 0.3, 0.2], [10, 11, 10, 0.0, 0.1, 0.0], [3, 0, 6, 0.3, 0.0, 0.6], [3, 2, 1, 0.3, 0.2, 0.1]] 128 | ], 129 | dtype=torch.float 130 | ) 131 | nl = knn_loss(a, 2, 2, "cuda:0") 132 | ecdn, ecdv = l2_normal_loss(a, b, device="cuda:0") 133 | print(nl) 134 | #print(ecdn, ecdv) 135 | #print(a, b) 136 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 | 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 
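// Note: for each batch b, channel l and output point j, the kernel above
// computes a weighted combination of the three nearest known features,
//
//     out[b][l][j] = w1 * points[b][l][i1] + w2 * points[b][l][i2] + w3 * points[b][l][i3],
//
// where (i1, i2, i3) are the indices returned by three_nn and (w1, w2, w3) are
// the normalized inverse-distance weights computed on the Python side (see
// PointnetFPModule in pointnet2/utils/pointnet2_modules.py). The wrapper below
// just launches this kernel on the current CUDA stream.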
103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, 
idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | int ChamferDistanceKernelLauncher( 5 | const int b, const int n, 6 | const float* xyz, 7 | const int m, 8 | const float* xyz2, 9 | float* result, 10 | int* result_i, 11 | float* result2, 12 | int* result2_i); 13 | 14 | int ChamferDistanceGradKernelLauncher( 15 | const int b, const int n, 16 | const float* xyz1, 17 | const int m, 18 | const float* xyz2, 19 | const float* grad_dist1, 20 | const int* idx1, 21 | const float* grad_dist2, 22 | const int* idx2, 23 | float* grad_xyz1, 24 | float* grad_xyz2); 25 | 26 | 27 | void chamfer_distance_forward_cuda( 28 | const at::Tensor xyz1, 29 | const at::Tensor xyz2, 30 | const at::Tensor dist1, 31 | const at::Tensor dist2, 32 | const at::Tensor idx1, 33 | const at::Tensor idx2) 34 | { 35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 36 | xyz2.size(1), xyz2.data(), 37 | dist1.data(), idx1.data(), 38 | dist2.data(), idx2.data()); 39 | } 40 | 41 | void chamfer_distance_backward_cuda( 42 | const at::Tensor xyz1, 43 | const at::Tensor xyz2, 44 | at::Tensor gradxyz1, 45 | at::Tensor gradxyz2, 46 | at::Tensor graddist1, 47 | at::Tensor graddist2, 48 | at::Tensor idx1, 49 | at::Tensor idx2) 50 | { 51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 52 | xyz2.size(1), xyz2.data(), 53 | graddist1.data(), idx1.data(), 54 | graddist2.data(), idx2.data(), 55 | gradxyz1.data(), gradxyz2.data()); 56 | } 57 | 58 | 59 | void nnsearch( 60 | const int b, const int n, const int m, 61 | const float* xyz1, 62 | const float* xyz2, 63 | float* dist, 64 | int* idx) 65 | { 66 | for (int i = 0; i < b; i++) { 67 | for (int j = 0; j < n; j++) { 68 | const float x1 = xyz1[(i*n+j)*3+0]; 69 | const float y1 = xyz1[(i*n+j)*3+1]; 70 | const float z1 = xyz1[(i*n+j)*3+2]; 71 | double best = 0; 72 | int besti = 0; 73 | for (int k = 0; k < m; k++) { 74 | const float x2 = xyz2[(i*m+k)*3+0] - x1; 75 | const float y2 = xyz2[(i*m+k)*3+1] - y1; 76 | const float z2 = xyz2[(i*m+k)*3+2] - z1; 77 | const double d=x2*x2+y2*y2+z2*z2; 78 | if (k==0 || d < best){ 79 | best = d; 80 | besti = k; 81 | } 82 | } 83 | dist[i*n+j] = best; 84 | idx[i*n+j] = besti; 85 | } 86 | } 87 | } 88 | 89 | 90 | void chamfer_distance_forward( 91 | const at::Tensor xyz1, 92 | const at::Tensor xyz2, 93 | const at::Tensor dist1, 94 | const at::Tensor dist2, 95 | const at::Tensor idx1, 96 | const at::Tensor idx2) 97 | { 98 | const int batchsize = xyz1.size(0); 99 | const int n = xyz1.size(1); 100 | const int m = xyz2.size(1); 101 | 102 | const float* xyz1_data = xyz1.data(); 103 | const float* xyz2_data = xyz2.data(); 104 | float* dist1_data = dist1.data(); 105 | float* dist2_data = dist2.data(); 106 | int* idx1_data = idx1.data(); 107 | int* idx2_data = idx2.data(); 108 | 109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 111 | } 112 | 113 | 114 | void chamfer_distance_backward( 115 | const at::Tensor xyz1, 116 | const at::Tensor xyz2, 117 | at::Tensor gradxyz1, 118 | at::Tensor gradxyz2, 
119 | at::Tensor graddist1, 120 | at::Tensor graddist2, 121 | at::Tensor idx1, 122 | at::Tensor idx2) 123 | { 124 | const int b = xyz1.size(0); 125 | const int n = xyz1.size(1); 126 | const int m = xyz2.size(1); 127 | 128 | const float* xyz1_data = xyz1.data(); 129 | const float* xyz2_data = xyz2.data(); 130 | float* gradxyz1_data = gradxyz1.data(); 131 | float* gradxyz2_data = gradxyz2.data(); 132 | float* graddist1_data = graddist1.data(); 133 | float* graddist2_data = graddist2.data(); 134 | const int* idx1_data = idx1.data(); 135 | const int* idx2_data = idx2.data(); 136 | 137 | for (int i = 0; i < b*n*3; i++) 138 | gradxyz1_data[i] = 0; 139 | for (int i = 0; i < b*m*3; i++) 140 | gradxyz2_data[i] = 0; 141 | for (int i = 0;i < b; i++) { 142 | for (int j = 0; j < n; j++) { 143 | const float x1 = xyz1_data[(i*n+j)*3+0]; 144 | const float y1 = xyz1_data[(i*n+j)*3+1]; 145 | const float z1 = xyz1_data[(i*n+j)*3+2]; 146 | const int j2 = idx1_data[i*n+j]; 147 | 148 | const float x2 = xyz2_data[(i*m+j2)*3+0]; 149 | const float y2 = xyz2_data[(i*m+j2)*3+1]; 150 | const float z2 = xyz2_data[(i*m+j2)*3+2]; 151 | const float g = graddist1_data[i*n+j]*2; 152 | 153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); 154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); 155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); 156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); 157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); 158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); 159 | } 160 | for (int j = 0; j < m; j++) { 161 | const float x1 = xyz2_data[(i*m+j)*3+0]; 162 | const float y1 = xyz2_data[(i*m+j)*3+1]; 163 | const float z1 = xyz2_data[(i*m+j)*3+2]; 164 | const int j2 = idx2_data[i*m+j]; 165 | const float x2 = xyz1_data[(i*n+j2)*3+0]; 166 | const float y2 = xyz1_data[(i*n+j2)*3+1]; 167 | const float z2 = xyz1_data[(i*n+j2)*3+2]; 168 | const float g = graddist2_data[i*m+j]*2; 169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); 170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); 171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); 172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); 173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); 174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); 175 | } 176 | } 177 | } 178 | 179 | 180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); 182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)"); 183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); 184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); 185 | } 186 | -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch 9 | import torch.nn as nn 10 | 11 | from pointnet2.utils.pointnet2_modules import PointnetFPModule, PointnetSAModuleMSG 12 | 13 | class Upsample(nn.Module): 14 | def __init__(self): 15 | super(Upsample, self).__init__() 16 | 17 | def forward(self, x): 18 | return x.view(x.shape[0], int(x.shape[1] / 2), x.shape[2] * 2) 19 | 20 | 21 | class PointCloudNet(nn.Module): 22 | r""" 23 | PointNet2 as base net with multi-scale grouping 24 | Point Cloud Upsample Network 25 | 26 | Parameters 27 | ---------- 28 | input_channels: int = 6 29 | Number of input channels in the feature descriptor for each point. 
If the point cloud is Nx9, this 30 | value should be 6 as in an Nx9 point cloud, 3 of the channels are xyz, and 6 are feature descriptors 31 | output_channels: int = 6 32 | Number of output channels. 33 | num_points: int 34 | Number of points in point cloud 35 | use_xyz: bool = True 36 | Whether or not to use the xyz position of a point as a feature 37 | """ 38 | 39 | def __init__(self, input_channels=6, output_channels=6, use_xyz=True, num_points=1024): 40 | super(PointCloudNet, self).__init__() 41 | print(num_points) 42 | self.GLOBAL_module = nn.Sequential( 43 | nn.Conv1d(in_channels=input_channels + 3, out_channels=32, kernel_size=1), 44 | nn.BatchNorm1d(32), 45 | nn.ReLU(), 46 | nn.Conv1d(in_channels=32, out_channels=64, kernel_size=1), 47 | nn.BatchNorm1d(64), 48 | nn.ReLU(), 49 | nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1), 50 | nn.BatchNorm1d(64), 51 | nn.ReLU(), 52 | nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1), 53 | nn.BatchNorm1d(128), 54 | nn.ReLU(), 55 | nn.Conv1d(in_channels=128, out_channels=256, kernel_size=1), 56 | nn.BatchNorm1d(256), 57 | nn.ReLU(), 58 | nn.Conv1d(in_channels=256, out_channels=512, kernel_size=1), 59 | nn.BatchNorm1d(512), 60 | nn.ReLU(), 61 | ) 62 | 63 | num_points = int(num_points) 64 | self.SA_modules = nn.ModuleList() 65 | c_in = input_channels 66 | self.SA_modules.append( 67 | PointnetSAModuleMSG( 68 | npoint=num_points, 69 | radii=[0.1,], 70 | nsamples=[32,], 71 | mlps=[[c_in, 32, 32, 64]], 72 | use_xyz=use_xyz, 73 | ) 74 | ) 75 | c_out_0 = 64 76 | num_points = int(num_points / 4) 77 | c_in = c_out_0 78 | self.SA_modules.append( 79 | PointnetSAModuleMSG( 80 | npoint=num_points, 81 | radii=[0.2,], 82 | nsamples=[32,], 83 | mlps=[[c_in, 64, 64, 128]], 84 | use_xyz=use_xyz, 85 | ) 86 | ) 87 | c_out_1 = 128 88 | num_points = int(num_points / 4) 89 | c_in = c_out_1 90 | self.SA_modules.append( 91 | PointnetSAModuleMSG( 92 | npoint=num_points, 93 | radii=[0.3,], 94 | nsamples=[32,], 95 | mlps=[[c_in, 128, 128, 256]], 96 | use_xyz=use_xyz, 97 | ) 98 | ) 99 | c_out_2 = 256 100 | num_points = int(num_points / 4) 101 | c_in = c_out_2 102 | self.SA_modules.append( 103 | PointnetSAModuleMSG( 104 | npoint=num_points, 105 | radii=[0.4], 106 | nsamples=[32], 107 | mlps=[[c_in, 256, 256, 512]], 108 | use_xyz=use_xyz, 109 | ) 110 | ) 111 | c_out_3 = 512 112 | 113 | self.FP_modules = nn.ModuleList() 114 | self.FP_modules.append(PointnetFPModule(mlp=[c_out_3 + c_out_2, 128, 256])) 115 | self.FP_modules.append(PointnetFPModule(mlp=[256 + c_out_1, 128, 128])) 116 | self.FP_modules.append(PointnetFPModule(mlp=[128 + c_out_0, 128, 128])) 117 | self.FP_modules.append(PointnetFPModule(mlp=[128 + input_channels + 3, 128, 256])) 118 | 119 | self.UPSAMPLING_module = nn.Sequential( 120 | # in_channels = 512 (global_channels) + (256) (Local Features) + 6 (xyz + input_channels) 121 | nn.Conv1d(in_channels=512 + 256 + 6, out_channels=512, kernel_size=1), 122 | nn.BatchNorm1d(512), 123 | nn.ReLU(), 124 | nn.Conv1d(in_channels=512, out_channels=256, kernel_size=1), 125 | nn.BatchNorm1d(256), 126 | nn.ReLU(), 127 | nn.Conv1d(in_channels=256, out_channels=128, kernel_size=1), 128 | nn.BatchNorm1d(128), 129 | nn.ReLU(), 130 | # Upsampling by factor 2 131 | Upsample(), 132 | nn.Conv1d(in_channels=64, out_channels=64, kernel_size=1), 133 | nn.BatchNorm1d(64), 134 | nn.ReLU(), 135 | nn.Conv1d(in_channels=64, out_channels=32, kernel_size=1), 136 | nn.BatchNorm1d(32), 137 | nn.ReLU(), 138 | # Upsampling by factor 2 139 | Upsample(), 140 | 
nn.Conv1d(in_channels=16, out_channels=16, kernel_size=1), 141 | nn.BatchNorm1d(16), 142 | nn.ReLU(), 143 | nn.Conv1d(in_channels=16, out_channels=output_channels, kernel_size=1), 144 | ) 145 | 146 | 147 | def _break_up_pc(self, pc): 148 | xyz = pc[..., 0:3].contiguous() 149 | features = pc[..., 3:].transpose(1, 2).contiguous() 150 | 151 | return xyz, features 152 | 153 | def forward(self, pointcloud): 154 | # type: (Pointnet2MSG, torch.cuda.FloatTensor) -> pt_utils.Seq 155 | r""" 156 | Forward pass of the network 157 | 158 | Parameters 159 | ---------- 160 | pointcloud: Variable(torch.cuda.FloatTensor) 161 | (B, N, 3 + input_channels) tensor 162 | Point cloud to run predicts on 163 | Each point in the point-cloud MUST 164 | be formated as (x, y, z, features...) 165 | """ 166 | #print(pointcloud.shape) 167 | num_points = pointcloud.shape[1] 168 | 169 | g_features = nn.MaxPool1d(num_points)(self.GLOBAL_module(pointcloud.permute(0, 2, 1))) 170 | #print("Global Features Shape, ", g_features.shape) 171 | 172 | xyz, features = self._break_up_pc(pointcloud) 173 | l0_xyz, l0_features = xyz, features 174 | ip_features = torch.cat((l0_xyz.permute(0, 2, 1), l0_features), dim=1) 175 | 176 | l1_xyz, l1_features = self.SA_modules[0](l0_xyz, l0_features) 177 | l2_xyz, l2_features = self.SA_modules[1](l1_xyz, l1_features) 178 | l3_xyz, l3_features = self.SA_modules[2](l2_xyz, l2_features) 179 | l4_xyz, l4_features = self.SA_modules[3](l3_xyz, l3_features) 180 | 181 | l3_features = self.FP_modules[0](l3_xyz, l4_xyz, l3_features, l4_features) 182 | l2_features = self.FP_modules[1](l2_xyz, l3_xyz, l2_features, l3_features) 183 | l1_features = self.FP_modules[2](l1_xyz, l2_xyz, l1_features, l2_features) 184 | l0_features = self.FP_modules[3](l0_xyz, l1_xyz, ip_features, l1_features) 185 | #print("Local Features Shape, ", l0_features.shape) 186 | 187 | #c_features = torch.cat([ip_features, l0_features], dim=1) 188 | c_features = torch.cat([ip_features, l0_features, g_features.repeat(1, 1, num_points)], dim=1) 189 | #print("Concat Features Shape, ", c_features.shape) 190 | 191 | return self.UPSAMPLING_module(c_features).permute(0, 2, 1) 192 | 193 | -------------------------------------------------------------------------------- /pointnet2/_ext-src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | for (int l = 
blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ? i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Ouput idxs (b, m) 69 | template 70 | __global__ void furthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ? 
d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | furthest_point_sampling_kernel<512> 185 | <<>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | furthest_point_sampling_kernel<256> 189 | <<>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | furthest_point_sampling_kernel<128> 193 | <<>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | furthest_point_sampling_kernel<64> 197 | <<>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | furthest_point_sampling_kernel<32> 201 | <<>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | furthest_point_sampling_kernel<16> 205 | <<>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | furthest_point_sampling_kernel<8> 209 | <<>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | furthest_point_sampling_kernel<4> 213 | <<>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | furthest_point_sampling_kernel<2> 217 | <<>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | furthest_point_sampling_kernel<1> 221 | <<>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | furthest_point_sampling_kernel<512> 225 | <<>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /pointnet2/utils/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from pointnet2.utils import pytorch_utils as pt_utils 12 | 13 | from pointnet2.utils import pointnet2_utils 14 | 15 | if False: 16 | # Workaround for type hints 
without depending on the `typing` module 17 | from typing import * 18 | 19 | 20 | class _PointnetSAModuleBase(nn.Module): 21 | def __init__(self): 22 | super(_PointnetSAModuleBase, self).__init__() 23 | self.npoint = None 24 | self.groupers = None 25 | self.mlps = None 26 | 27 | def forward(self, xyz, features=None): 28 | # type: (_PointnetSAModuleBase, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 29 | r""" 30 | Parameters 31 | ---------- 32 | xyz : torch.Tensor 33 | (B, N, 3) tensor of the xyz coordinates of the features 34 | features : torch.Tensor 35 | (B, N, C) tensor of the descriptors of the the features 36 | 37 | Returns 38 | ------- 39 | new_xyz : torch.Tensor 40 | (B, npoint, 3) tensor of the new features' xyz 41 | new_features : torch.Tensor 42 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 43 | """ 44 | 45 | new_features_list = [] 46 | 47 | xyz_flipped = xyz.transpose(1, 2).contiguous() 48 | new_xyz = ( 49 | pointnet2_utils.gather_operation( 50 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) 51 | ) 52 | .transpose(1, 2) 53 | .contiguous() 54 | if self.npoint is not None 55 | else None 56 | ) 57 | 58 | for i in range(len(self.groupers)): 59 | new_features = self.groupers[i]( 60 | xyz, new_xyz, features 61 | ) # (B, C, npoint, nsample) 62 | 63 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 64 | new_features = F.max_pool2d( 65 | new_features, kernel_size=[1, new_features.size(3)] 66 | ) # (B, mlp[-1], npoint, 1) 67 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 68 | 69 | new_features_list.append(new_features) 70 | 71 | return new_xyz, torch.cat(new_features_list, dim=1) 72 | 73 | 74 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 75 | r"""Pointnet set abstrction layer with multiscale grouping 76 | 77 | Parameters 78 | ---------- 79 | npoint : int 80 | Number of features 81 | radii : list of float32 82 | list of radii to group with 83 | nsamples : list of int32 84 | Number of samples in each ball query 85 | mlps : list of list of int32 86 | Spec of the pointnet before the global max_pool for each scale 87 | bn : bool 88 | Use batchnorm 89 | """ 90 | 91 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): 92 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None 93 | super(PointnetSAModuleMSG, self).__init__() 94 | 95 | assert len(radii) == len(nsamples) == len(mlps) 96 | 97 | self.npoint = npoint 98 | self.groupers = nn.ModuleList() 99 | self.mlps = nn.ModuleList() 100 | for i in range(len(radii)): 101 | radius = radii[i] 102 | nsample = nsamples[i] 103 | self.groupers.append( 104 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 105 | if npoint is not None 106 | else pointnet2_utils.GroupAll(use_xyz) 107 | ) 108 | mlp_spec = mlps[i] 109 | if use_xyz: 110 | mlp_spec[0] += 3 111 | 112 | self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn)) 113 | 114 | 115 | class PointnetSAModule(PointnetSAModuleMSG): 116 | r"""Pointnet set abstrction layer 117 | 118 | Parameters 119 | ---------- 120 | npoint : int 121 | Number of features 122 | radius : float 123 | Radius of ball 124 | nsample : int 125 | Number of samples in the ball query 126 | mlp : list 127 | Spec of the pointnet before the global max_pool 128 | bn : bool 129 | Use batchnorm 130 | """ 131 | 132 | def __init__( 133 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True 134 | ): 135 | # type: (PointnetSAModule, 
List[int], int, float, int, bool, bool) -> None 136 | super(PointnetSAModule, self).__init__( 137 | mlps=[mlp], 138 | npoint=npoint, 139 | radii=[radius], 140 | nsamples=[nsample], 141 | bn=bn, 142 | use_xyz=use_xyz, 143 | ) 144 | 145 | 146 | class PointnetFPModule(nn.Module): 147 | r"""Propigates the features of one set to another 148 | 149 | Parameters 150 | ---------- 151 | mlp : list 152 | Pointnet module parameters 153 | bn : bool 154 | Use batchnorm 155 | """ 156 | 157 | def __init__(self, mlp, bn=True): 158 | # type: (PointnetFPModule, List[int], bool) -> None 159 | super(PointnetFPModule, self).__init__() 160 | self.mlp = pt_utils.SharedMLP(mlp, bn=bn) 161 | 162 | def forward(self, unknown, known, unknow_feats, known_feats): 163 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 164 | r""" 165 | Parameters 166 | ---------- 167 | unknown : torch.Tensor 168 | (B, n, 3) tensor of the xyz positions of the unknown features 169 | known : torch.Tensor 170 | (B, m, 3) tensor of the xyz positions of the known features 171 | unknow_feats : torch.Tensor 172 | (B, C1, n) tensor of the features to be propigated to 173 | known_feats : torch.Tensor 174 | (B, C2, m) tensor of features to be propigated 175 | 176 | Returns 177 | ------- 178 | new_features : torch.Tensor 179 | (B, mlp[-1], n) tensor of the features of the unknown features 180 | """ 181 | 182 | if known is not None: 183 | dist, idx = pointnet2_utils.three_nn(unknown, known) 184 | dist_recip = 1.0 / (dist + 1e-8) 185 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 186 | weight = dist_recip / norm 187 | 188 | interpolated_feats = pointnet2_utils.three_interpolate( 189 | known_feats, idx, weight 190 | ) 191 | else: 192 | interpolated_feats = known_feats.expand( 193 | *(known_feats.size()[0:2] + [unknown.size(1)]) 194 | ) 195 | 196 | if unknow_feats is not None: 197 | new_features = torch.cat( 198 | [interpolated_feats, unknow_feats], dim=1 199 | ) # (B, C2 + C1, n) 200 | else: 201 | new_features = interpolated_feats 202 | 203 | new_features = new_features.unsqueeze(-1) 204 | new_features = self.mlp(new_features) 205 | 206 | return new_features.squeeze(-1) 207 | 208 | 209 | if __name__ == "__main__": 210 | from torch.autograd import Variable 211 | 212 | torch.manual_seed(1) 213 | torch.cuda.manual_seed_all(1) 214 | xyz = Variable(torch.randn(2, 9, 3).cuda(), requires_grad=True) 215 | xyz_feats = Variable(torch.randn(2, 9, 6).cuda(), requires_grad=True) 216 | 217 | test_module = PointnetSAModuleMSG( 218 | npoint=2, radii=[5.0, 10.0], nsamples=[6, 3], mlps=[[9, 3], [9, 6]] 219 | ) 220 | test_module.cuda() 221 | print(test_module(xyz, xyz_feats)) 222 | 223 | # test_module = PointnetFPModule(mlp=[6, 6]) 224 | # test_module.cuda() 225 | # from torch.autograd import gradcheck 226 | # inputs = (xyz, xyz, None, xyz_feats) 227 | # test = gradcheck(test_module, inputs, eps=1e-6, atol=1e-4) 228 | # print(test) 229 | 230 | for _ in range(1): 231 | _, new_features = test_module(xyz, xyz_feats) 232 | new_features.backward(torch.cuda.FloatTensor(*new_features.size()).fill_(1)) 233 | print(new_features) 234 | print(xyz.grad) 235 | -------------------------------------------------------------------------------- /code/utils/data_provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import time 4 | import os 5 | 6 | def normalize_point_cloud(input): 7 | if len(input.shape)==2: 8 | axis = 0 9 | elif len(input.shape)==3: 10 
| axis = 1 11 | centroid = np.mean(input, axis=axis, keepdims=True) 12 | input = input - centroid 13 | furthest_distance = np.amax(np.sqrt(np.sum(input ** 2, axis=-1)),axis=axis,keepdims=True) 14 | input = input / furthest_distance 15 | return input, centroid,furthest_distance 16 | 17 | def load_patch_data(h5_filename='../h5_data/Patches_noHole_and_collected.h5', skip_rate = 1,num_point=2048, use_randominput=True, norm=False): 18 | if use_randominput: 19 | print("use randominput, input h5 file is:", h5_filename) 20 | f = h5py.File(h5_filename) 21 | input = f['poisson_4096'][:] 22 | gt = f['poisson_4096'][:] 23 | else: 24 | print("Do not randominput, input h5 file is:",h5_filename) 25 | f = h5py.File(h5_filename) 26 | gt = f['poisson_4096'][:] 27 | input = f['montecarlo_1024'][:] 28 | 29 | name = f['name'][:] 30 | assert len(input) == len(gt) 31 | 32 | if norm: 33 | print("Normalization the data") 34 | data_radius = np.ones(shape=(len(input))) 35 | centroid = np.mean(gt[:,:,0:3], axis=1, keepdims=True) 36 | gt[:,:,0:3] = gt[:,:,0:3] - centroid 37 | furthest_distance = np.amax(np.sqrt(np.sum(gt[:,:,0:3] ** 2, axis=-1)),axis=1,keepdims=True) 38 | gt[:, :, 0:3] = gt[:,:,0:3] / np.expand_dims(furthest_distance,axis=-1) 39 | input[:, :, 0:3] = input[:, :, 0:3] - centroid 40 | input[:, :, 0:3] = input[:, :, 0:3] / np.expand_dims(furthest_distance,axis=-1) 41 | else: 42 | print("Do not normalization the data") 43 | centroid = np.mean(gt[:, :, 0:3], axis=1, keepdims=True) 44 | furthest_distance = np.amax(np.sqrt(np.sum((gt[:, :, 0:3] - centroid) ** 2, axis=-1)), axis=1, keepdims=True) 45 | data_radius = furthest_distance[0,:] 46 | 47 | input = input[::skip_rate] 48 | gt = gt[::skip_rate] 49 | data_radius = data_radius[::skip_rate] 50 | name = name[::skip_rate] 51 | 52 | object_name = list(set([item.split('/')[-1].split('_')[0] for item in name])) 53 | object_name.sort() 54 | print("load object names {}".format(object_name)) 55 | print("total %d samples" % (len(input))) 56 | return input, gt, data_radius, name 57 | 58 | 59 | def rotate_point_cloud_and_gt(batch_data,batch_gt=None): 60 | """ Randomly rotate the point clouds to augument the dataset 61 | rotation is per shape based along up direction 62 | Input: 63 | BxNx3 array, original batch of point clouds 64 | Return: 65 | BxNx3 array, rotated batch of point clouds 66 | """ 67 | for k in range(batch_data.shape[0]): 68 | angles = np.random.uniform(size=(3)) * 2 * np.pi 69 | Rx = np.array([[1, 0, 0], 70 | [0, np.cos(angles[0]), -np.sin(angles[0])], 71 | [0, np.sin(angles[0]), np.cos(angles[0])]]) 72 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])], 73 | [0, 1, 0], 74 | [-np.sin(angles[1]), 0, np.cos(angles[1])]]) 75 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0], 76 | [np.sin(angles[2]), np.cos(angles[2]), 0], 77 | [0, 0, 1]]) 78 | rotation_matrix = np.dot(Rz, np.dot(Ry, Rx)) 79 | 80 | # rotation_angle = np.random.uniform(size=(3)) * 2 * np.pi 81 | # cosval = np.cos(rotation_angle) 82 | # sinval = np.sin(rotation_angle) 83 | # rotation_matrix = np.array([[cosval, 0, sinval], 84 | # [0, 1, 0], 85 | # [-sinval, 0, cosval]]) 86 | 87 | batch_data[k, ..., 0:3] = np.dot(batch_data[k, ..., 0:3].reshape((-1, 3)), rotation_matrix) 88 | if batch_data.shape[-1]>3: 89 | batch_data[k, ..., 3:] = np.dot(batch_data[k, ..., 3:].reshape((-1, 3)), rotation_matrix) 90 | 91 | if batch_gt is not None: 92 | batch_gt[k, ..., 0:3] = np.dot(batch_gt[k, ..., 0:3].reshape((-1, 3)), rotation_matrix) 93 | if batch_gt.shape[-1] > 3: 94 | batch_gt[k, ..., 
3:] = np.dot(batch_gt[k, ..., 3:].reshape((-1, 3)), rotation_matrix) 95 | 96 | return batch_data,batch_gt 97 | 98 | 99 | def shift_point_cloud_and_gt(batch_data, batch_gt = None, shift_range=0.3): 100 | """ Randomly shift point cloud. Shift is per point cloud. 101 | Input: 102 | BxNx3 array, original batch of point clouds 103 | Return: 104 | BxNx3 array, shifted batch of point clouds 105 | """ 106 | B, N, C = batch_data.shape 107 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 108 | for batch_index in range(B): 109 | batch_data[batch_index,:,0:3] += shifts[batch_index,0:3] 110 | 111 | if batch_gt is not None: 112 | for batch_index in range(B): 113 | batch_gt[batch_index, :, 0:3] += shifts[batch_index, 0:3] 114 | 115 | return batch_data,batch_gt 116 | 117 | 118 | def random_scale_point_cloud_and_gt(batch_data, batch_gt = None, scale_low=0.5, scale_high=2): 119 | """ Randomly scale the point cloud. Scale is per point cloud. 120 | Input: 121 | BxNx3 array, original batch of point clouds 122 | Return: 123 | BxNx3 array, scaled batch of point clouds 124 | """ 125 | B, N, C = batch_data.shape 126 | scales = np.random.uniform(scale_low, scale_high, B) 127 | for batch_index in range(B): 128 | batch_data[batch_index,:,0:3] *= scales[batch_index] 129 | 130 | if batch_gt is not None: 131 | for batch_index in range(B): 132 | batch_gt[batch_index, :, 0:3] *= scales[batch_index] 133 | 134 | return batch_data,batch_gt,scales 135 | 136 | 137 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.03, angle_clip=0.09): 138 | """ Randomly perturb the point clouds by small rotations 139 | Input: 140 | BxNx3 array, original batch of point clouds 141 | Return: 142 | BxNx3 array, rotated batch of point clouds 143 | """ 144 | for k in range(batch_data.shape[0]): 145 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 146 | Rx = np.array([[1,0,0], 147 | [0,np.cos(angles[0]),-np.sin(angles[0])], 148 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 149 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 150 | [0,1,0], 151 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 152 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 153 | [np.sin(angles[2]),np.cos(angles[2]),0], 154 | [0,0,1]]) 155 | R = np.dot(Rz, np.dot(Ry,Rx)) 156 | batch_data[k, ...,0:3] = np.dot(batch_data[k, ...,0:3].reshape((-1, 3)), R) 157 | if batch_data.shape[-1]>3: 158 | batch_data[k, ..., 3:] = np.dot(batch_data[k, ..., 3:].reshape((-1, 3)), R) 159 | 160 | return batch_data 161 | 162 | 163 | def jitter_perturbation_point_cloud(batch_data, sigma=0.005, clip=0.02): 164 | """ Randomly jitter points. jittering is per point. 
165 | Input: 166 | BxNx3 array, original batch of point clouds 167 | Return: 168 | BxNx3 array, jittered batch of point clouds 169 | """ 170 | B, N, C = batch_data.shape 171 | assert(clip > 0) 172 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 173 | jittered_data[:,:,3:] = 0 174 | jittered_data += batch_data 175 | return jittered_data 176 | 177 | 178 | def save_pl(path, pl): 179 | if not os.path.exists(os.path.split(path)[0]): 180 | os.makedirs(os.path.split(path)[0]) 181 | myfile = open(path, "w") 182 | point_num = pl.shape[0] 183 | for j in range(point_num): 184 | if len(pl[j])==3: 185 | print("%f %f %f" % (pl[j,0],pl[j,1],pl[j,2]), file=myfile) 186 | elif len(pl[j])==6: 187 | print("%f %f %f %f %f %f" % (pl[j, 0], pl[j, 1], pl[j, 2],pl[j, 3],pl[j, 4],pl[j, 5]), file=myfile) 188 | # print("%f %f %f %f %f %f %f" % ( 189 | # pl[j, 0], pl[j, 1], pl[j, 2], pl[j, 3], pl[j, 4], pl[j, 5], pl[j, 2]), file=myfile) 190 | elif len(pl[j])==7: 191 | print("%f %f %f %f %f %f %f" % (pl[j, 0], pl[j, 1], pl[j, 2],pl[j, 3],pl[j, 4],pl[j, 5],pl[j, 2]), file=myfile) 192 | myfile.close() 193 | if np.random.rand()>1.9: 194 | show3d.showpoints(pl[:, 0:3]) 195 | 196 | 197 | def nonuniform_sampling(num = 4096, sample_num = 1024): 198 | sample = set() 199 | loc = np.random.rand()*0.8+0.1 200 | while(len(sample) < sample_num): 201 | a = int(np.random.normal(loc=loc, scale=0.3) * num) 202 | if a < 0 or a >= num: 203 | continue 204 | sample.add(a) 205 | return list(sample) 206 | 207 | def add_noise(ip): 208 | if np.random.rand() > 0.5: 209 | #print("Adding shift noise") 210 | ip = jitter_perturbation_point_cloud(ip, sigma=0.025,clip=0.05) 211 | if np.random.rand() > 0.5: 212 | #print("Adding rotation noise") 213 | ip = rotate_perturbation_point_cloud(ip, angle_sigma=0.03, angle_clip=0.09) 214 | #ip = jitter_perturbation_point_cloud(ip, sigma=0.025,clip=0.05) 215 | return ip 216 | 217 | def data_augmentation(ip, gt): 218 | #np.savetxt("temp/test_input_wo_pp.xyz", ip[0]) 219 | #np.savetxt("temp/test_gt_wo_pp.xyz", gt[0]) 220 | ip, gt = rotate_point_cloud_and_gt(ip, gt) 221 | ip, gt, _ = random_scale_point_cloud_and_gt(ip, gt, scale_low=0.9, scale_high=1.1) 222 | ip, gt = shift_point_cloud_and_gt(ip, gt, shift_range=0.1) 223 | 224 | #np.savetxt("temp/test_input_w_pp.xyz", ip[0]) 225 | #np.savetxt("temp/test_gt_w_pp.xyz", gt[0]) 226 | return ip, gt 227 | -------------------------------------------------------------------------------- /sampling/sampling_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #include "cuda_utils.h" 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | #include "cuda_utils.h" 15 | 16 | 17 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 18 | int idx1, int idx2) { 19 | 20 | const float v1 = dists[idx1], v2 = dists[idx2]; 21 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 22 | dists[idx1] = max(v1, v2); 23 | dists_i[idx1] = v2 > v1 ?
i2 : i1; 24 | } 25 | 26 | // input: points(b, c, n) idx(b, m) 27 | // output: out(b, c, m) 28 | template 29 | __global__ void gather_points_forward_kernel(int b, int c, int n, int m, 30 | const scalar_t *__restrict__ points, 31 | const int *__restrict__ idx, 32 | scalar_t *__restrict__ out) { 33 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 34 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 35 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 36 | int a = idx[i * m + j]; 37 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 38 | } 39 | } 40 | } 41 | } 42 | 43 | at::Tensor gather_points_cuda_forward(int b, int c, int n, int npoints, 44 | at::Tensor points, at::Tensor idx, 45 | at::Tensor out) { 46 | 47 | cudaError_t err; 48 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(points.type(), "gather_points_cuda_forward", ([&] { 49 | gather_points_forward_kernel<<>>( 50 | b, c, n, npoints, 51 | points.data(), 52 | idx.data(), 53 | out.data()); 54 | })); 55 | 56 | err = cudaGetLastError(); 57 | if (cudaSuccess != err) { 58 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 59 | exit(-1); 60 | } 61 | return out; 62 | } 63 | 64 | // input: grad_out(b, c, m) idx(b, m) 65 | // output: grad_points(b, c, n) 66 | template 67 | __global__ void gather_points_backward_kernel(int b, int c, int n, int m, 68 | scalar_t *__restrict__ grad_out, 69 | const int *__restrict__ idx, 70 | scalar_t *__restrict__ grad_points) { 71 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 72 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 73 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 74 | int a = idx[i * m + j]; 75 | atomicAdd(grad_points + (i * c + l) * n + a, 76 | grad_out[(i * c + l) * m + j]); 77 | } 78 | } 79 | } 80 | } 81 | 82 | 83 | at::Tensor gather_points_cuda_backward(int b, int c, int n, int npoints, 84 | at::Tensor grad_out, at::Tensor idx, at::Tensor grad_points) { 85 | cudaError_t err; 86 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.type(), "gather_points_cuda_backward", ([&] { 87 | gather_points_backward_kernel<<>>( 88 | b, c, n, npoints, 89 | grad_out.data(), 90 | idx.data(), 91 | grad_points.data()); 92 | })); 93 | 94 | err = cudaGetLastError(); 95 | if (cudaSuccess != err) { 96 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 97 | exit(-1); 98 | } 99 | return grad_points; 100 | } 101 | 102 | 103 | template 104 | __global__ void furthest_point_sampling_forward_kernel(int b, int n, int m, 105 | const float * __restrict__ input, float * __restrict__ temp, int * __restrict__ idx) { 106 | // temp: (nxb) the closest distance from each of the n points to the existing set 107 | if (m <= 0) return; 108 | __shared__ float dists[block_size]; 109 | __shared__ int dists_i[block_size]; 110 | const unsigned int buffer_size = block_size; 111 | __shared__ float buf[block_size*3]; 112 | for (int i=blockIdx.x; ibest){ 148 | best=d2; 149 | besti=k; 150 | } 151 | } 152 | dists[threadIdx.x]=best; 153 | dists_i[threadIdx.x]=besti; 154 | // u from 0~log2(block_size) 155 | for (int u=0;(1<>(u+1))){ 160 | int i1=(threadIdx.x*2)<<<>>( 184 | b, n, m, input.data(), 185 | temp.data(), 186 | idx.data()); 187 | break; 188 | case 256: 189 | furthest_point_sampling_forward_kernel<256><<>>( 190 | b, n, m, 191 | input.data(), 192 | temp.data(), 193 | idx.data()); 194 | break; 195 | case 128: 196 | furthest_point_sampling_forward_kernel<128><<>>( 197 | b, n, m, 198 | input.data(), 199 | temp.data(), 200 | idx.data()); 201 | break; 202 | case 64: 203 | 
furthest_point_sampling_forward_kernel<64><<>>( 204 | b, n, m, 205 | input.data(), 206 | temp.data(), 207 | idx.data()); 208 | break; 209 | case 32: 210 | furthest_point_sampling_forward_kernel<32><<>>( 211 | b, n, m, 212 | input.data(), 213 | temp.data(), 214 | idx.data()); 215 | break; 216 | case 16: 217 | furthest_point_sampling_forward_kernel<16><<>>( 218 | b, n, m, 219 | input.data(), 220 | temp.data(), 221 | idx.data()); 222 | break; 223 | case 8: 224 | furthest_point_sampling_forward_kernel<8><<>>( 225 | b, n, m, 226 | input.data(), 227 | temp.data(), 228 | idx.data()); 229 | break; 230 | case 4: 231 | furthest_point_sampling_forward_kernel<4><<>>( 232 | b, n, m, 233 | input.data(), 234 | temp.data(), 235 | idx.data()); 236 | break; 237 | case 2: 238 | furthest_point_sampling_forward_kernel<2><<>>( 239 | b, n, m, 240 | input.data(), 241 | temp.data(), 242 | idx.data()); 243 | break; 244 | case 1: 245 | furthest_point_sampling_forward_kernel<1><<>>( 246 | b, n, m, 247 | input.data(), 248 | temp.data(), 249 | idx.data()); 250 | break; 251 | default: 252 | furthest_point_sampling_forward_kernel<512><<>>( 253 | b, n, m, 254 | input.data(), 255 | temp.data(), 256 | idx.data()); 257 | } 258 | 259 | cudaError_t err = cudaGetLastError(); 260 | if (cudaSuccess != err) { 261 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 262 | exit(-1); 263 | } 264 | return idx; 265 | } 266 | 267 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 268 | // output: idx(b, m, nsample) 269 | template 270 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 271 | int nsample, 272 | const scalar_t *__restrict__ new_xyz, 273 | const scalar_t *__restrict__ xyz, 274 | int *__restrict__ idx) { 275 | int batch_index = blockIdx.x; 276 | xyz += batch_index * n * 3; 277 | new_xyz += batch_index * m * 3; 278 | idx += m * nsample * batch_index; 279 | 280 | int index = threadIdx.x; 281 | int stride = blockDim.x; 282 | 283 | float radius2 = radius * radius; 284 | for (int j = index; j < m; j += stride) { 285 | scalar_t new_x = new_xyz[j * 3 + 0]; 286 | scalar_t new_y = new_xyz[j * 3 + 1]; 287 | scalar_t new_z = new_xyz[j * 3 + 2]; 288 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 289 | scalar_t x = xyz[k * 3 + 0]; 290 | scalar_t y = xyz[k * 3 + 1]; 291 | scalar_t z = xyz[k * 3 + 2]; 292 | scalar_t d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 293 | (new_z - z) * (new_z - z); 294 | if (d2 < radius2) { 295 | if (cnt == 0) { 296 | for (int l = 0; l < nsample; ++l) { 297 | idx[j * nsample + l] = k; 298 | } 299 | } 300 | idx[j * nsample + cnt] = k; 301 | ++cnt; 302 | } 303 | } 304 | } 305 | } 306 | 307 | at::Tensor ball_query_cuda_forward(int b, int n, int m, float radius, 308 | int nsample, at::Tensor query, 309 | at::Tensor xyz, at::Tensor idx) { 310 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 311 | AT_DISPATCH_FLOATING_TYPES(xyz.type(), "query_ball_point_kernel", ([&]() { 312 | query_ball_point_kernel<<>>(b, n, m, radius, nsample, 313 | query.data(), xyz.data(), idx.data()); 314 | })); 315 | CUDA_CHECK_ERRORS(); 316 | return idx; 317 | } -------------------------------------------------------------------------------- /pointnet2/utils/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import torch 9 | from torch.autograd import Function 10 | import torch.nn as nn 11 
| from pointnet2.utils import pytorch_utils as pt_utils 12 | import sys 13 | 14 | try: 15 | import builtins 16 | except: 17 | import __builtin__ as builtins 18 | 19 | try: 20 | import pointnet2._ext as _ext 21 | except ImportError: 22 | if not getattr(builtins, "__POINTNET2_SETUP__", False): 23 | raise ImportError( 24 | "Could not import _ext module.\n" 25 | "Please see the setup instructions in the README: " 26 | "https://github.com/erikwijmans/Pointnet2_PyTorch/blob/master/README.rst" 27 | ) 28 | 29 | if False: 30 | # Workaround for type hints without depending on the `typing` module 31 | from typing import * 32 | 33 | 34 | class RandomDropout(nn.Module): 35 | def __init__(self, p=0.5, inplace=False): 36 | super(RandomDropout, self).__init__() 37 | self.p = p 38 | self.inplace = inplace 39 | 40 | def forward(self, X): 41 | theta = torch.Tensor(1).uniform_(0, self.p)[0] 42 | return pt_utils.feature_dropout_no_scaling(X, theta, self.train, self.inplace) 43 | 44 | 45 | class FurthestPointSampling(Function): 46 | @staticmethod 47 | def forward(ctx, xyz, npoint): 48 | # type: (Any, torch.Tensor, int) -> torch.Tensor 49 | r""" 50 | Uses iterative furthest point sampling to select a set of npoint features that have the largest 51 | minimum distance 52 | 53 | Parameters 54 | ---------- 55 | xyz : torch.Tensor 56 | (B, N, 3) tensor where N > npoint 57 | npoint : int32 58 | number of features in the sampled set 59 | 60 | Returns 61 | ------- 62 | torch.Tensor 63 | (B, npoint) tensor containing the set 64 | """ 65 | return _ext.furthest_point_sampling(xyz, npoint) 66 | 67 | @staticmethod 68 | def backward(xyz, a=None): 69 | return None, None 70 | 71 | 72 | furthest_point_sample = FurthestPointSampling.apply 73 | 74 | 75 | class GatherOperation(Function): 76 | @staticmethod 77 | def forward(ctx, features, idx): 78 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 79 | r""" 80 | 81 | Parameters 82 | ---------- 83 | features : torch.Tensor 84 | (B, C, N) tensor 85 | 86 | idx : torch.Tensor 87 | (B, npoint) tensor of the features to gather 88 | 89 | Returns 90 | ------- 91 | torch.Tensor 92 | (B, C, npoint) tensor 93 | """ 94 | 95 | _, C, N = features.size() 96 | 97 | ctx.for_backwards = (idx, C, N) 98 | 99 | return _ext.gather_points(features, idx) 100 | 101 | @staticmethod 102 | def backward(ctx, grad_out): 103 | idx, C, N = ctx.for_backwards 104 | 105 | grad_features = _ext.gather_points_grad(grad_out.contiguous(), idx, N) 106 | return grad_features, None 107 | 108 | 109 | gather_operation = GatherOperation.apply 110 | 111 | 112 | class ThreeNN(Function): 113 | @staticmethod 114 | def forward(ctx, unknown, known): 115 | # type: (Any, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor] 116 | r""" 117 | Find the three nearest neighbors of unknown in known 118 | Parameters 119 | ---------- 120 | unknown : torch.Tensor 121 | (B, n, 3) tensor of known features 122 | known : torch.Tensor 123 | (B, m, 3) tensor of unknown features 124 | 125 | Returns 126 | ------- 127 | dist : torch.Tensor 128 | (B, n, 3) l2 distance to the three nearest neighbors 129 | idx : torch.Tensor 130 | (B, n, 3) index of 3 nearest neighbors 131 | """ 132 | dist2, idx = _ext.three_nn(unknown, known) 133 | 134 | return torch.sqrt(dist2), idx 135 | 136 | @staticmethod 137 | def backward(ctx, a=None, b=None): 138 | return None, None 139 | 140 | 141 | three_nn = ThreeNN.apply 142 | 143 | 144 | class ThreeInterpolate(Function): 145 | @staticmethod 146 | def forward(ctx, features, idx, weight): 147 | # type(Any, 
torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor 148 | r""" 149 | Performs weight linear interpolation on 3 features 150 | Parameters 151 | ---------- 152 | features : torch.Tensor 153 | (B, c, m) Features descriptors to be interpolated from 154 | idx : torch.Tensor 155 | (B, n, 3) three nearest neighbors of the target features in features 156 | weight : torch.Tensor 157 | (B, n, 3) weights 158 | 159 | Returns 160 | ------- 161 | torch.Tensor 162 | (B, c, n) tensor of the interpolated features 163 | """ 164 | B, c, m = features.size() 165 | n = idx.size(1) 166 | 167 | ctx.three_interpolate_for_backward = (idx, weight, m) 168 | 169 | return _ext.three_interpolate(features, idx, weight) 170 | 171 | @staticmethod 172 | def backward(ctx, grad_out): 173 | # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor] 174 | r""" 175 | Parameters 176 | ---------- 177 | grad_out : torch.Tensor 178 | (B, c, n) tensor with gradients of ouputs 179 | 180 | Returns 181 | ------- 182 | grad_features : torch.Tensor 183 | (B, c, m) tensor with gradients of features 184 | 185 | None 186 | 187 | None 188 | """ 189 | idx, weight, m = ctx.three_interpolate_for_backward 190 | 191 | grad_features = _ext.three_interpolate_grad( 192 | grad_out.contiguous(), idx, weight, m 193 | ) 194 | 195 | return grad_features, None, None 196 | 197 | 198 | three_interpolate = ThreeInterpolate.apply 199 | 200 | 201 | class GroupingOperation(Function): 202 | @staticmethod 203 | def forward(ctx, features, idx): 204 | # type: (Any, torch.Tensor, torch.Tensor) -> torch.Tensor 205 | r""" 206 | 207 | Parameters 208 | ---------- 209 | features : torch.Tensor 210 | (B, C, N) tensor of features to group 211 | idx : torch.Tensor 212 | (B, npoint, nsample) tensor containing the indicies of features to group with 213 | 214 | Returns 215 | ------- 216 | torch.Tensor 217 | (B, C, npoint, nsample) tensor 218 | """ 219 | B, nfeatures, nsample = idx.size() 220 | _, C, N = features.size() 221 | 222 | ctx.for_backwards = (idx, N) 223 | 224 | return _ext.group_points(features, idx) 225 | 226 | @staticmethod 227 | def backward(ctx, grad_out): 228 | # type: (Any, torch.tensor) -> Tuple[torch.Tensor, torch.Tensor] 229 | r""" 230 | 231 | Parameters 232 | ---------- 233 | grad_out : torch.Tensor 234 | (B, C, npoint, nsample) tensor of the gradients of the output from forward 235 | 236 | Returns 237 | ------- 238 | torch.Tensor 239 | (B, C, N) gradient of the features 240 | None 241 | """ 242 | idx, N = ctx.for_backwards 243 | 244 | grad_features = _ext.group_points_grad(grad_out.contiguous(), idx, N) 245 | 246 | return grad_features, None 247 | 248 | 249 | grouping_operation = GroupingOperation.apply 250 | 251 | 252 | class BallQuery(Function): 253 | @staticmethod 254 | def forward(ctx, radius, nsample, xyz, new_xyz): 255 | # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor 256 | r""" 257 | 258 | Parameters 259 | ---------- 260 | radius : float 261 | radius of the balls 262 | nsample : int 263 | maximum number of features in the balls 264 | xyz : torch.Tensor 265 | (B, N, 3) xyz coordinates of the features 266 | new_xyz : torch.Tensor 267 | (B, npoint, 3) centers of the ball query 268 | 269 | Returns 270 | ------- 271 | torch.Tensor 272 | (B, npoint, nsample) tensor with the indicies of the features that form the query balls 273 | """ 274 | return _ext.ball_query(new_xyz, xyz, radius, nsample) 275 | 276 | @staticmethod 277 | def backward(ctx, a=None): 278 | return None, None, None, None 279 | 280 | 281 | 
ball_query = BallQuery.apply 282 | 283 | 284 | class QueryAndGroup(nn.Module): 285 | r""" 286 | Groups with a ball query of radius 287 | 288 | Parameters 289 | --------- 290 | radius : float32 291 | Radius of ball 292 | nsample : int32 293 | Maximum number of features to gather in the ball 294 | """ 295 | 296 | def __init__(self, radius, nsample, use_xyz=True): 297 | # type: (QueryAndGroup, float, int, bool) -> None 298 | super(QueryAndGroup, self).__init__() 299 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 300 | 301 | def forward(self, xyz, new_xyz, features=None): 302 | # type: (QueryAndGroup, torch.Tensor. torch.Tensor, torch.Tensor) -> Tuple[Torch.Tensor] 303 | r""" 304 | Parameters 305 | ---------- 306 | xyz : torch.Tensor 307 | xyz coordinates of the features (B, N, 3) 308 | new_xyz : torch.Tensor 309 | centriods (B, npoint, 3) 310 | features : torch.Tensor 311 | Descriptors of the features (B, C, N) 312 | 313 | Returns 314 | ------- 315 | new_features : torch.Tensor 316 | (B, 3 + C, npoint, nsample) tensor 317 | """ 318 | 319 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 320 | xyz_trans = xyz.transpose(1, 2).contiguous() 321 | grouped_xyz = grouping_operation(xyz_trans, idx) # (B, 3, npoint, nsample) 322 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 323 | 324 | if features is not None: 325 | grouped_features = grouping_operation(features, idx) 326 | if self.use_xyz: 327 | new_features = torch.cat( 328 | [grouped_xyz, grouped_features], dim=1 329 | ) # (B, C + 3, npoint, nsample) 330 | else: 331 | new_features = grouped_features 332 | else: 333 | assert ( 334 | self.use_xyz 335 | ), "Cannot have not features and not use xyz as a feature!" 336 | new_features = grouped_xyz 337 | 338 | return new_features 339 | 340 | 341 | class GroupAll(nn.Module): 342 | r""" 343 | Groups all features 344 | 345 | Parameters 346 | --------- 347 | """ 348 | 349 | def __init__(self, use_xyz=True): 350 | # type: (GroupAll, bool) -> None 351 | super(GroupAll, self).__init__() 352 | self.use_xyz = use_xyz 353 | 354 | def forward(self, xyz, new_xyz, features=None): 355 | # type: (GroupAll, torch.Tensor, torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor] 356 | r""" 357 | Parameters 358 | ---------- 359 | xyz : torch.Tensor 360 | xyz coordinates of the features (B, N, 3) 361 | new_xyz : torch.Tensor 362 | Ignored 363 | features : torch.Tensor 364 | Descriptors of the features (B, C, N) 365 | 366 | Returns 367 | ------- 368 | new_features : torch.Tensor 369 | (B, C + 3, 1, N) tensor 370 | """ 371 | 372 | grouped_xyz = xyz.transpose(1, 2).unsqueeze(2) 373 | if features is not None: 374 | grouped_features = features.unsqueeze(2) 375 | if self.use_xyz: 376 | new_features = torch.cat( 377 | [grouped_xyz, grouped_features], dim=1 378 | ) # (B, 3 + C, 1, N) 379 | else: 380 | new_features = grouped_features 381 | else: 382 | new_features = grouped_xyz 383 | 384 | return new_features 385 | -------------------------------------------------------------------------------- /pointnet2/utils/pytorch_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import ( 2 | division, 3 | absolute_import, 4 | with_statement, 5 | print_function, 6 | unicode_literals, 7 | ) 8 | import os 9 | import torch 10 | import torch.nn as nn 11 | from torch.autograd.function import InplaceFunction 12 | from itertools import repeat 13 | import numpy as np 14 | import shutil 15 | import tqdm 16 | from scipy.stats import t as student_t 17 | import 
statistics as stats 18 | 19 | 20 | if False: 21 | # Workaround for type hints without depending on the `typing` module 22 | from typing import * 23 | 24 | 25 | class SharedMLP(nn.Sequential): 26 | def __init__( 27 | self, 28 | args, 29 | bn=False, 30 | activation=nn.ReLU(inplace=True), 31 | preact=False, 32 | first=False, 33 | name="", 34 | ): 35 | # type: (SharedMLP, List[int], bool, Any, bool, bool, AnyStr) -> None 36 | super(SharedMLP, self).__init__() 37 | 38 | for i in range(len(args) - 1): 39 | self.add_module( 40 | name + "layer{}".format(i), 41 | Conv2d( 42 | args[i], 43 | args[i + 1], 44 | bn=(not first or not preact or (i != 0)) and bn, 45 | activation=activation 46 | if (not first or not preact or (i != 0)) 47 | else None, 48 | preact=preact, 49 | ), 50 | ) 51 | 52 | 53 | class _BNBase(nn.Sequential): 54 | def __init__(self, in_size, batch_norm=None, name=""): 55 | super(_BNBase, self).__init__() 56 | self.add_module(name + "bn", batch_norm(in_size)) 57 | 58 | nn.init.constant_(self[0].weight, 1.0) 59 | nn.init.constant_(self[0].bias, 0) 60 | 61 | 62 | class BatchNorm1d(_BNBase): 63 | def __init__(self, in_size, name=""): 64 | # type: (BatchNorm1d, int, AnyStr) -> None 65 | super(BatchNorm1d, self).__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 66 | 67 | 68 | class BatchNorm2d(_BNBase): 69 | def __init__(self, in_size, name=""): 70 | # type: (BatchNorm2d, int, AnyStr) -> None 71 | super(BatchNorm2d, self).__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 72 | 73 | 74 | class BatchNorm3d(_BNBase): 75 | def __init__(self, in_size, name=""): 76 | # type: (BatchNorm3d, int, AnyStr) -> None 77 | super(BatchNorm3d, self).__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 78 | 79 | 80 | class _ConvBase(nn.Sequential): 81 | def __init__( 82 | self, 83 | in_size, 84 | out_size, 85 | kernel_size, 86 | stride, 87 | padding, 88 | dilation, 89 | activation, 90 | bn, 91 | init, 92 | conv=None, 93 | norm_layer=None, 94 | bias=True, 95 | preact=False, 96 | name="", 97 | ): 98 | super(_ConvBase, self).__init__() 99 | 100 | bias = bias and (not bn) 101 | conv_unit = conv( 102 | in_size, 103 | out_size, 104 | kernel_size=kernel_size, 105 | stride=stride, 106 | padding=padding, 107 | dilation=dilation, 108 | bias=bias, 109 | ) 110 | init(conv_unit.weight) 111 | if bias: 112 | nn.init.constant_(conv_unit.bias, 0) 113 | 114 | if bn: 115 | if not preact: 116 | bn_unit = norm_layer(out_size) 117 | else: 118 | bn_unit = norm_layer(in_size) 119 | 120 | if preact: 121 | if bn: 122 | self.add_module(name + "normlayer", bn_unit) 123 | 124 | if activation is not None: 125 | self.add_module(name + "activation", activation) 126 | 127 | self.add_module(name + "conv", conv_unit) 128 | 129 | if not preact: 130 | if bn: 131 | self.add_module(name + "normlayer", bn_unit) 132 | 133 | if activation is not None: 134 | self.add_module(name + "activation", activation) 135 | 136 | 137 | class Conv1d(_ConvBase): 138 | def __init__( 139 | self, 140 | in_size, 141 | out_size, 142 | kernel_size=1, 143 | stride=1, 144 | padding=0, 145 | dilation=1, 146 | activation=nn.ReLU(inplace=True), 147 | bn=False, 148 | init=nn.init.kaiming_normal_, 149 | bias=True, 150 | preact=False, 151 | name="", 152 | norm_layer=BatchNorm1d, 153 | ): 154 | # type: (Conv1d, int, int, int, int, int, int, Any, bool, Any, bool, bool, AnyStr, _BNBase) -> None 155 | super(Conv1d, self).__init__( 156 | in_size, 157 | out_size, 158 | kernel_size, 159 | stride, 160 | padding, 161 | dilation, 162 | activation, 163 | bn, 164 | init, 165 | 
conv=nn.Conv1d, 166 | norm_layer=norm_layer, 167 | bias=bias, 168 | preact=preact, 169 | name=name, 170 | ) 171 | 172 | 173 | class Conv2d(_ConvBase): 174 | def __init__( 175 | self, 176 | in_size, 177 | out_size, 178 | kernel_size=(1, 1), 179 | stride=(1, 1), 180 | padding=(0, 0), 181 | dilation=(1, 1), 182 | activation=nn.ReLU(inplace=True), 183 | bn=False, 184 | init=nn.init.kaiming_normal_, 185 | bias=True, 186 | preact=False, 187 | name="", 188 | norm_layer=BatchNorm2d, 189 | ): 190 | # type: (Conv2d, int, int, Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[int, int], Any, bool, Any, bool, bool, AnyStr, _BNBase) -> None 191 | super(Conv2d, self).__init__( 192 | in_size, 193 | out_size, 194 | kernel_size, 195 | stride, 196 | padding, 197 | dilation, 198 | activation, 199 | bn, 200 | init, 201 | conv=nn.Conv2d, 202 | norm_layer=norm_layer, 203 | bias=bias, 204 | preact=preact, 205 | name=name, 206 | ) 207 | 208 | 209 | class Conv3d(_ConvBase): 210 | def __init__( 211 | self, 212 | in_size, 213 | out_size, 214 | kernel_size=(1, 1, 1), 215 | stride=(1, 1, 1), 216 | padding=(0, 0, 0), 217 | dilation=(1, 1, 1), 218 | activation=nn.ReLU(inplace=True), 219 | bn=False, 220 | init=nn.init.kaiming_normal_, 221 | bias=True, 222 | preact=False, 223 | name="", 224 | norm_layer=BatchNorm3d, 225 | ): 226 | # type: (Conv3d, int, int, Tuple[int, int, int], Tuple[int, int, int], Tuple[int, int, int], Tuple[int, int, int], Any, bool, Any, bool, bool, AnyStr, _BNBase) -> None 227 | super(Conv3d, self).__init__( 228 | in_size, 229 | out_size, 230 | kernel_size, 231 | stride, 232 | padding, 233 | dilation, 234 | activation, 235 | bn, 236 | init, 237 | conv=nn.Conv3d, 238 | norm_layer=norm_layer, 239 | bias=bias, 240 | preact=preact, 241 | name=name, 242 | ) 243 | 244 | 245 | class FC(nn.Sequential): 246 | def __init__( 247 | self, 248 | in_size, 249 | out_size, 250 | activation=nn.ReLU(inplace=True), 251 | bn=False, 252 | init=None, 253 | preact=False, 254 | name="", 255 | ): 256 | # type: (FC, int, int, Any, bool, Any, bool, AnyStr) -> None 257 | super(FC, self).__init__() 258 | 259 | fc = nn.Linear(in_size, out_size, bias=not bn) 260 | if init is not None: 261 | init(fc.weight) 262 | if not bn: 263 | nn.init.constant_(fc.bias, 0) 264 | 265 | if preact: 266 | if bn: 267 | self.add_module(name + "bn", BatchNorm1d(in_size)) 268 | 269 | if activation is not None: 270 | self.add_module(name + "activation", activation) 271 | 272 | self.add_module(name + "fc", fc) 273 | 274 | if not preact: 275 | if bn: 276 | self.add_module(name + "bn", BatchNorm1d(out_size)) 277 | 278 | if activation is not None: 279 | self.add_module(name + "activation", activation) 280 | 281 | 282 | class _DropoutNoScaling(InplaceFunction): 283 | @staticmethod 284 | def _make_noise(input): 285 | return input.new().resize_as_(input) 286 | 287 | @staticmethod 288 | def symbolic(g, input, p=0.5, train=False, inplace=False): 289 | if inplace: 290 | return None 291 | n = g.appendNode( 292 | g.create("Dropout", [input]).f_("ratio", p).i_("is_test", not train) 293 | ) 294 | real = g.appendNode(g.createSelect(n, 0)) 295 | g.appendNode(g.createSelect(n, 1)) 296 | return real 297 | 298 | @classmethod 299 | def forward(cls, ctx, input, p=0.5, train=False, inplace=False): 300 | if p < 0 or p > 1: 301 | raise ValueError( 302 | "dropout probability has to be between 0 and 1, " "but got {}".format(p) 303 | ) 304 | ctx.p = p 305 | ctx.train = train 306 | ctx.inplace = inplace 307 | 308 | if ctx.inplace: 309 | ctx.mark_dirty(input) 310 | output = 
input 311 | else: 312 | output = input.clone() 313 | 314 | if ctx.p > 0 and ctx.train: 315 | ctx.noise = cls._make_noise(input) 316 | if ctx.p == 1: 317 | ctx.noise.fill_(0) 318 | else: 319 | ctx.noise.bernoulli_(1 - ctx.p) 320 | ctx.noise = ctx.noise.expand_as(input) 321 | output.mul_(ctx.noise) 322 | 323 | return output 324 | 325 | @staticmethod 326 | def backward(ctx, grad_output): 327 | if ctx.p > 0 and ctx.train: 328 | return grad_output.mul(ctx.noise), None, None, None 329 | else: 330 | return grad_output, None, None, None 331 | 332 | 333 | dropout_no_scaling = _DropoutNoScaling.apply 334 | 335 | 336 | class _FeatureDropoutNoScaling(_DropoutNoScaling): 337 | @staticmethod 338 | def symbolic(input, p=0.5, train=False, inplace=False): 339 | return None 340 | 341 | @staticmethod 342 | def _make_noise(input): 343 | return input.new().resize_( 344 | input.size(0), input.size(1), *repeat(1, input.dim() - 2) 345 | ) 346 | 347 | 348 | feature_dropout_no_scaling = _FeatureDropoutNoScaling.apply 349 | 350 | 351 | def group_model_params(model, **kwargs): 352 | # type: (nn.Module, ...) -> List[Dict] 353 | decay_group = [] 354 | no_decay_group = [] 355 | 356 | for name, param in model.named_parameters(): 357 | if name.find("normlayer") != -1 or name.find("bias") != -1: 358 | no_decay_group.append(param) 359 | else: 360 | decay_group.append(param) 361 | 362 | assert len(list(model.parameters())) == len(decay_group) + len(no_decay_group) 363 | 364 | return [ 365 | dict(params=decay_group, **kwargs), 366 | dict(params=no_decay_group, weight_decay=0.0, **kwargs), 367 | ] 368 | 369 | 370 | def checkpoint_state(model=None, optimizer=None, best_prec=None, epoch=None, it=None): 371 | optim_state = optimizer.state_dict() if optimizer is not None else None 372 | if model is not None: 373 | if isinstance(model, torch.nn.DataParallel): 374 | model_state = model.module.state_dict() 375 | else: 376 | model_state = model.state_dict() 377 | else: 378 | model_state = None 379 | 380 | return { 381 | "epoch": epoch, 382 | "it": it, 383 | "best_prec": best_prec, 384 | "model_state": model_state, 385 | "optimizer_state": optim_state, 386 | } 387 | 388 | 389 | def save_checkpoint(state, is_best, filename="checkpoint", bestname="model_best"): 390 | filename = "{}.pth.tar".format(filename) 391 | torch.save(state, filename) 392 | if is_best: 393 | shutil.copyfile(filename, "{}.pth.tar".format(bestname)) 394 | 395 | 396 | def load_checkpoint(model=None, optimizer=None, filename="checkpoint"): 397 | filename = "{}.pth.tar".format(filename) 398 | 399 | if os.path.isfile(filename): 400 | print("==> Loading from checkpoint '{}'".format(filename)) 401 | checkpoint = torch.load(filename) 402 | epoch = checkpoint["epoch"] 403 | it = checkpoint.get("it", 0.0) 404 | best_prec = checkpoint["best_prec"] 405 | if model is not None and checkpoint["model_state"] is not None: 406 | model.load_state_dict(checkpoint["model_state"]) 407 | if optimizer is not None and checkpoint["optimizer_state"] is not None: 408 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 409 | print("==> Done") 410 | return it, epoch, best_prec 411 | else: 412 | print("==> Checkpoint '{}' not found".format(filename)) 413 | return None 414 | 415 | 416 | def variable_size_collate(pad_val=0, use_shared_memory=True): 417 | import collections 418 | 419 | _numpy_type_map = { 420 | "float64": torch.DoubleTensor, 421 | "float32": torch.FloatTensor, 422 | "float16": torch.HalfTensor, 423 | "int64": torch.LongTensor, 424 | "int32": torch.IntTensor, 425 | "int16": 
torch.ShortTensor, 426 | "int8": torch.CharTensor, 427 | "uint8": torch.ByteTensor, 428 | } 429 | 430 | def wrapped(batch): 431 | "Puts each data field into a tensor with outer dimension batch size" 432 | 433 | error_msg = "batch must contain tensors, numbers, dicts or lists; found {}" 434 | elem_type = type(batch[0]) 435 | if torch.is_tensor(batch[0]): 436 | max_len = 0 437 | for b in batch: 438 | max_len = max(max_len, b.size(0)) 439 | 440 | numel = sum([int(b.numel() / b.size(0) * max_len) for b in batch]) 441 | if use_shared_memory: 442 | # If we're in a background process, concatenate directly into a 443 | # shared memory tensor to avoid an extra copy 444 | storage = batch[0].storage()._new_shared(numel) 445 | out = batch[0].new(storage) 446 | else: 447 | out = batch[0].new(numel) 448 | 449 | out = out.view( 450 | len(batch), 451 | max_len, 452 | *[batch[0].size(i) for i in range(1, batch[0].dim())] 453 | ) 454 | out.fill_(pad_val) 455 | for i in range(len(batch)): 456 | out[i, 0 : batch[i].size(0)] = batch[i] 457 | 458 | return out 459 | elif ( 460 | elem_type.__module__ == "numpy" 461 | and elem_type.__name__ != "str_" 462 | and elem_type.__name__ != "string_" 463 | ): 464 | elem = batch[0] 465 | if elem_type.__name__ == "ndarray": 466 | # array of string classes and object 467 | if re.search("[SaUO]", elem.dtype.str) is not None: 468 | raise TypeError(error_msg.format(elem.dtype)) 469 | 470 | return wrapped([torch.from_numpy(b) for b in batch]) 471 | if elem.shape == (): # scalars 472 | py_type = float if elem.dtype.name.startswith("float") else int 473 | return _numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 474 | elif isinstance(batch[0], int): 475 | return torch.LongTensor(batch) 476 | elif isinstance(batch[0], float): 477 | return torch.DoubleTensor(batch) 478 | elif isinstance(batch[0], collections.Mapping): 479 | return {key: wrapped([d[key] for d in batch]) for key in batch[0]} 480 | elif isinstance(batch[0], collections.Sequence): 481 | transposed = zip(*batch) 482 | return [wrapped(samples) for samples in transposed] 483 | 484 | raise TypeError((error_msg.format(type(batch[0])))) 485 | 486 | return wrapped 487 | 488 | 489 | class TrainValSplitter: 490 | r""" 491 | Creates a training and validation split to be used as the sampler in a pytorch DataLoader 492 | Parameters 493 | --------- 494 | numel : int 495 | Number of elements in the entire training dataset 496 | percent_train : float 497 | Percentage of data in the training split 498 | shuffled : bool 499 | Whether or not shuffle which data goes to which split 500 | """ 501 | 502 | def __init__(self, numel, percent_train, shuffled=False): 503 | # type: (TrainValSplitter, int, float, bool) -> None 504 | indicies = np.array([i for i in range(numel)]) 505 | if shuffled: 506 | np.random.shuffle(indicies) 507 | 508 | self.train = torch.utils.data.sampler.SubsetRandomSampler( 509 | indicies[0 : int(percent_train * numel)] 510 | ) 511 | self.val = torch.utils.data.sampler.SubsetRandomSampler( 512 | indicies[int(percent_train * numel) : -1] 513 | ) 514 | 515 | 516 | class CrossValSplitter: 517 | r""" 518 | Class that creates cross validation splits. The train and val splits can be used in pytorch DataLoaders. The splits can be updated 519 | by calling next(self) or using a loop: 520 | for _ in self: 521 | .... 
522 | Parameters 523 | --------- 524 | numel : int 525 | Number of elements in the training set 526 | k_folds : int 527 | Number of folds 528 | shuffled : bool 529 | Whether or not to shuffle which data goes in which fold 530 | """ 531 | 532 | def __init__(self, numel, k_folds, shuffled=False): 533 | # type: (CrossValSplitter, int, int, bool) -> None 534 | indices = np.array([i for i in range(numel)]) 535 | if shuffled: 536 | np.random.shuffle(indices) 537 | 538 | self.folds = np.array(np.array_split(indices, k_folds), dtype=object) 539 | self.current_v_ind = -1 540 | 541 | self.val = torch.utils.data.sampler.SubsetRandomSampler(self.folds[0]) 542 | self.train = torch.utils.data.sampler.SubsetRandomSampler( 543 | np.concatenate(self.folds[1:], axis=0) 544 | ) 545 | 546 | self.metrics = {} 547 | 548 | def __iter__(self): 549 | self.current_v_ind = -1 550 | return self 551 | 552 | def __len__(self): 553 | return len(self.folds) 554 | 555 | def __getitem__(self, idx): 556 | assert idx >= 0 and idx < len(self) 557 | self.val.indices = self.folds[idx]  # SubsetRandomSampler keeps its index list in .indices 558 | self.train.indices = np.concatenate( 559 | self.folds[np.arange(len(self)) != idx], axis=0 560 | ) 561 | 562 | def __next__(self): 563 | self.current_v_ind += 1 564 | if self.current_v_ind >= len(self): 565 | raise StopIteration 566 | 567 | self[self.current_v_ind] 568 | 569 | def update_metrics(self, to_post): 570 | # type: (CrossValSplitter, dict) -> None 571 | for k, v in to_post.items(): 572 | if k in self.metrics: 573 | self.metrics[k].append(v) 574 | else: 575 | self.metrics[k] = [v] 576 | 577 | def print_metrics(self): 578 | for name, samples in self.metrics.items(): 579 | xbar = stats.mean(samples) 580 | sx = stats.stdev(samples, xbar) 581 | tstar = student_t.ppf(1.0 - 0.025, len(samples) - 1) 582 | margin_of_error = tstar * sx / np.sqrt(len(samples)) 583 | print("{}: {} +/- {}".format(name, xbar, margin_of_error)) 584 | 585 | 586 | def set_bn_momentum_default(bn_momentum): 587 | def fn(m): 588 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 589 | m.momentum = bn_momentum 590 | 591 | return fn 592 | 593 | 594 | class BNMomentumScheduler(object): 595 | def __init__(self, model, bn_lambda, last_epoch=-1, setter=set_bn_momentum_default): 596 | if not isinstance(model, nn.Module): 597 | raise RuntimeError( 598 | "Class '{}' is not a PyTorch nn Module".format(type(model).__name__) 599 | ) 600 | 601 | self.model = model 602 | self.setter = setter 603 | self.lmbd = bn_lambda 604 | 605 | self.step(last_epoch + 1) 606 | self.last_epoch = last_epoch 607 | 608 | def step(self, epoch=None): 609 | if epoch is None: 610 | epoch = self.last_epoch + 1 611 | 612 | self.last_epoch = epoch 613 | self.model.apply(self.setter(self.lmbd(epoch))) 614 | 615 | 616 | class Trainer(object): 617 | r""" 618 | Reasonably generic trainer for pytorch models 619 | 620 | Parameters 621 | ---------- 622 | model : pytorch model 623 | Model to be trained 624 | model_fn : function (model, inputs, labels) -> preds, loss, accuracy 625 | optimizer : torch.optim 626 | Optimizer for model 627 | checkpoint_name : str 628 | Name of file to save checkpoints to 629 | best_name : str 630 | Name of file to save best model to 631 | lr_scheduler : torch.optim.lr_scheduler 632 | Learning rate scheduler. .step() will be called at the start of every epoch 633 | bnm_scheduler : BNMomentumScheduler 634 | Batchnorm momentum scheduler. 
.step() will be called at the start of every epoch 635 | eval_frequency : int 636 | How often to run an eval 637 | log_name : str 638 | Name of file to output tensorboard_logger to 639 | """ 640 | 641 | def __init__( 642 | self, 643 | model, 644 | model_fn, 645 | optimizer, 646 | checkpoint_name="ckpt", 647 | best_name="best", 648 | lr_scheduler=None, 649 | bnm_scheduler=None, 650 | eval_frequency=-1, 651 | viz=None, 652 | ): 653 | self.model, self.model_fn, self.optimizer, self.lr_scheduler, self.bnm_scheduler = ( 654 | model, 655 | model_fn, 656 | optimizer, 657 | lr_scheduler, 658 | bnm_scheduler, 659 | ) 660 | 661 | self.checkpoint_name, self.best_name = checkpoint_name, best_name 662 | self.eval_frequency = eval_frequency 663 | 664 | self.training_best, self.eval_best = {}, {} 665 | self.viz = viz 666 | 667 | @staticmethod 668 | def _decode_value(v): 669 | if isinstance(v[0], float): 670 | return np.mean(v) 671 | elif isinstance(v[0], tuple): 672 | if len(v[0]) == 3: 673 | num = [l[0] for l in v] 674 | denom = [l[1] for l in v] 675 | w = v[0][2] 676 | else: 677 | num = [l[0] for l in v] 678 | denom = [l[1] for l in v] 679 | w = None 680 | 681 | return np.average( 682 | np.sum(num, axis=0) / (np.sum(denom, axis=0) + 1e-6), weights=w 683 | ) 684 | else: 685 | raise AssertionError("Unknown type: {}".format(type(v))) 686 | 687 | def _train_it(self, it, batch): 688 | self.model.train() 689 | 690 | if self.lr_scheduler is not None: 691 | self.lr_scheduler.step(it) 692 | 693 | if self.bnm_scheduler is not None: 694 | self.bnm_scheduler.step(it) 695 | 696 | self.optimizer.zero_grad() 697 | _, loss, eval_res = self.model_fn(self.model, batch) 698 | 699 | loss.backward() 700 | self.optimizer.step() 701 | 702 | return eval_res 703 | 704 | def eval_epoch(self, d_loader): 705 | self.model.eval() 706 | 707 | eval_dict = {} 708 | total_loss = 0.0 709 | count = 1.0 710 | for i, data in tqdm.tqdm( 711 | enumerate(d_loader, 0), total=len(d_loader), leave=False, desc="val" 712 | ): 713 | self.optimizer.zero_grad() 714 | 715 | _, loss, eval_res = self.model_fn(self.model, data, eval=True) 716 | 717 | total_loss += loss.item() 718 | count += 1 719 | for k, v in eval_res.items(): 720 | if v is not None: 721 | eval_dict[k] = eval_dict.get(k, []) + [v] 722 | 723 | return total_loss / count, eval_dict 724 | 725 | def train( 726 | self, 727 | start_it, 728 | start_epoch, 729 | n_epochs, 730 | train_loader, 731 | test_loader=None, 732 | best_loss=0.0, 733 | ): 734 | r""" 735 | Call to begin training the model 736 | 737 | Parameters 738 | ---------- 739 | start_epoch : int 740 | Epoch to start at 741 | n_epochs : int 742 | Number of epochs to train for 743 | test_loader : torch.utils.data.DataLoader 744 | DataLoader of the test_data 745 | train_loader : torch.utils.data.DataLoader 746 | DataLoader of training data 747 | best_loss : float 748 | Testing loss of the best model 749 | """ 750 | 751 | eval_frequency = ( 752 | self.eval_frequency if self.eval_frequency > 0 else len(train_loader) 753 | ) 754 | 755 | it = start_it 756 | with tqdm.trange(start_epoch, n_epochs + 1, desc="epochs") as tbar, tqdm.tqdm( 757 | total=eval_frequency, leave=False, desc="train" 758 | ) as pbar: 759 | 760 | for epoch in tbar: 761 | for batch in train_loader: 762 | res = self._train_it(it, batch) 763 | it += 1 764 | 765 | pbar.update() 766 | pbar.set_postfix(dict(total_it=it)) 767 | tbar.refresh() 768 | 769 | if self.viz is not None: 770 | self.viz.update("train", it, res) 771 | 772 | if (it % eval_frequency) == 0: 773 | pbar.close() 
774 | 775 | if test_loader is not None: 776 | val_loss, res = self.eval_epoch(test_loader) 777 | 778 | if self.viz is not None: 779 | self.viz.update("val", it, res) 780 | 781 | is_best = val_loss < best_loss 782 | best_loss = min(best_loss, val_loss) 783 | save_checkpoint( 784 | checkpoint_state( 785 | self.model, self.optimizer, val_loss, epoch, it 786 | ), 787 | is_best, 788 | filename=self.checkpoint_name, 789 | bestname=self.best_name, 790 | ) 791 | 792 | pbar = tqdm.tqdm( 793 | total=eval_frequency, leave=False, desc="train" 794 | ) 795 | pbar.set_postfix(dict(total_it=it)) 796 | 797 | if self.viz is not None: self.viz.flush()  # viz defaults to None, so guard before flushing 798 | 799 | return best_loss 800 | --------------------------------------------------------------------------------
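
The wrappers in pointnet2/utils/pointnet2_utils.py are thin autograd bindings around the compiled pointnet2._ext kernels, and they compose in the usual sample-then-group pattern. The sketch below is not part of the repository; it is a minimal illustration that assumes the _ext extension has been built as described in the README and that a CUDA device is available, with tensor shapes taken from the docstrings above.

# Minimal usage sketch (assumption: pointnet2._ext is built and a GPU is present).
import torch
from pointnet2.utils.pointnet2_utils import (
    furthest_point_sample,
    gather_operation,
    QueryAndGroup,
)

B, N, C, npoint, nsample = 2, 1024, 32, 256, 16
xyz = torch.rand(B, N, 3).cuda()        # (B, N, 3) point coordinates
features = torch.rand(B, C, N).cuda()   # (B, C, N) per-point features

# 1) Furthest point sampling picks npoint well-spread centroid indices: (B, npoint).
idx = furthest_point_sample(xyz, npoint)

# 2) gather_operation expects channels-first (B, C, N) input, so transpose xyz
#    before gathering the centroid coordinates, then transpose back to (B, npoint, 3).
new_xyz = (
    gather_operation(xyz.transpose(1, 2).contiguous(), idx)
    .transpose(1, 2)
    .contiguous()
)

# 3) QueryAndGroup runs a ball query around each centroid and groups its neighbours;
#    with use_xyz=True the grouped output is (B, 3 + C, npoint, nsample).
grouper = QueryAndGroup(radius=0.1, nsample=nsample, use_xyz=True)
new_features = grouper(xyz, new_xyz, features)

# Expected: (B, npoint), (B, npoint, 3), (B, 3 + C, npoint, nsample)
print(idx.shape, new_xyz.shape, new_features.shape)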