├── .idea ├── .gitignore ├── CEMNet.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── README.md ├── batch_svd ├── README.md ├── __init__.py ├── setup.py ├── tests │ └── tests.py ├── torch_batch_svd.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt └── torch_batch_svd │ ├── __init__.py │ ├── batch_svd.py │ ├── csrc │ ├── bindings.cpp │ └── torch_batch_svd.cpp │ └── include │ ├── torch_batch_svd.h │ └── utils.h ├── cemnet_lib ├── __init__.py ├── cemnet_lib.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── functions │ ├── __init__.py │ └── distances.py ├── setup.py └── src │ ├── cemnet_lib_api.cpp │ ├── closest_point │ ├── closest_point_cuda.cpp │ ├── closest_point_cuda_kernel.cu │ ├── closest_point_cuda_kernel.h │ └── test.cpp │ ├── cuda_utils.h │ ├── distances │ ├── distances_cuda.cpp │ ├── distances_cuda_kernel.cu │ └── distances_cuda_kernel.h │ ├── ops │ ├── ops_cuda.cpp │ ├── ops_cuda_kernel.cu │ └── ops_cuda_kernel.h │ ├── ops_saved │ ├── ops_cuda.cpp │ ├── ops_cuda_kernel.cu │ └── ops_cuda_kernel.h │ └── prcem_lib_api.cpp ├── cems ├── __pycache__ │ ├── base_cem.cpython-36.pyc │ └── guided_cem.cpython-36.pyc ├── base_cem.py └── guided_cem.py ├── datasets ├── __pycache__ │ ├── base_dataset.cpython-36.pyc │ ├── dataset.cpython-36.pyc │ └── get_dataset.cpython-36.pyc ├── base_dataset.py ├── dataset.py ├── get_dataset.py ├── process_dataset.py ├── se_math │ ├── __init__.py │ ├── invmat.py │ ├── mesh.py │ ├── se3.py │ ├── sinc.py │ ├── so3.py │ └── transforms.py └── utils │ ├── commons.py │ ├── db_icl_nuim.py │ ├── gen_normal.py │ └── npmat2euler.py ├── modules ├── __pycache__ │ ├── commons.cpython-36.pyc │ ├── dcp_net.cpython-36.pyc │ ├── dgcnn.cpython-36.pyc │ └── sparsemax.cpython-36.pyc ├── commons.py ├── dcp_net.py ├── dgcnn.py └── sparsemax.py ├── results ├── 
icl_nuim_n768_unseen0_noise0_seed123 │ └── model.pth ├── modelnet40_n768_unseen0_noise0_seed123 │ └── model.pth ├── modelnet40_n768_unseen0_noise1_seed123 │ └── model.pth ├── modelnet40_n768_unseen1_noise0_seed123 │ └── model.pth └── scene7_n768_unseen0_noise0_seed123 │ └── model.pth ├── run.sh ├── test_model.py ├── train_model.py └── utils ├── __pycache__ ├── attr_dict.cpython-36.pyc ├── batch_icp.cpython-36.pyc ├── commons.cpython-36.pyc ├── euler2mat.cpython-36.pyc ├── losses.cpython-36.pyc ├── mat2euler.cpython-36.pyc ├── options.cpython-36.pyc ├── recorder.cpython-36.pyc ├── test.cpython-36.pyc └── transform_pc.cpython-36.pyc ├── attr_dict.py ├── batch_icp.py ├── commons.py ├── euler2mat.py ├── losses.py ├── mat2euler.py ├── options.py ├── recorder.py ├── test.py └── transform_pc.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/CEMNet.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 30 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 
-------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jiang-HB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sampling Network Guided Cross-Entropy Method for Unsupervised Point Cloud Registration (ICCV2021) 2 | 3 | PyTorch implementation of CEMNet for the ICCV 2021 paper ["Sampling Network Guided Cross-Entropy Method for Unsupervised Point Cloud Registration"](https://openaccess.thecvf.com/content/ICCV2021/papers/Jiang_Sampling_Network_Guided_Cross-Entropy_Method_for_Unsupervised_Point_Cloud_Registration_ICCV_2021_paper.pdf), by *Haobo Jiang, Yaqi Shen, Jin Xie, Jun Li, Jianjun Qian, and Jian Yang* from PCA Lab, Nanjing University of Science and Technology, China. 4 | 5 | This paper focuses on unsupervised deep learning for 3D point cloud registration. If you find this project useful, please cite: 6 | 7 | ```bibtex 8 | @inproceedings{jiang2021sampling, 9 | title={{S}ampling {N}etwork {G}uided {C}ross-{E}ntropy {M}ethod for {U}nsupervised {P}oint {C}loud {R}egistration}, 10 | author={Jiang, Haobo and Shen, Yaqi and Xie, Jin and Li, Jun and Qian, Jianjun and Yang, Jian}, 11 | booktitle={ICCV}, 12 | year={2021} 13 | } 14 | ``` 15 | 16 | ## Introduction 17 | 18 | In this paper, by modeling the point cloud registration task as a Markov decision process, we propose an end-to-end deep model embedded with the cross-entropy method (CEM) for unsupervised 3D registration. 19 | Our model consists of a sampling network module and a differentiable CEM module. In our sampling network module, given a pair of point clouds, the sampling network learns a prior sampling distribution over the transformation space. The learned sampling distribution can be used as a "good" initialization of the differentiable CEM module. In our differentiable CEM module, we first propose an alignment metric based on the maximum consensus criterion as the reward function for the point cloud registration task.
Based on the reward function, for each state, we then construct a fused score function to evaluate the sampled transformations, where we weight the current and future rewards of the transformations. In particular, the future rewards of the sampled transformations are obtained by performing the iterative closest point (ICP) algorithm on the transformed state. Extensive experimental results demonstrate the good registration performance of our method on benchmark datasets. 20 | 21 | ## Requirements 22 | 23 | Before running our code, you need to install the `cemnet_lib` and `batch_svd` libraries via: 24 | 25 | ```bash 26 | bash run.sh 27 | ``` 28 | (If you encounter errors when installing `batch_svd`, please refer to [torch-batch-svd](https://github.com/KinglittleQ/torch-batch-svd).) 29 | 30 | ## Dataset Preprocessing 31 | 32 | We generated the dataset files `modelnet40_normal_n2048.pth`, `7scene_normal_n2048.pth` and `icl_nuim_normal_n2048.pth` by preprocessing the raw point clouds of *ModelNet40*, *7Scene* and *ICL-NUIM*, and uploaded them to [Google Drive](https://drive.google.com/drive/folders/1ne9naYI8M8v4Lv0L9AcQm60Jqb8ciQ6t?usp=sharing). Alternatively, you can generate them yourself via: 33 | 34 | ```bash 35 | cd datasets 36 | python3 process_dataset.py 37 | ``` 38 | 39 | After that, you need to modify the dataset paths in `utils/options.py`. 40 | 41 | ## Pretrained Model 42 | 43 | The pretrained models are provided at: 44 | 45 | *ModelNet40*: `results/modelnet40_n768_unseen0_noise0_seed123/model.pth`, 46 | 47 | *7Scene*: `results/scene7_n768_unseen0_noise0_seed123/model.pth`, 48 | 49 | *ICL-NUIM*: `results/icl_nuim_n768_unseen0_noise0_seed123/model.pth`. 50 | 51 | ## Acknowledgments 52 | We thank the authors of 53 | - [DCP](https://github.com/WangYueFt/dcp) 54 | - [torch-batch-svd](https://github.com/KinglittleQ/torch-batch-svd) 55 | 56 | for open-sourcing their methods.
57 | -------------------------------------------------------------------------------- /batch_svd/README.md: -------------------------------------------------------------------------------- 1 | # Pytorch Batch SVD 2 | 3 | ## 1) Introduction 4 | 5 | A batch version of SVD in PyTorch implemented using cuSOLVER, 6 | including forward and backward functions. 7 | In terms of speed, it is faster than `torch.svd`. 8 | 9 | ``` python 10 | import torch 11 | from torch_batch_svd import svd 12 | 13 | A = torch.rand(1000000, 3, 3).cuda() 14 | u, s, v = svd(A) 15 | u, s, v = torch.svd(A) # probably you should take a coffee break here 16 | ``` 17 | 18 | The catch here is that it only works for matrices whose numbers of rows and columns are both smaller than `32`. 19 | Other than that, `torch_batch_svd.svd` can be a drop-in replacement for the native one. 20 | 21 | The forward function is modified from [ShigekiKarita/pytorch-cusolver](https://github.com/ShigekiKarita/pytorch-cusolver) and I fixed several bugs in it. The backward function is adapted from PyTorch's official [svd backward function](https://github.com/pytorch/pytorch/blob/b0545aa85f7302be5b9baf8320398981365f003d/tools/autograd/templates/Functions.cpp#L1476). I converted it to a batch version. 22 | 23 | **NOTE**: `batch_svd` supports all `torch.half`, `torch.float` and `torch.double` tensors now. 24 | 25 | **NOTE**: SVD for `torch.half` is performed by casting to `torch.float` 26 | as there is no cuSOLVER implementation for `c10::half`. 27 | 28 | **NOTE**: Sometimes, tests will fail for `torch.double` tensors due to numerical imprecision. 29 | 30 | ## 2) Requirements 31 | 32 | - Pytorch >= 1.0 33 | 34 | > diag_embed() is used in the backward function in torch_batch_svd.cpp. PyTorch versions lower than 1.0 do not contain diag_embed(). If you want to use this library with an older PyTorch version, you can replace diag_embed() with an equivalent existing function.
35 | 36 | - CUDA 9.0/10.2 (should work with 10.0/10.1 too) 37 | 38 | - Tested in Pytorch 1.4 & 1.5, with CUDA 10.2 39 | 40 | ## 3) Install 41 | 42 | Set environment variables 43 | 44 | ``` shell 45 | export CUDA_HOME=/your/cuda/home/directory/ 46 | export LIBRARY_PATH=$LIBRARY_PATH:/your/cuda/lib64/ (optional) 47 | ``` 48 | 49 | Run `setup.py` 50 | 51 | ``` shell 52 | python setup.py install 53 | ``` 54 | 55 | Run `test.py` 56 | 57 | ```shell 58 | cd tests 59 | python -m pytest test.py 60 | ``` 61 | 62 | ## 4) Differences between `torch.svd()` 63 | 64 | - The sign of column vectors at U and V may be different from `torch.svd()`. 65 | 66 | - `batch_svd()`is much more faster than `torch.svd()` using loop. 67 | 68 | ## 5) Example 69 | 70 | See `test.py` and [introduction](#1-introduction). 71 | -------------------------------------------------------------------------------- /batch_svd/__init__.py: -------------------------------------------------------------------------------- 1 | from svd.batch_svd import batch_svd 2 | -------------------------------------------------------------------------------- /batch_svd/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 3 | import os 4 | import glob 5 | 6 | libname = "torch_batch_svd" 7 | ext_src = glob.glob(os.path.join(libname, 'csrc/*.cpp')) 8 | print(ext_src) 9 | 10 | setup(name=libname, 11 | packages=find_packages(exclude=('tests', 'build', 'csrc', 'include', 'torch_batch_svd.egg-info')), 12 | ext_modules=[CUDAExtension( 13 | # libname + '._c', 14 | "torch_batch_svd_cuda", 15 | sources=ext_src, 16 | libraries=["cusolver", "cublas"], 17 | extra_compile_args={'cxx': ['-O2', '-I{}'.format('{}/include'.format(libname))], 18 | 'nvcc': ['-O2']} 19 | )], 20 | cmdclass={'build_ext': BuildExtension} 21 | # cmdclass={'build_ext': BuildExtension.with_options(use_ninja=False)} 
22 | ) 23 | -------------------------------------------------------------------------------- /batch_svd/tests/tests.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import testing 3 | 4 | from torch_batch_svd import svd 5 | 6 | 7 | def test_float(): 8 | torch.manual_seed(0) 9 | a = torch.randn(1000000, 9, 3).cuda() 10 | b = a.clone() 11 | a.requires_grad = True 12 | b.requires_grad = True 13 | 14 | U, S, V = svd(a) 15 | loss = U.sum() + S.sum() + V.sum() 16 | loss.backward() 17 | 18 | u, s, v = torch.svd(b[0], some=True, compute_uv=True) 19 | loss0 = u.sum() + s.sum() + v.sum() 20 | loss0.backward() 21 | 22 | testing.assert_allclose(U[0].abs(), u.abs()) # eigenvectors are only precise up to sign 23 | testing.assert_allclose(S[0].abs(), s.abs()) 24 | testing.assert_allclose(V[0].abs(), v.abs()) 25 | testing.assert_allclose(a, torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1))) 26 | 27 | 28 | def test_double(): 29 | torch.manual_seed(0) 30 | a = torch.randn(10, 9, 3).cuda().double() 31 | b = a.clone() 32 | a.requires_grad = True 33 | b.requires_grad = True 34 | 35 | U, S, V = svd(a) 36 | loss = U.sum() + S.sum() + V.sum() 37 | loss.backward() 38 | 39 | u, s, v = torch.svd(b[0], some=True, compute_uv=True) 40 | loss0 = u.sum() + s.sum() + v.sum() 41 | loss0.backward() 42 | 43 | assert U.dtype == torch.double 44 | assert S.dtype == torch.double 45 | assert V.dtype == torch.double 46 | assert a.grad.dtype == torch.double 47 | testing.assert_allclose(U[0].abs(), u.abs()) # eigenvectors are only precise up to sign 48 | testing.assert_allclose(S[0].abs(), s.abs()) 49 | testing.assert_allclose(V[0].abs(), v.abs()) 50 | testing.assert_allclose(a, torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1))) 51 | 52 | 53 | def test_half(): 54 | torch.manual_seed(0) 55 | a = torch.randn(10, 9, 3).cuda().half() 56 | b = a.clone() 57 | a.requires_grad = True 58 | b.requires_grad = 
True 59 | 60 | U, S, V = svd(a) 61 | loss = U.sum() + S.sum() + V.sum() 62 | loss.backward() 63 | 64 | assert U.dtype == torch.half 65 | assert S.dtype == torch.half 66 | assert V.dtype == torch.half 67 | assert a.grad.dtype == torch.half 68 | testing.assert_allclose(a, torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1))) 69 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: torch-batch-svd 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | torch_batch_svd/__init__.py 4 | torch_batch_svd/batch_svd.py 5 | torch_batch_svd.egg-info/PKG-INFO 6 | torch_batch_svd.egg-info/SOURCES.txt 7 | torch_batch_svd.egg-info/dependency_links.txt 8 | torch_batch_svd.egg-info/top_level.txt 9 | torch_batch_svd/csrc/bindings.cpp 10 | torch_batch_svd/csrc/torch_batch_svd.cpp -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | torch_batch_svd 2 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .batch_svd import svd 2 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd/batch_svd.py: -------------------------------------------------------------------------------- 1 | import torch, torch_batch_svd_cuda 2 | 3 | # from . import _c 4 | 5 | 6 | 7 | class BatchSVDFunction(torch.autograd.Function): 8 | 9 | @staticmethod 10 | def forward(ctx, input: torch.Tensor, some=True, compute_uv=True, out=None): 11 | """ 12 | This function returns `(U, S, V)` 13 | which is the singular value decomposition 14 | of a input real matrix or batches of real matrices `input` 15 | 16 | :param ctx: 17 | :param input: 18 | :param out: 19 | :return: 20 | """ 21 | assert input.shape[-1] < 32 and input.shape[-2] < 32, \ 22 | 'This implementation only supports matrices having dims smaller than 32' 23 | 24 | is_double = True if input.dtype == torch.double else False 25 | if input.dtype == torch.half: 26 | input = input.float() 27 | ctx.is_half = True 28 | else: 29 | ctx.is_half = False 30 | 31 | if out is None: 32 | b, m, n = input.shape 33 | U = torch.empty(b, m, m, dtype=input.dtype).to(input.device) 34 | S = torch.empty(b, min(m, n), dtype=input.dtype).to(input.device) 35 | V = torch.empty(b, n, n, dtype=input.dtype).to(input.device) 36 | else: 37 | U, S, V = out 38 | 39 | torch_batch_svd_cuda.batch_svd_forward(input, U, S, V, True, 1e-7, 100, is_double) 40 | U.transpose_(1, 2) 41 | V.transpose_(1, 2) 42 | if ctx.is_half: 43 | U, S, V = U.half(), S.half(), V.half() 44 | 45 | k = S.size(1) 46 | U_reduced: torch.Tensor = U[:, :, :k] 47 | V_reduced: torch.Tensor = V[:, :, :k] 48 | ctx.save_for_backward(input, U_reduced, S, V_reduced) 49 | 50 | if not compute_uv: 51 | U = torch.zeros(b, m, m, dtype=S.dtype).to(input.device) 52 | V = torch.zeros(b, m, m, dtype=S.dtype).to(input.device) 53 | return U, S, V 54 | 55 | return (U_reduced, S, 
V_reduced) if some else (U, S, V) 56 | 57 | @staticmethod 58 | def backward(ctx, grad_u: torch.Tensor, grad_s: torch.Tensor, grad_v: torch.Tensor): 59 | A, U, S, V = ctx.saved_tensors 60 | if ctx.is_half: 61 | grad_u, grad_s, grad_v = grad_u.float(), grad_s.float(), grad_v.float() 62 | 63 | grad_out: torch.Tensor = torch_batch_svd_cuda.batch_svd_backward( 64 | [grad_u, grad_s, grad_v], 65 | A, True, True, U.to(A.dtype), S.to(A.dtype), V.to(A.dtype) 66 | ) 67 | if ctx.is_half: 68 | grad_out = grad_out.half() 69 | 70 | return grad_out 71 | 72 | 73 | svd = BatchSVDFunction.apply 74 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd/csrc/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "torch_batch_svd.h" 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("batch_svd_forward", &batch_svd_forward, 5 | "cusolver based batch svd forward"); 6 | m.def("batch_svd_backward", &batch_svd_backward, 7 | "batch svd backward"); 8 | } 9 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd/csrc/torch_batch_svd.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "torch_batch_svd.h" 7 | #include "utils.h" 8 | 9 | // solve U S V = svd(A) a.k.a. 
gesvdj (batched Jacobi SVD), where A (b, m, n), U (b, m, m), S (b, min(m, n)), V (b, n, n) 10 | // see also https://docs.nvidia.com/cuda/cusolver/index.html#batchgesvdj-example1 11 | void batch_svd_forward(at::Tensor a, at::Tensor U, at::Tensor s, at::Tensor V, bool is_sort, double tol, int max_sweeps, bool is_double) 12 | { 13 | CHECK_CUDA(a); 14 | CHECK_CUDA(U); 15 | CHECK_CUDA(s); 16 | CHECK_CUDA(V); 17 | CHECK_IS_FLOAT(a); 18 | 19 | auto handle_ptr = unique_allocate(cusolverDnCreate, cusolverDnDestroy); 20 | const auto A = a.contiguous().clone().transpose(1, 2).contiguous().transpose(1, 2); // important: cuSOLVER is column-major and may overwrite A, so hand it a column-major clone (lda = m) 21 | const auto batch_size = A.size(0); 22 | const auto m = A.size(1); 23 | TORCH_CHECK(m <= 32, "matrix row should be <= 32"); 24 | const auto n = A.size(2); 25 | TORCH_CHECK(n <= 32, "matrix col should be <= 32"); 26 | const auto lda = m; 27 | const auto minmn = std::min(m, n); 28 | const auto ldu = m; 29 | const auto ldv = n; 30 | 31 | auto params = unique_allocate(cusolverDnCreateGesvdjInfo, cusolverDnDestroyGesvdjInfo); 32 | auto status = cusolverDnXgesvdjSetTolerance(params.get(), tol); 33 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status); 34 | status = cusolverDnXgesvdjSetMaxSweeps(params.get(), max_sweeps); 35 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status); 36 | status = cusolverDnXgesvdjSetSortEig(params.get(), is_sort); 37 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status); 38 | 39 | auto jobz = CUSOLVER_EIG_MODE_VECTOR; // compute the singular vectors U and V as well as the singular values 40 | int lwork; 41 | auto info_ptr = unique_cuda_ptr(batch_size); 42 | 43 | if (is_double) { // double path; otherwise float (the Python wrapper casts half inputs to float before calling) 44 | const auto d_A = A.data(); 45 | auto d_s = s.data(); 46 | const auto d_U = U.data(); 47 | const auto d_V = V.data(); 48 | 49 | auto status_buffer = cusolverDnDgesvdjBatched_bufferSize( 50 | handle_ptr.get(), 51 | jobz, 52 | m, 53 | n, 54 | d_A, 55 | lda, 56 | d_s, 57 | d_U, 58 | ldu, 59 | d_V, 60 | ldv, 61 | &lwork, 62 | params.get(), 63 | batch_size); 64 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status_buffer); 65 | auto
work_ptr = unique_cuda_ptr(lwork); 66 | 67 | status = cusolverDnDgesvdjBatched( 68 | handle_ptr.get(), 69 | jobz, 70 | m, 71 | n, 72 | d_A, 73 | lda, 74 | d_s, 75 | d_U, 76 | ldu, 77 | d_V, 78 | ldv, 79 | work_ptr.get(), 80 | lwork, 81 | info_ptr.get(), 82 | params.get(), 83 | batch_size 84 | ); 85 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status); 86 | } 87 | else { 88 | const auto d_A = A.data(); 89 | auto d_s = s.data(); 90 | const auto d_U = U.data(); 91 | const auto d_V = V.data(); 92 | 93 | auto status_buffer = cusolverDnSgesvdjBatched_bufferSize( 94 | handle_ptr.get(), 95 | jobz, 96 | m, 97 | n, 98 | d_A, 99 | lda, 100 | d_s, 101 | d_U, 102 | ldu, 103 | d_V, 104 | ldv, 105 | &lwork, 106 | params.get(), 107 | batch_size); 108 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status_buffer); 109 | auto work_ptr = unique_cuda_ptr(lwork); 110 | 111 | status = cusolverDnSgesvdjBatched( 112 | handle_ptr.get(), 113 | jobz, 114 | m, 115 | n, 116 | d_A, 117 | lda, 118 | d_s, 119 | d_U, 120 | ldu, 121 | d_V, 122 | ldv, 123 | work_ptr.get(), 124 | lwork, 125 | info_ptr.get(), 126 | params.get(), 127 | batch_size 128 | ); 129 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status); 130 | } 131 | // Per-matrix status: 0 = success, < 0 = invalid parameter (fatal), > 0 = Jacobi sweeps did not converge (warning only) 132 | std::vector hinfo(batch_size); 133 | auto status_memcpy = cudaMemcpy(hinfo.data(), info_ptr.get(), sizeof(int) * batch_size, cudaMemcpyDeviceToHost); 134 | TORCH_CHECK(cudaSuccess == status_memcpy); 135 | 136 | for(int i = 0 ; i < batch_size; ++i) 137 | { 138 | if ( 0 == hinfo[i] ) 139 | { 140 | continue; 141 | } 142 | else if ( 0 > hinfo[i] ) 143 | { 144 | std::cout << "Error: " << -hinfo[i] << "-th parameter is wrong" << std::endl; 145 | TORCH_CHECK(false); 146 | } 147 | else 148 | { 149 | std::cout << "WARNING: matrix " << i << ", info = " << hinfo[i] << ": Jacobi method does not converge" << std::endl; 150 | } 151 | } 152 | } 153 | 154 | 155 | 156 | // https://j-towns.github.io/papers/svd-derivative.pdf 157 | // Batched adaptation of PyTorch's svd_backward. 158 | // This makes no assumption on the signs of sigma.
159 | at::Tensor batch_svd_backward(const std::vector &grads, const at::Tensor& self, 160 | bool some, bool compute_uv, const at::Tensor& raw_u, const at::Tensor& sigma, const at::Tensor& raw_v) { 161 | TORCH_CHECK(compute_uv, 162 | "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ", 163 | "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)"); 164 | 165 | // self is the forward input A, shape [b, m, n] 166 | // auto b = self.size(0); 167 | auto m = self.size(1); 168 | auto n = self.size(2); 169 | auto k = sigma.size(1); 170 | auto gsigma = grads[1]; 171 | 172 | auto u = raw_u; 173 | auto v = raw_v; 174 | auto gu = grads[0]; 175 | auto gv = grads[2]; 176 | 177 | if (!some) { 178 | // We ignore the free subspace here because possible base vectors cancel 179 | // each other, e.g., both -v and +v are valid base for a dimension. 180 | // Don't assume behavior of any particular implementation of svd. 181 | u = raw_u.narrow(2, 0, k); 182 | v = raw_v.narrow(2, 0, k); 183 | if (gu.defined()) { 184 | gu = gu.narrow(2, 0, k); 185 | } 186 | if (gv.defined()) { 187 | gv = gv.narrow(2, 0, k); 188 | } 189 | } 190 | auto vt = v.transpose(1, 2); 191 | 192 | at::Tensor sigma_term; 193 | if (gsigma.defined()) { 194 | sigma_term = u.bmm(gsigma.diag_embed()).bmm(vt); 195 | } else { 196 | sigma_term = at::zeros({1}, self.options()).expand_as(self); 197 | } 198 | // in case that there are no gu and gv, we can avoid the series of kernel 199 | // calls below 200 | if (!gv.defined() && !gu.defined()) { 201 | return sigma_term; 202 | } 203 | 204 | auto ut = u.transpose(1, 2); 205 | auto im = at::eye(m, self.options()); // [m, m] identity; broadcasts against the batched [b, m, m] u.bmm(ut) below 206 | auto in = at::eye(n, self.options()); 207 | auto sigma_mat = sigma.diag_embed(); 208 | auto sigma_mat_inv = sigma.pow(-1).diag_embed(); 209 | auto sigma_expanded_sq = sigma.pow(2).unsqueeze(1).expand_as(sigma_mat); 210 | auto F = sigma_expanded_sq - sigma_expanded_sq.transpose(1, 2); 211 | // The following two lines
invert values of F, and fill the diagonal with 0s. 212 | // Notice that F currently has 0s on diagonal. So we fill diagonal with +inf 213 | // first to prevent nan from appearing in backward of this function. 214 | F.diagonal(0, -2, -1).fill_(INFINITY); 215 | F = F.pow(-1); 216 | 217 | at::Tensor u_term, v_term; 218 | 219 | if (gu.defined()) { 220 | u_term = u.bmm(F.mul(ut.bmm(gu) - gu.transpose(1, 2).bmm(u))).bmm(sigma_mat); 221 | if (m > k) { // add the component of gu orthogonal to span(u) 222 | u_term = u_term + (im - u.bmm(ut)).bmm(gu).bmm(sigma_mat_inv); 223 | } 224 | u_term = u_term.bmm(vt); 225 | } else { 226 | u_term = at::zeros({1}, self.options()).expand_as(self); 227 | } 228 | 229 | if (gv.defined()) { 230 | auto gvt = gv.transpose(1, 2); 231 | v_term = sigma_mat.bmm(F.mul(vt.bmm(gv) - gvt.bmm(v))).bmm(vt); 232 | if (n > k) { // add the component of gv orthogonal to span(v) 233 | v_term = v_term + sigma_mat_inv.bmm(gvt.bmm(in - v.bmm(vt))); 234 | } 235 | v_term = u.bmm(v_term); 236 | } else { 237 | v_term = at::zeros({1}, self.options()).expand_as(self); 238 | } 239 | 240 | return u_term + sigma_term + v_term; 241 | } 242 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd/include/torch_batch_svd.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | void batch_svd_forward(at::Tensor a, at::Tensor U, at::Tensor s, at::Tensor V, 6 | bool is_sort, double tol=1e-7, int max_sweeps=100, bool is_double=false); 7 | at::Tensor batch_svd_backward(const std::vector &grads, const at::Tensor& self, 8 | bool some, bool compute_uv, const at::Tensor& raw_u, const at::Tensor& sigma, const at::Tensor& raw_v); 9 | -------------------------------------------------------------------------------- /batch_svd/torch_batch_svd/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #define
CHECK_CUDA(x) \ 11 | do { \ 12 | TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_CONTIGUOUS(x) \ 16 | do { \ 17 | TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor"); \ 18 | } while (0) 19 | 20 | #define CHECK_IS_FLOAT(x) \ 21 | do { \ 22 | TORCH_CHECK((x.scalar_type() == at::ScalarType::Float) || (x.scalar_type() == at::ScalarType::Half) || (x.scalar_type() == at::ScalarType::Double), \ 23 | #x " must be a double, float or half tensor"); \ 24 | } while (0) 25 | 26 | // Wraps a C-style create/destroy pair (e.g. cusolverDnCreate / cusolverDnDestroy) in a unique_ptr. 27 | template // , class A = Status(*)(P), class D = Status(*)(T)> 28 | std::unique_ptr unique_allocate(Status(allocator)(T**), Status(deleter)(T*)) 29 | { 30 | T* ptr; 31 | auto stat = allocator(&ptr); 32 | TORCH_CHECK(stat == success); 33 | return {ptr, deleter}; 34 | } 35 | // Allocates len elements of T with cudaMalloc; released with cudaFree when the unique_ptr is destroyed. 36 | template 37 | std::unique_ptr unique_cuda_ptr(size_t len) { 38 | T* ptr; 39 | auto stat = cudaMalloc(&ptr, sizeof(T) * len); 40 | TORCH_CHECK(stat == cudaSuccess, "cudaMalloc failed: ", cudaGetErrorString(stat)); 41 | return {ptr, cudaFree}; 42 | } 43 | -------------------------------------------------------------------------------- /cemnet_lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import cd_distance, mc_distance, gm_distance, closest_point 2 | -------------------------------------------------------------------------------- /cemnet_lib/cemnet_lib.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: cemnet-lib 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /cemnet_lib/cemnet_lib.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | cemnet_lib.egg-info/PKG-INFO 3 |
cemnet_lib.egg-info/SOURCES.txt 4 | cemnet_lib.egg-info/dependency_links.txt 5 | cemnet_lib.egg-info/top_level.txt 6 | src/cemnet_lib_api.cpp 7 | src/ops/ops_cuda.cpp 8 | src/ops/ops_cuda_kernel.cu -------------------------------------------------------------------------------- /cemnet_lib/cemnet_lib.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /cemnet_lib/cemnet_lib.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | cemnet_lib_cuda 2 | -------------------------------------------------------------------------------- /cemnet_lib/functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .distances import mc_distance, cd_distance, gm_distance, closest_point -------------------------------------------------------------------------------- /cemnet_lib/functions/distances.py: -------------------------------------------------------------------------------- 1 | import torch, cemnet_lib_cuda 2 | from torch.autograd import Function 3 | 4 | # closest point 5 | class ClosestPoint(Function): 6 | @staticmethod 7 | def forward(ctx, srcs, tgt): 8 | """ 9 | input: srcs: (B, 3, N), tgt: (3, N) 10 | return: closest_points: (B, 3, N) 11 | """ 12 | closest_points = torch.zeros_like(srcs).to(srcs.device) 13 | cemnet_lib_cuda.closest_point_cuda(srcs, tgt, closest_points) 14 | return closest_points 15 | 16 | @staticmethod 17 | def backward(srcs=None, tgt=None, closest_idxs=None): 18 | return None, None 19 | 20 | closest_point = ClosestPoint.apply 21 | 22 | # Chamfer distance 23 | class CD(Function): 24 | @staticmethod 25 | def forward(ctx, srcs, tgt): 26 | """ 27 | srcs: (B, 3, N) 28 | tgt: (3, N) 29 | return: distances: (B) 30 | """ 31 | distances = torch.zeros(len(srcs), srcs.size(2), 2).to(srcs.device) 32 | 
cemnet_lib_cuda.cd_distance_cuda(srcs, tgt, distances) 33 | distances = distances.sum(2).mean(1) 34 | return distances 35 | 36 | @staticmethod 37 | def backward(srcs=None, tgt=None, distances=None, r=None): 38 | return None, None 39 | 40 | cd_distance = CD.apply 41 | 42 | # Geman-McClure estimator based distance 43 | class GM(Function): 44 | @staticmethod 45 | def forward(ctx, srcs, tgt, mu): 46 | """ 47 | srcs: (B, 3, N) 48 | tgt: (3, N) 49 | return: distances: (B) 50 | """ 51 | distances = torch.zeros(len(srcs), srcs.size(2), 2).to(srcs.device) 52 | cemnet_lib_cuda.cd_distance_cuda(srcs, tgt, distances) # (B, N, 2) 53 | distances = ((distances * mu) / (distances + mu)).sum(2).mean(1) 54 | return distances 55 | 56 | @staticmethod 57 | def backward(srcs=None, tgt=None, distances=None, r=None): 58 | return None, None 59 | 60 | gm_distance = GM.apply 61 | 62 | # Maximum consensus based distance 63 | class MC(Function): 64 | @staticmethod 65 | def forward(ctx, srcs, tgt, epsilon): 66 | """ 67 | srcs: (B, 3, N) 68 | tgt: (3, N) 69 | epsilon: float 70 | return: distances: (B) 71 | """ 72 | distances = torch.zeros(len(srcs), srcs.size(2), 2).to(srcs.device) 73 | min_idxs = torch.zeros(len(srcs), srcs.size(2), 2).type(torch.int).to(srcs.device) 74 | cemnet_lib_cuda.mc_distance_cuda(srcs, tgt, distances, epsilon, min_idxs) 75 | distances = 2.0 - distances.sum(2).mean(1) 76 | return distances 77 | 78 | @staticmethod 79 | def backward(srcs=None, tgt=None, distances=None, r=None, is_min_idxs=None): 80 | return None, None 81 | 82 | mc_distance = MC.apply -------------------------------------------------------------------------------- /cemnet_lib/setup.py: -------------------------------------------------------------------------------- 1 | #python3 setup.py install 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | libname = "cemnet_lib" 6 | 7 | setup( 8 | name='cemnet_lib', 9 | ext_modules=[ 10 | 
CUDAExtension("cemnet_lib_cuda", [ 11 | "src/cemnet_lib_api.cpp", 12 | "src/ops/ops_cuda.cpp", 13 | "src/ops/ops_cuda_kernel.cu", 14 | ], 15 | extra_compile_args={'cxx': ['-O2', '-I{}'.format('{}/include'.format(libname))],'nvcc': ['-O2']}) 16 | ], 17 | cmdclass={'build_ext': BuildExtension} 18 | ) -------------------------------------------------------------------------------- /cemnet_lib/src/cemnet_lib_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ops/ops_cuda_kernel.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("closest_point_cuda", &closest_point_cuda, "closest_point_cuda"); // name in python, cpp function address, docs 8 | m.def("cd_distance_cuda", &cd_distance_cuda, "cd_distance_cuda"); 9 | // m.def("iou_distance_cuda", &iou_distance_cuda, "iou_distance_cuda"); 10 | m.def("mc_distance_cuda", &mc_distance_cuda, "mc_distance_cuda"); 11 | // m.def("cycle_distance_cuda", &cycle_distance_cuda, "cycle_distance_cuda"); 12 | } -------------------------------------------------------------------------------- /cemnet_lib/src/closest_point/closest_point_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "closest_point_cuda_kernel.h" 6 | 7 | extern THCState *state; 8 | 9 | void closest_point_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor closest_idxs_tensor){ 10 | int b = srcs_tensor.size(0); 11 | int c = srcs_tensor.size(1); 12 | int n = srcs_tensor.size(2); 13 | const float *srcs = srcs_tensor.data(); 14 | const float *tgt = tgt_tensor.data(); 15 | int *closest_idxs = closest_idxs_tensor.data(); 16 | closest_point_cuda_launcher(b, c, n, closest_idxs, srcs, tgt); 17 | } 18 | 19 | 20 | 21 | 22 | //#include 23 | //#include 24 | //#include 25 | //#include 26 | //#include "sampling_cuda_kernel.h" 27 | // 28 | //extern THCState *state; 29 | // 30 | 
//void gathering_forward_cuda(int b, int c, int n, int m, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) 31 | //{ 32 | // const float *points = points_tensor.data(); 33 | // const int *idx = idx_tensor.data(); 34 | // float *out = out_tensor.data(); 35 | // gathering_forward_cuda_launcher(b, c, n, m, points, idx, out); 36 | //} 37 | // 38 | //void gathering_backward_cuda(int b, int c, int n, int m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) 39 | //{ 40 | // 41 | // const float *grad_out = grad_out_tensor.data(); 42 | // const int *idx = idx_tensor.data(); 43 | // float *grad_points = grad_points_tensor.data(); 44 | // gathering_backward_cuda_launcher(b, c, n, m, grad_out, idx, grad_points); 45 | //} 46 | // 47 | //void furthestsampling_cuda(int b, int n, int m, at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) 48 | //{ 49 | // const float *points = points_tensor.data(); 50 | // float *temp = temp_tensor.data(); 51 | // int *idx = idx_tensor.data(); 52 | // furthestsampling_cuda_launcher(b, n, m, points, temp, idx); 53 | //} 54 | -------------------------------------------------------------------------------- /cemnet_lib/src/closest_point/closest_point_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "closest_point_cuda_kernel.h" 3 | #include 4 | 5 | // input: srcs(B, 3, n) tgt(3, n) 6 | // output: out(b, n) 7 | //__global__ void closest_point_cuda_kernel(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt) 8 | //{ 9 | // int n_block = gridDim.x; 10 | // int n_thread = blockDim.x; 11 | // int idx_block = blockIdx.x; 12 | // int idx_thread = threadIdx.x; 13 | // 14 | // if (idx_block < n && idx_thread < b) 15 | // { 16 | // float min_val = 1000.0; 17 | // int idx = -1; 18 | // for (int k = 0; k < n; k += 1) 19 | // { 20 | // float val = 0.0; 21 | // for (int i = 0; i 
< c; i += 1) 22 | // { 23 | // val += pow(srcs[idx_thread * c * n + i * n + idx_block] - tgt[i * n + k], 2.0); 24 | // } 25 | // if (val < min_val) 26 | // { 27 | // min_val = val; 28 | // idx = k; 29 | // } 30 | // } 31 | // closest_idxs[idx_thread * n + idx_block] = idx; 32 | // } 33 | //} 34 | 35 | //void closest_point_cuda_launcher(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt) 36 | //{ 37 | // dim3 grid(n, 1, 1); 38 | // dim3 block(b, 1, 1); 39 | // closest_point_cuda_kernel<<>>(b, c, n, closest_idxs, srcs, tgt); 40 | //} 41 | 42 | 43 | // input: srcs(B, 3, n) tgt(3, n) 44 | // output: out(b, n) 45 | __global__ void closest_point_cuda_kernel(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt) 46 | { 47 | int idx_block = blockIdx.x; 48 | int idx_thread = threadIdx.x; 49 | 50 | if (idx_block < b && idx_thread < n) 51 | { 52 | float min_val = 1000.0; 53 | int idx = -1; 54 | for (int k = 0; k < n; k += 1) 55 | { 56 | float val = 0.0; 57 | for (int i = 0; i < c; i += 1) 58 | { 59 | val += pow(srcs[idx_block * c * n + i * n + idx_thread] - tgt[i * n + k], 2.0); 60 | } 61 | if (val < min_val) 62 | { 63 | min_val = val; 64 | idx = k; 65 | } 66 | } 67 | closest_idxs[idx_block * n + idx_thread] = idx; 68 | } 69 | } 70 | 71 | void closest_point_cuda_launcher(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt) 72 | { 73 | dim3 grid(b, 1, 1); 74 | dim3 block(n, 1, 1); 75 | closest_point_cuda_kernel<<>>(b, c, n, closest_idxs, srcs, tgt); 76 | } -------------------------------------------------------------------------------- /cemnet_lib/src/closest_point/closest_point_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_CUDA_KERNEL 2 | #define _SAMPLING_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void closest_point_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor closest_idxs); 8 | 9 | #ifdef __cplusplus 10 | extern 
"C" { 11 | #endif 12 | 13 | void closest_point_cuda_launcher(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | #endif 19 | 20 | 21 | //void gathering_forward_cuda(int b, int c, int n, int m, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor); 22 | //void gathering_backward_cuda(int b, int c, int n, int m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor); 23 | //void furthestsampling_cuda(int b, int n, int m, at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor); 24 | // 25 | //#ifdef __cplusplus 26 | //extern "C" { 27 | //#endif 28 | // 29 | //void gathering_forward_cuda_launcher(int b, int c, int n, int m, const float *points, const int *idx, float *out); 30 | //void gathering_backward_cuda_launcher(int b, int c, int n, int m, const float *grad_out, const int *idx, float *grad_points); 31 | //void furthestsampling_cuda_launcher(int b, int n, int m, const float *dataset, float *temp, int *idxs); 32 | // 33 | //#ifdef __cplusplus 34 | //} 35 | //#endif 36 | //#endif 37 | -------------------------------------------------------------------------------- /cemnet_lib/src/closest_point/test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | using namespace std; 6 | int main() 7 | { 8 | cout << "Hello, world, from Visual C++!" 
<< endl; 9 | } -------------------------------------------------------------------------------- /cemnet_lib/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | 6 | #define TOTAL_THREADS 1500 7 | 8 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ") 9 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x) 11 | 12 | #define THREADS_PER_BLOCK 800 13 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 18 | } 19 | 20 | inline dim3 opt_block_config(int x, int y) { 21 | // const int x_threads = opt_n_threads(x); 22 | // const int y_threads = max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 23 | // dim3 block_config(x_threads, y_threads, 1); 24 | // return block_config; 25 | dim3 block_config(TOTAL_THREADS, 1, 1); 26 | return block_config; 27 | } 28 | 29 | #endif 30 | -------------------------------------------------------------------------------- /cemnet_lib/src/distances/distances_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "distances_cuda_kernel.h" 6 | 7 | extern THCState *state; 8 | 9 | void cd_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor) 10 | { 11 | int b = srcs_tensor.size(0); 12 | int c = srcs_tensor.size(1); 13 | int n = srcs_tensor.size(2); 14 | const float *srcs = srcs_tensor.data(); 15 | const float *tgts = tgt_tensor.data(); 16 | float *distances = distances_tensor.data(); 17 | cd_distance_cuda_launcher(b, c, n, srcs, tgts, distances); 18 | } 19 | 20 | void iou_distance_cuda(at::Tensor 
srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor, float r) 21 | { 22 | int b = srcs_tensor.size(0); 23 | int c = srcs_tensor.size(1); 24 | int n = srcs_tensor.size(2); 25 | const float *srcs = srcs_tensor.data(); 26 | const float *tgt = tgt_tensor.data(); 27 | float *distances = distances_tensor.data(); 28 | iou_distance_cuda_launcher(b, c, n, r, srcs, tgt, distances); 29 | } -------------------------------------------------------------------------------- /cemnet_lib/src/distances/distances_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "distances_cuda_kernel.h" 3 | #include 4 | 5 | // input: srcs(b, 3, n) tgt(3, n) distances(b) 6 | // output: None 7 | __global__ void iou_distance_cuda_kernel(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances) 8 | { 9 | int idx_block = blockIdx.x; // b_idx 10 | int idx_thread = threadIdx.x; // n_idx 11 | 12 | if (idx_block < b && idx_thread < n) 13 | { 14 | float min_distance1, min_distance2 = 0.0, 0.0; 15 | for (int i = 0; i < n; i += 1) 16 | { 17 | float distance1, distance2 = 0.0, 0.0; 18 | for (int j = 0; j < c; j += 1) 19 | { 20 | distance1 += pow(srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i], 2.0); 21 | distance2 += pow(srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread], 2.0); 22 | } 23 | if (distance1 < min_distance1) 24 | { 25 | min_distance1 = distance1; 26 | } 27 | if (distance2 < min_distance2) 28 | { 29 | min_distance2 = distance2; 30 | } 31 | } 32 | if (min_distance1 <= r) 33 | { 34 | distances[idx_block] += (1.0 - min_distance1 / r); 35 | } 36 | if (min_distance2 <= r) 37 | { 38 | distances[idx_block] += (1.0 - min_distance2 / r); 39 | } 40 | } 41 | } 42 | 43 | void iou_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances) 44 | { 45 | dim3 grid(b, 1, 1); 46 | dim3 block(n, 1, 1); 47 | 
iou_distance_cuda_kernel<<>>(b, c, n, r, srcs, tgt, distances); 48 | } 49 | 50 | // input: srcs(b, 3, n) tgt(3, n) distances(b) 51 | // output: None 52 | __global__ void cd_distance_cuda_kernel(int b, int c, int n, const float *srcs, const float *tgt, float * distances) 53 | { 54 | int idx_block = blockIdx.x; // b_idx 55 | int idx_thread = threadIdx.x; // n_idx 56 | 57 | if (idx_block < b && idx_thread < n) 58 | { 59 | float min_distance1, min_distance2 = 0.0, 0.0; 60 | for (int i = 0; i < n; i += 1) 61 | { 62 | float distance1, distance2 = 0.0, 0.0; 63 | for (int j = 0; j < c; j += 1) 64 | { 65 | distance1 += pow(srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i], 2.0); 66 | distance2 += pow(srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread], 2.0); 67 | } 68 | if (distance1 < min_distance1) 69 | { 70 | min_distance1 = distance1; 71 | } 72 | if (distance2 < min_distance2) 73 | { 74 | min_distance2 = distance2; 75 | } 76 | } 77 | distances[idx_block] += (min_distance1 + min_distance2); 78 | } 79 | } 80 | 81 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances) 82 | { 83 | dim3 grid(b, 1, 1); 84 | dim3 block(n, 1, 1); 85 | cd_distance_cuda_kernel<<>>(b, c, n, srcs, tgt, distances); 86 | } -------------------------------------------------------------------------------- /cemnet_lib/src/distances/distances_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_CUDA_KERNEL 2 | #define _SAMPLING_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void cd_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances); 8 | void iou_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances, float r); 9 | 10 | #ifdef __cplusplus 11 | extern "C" { 12 | #endif 13 | 14 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances); 15 | void iou_distance_cuda_launcher(int 
b, int c, int n, float r, const float *srcs, const float *tgt, float * distances); 16 | 17 | #ifdef __cplusplus 18 | } 19 | #endif 20 | #endif -------------------------------------------------------------------------------- /cemnet_lib/src/ops/ops_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "ops_cuda_kernel.h" 6 | 7 | extern THCState *state; 8 | 9 | void closest_point_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor closest_points_tensor){ 10 | int b = srcs_tensor.size(0); 11 | int c = srcs_tensor.size(1); 12 | int n = srcs_tensor.size(2); 13 | const float *srcs = srcs_tensor.data(); 14 | const float *tgt = tgt_tensor.data(); 15 | float *closest_points = closest_points_tensor.data(); 16 | closest_point_cuda_launcher(b, c, n, closest_points, srcs, tgt); 17 | } 18 | 19 | void cd_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor) 20 | { 21 | int b = srcs_tensor.size(0); 22 | int c = srcs_tensor.size(1); 23 | int n = srcs_tensor.size(2); 24 | const float *srcs = srcs_tensor.data(); 25 | const float *tgts = tgt_tensor.data(); 26 | float *distances = distances_tensor.data(); 27 | cd_distance_cuda_launcher(b, c, n, srcs, tgts, distances); 28 | } 29 | 30 | void mc_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor, float r, at::Tensor min_idxs_tensor) 31 | { 32 | int b = srcs_tensor.size(0); 33 | int c = srcs_tensor.size(1); 34 | int n = srcs_tensor.size(2); 35 | const float *srcs = srcs_tensor.data(); 36 | const float *tgt = tgt_tensor.data(); 37 | float *distances = distances_tensor.data(); 38 | int *min_idxs = min_idxs_tensor.data(); 39 | mc_distance_cuda_launcher(b, c, n, r, srcs, tgt, distances, min_idxs); 40 | } 41 | 42 | void cycle_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor, int N, at::Tensor min_idxs_tensor, at::Tensor 
n_distances_tensor, at::Tensor n_idxs_tensor) 43 | { 44 | int b = srcs_tensor.size(0); 45 | int c = srcs_tensor.size(1); 46 | int n = srcs_tensor.size(2); 47 | const float *srcs = srcs_tensor.data(); 48 | const float *tgt = tgt_tensor.data(); 49 | float *distances = distances_tensor.data(); 50 | int *min_idxs = min_idxs_tensor.data(); 51 | float *n_distances = n_distances_tensor.data(); 52 | int *n_idxs = n_idxs_tensor.data(); 53 | cycle_distance_cuda_launcher(b, c, n, N, srcs, tgt, distances, min_idxs, n_distances, n_idxs); 54 | } -------------------------------------------------------------------------------- /cemnet_lib/src/ops/ops_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "ops_cuda_kernel.h" 3 | #include "math.h" 4 | 5 | // input: srcs(B, 3, n) tgt(3, n) 6 | // output: out(b, n) 7 | __global__ void closest_point_cuda_kernel(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt) 8 | { 9 | int idx_block = blockIdx.x; 10 | int idx_thread = threadIdx.x; 11 | 12 | if (idx_block < b && idx_thread < n) 13 | { 14 | float min_val = 1000.0; 15 | int idx = -1; 16 | for (int k = 0; k < n; k += 1) 17 | { 18 | float val = 0.0; 19 | for (int i = 0; i < c; i += 1) 20 | { 21 | val += (srcs[idx_block * c * n + i * n + idx_thread] - tgt[i * n + k]) * (srcs[idx_block * c * n + i * n + idx_thread] - tgt[i * n + k]); 22 | } 23 | if (val < min_val) 24 | { 25 | min_val = val; 26 | idx = k; 27 | } 28 | } 29 | closest_points[idx_block * n * c + idx_thread] = tgt[idx]; 30 | closest_points[idx_block * n * c + n + idx_thread] = tgt[n + idx]; 31 | closest_points[idx_block * n * c + 2 * n + idx_thread] = tgt[2 * n + idx]; 32 | } 33 | } 34 | 35 | void closest_point_cuda_launcher(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt) 36 | { 37 | dim3 grid(b, 1, 1); 38 | dim3 block(n, 1, 1); 39 | closest_point_cuda_kernel<<>>(b, c, n, 
closest_points, srcs, tgt); 40 | } 41 | 42 | // input: srcs(b, 3, n) tgt(3, n) distances(b) 43 | // output: None 44 | __global__ void mc_distance_cuda_kernel(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances, int *min_idxs) 45 | { 46 | int idx_block = blockIdx.x; // b_idx 47 | int idx_thread = threadIdx.x; // n_idx 48 | float rr = r * r; 49 | if (idx_block < b && idx_thread < n) 50 | { 51 | float min_distance1 = 1000.0; 52 | float min_distance2 = 1000.0; 53 | int min_idx1 = -1; 54 | int min_idx2 = -1; 55 | for (int i = 0; i < n; i += 1) 56 | { 57 | float distance1 = 0.0; 58 | float distance2 = 0.0; 59 | for (int j = 0; j < c; j += 1) 60 | { 61 | distance1 += (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]) * (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]); 62 | distance2 += (srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread]) * (srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread]); 63 | } 64 | if (distance1 < min_distance1) 65 | { 66 | min_distance1 = distance1; 67 | min_idx1 = i; 68 | } 69 | if (distance2 < min_distance2) 70 | { 71 | min_distance2 = distance2; 72 | min_idx2 = i; 73 | } 74 | } 75 | if (min_distance1 <= rr) 76 | { 77 | distances[idx_block * n * 2 + idx_thread * 2] = (1.0 - sqrt((float)(min_distance1)) / r); 78 | min_idxs[idx_block * n * 2 + idx_thread * 2] = min_idx1; 79 | } 80 | if (min_distance2 <= rr) 81 | { 82 | distances[idx_block * n * 2 + idx_thread * 2 + 1] = (1.0 - sqrt((float)(min_distance2)) / r); 83 | min_idxs[idx_block * n * 2 + idx_thread * 2 + 1] = min_idx2; 84 | } 85 | } 86 | } 87 | 88 | void mc_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances, int *min_idxs) 89 | { 90 | dim3 grid(b, 1, 1); 91 | dim3 block(n, 1, 1); 92 | mc_distance_cuda_kernel<<>>(b, c, n, r, srcs, tgt, distances, min_idxs); 93 | } 94 | 95 | // input: srcs(b, 3, n) tgt(3, n) distances(b) 96 | // output: None 97 | 
__global__ void cd_distance_cuda_kernel(int b, int c, int n, const float *srcs, const float *tgt, float * distances) 98 | { 99 | int idx_block = blockIdx.x; // b_idx 100 | int idx_thread = threadIdx.x; // n_idx 101 | 102 | if (idx_block < b && idx_thread < n) 103 | { 104 | float min_distance1 = 1000.0; 105 | float min_distance2 = 1000.0; 106 | for (int i = 0; i < n; i += 1) 107 | { 108 | float distance1 = 0.0; 109 | float distance2 = 0.0; 110 | for (int j = 0; j < c; j += 1) 111 | { 112 | distance1 += (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]) * (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]); 113 | distance2 += (srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread]) * (srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread]); 114 | } 115 | if (distance1 < min_distance1) 116 | { 117 | min_distance1 = distance1; 118 | } 119 | if (distance2 < min_distance2) 120 | { 121 | min_distance2 = distance2; 122 | } 123 | } 124 | distances[idx_block * n * 2 + idx_thread * 2] = min_distance1; 125 | distances[idx_block * n * 2 + idx_thread * 2 + 1] = min_distance2; 126 | } 127 | } 128 | 129 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances) 130 | { 131 | dim3 grid(b, 1, 1); 132 | dim3 block(n, 1, 1); 133 | cd_distance_cuda_kernel<<>>(b, c, n, srcs, tgt, distances); 134 | } 135 | 136 | 137 | // input: srcs(b, 3, n) tgt(3, n) distances(b) 138 | // output: None 139 | __global__ void cycle_distance_cuda_kernel(int b, int c, int n, int N, const float *srcs, const float *tgt, float *distances, int *min_idxs, float * n_distances, int * n_idxs) 140 | { 141 | int idx_block = blockIdx.x; // b_idx 142 | int idx_thread = threadIdx.x; // n_idx 143 | if (idx_block < b && idx_thread < n) 144 | { 145 | 146 | for (int i = 0; i < N; i += 1) 147 | { 148 | n_distances[i] = 100.0; 149 | n_idxs[i] = -1; 150 | } 151 | 152 | // cal xy_idxs and xy_distance 153 | float min_distance1 = 1000.0; 154 
| int min_idx1 = -1; 155 | for (int i = 0; i < n; i += 1) 156 | { 157 | float distance1 = 0.0; 158 | for (int j = 0; j < c; j += 1) 159 | { 160 | distance1 += (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]) * (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]); 161 | } 162 | if (distance1 < min_distance1) 163 | { 164 | min_distance1 = distance1; 165 | min_idx1 = i; 166 | } 167 | } 168 | 169 | // cal yx_n_idxs and yx_n_distance 170 | int cnt = 0; 171 | // for (int i = 0; i < n; i += 1) 172 | // { 173 | // if (cnt >= N) 174 | // { 175 | // break; 176 | // } 177 | // float distance2 = 0.0; 178 | // for (int j = 0; j < c; j += 1) 179 | // { 180 | // distance2 += (srcs[idx_block * c * n + j * n + i] - tgt[j * n + min_idx1]) * (srcs[idx_block * c * n + j * n + i] - tgt[j * n + min_idx1]); 181 | // } 182 | // if (distance2 < distance1) 183 | // { 184 | // cnt += 1; 185 | // } 186 | // } 187 | 188 | if (cnt < N) 189 | { 190 | distances[idx_block * n + idx_thread] = - min_distance1; 191 | min_idxs[idx_block * n + idx_thread] = min_idx1; 192 | } 193 | } 194 | } 195 | 196 | void cycle_distance_cuda_launcher(int b, int c, int n, int N, const float *srcs, const float *tgt, float * distances, int *min_idxs, float * n_distances, int * n_idxs) 197 | { 198 | dim3 grid(b, 1, 1); 199 | dim3 block(n, 1, 1); 200 | cycle_distance_cuda_kernel<<>>(b, c, n, N, srcs, tgt, distances, min_idxs, n_distances, n_idxs); 201 | } -------------------------------------------------------------------------------- /cemnet_lib/src/ops/ops_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_CUDA_KERNEL 2 | #define _SAMPLING_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void closest_point_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor closest_points); 8 | void cd_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances); 9 | void mc_distance_cuda(at::Tensor srcs, at::Tensor tgt, 
at::Tensor distances, float r, at::Tensor min_idxs); 10 | void cycle_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances, int N, at::Tensor min_idxs, at::Tensor n_distances, at::Tensor n_idxs); 11 | 12 | #ifdef __cplusplus 13 | extern "C" { 14 | #endif 15 | 16 | void closest_point_cuda_launcher(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt); 17 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances); 18 | void mc_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances, int *min_idxs); 19 | void cycle_distance_cuda_launcher(int b, int c, int n, int N, const float *srcs, const float *tgt, float * distances, int *min_idxs, float * n_distances, int * n_idxs); 20 | 21 | #ifdef __cplusplus 22 | } 23 | #endif 24 | #endif -------------------------------------------------------------------------------- /cemnet_lib/src/ops_saved/ops_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "ops_cuda_kernel.h" 6 | 7 | extern THCState *state; 8 | 9 | void closest_point_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor closest_points_tensor){ 10 | int b = srcs_tensor.size(0); 11 | int c = srcs_tensor.size(1); 12 | int n = srcs_tensor.size(2); 13 | const float *srcs = srcs_tensor.data(); 14 | const float *tgt = tgt_tensor.data(); 15 | float *closest_points = closest_points_tensor.data(); 16 | closest_point_cuda_launcher(b, c, n, closest_points, srcs, tgt); 17 | } 18 | 19 | void cd_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor) 20 | { 21 | int b = srcs_tensor.size(0); 22 | int c = srcs_tensor.size(1); 23 | int n = srcs_tensor.size(2); 24 | const float *srcs = srcs_tensor.data(); 25 | const float *tgts = tgt_tensor.data(); 26 | float *distances = distances_tensor.data(); 27 | 
cd_distance_cuda_launcher(b, c, n, srcs, tgts, distances); 28 | } 29 | 30 | void iou_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor, float r) 31 | { 32 | int b = srcs_tensor.size(0); 33 | int c = srcs_tensor.size(1); 34 | int n = srcs_tensor.size(2); 35 | const float *srcs = srcs_tensor.data(); 36 | const float *tgt = tgt_tensor.data(); 37 | float *distances = distances_tensor.data(); 38 | iou_distance_cuda_launcher(b, c, n, r, srcs, tgt, distances); 39 | } -------------------------------------------------------------------------------- /cemnet_lib/src/ops_saved/ops_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include "../cuda_utils.h" 2 | #include "ops_cuda_kernel.h" 3 | #include "math.h" 4 | 5 | // input: srcs(B, 3, n) tgt(3, n) 6 | // output: out(b, n) 7 | __global__ void closest_point_cuda_kernel(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt) 8 | { 9 | int idx_block = blockIdx.x; 10 | int idx_thread = threadIdx.x; 11 | 12 | if (idx_block < b && idx_thread < n) 13 | { 14 | float min_val = 1000.0; 15 | int idx = -1; 16 | for (int k = 0; k < n; k += 1) 17 | { 18 | float val = 0.0; 19 | for (int i = 0; i < c; i += 1) 20 | { 21 | val += (srcs[idx_block * c * n + i * n + idx_thread] - tgt[i * n + k]) * (srcs[idx_block * c * n + i * n + idx_thread] - tgt[i * n + k]); 22 | } 23 | if (val < min_val) 24 | { 25 | min_val = val; 26 | idx = k; 27 | } 28 | } 29 | closest_points[idx_block * n * c + idx_thread] = tgt[idx]; 30 | closest_points[idx_block * n * c + n + idx_thread] = tgt[n + idx]; 31 | closest_points[idx_block * n * c + 2 * n + idx_thread] = tgt[2 * n + idx]; 32 | } 33 | } 34 | 35 | void closest_point_cuda_launcher(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt) 36 | { 37 | dim3 grid(b, 1, 1); 38 | dim3 block(n, 1, 1); 39 | closest_point_cuda_kernel<<>>(b, c, n, closest_points, srcs, tgt); 40 | } 41 | 
42 | // input: srcs(b, 3, n) tgt(3, n) distances(b) 43 | // output: None 44 | __global__ void iou_distance_cuda_kernel(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances) 45 | { 46 | int idx_block = blockIdx.x; // b_idx 47 | int idx_thread = threadIdx.x; // n_idx 48 | float rr = r * r; 49 | if (idx_block < b && idx_thread < n) 50 | { 51 | float min_distance1 = 1000.0; 52 | float min_distance2 = 1000.0; 53 | for (int i = 0; i < n; i += 1) 54 | { 55 | float distance1 = 0.0; 56 | float distance2 = 0.0; 57 | for (int j = 0; j < c; j += 1) 58 | { 59 | distance1 += (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]) * (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]); 60 | distance2 += (srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread]) * (srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread]); 61 | } 62 | if (distance1 < min_distance1) 63 | { 64 | min_distance1 = distance1; 65 | } 66 | if (distance2 < min_distance2) 67 | { 68 | min_distance2 = distance2; 69 | } 70 | } 71 | if (min_distance1 <= rr) 72 | { 73 | distances[idx_block * n * 2 + idx_thread * 2] = (1.0 - sqrt((float)(min_distance1)) / r); 74 | } 75 | if (min_distance2 <= rr) 76 | { 77 | distances[idx_block * n * 2 + idx_thread * 2 + 1] = (1.0 - sqrt((float)(min_distance2)) / r); 78 | } 79 | } 80 | } 81 | 82 | void iou_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances) 83 | { 84 | dim3 grid(b, 1, 1); 85 | dim3 block(n, 1, 1); 86 | iou_distance_cuda_kernel<<>>(b, c, n, r, srcs, tgt, distances); 87 | } 88 | 89 | // input: srcs(b, 3, n) tgt(3, n) distances(b) 90 | // output: None 91 | __global__ void cd_distance_cuda_kernel(int b, int c, int n, const float *srcs, const float *tgt, float * distances) 92 | { 93 | int idx_block = blockIdx.x; // b_idx 94 | int idx_thread = threadIdx.x; // n_idx 95 | 96 | if (idx_block < b && idx_thread < n) 97 | { 98 | float min_distance1 = 1000.0; 
99 | float min_distance2 = 1000.0; 100 | for (int i = 0; i < n; i += 1) 101 | { 102 | float distance1 = 0.0; 103 | float distance2 = 0.0; 104 | for (int j = 0; j < c; j += 1) 105 | { 106 | distance1 += (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]) * (srcs[idx_block * c * n + j * n + idx_thread] - tgt[j * n + i]); 107 | distance2 += (srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread]) * (srcs[idx_block * c * n + j * n + i] - tgt[j * n + idx_thread]); 108 | } 109 | if (distance1 < min_distance1) 110 | { 111 | min_distance1 = distance1; 112 | } 113 | if (distance2 < min_distance2) 114 | { 115 | min_distance2 = distance2; 116 | } 117 | } 118 | distances[idx_block * n * 2 + idx_thread * 2] = min_distance1; 119 | distances[idx_block * n * 2 + idx_thread * 2 + 1] = min_distance2; 120 | } 121 | } 122 | 123 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances) 124 | { 125 | dim3 grid(b, 1, 1); 126 | dim3 block(n, 1, 1); 127 | cd_distance_cuda_kernel<<>>(b, c, n, srcs, tgt, distances); 128 | } -------------------------------------------------------------------------------- /cemnet_lib/src/ops_saved/ops_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_CUDA_KERNEL 2 | #define _SAMPLING_CUDA_KERNEL 3 | #include 4 | #include 5 | #include 6 | 7 | void closest_point_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor closest_points); 8 | void cd_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances); 9 | void iou_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances, float r); 10 | 11 | #ifdef __cplusplus 12 | extern "C" { 13 | #endif 14 | 15 | void closest_point_cuda_launcher(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt); 16 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances); 17 | void 
iou_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances); 18 | 19 | #ifdef __cplusplus 20 | } 21 | #endif 22 | #endif -------------------------------------------------------------------------------- /cemnet_lib/src/prcem_lib_api.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ops/ops_cuda_kernel.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("closest_point_cuda", &closest_point_cuda, "closest_point_cuda"); // name in python, cpp function address, docs 8 | m.def("cd_distance_cuda", &cd_distance_cuda, "cd_distance_cuda"); 9 | // m.def("iou_distance_cuda", &iou_distance_cuda, "iou_distance_cuda"); 10 | m.def("mc_distance_cuda", &mc_distance_cuda, "mc_distance_cuda"); 11 | // m.def("cycle_distance_cuda", &cycle_distance_cuda, "cycle_distance_cuda"); 12 | } -------------------------------------------------------------------------------- /cems/__pycache__/base_cem.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/cems/__pycache__/base_cem.cpython-36.pyc -------------------------------------------------------------------------------- /cems/__pycache__/guided_cem.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/cems/__pycache__/guided_cem.cpython-36.pyc -------------------------------------------------------------------------------- /cems/base_cem.py: -------------------------------------------------------------------------------- 1 | import torch, cemnet_lib 2 | from utils.euler2mat import euler2mat_torch 3 | from utils.commons import stack_transforms_seq 4 | from utils.transform_pc import transform_pc_torch 5 | from utils.batch_icp import 
batch_icp 6 | 7 | class BaseCEM(torch.nn.Module): 8 | def __init__(self, opts): 9 | super(BaseCEM, self).__init__() 10 | self.opts = opts 11 | self.n_iters = opts.cem.n_iters 12 | self.n_candidates = opts.cem.n_candidates[0] 13 | self.planning_horizon = opts.cem.planning_horizon 14 | self.n_elites = opts.cem.n_elites 15 | self.r_size, self.t_size, self.action_size = 3, 3, 6 16 | self.min_r, self.max_r = opts.cem.r_range 17 | self.min_t, self.max_t = opts.cem.t_range[opts.db_nm] 18 | self.r_init_sigma, self.t_init_sigma = opts.cem.init_sigma[opts.db_nm] 19 | self.recorder = opts.recorder 20 | self.device = opts.device 21 | 22 | def init_distrib(self): 23 | # mus 24 | mus_r = torch.zeros([self.planning_horizon, self.batch_size, 1, self.r_size]) + (self.min_r + self.max_r) / 2. # (H, B, 1, 3) 25 | mus_t = torch.zeros([self.planning_horizon, self.batch_size, 1, self.t_size]) + (self.min_t + self.max_t) / 2. # (H, B, 1, 3) 26 | self.mus = torch.cat([mus_r, mus_t], 3).to(self.device) # (H, B, 1, 6) 27 | init_r_t = stack_transforms_seq(self.mus.squeeze(2)) # (B, 6) 28 | self.r_init = init_r_t[:, :3] # (B, 3) 29 | self.t_init = init_r_t[:, 3:] # (B, 3) 30 | # sigmas 31 | self.sigmas = torch.ones_like(self.mus).to(self.device) # (H, B, 1, 6) 32 | self.sigmas[:, :, :, :3] *= self.r_init_sigma # (H, B, 1, 3) 33 | self.sigmas[:, :, :, 3:] *= self.t_init_sigma # (H, B, 1, 3) 34 | self.mus_init = self.mus 35 | self.sigmas_init = self.sigmas 36 | 37 | def sample_candidates(self): 38 | self.candidates = self.mus + self.sigmas * torch.randn(self.planning_horizon, self.batch_size, self.n_candidates, self.action_size).to(self.device) # (H, B, C, 6) 39 | 40 | def perform_candidates(self): 41 | stack_candidates = stack_transforms_seq(self.candidates) # (B * C, 6) 42 | Rs = euler2mat_torch(stack_candidates[:, :self.r_size], seq="zyx") # (B * C, 3, 3) 43 | ts = stack_candidates[:, self.r_size:] # (B * C, 3) 44 | transformed_srcs_repeat = transform_pc_torch(self.srcs_repeat, Rs, 
ts).reshape(self.batch_size, self.n_candidates, 3, -1) # (B, C, 3, N) 45 | return transformed_srcs_repeat 46 | 47 | def distance(self, srcs, tgt): 48 | if self.opts.cem.metric_type[0] == "MC": 49 | return cemnet_lib.mc_distance(srcs, tgt, self.opts.cem.metric_type[1]) 50 | if self.opts.cem.metric_type[0] == "CD": 51 | return cemnet_lib.cd_distance(srcs, tgt) 52 | if self.opts.cem.metric_type[0] == "GM": 53 | return cemnet_lib.gm_distance(srcs, tgt, self.opts.cem.metric_type[1]) * 100 54 | 55 | def evaluate_candidates(self): 56 | transformed_srcs_repeat = self.perform_candidates() # (B, C, 3, N) 57 | ## current reward 58 | self.alignment_errors = torch.zeros(self.batch_size, self.n_candidates).to(self.device) # (B, C) 59 | for k, (transformed_src_repeat, tgt) in enumerate(zip(transformed_srcs_repeat, self.tgts)): 60 | self.alignment_errors[k] = self.distance(transformed_src_repeat, tgt).detach() 61 | ## future potential 62 | if self.opts.cem.is_fused_reward[0] and self.k < self.opts.cem.is_fused_reward[2]: 63 | alpha = self.opts.cem.is_fused_reward[1] 64 | for k, (transformed_src_repeat, tgt) in enumerate(zip(transformed_srcs_repeat, self.tgts)): 65 | Rs_icp, ts_icp = batch_icp(self.opts, transformed_src_repeat, tgt) # (C, 3, 3), (C, 3) 66 | transformed_src_repeat_icp = transform_pc_torch(transformed_src_repeat, Rs_icp, ts_icp) # (C, 3, N) 67 | potential = self.distance(transformed_src_repeat_icp, tgt).detach() 68 | self.alignment_errors[k] = alpha * self.alignment_errors[k] + (1- alpha) * potential 69 | self.alignment_errors = self.alignment_errors.detach() 70 | return self.alignment_errors 71 | 72 | def update_distrib(self): 73 | self.mus = self.elites.mean(dim=2, keepdim=True) # mus: [H, B, 1, 6] 74 | self.sigmas = self.elites.std(dim=2, unbiased=False, keepdim=True) # sigmas: [H, B, 1, 6] 75 | 76 | def elite_selection(self): 77 | self.candidates = self.candidates.reshape(self.planning_horizon, self.batch_size * self.n_candidates, self.action_size) # (H, B * C, 
6) 78 | self.alignment_errors = self.evaluate_candidates() # (B, C) 79 | self.elite_errors, self.elite_idxs = self.alignment_errors.topk(self.n_elites, dim=1, largest=False, sorted=False) # (B, K) 80 | self.elite_idxs += self.n_candidates * torch.arange(0, self.batch_size, dtype=torch.int64).reshape(self.batch_size, 1).to(self.device) # topk: (B, K) 81 | self.elite_idxs = self.elite_idxs.view(-1) # (B * K, ) 82 | self.elites = self.candidates[:, self.elite_idxs].reshape(self.planning_horizon, self.batch_size, self.n_elites, self.action_size) # (H, B, K, 6) 83 | return self.elites 84 | 85 | def forward(self, srcs, tgts): 86 | """ 87 | :param ids: ids 88 | :param srcs: (B, 3, N), torch.FloatTensor 89 | :param tgts: (B, 3, N), torch.FloatTensor 90 | :return: None 91 | """ 92 | self.batch_size, self.n_points = srcs.size(0), srcs.size(2) 93 | self.srcs, self.tgts = srcs.to(self.device), tgts.to(self.device) # (B, 3, N) 94 | self.srcs_repeat = self.srcs.unsqueeze(1).repeat(1, self.n_candidates, 1, 1).reshape(self.batch_size * self.n_candidates, 95 | 3, self.n_points) # (B * C, 3, N) 96 | self.init_distrib() # (H, B, 1, 6) 97 | self.k = 0 98 | for iter_idx in range(self.n_iters): 99 | 100 | # 1. sample candidates 101 | self.sample_candidates() 102 | 103 | # 2. elite selection 104 | self.elite_selection() 105 | 106 | # 3. 
update distribution 107 | self.update_distrib() 108 | 109 | self.k += 1 110 | 111 | actions = stack_transforms_seq(self.mus.squeeze(2)) # (B, 6) 112 | return {"r": actions[:, :3], "t": actions[:, 3:], "r_init": self.r_init, "t_init": self.t_init} -------------------------------------------------------------------------------- /cems/guided_cem.py: -------------------------------------------------------------------------------- 1 | from cems.base_cem import BaseCEM 2 | from modules.dcp_net import DCPNet 3 | from modules.sparsemax import Sparsemax 4 | from utils.commons import stack_transforms_seq 5 | import pdb, torch 6 | 7 | class VanillaCEM(BaseCEM): 8 | 9 | def __init__(self, opts): 10 | super(VanillaCEM, self).__init__(opts) 11 | self.opts = opts 12 | 13 | class GuidedCEM(BaseCEM): 14 | 15 | def __init__(self, opts): 16 | super(GuidedCEM, self).__init__(opts) 17 | self.coarse_policy = DCPNet(opts) 18 | self.top_k = Sparsemax(dim=1) 19 | self.opts = opts 20 | 21 | def add_exploration_noise(self, mus, sigmas): 22 | mus_ = torch.zeros_like(mus).to(self.device) 23 | sigmas_ = torch.ones_like(mus_).to(self.device) # (B, 6) 24 | sigmas_[:, :3] *= self.r_init_sigma 25 | sigmas_[:, 3:] *= self.t_init_sigma 26 | r = self.opts.exploration_weight 27 | mus = (1 - r) * mus + r * mus_ 28 | sigmas = (1 - r) * sigmas + r * sigmas_ 29 | return mus, sigmas 30 | 31 | def init_distrib(self): 32 | mus, sigmas = self.coarse_policy(self.srcs, self.tgts, is_sigma=True) # (H, B, 6) 33 | if not self.opts.is_train: 34 | mus, sigmas = self.add_exploration_noise(mus, sigmas) 35 | self.mus, self.sigmas = mus.unsqueeze(2), sigmas.unsqueeze(2) # (H, B, 1, 6) 36 | self.mus_init, self.sigmas_init = self.mus, self.sigmas 37 | init_r_t = stack_transforms_seq(self.mus.squeeze(2)).clone() # (B, 6) 38 | self.r_init = init_r_t[:, :3] # (B, 3) 39 | self.t_init = init_r_t[:, 3:] # (B, 3) 40 | 41 | def elite_selection(self): 42 | super().elite_selection() 43 | self.sparsemax_probs = 
self.top_k(-self.alignment_errors).reshape(1, self.batch_size, self.n_candidates, 1) # (1, B, C, 1) 44 | self.elites = self.candidates.reshape(self.planning_horizon, self.batch_size, self.n_candidates, self.action_size) * self.sparsemax_probs # (H, B, C, 6) 45 | 46 | def update_distrib(self): 47 | self.mus = self.elites.sum(dim=2, keepdim=True) # (H, B, 1, 6) 48 | self.sigmas = (((self.candidates.reshape(self.planning_horizon, self.batch_size, self.n_candidates, self.action_size) - 49 | self.mus) ** 2) * self.sparsemax_probs).sum(dim=2, keepdim=True).sqrt() # (H, B, 1, 6) 50 | 51 | def load_model(self, model_path): 52 | self.load_state_dict(torch.load(model_path)) 53 | return self -------------------------------------------------------------------------------- /datasets/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/datasets/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/datasets/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/get_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/datasets/__pycache__/get_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from scipy.spatial.transform 
import Rotation 3 | from sklearn.neighbors import NearestNeighbors 4 | from scipy.spatial.distance import minkowski 5 | import numpy as np, pdb, pickle 6 | 7 | def load_data(path): 8 | file = open(path, "rb") 9 | data = pickle.load(file) 10 | file.close() 11 | return data 12 | 13 | class BaseDataset(Dataset): 14 | def __init__(self, opts, partition, is_normal=False, cls_idx=-1): 15 | self.opts = opts 16 | self.cls_idx = cls_idx 17 | self.n_points = opts.db.n_points 18 | self.n_sub_points = opts.db.n_sub_points 19 | self.partition = partition 20 | self.gaussian_noise = opts.db.gaussian_noise 21 | self.unseen = opts.db.unseen 22 | self.factor = opts.db.factor 23 | self.is_normal = is_normal 24 | self.pcs = None 25 | 26 | def load_data(self, opts, partition): 27 | db_path = opts.db.path 28 | db = load_data(db_path)[partition] 29 | if opts.infos.db_nm == "scene7": 30 | db["normal_pcs"] = db["normal_pcs"].transpose(0, 2, 1) 31 | pcs = db["normal_pcs"][:, :, :3] 32 | lbs = db["lbs"] 33 | if self.cls_idx != -1: 34 | pcs = pcs[lbs == self.cls_idx] 35 | lbs = lbs[lbs == self.cls_idx] 36 | return pcs, lbs 37 | 38 | def jitter_pointcloud(self, pc, sigma=0.01, clip=0.05): 39 | pc += np.clip(sigma * np.random.randn(*pc.shape), -1 * clip, clip) 40 | return pc 41 | 42 | def farthest_subsample_points(self, pc1, pc2, n_sub_points=768): 43 | pc1, pc2 = pc1.T, pc2.T 44 | nbrs1 = NearestNeighbors(n_neighbors=n_sub_points, algorithm='auto', 45 | metric=lambda x, y: minkowski(x, y)).fit(pc1) 46 | random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1]) 47 | idx1 = nbrs1.kneighbors(random_p1, return_distance=False).reshape((n_sub_points,)) 48 | nbrs2 = NearestNeighbors(n_neighbors=n_sub_points, algorithm='auto', 49 | metric=lambda x, y: minkowski(x, y)).fit(pc2) 50 | random_p2 = random_p1 51 | idx2 = nbrs2.kneighbors(random_p2, return_distance=False).reshape((n_sub_points,)) 52 | return pc1[idx1, :].T, pc2[idx2, :].T 53 | 54 | def 
__getitem__(self, item): 55 | pc = self.pcs[item][:self.opts.db.n_points] # (N, 3) 56 | if self.partition != 'train': 57 | np.random.seed(item) 58 | 59 | angle_x = np.random.uniform(0., np.pi / self.factor) 60 | angle_y = np.random.uniform(0., np.pi / self.factor) 61 | angle_z = np.random.uniform(0., np.pi / self.factor) 62 | t_lb = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5)]) 63 | 64 | pc1 = pc.T # (3, N) 65 | r_lb = np.array([angle_z, angle_y, angle_x]) 66 | pc2 = Rotation.from_euler('zyx', r_lb).apply(pc1.T).T + np.expand_dims(t_lb, axis=1) 67 | 68 | pc1 = np.random.permutation(pc1.T).T 69 | pc2 = np.random.permutation(pc2.T).T 70 | 71 | if self.gaussian_noise: 72 | pc1 = self.jitter_pointcloud(pc1) 73 | pc2 = self.jitter_pointcloud(pc2) 74 | 75 | if self.n_points != self.n_sub_points: 76 | pc1, pc2 = self.farthest_subsample_points(pc1, pc2, n_sub_points = self.n_sub_points) 77 | 78 | return pc1.astype('float32'), pc2.astype('float32'), r_lb.astype('float32'), t_lb.astype('float32') 79 | 80 | def __len__(self): 81 | return len(self.pcs) -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from datasets.base_dataset import BaseDataset 2 | from utils.commons import load_data 3 | 4 | class DB_ModelNet40(BaseDataset): 5 | def __init__(self, opts, partition, is_normal=False, cls_idx=-1): 6 | super(DB_ModelNet40, self).__init__(opts, partition, is_normal, cls_idx) 7 | self.pcs, self.lbs = self.load_data(opts, partition) # (B, N, 3), (B, ) 8 | 9 | def load_data(self, opts, partition): 10 | db_path = opts.db.path 11 | db = load_data(db_path)[partition] 12 | pcs = db["normal_pcs"][:, :, :3] 13 | lbs = db["lbs"] 14 | if self.unseen: 15 | if self.partition == 'test': 16 | pcs = pcs[lbs >= 20] 17 | lbs = lbs[lbs >= 20] 18 | elif self.partition == 'train': 19 | pcs = pcs[lbs < 20] 20 | 
lbs = lbs[lbs < 20] 21 | if self.cls_idx != -1: 22 | pcs = pcs[lbs == self.cls_idx] 23 | lbs = lbs[lbs == self.cls_idx] 24 | return pcs, lbs 25 | 26 | class DB_7Scene(BaseDataset): 27 | def __init__(self, opts, partition, is_normal=False, cls_idx=-1): 28 | super(DB_7Scene, self).__init__(opts, partition, is_normal, cls_idx) 29 | self.pcs, self.lbs = self.load_data(opts, partition) # (B, N, 3), (B, ) 30 | 31 | def load_data(self, opts, partition): 32 | db_path = opts.db.path 33 | db = load_data(db_path)[partition] 34 | db["normal_pcs"] = db["normal_pcs"].transpose(0, 2, 1) 35 | pcs = db["normal_pcs"][:, :, :3] 36 | lbs = db["lbs"] 37 | if self.cls_idx != -1: 38 | pcs = pcs[lbs == self.cls_idx] 39 | lbs = lbs[lbs == self.cls_idx] 40 | return pcs, lbs 41 | 42 | class DB_ICL_NUIM(BaseDataset): 43 | def __init__(self, opts, partition, is_normal=False, cls_idx=-1): 44 | super(DB_ICL_NUIM, self).__init__(opts, partition, is_normal, cls_idx) 45 | self.pcs, self.lbs = self.load_data(opts, partition) # (B, N, 3), (B, ) 46 | 47 | def load_data(self, opts, partition): 48 | db_path = opts.db.path 49 | db = load_data(db_path)[partition] 50 | db["normal_pcs"] = db["normal_pcs"].transpose(0, 2, 1) 51 | pcs = db["normal_pcs"][:, :, :3] 52 | lbs = db["lbs"] 53 | if self.cls_idx != -1: 54 | pcs = pcs[lbs == self.cls_idx] 55 | lbs = lbs[lbs == self.cls_idx] 56 | return pcs, lbs -------------------------------------------------------------------------------- /datasets/get_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from datasets.dataset import DB_ModelNet40, DB_7Scene, DB_ICL_NUIM 3 | 4 | def get_dataset(opts, db_nm, partition, batch_size, shuffle, drop_last, is_normal=False, cls_idx=-1, n_cores=1): 5 | loader, db = None, None 6 | if db_nm == "modelnet40": 7 | db = DB_ModelNet40(opts, partition, is_normal, cls_idx=cls_idx) 8 | loader = DataLoader(db, batch_size=batch_size, shuffle=shuffle, 
drop_last=drop_last, num_workers=n_cores) 9 | if db_nm == "scene7": 10 | db = DB_7Scene(opts, partition, is_normal, cls_idx=cls_idx) 11 | loader = DataLoader(db, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=n_cores) 12 | if db_nm == "icl_nuim": 13 | db = DB_ICL_NUIM(opts, partition, is_normal, cls_idx=cls_idx) 14 | loader = DataLoader(db, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=n_cores) 15 | return loader, db 16 | -------------------------------------------------------------------------------- /datasets/process_dataset.py: -------------------------------------------------------------------------------- 1 | import h5py, sys, os, glob, numpy as np, pdb 2 | sys.path.insert(0, "..") 3 | from datasets.utils.gen_normal import gen_normal 4 | from datasets.utils.commons import save_data 5 | import datasets.se_math.mesh as mesh, datasets.se_math.transforms as transforms, torchvision, torch 6 | 7 | # ============ ModelNet40 ============== 8 | def process_modelnet40(): 9 | db_dir = "/test/datasets/registration/modelnet40" 10 | res = {"train": {"normal_pcs": [], "lbs": []}, 11 | "test": {"normal_pcs": [], "lbs": []}} 12 | for partition in ["train", "test"]: 13 | for db_path in glob.glob(os.path.join(db_dir, "ply_data_%s*.h5" % partition)): 14 | f = h5py.File(db_path) 15 | normal_pcs_batch = np.concatenate([f['data'][:], f['normal'][:]], axis=2) # (B, N, 6) 16 | lbs_batch = f['label'][:].astype('int64') # (B, 1) 17 | f.close() 18 | res[partition]["normal_pcs"].append(normal_pcs_batch) 19 | res[partition]["lbs"].append(lbs_batch) 20 | print("--- %s, finished. 
---" % (db_path)) 21 | res[partition]["normal_pcs"] = np.concatenate(res[partition]["normal_pcs"], 0) 22 | res[partition]["lbs"] = np.concatenate(res[partition]["lbs"], 0).reshape(-1) 23 | 24 | save_data(os.path.join(db_dir, "modelnet40_normal_n2048.pth"), res) 25 | 26 | # ============ 7Scene ============== 27 | def get_cls(cls_path): 28 | cls = [line.rstrip('\n') for line in open(cls_path)] 29 | cls.sort() 30 | cls_to_lb = {cls[i]: i for i in range(len(cls))} 31 | return cls, cls_to_lb 32 | 33 | def process_7scene(): 34 | db_dir = "/test/datasets/registration/7scene" 35 | res = {"train": {"normal_pcs": [], "lbs": []}, 36 | "test": {"normal_pcs": [], "lbs": []}} 37 | transform = torchvision.transforms.Compose([transforms.Mesh2Points(), 38 | transforms.OnUnitCube(), 39 | transforms.Resampler(2048)]) 40 | for partition in ["train", "test"]: 41 | pc_dir = os.path.join(db_dir, "7scene") 42 | cls_path = os.path.join(db_dir, "categories/7scene_%s.txt" % (partition)) 43 | cls_nms, nms_to_lbs = get_cls(cls_path) 44 | for cls_nm in cls_nms: 45 | pc_paths = os.path.join(pc_dir, cls_nm, "*.ply") 46 | lb = nms_to_lbs[cls_nm] 47 | for pc_path in glob.glob(pc_paths): 48 | pc = mesh.plyread(pc_path) 49 | pc = transform(pc).numpy().transpose([1, 0]) # [3, 2048] 50 | normal_pc = gen_normal(pc[None, :]) 51 | res[partition]["normal_pcs"].append(normal_pc) 52 | res[partition]["lbs"].append(lb) 53 | print("--- %s, finished. 
---" % (pc_path)) 54 | res[partition]["normal_pcs"] = np.concatenate(res[partition]["normal_pcs"], 0) 55 | res[partition]["lbs"] = np.asarray(res[partition]["lbs"]) 56 | save_data(os.path.join(db_dir, "7scene_normal_n2048.pth"), res) 57 | 58 | # ============ ICL-NUIM ============== 59 | def process_icl_nuim(): 60 | db_dir = "/test/datasets/registration/icl_nuim" 61 | res = {"train": {"normal_pcs": [], "lbs": []}, 62 | "test": {"normal_pcs": [], "lbs": []}} 63 | transform = torchvision.transforms.Compose([transforms.OnUnitCube(), 64 | transforms.Resampler(2048)]) 65 | for partition in ["train", "test"]: 66 | db_path = os.path.join(db_dir, "icl_nuim_%s.h5" % partition) 67 | if partition == "train": 68 | f = h5py.File(db_path, "r") 69 | pcs = f['points'][...] 70 | for idx, pc in enumerate(pcs): 71 | pc = transform(torch.FloatTensor(pc)).numpy().transpose([1, 0]) 72 | normal_pc = gen_normal(pc[None, :]) 73 | res[partition]["normal_pcs"].append(normal_pc) 74 | res[partition]["lbs"].append(idx) 75 | elif partition == "test": 76 | f = h5py.File(db_path, "r") 77 | pcs = f['source'][...] 78 | # tgt_pcs = f['target'][...] 79 | # transforms = f['transform'][...] 80 | for idx in range(len(pcs)): 81 | pc = pcs[idx] 82 | pc = transform(torch.FloatTensor(pc)).numpy().transpose([1, 0]) 83 | normal_pc = gen_normal(pc[None, :]) 84 | res[partition]["normal_pcs"].append(normal_pc) 85 | res[partition]["lbs"].append(idx) 86 | res[partition]["normal_pcs"] = np.concatenate(res[partition]["normal_pcs"], 0) 87 | res[partition]["lbs"] = np.asarray(res[partition]["lbs"]) 88 | 89 | save_data(os.path.join(db_dir, "ic_nuim_normal_n2048.pth"), res) 90 | 91 | if __name__ == '__main__': 92 | process_modelnet40() 93 | # process_7scene() 94 | # process_icl_nuim() -------------------------------------------------------------------------------- /datasets/se_math/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import invmat, se3, sinc, so3, mesh, transforms -------------------------------------------------------------------------------- /datasets/se_math/invmat.py: -------------------------------------------------------------------------------- 1 | """ inverse matrix """ 2 | 3 | import torch 4 | 5 | 6 | def batch_inverse(x): 7 | """ M(n) -> M(n); x -> x^-1 """ 8 | batch_size, h, w = x.size() 9 | assert h == w 10 | y = torch.zeros_like(x) 11 | for i in range(batch_size): 12 | y[i, :, :] = x[i, :, :].inverse() 13 | return y 14 | 15 | 16 | def batch_inverse_dx(y): 17 | """ backward """ 18 | # Let y(x) = x^-1. 19 | # compute dy 20 | # dy = dy(j,k) 21 | # = - y(j,m) * dx(m,n) * y(n,k) 22 | # = - y(j,m) * y(n,k) * dx(m,n) 23 | # therefore, 24 | # dy(j,k)/dx(m,n) = - y(j,m) * y(n,k) 25 | batch_size, h, w = y.size() 26 | assert h == w 27 | # compute dy(j,k,m,n) = dy(j,k)/dx(m,n) = - y(j,m) * y(n,k) 28 | # = - (y(j,:))' * y'(k,:) 29 | yl = y.repeat(1, 1, h).view(batch_size * h * h, h, 1) 30 | yr = y.transpose(1, 2).repeat(1, h, 1).view(batch_size * h * h, 1, h) 31 | dy = - yl.bmm(yr).view(batch_size, h, h, h, h) 32 | 33 | # compute dy(m,n,j,k) = dy(j,k)/dx(m,n) = - y(j,m) * y(n,k) 34 | # = - (y'(m,:))' * y(n,:) 35 | # yl = y.transpose(1, 2).repeat(1, 1, h).view(batch_size*h*h, h, 1) 36 | # yr = y.repeat(1, h, 1).view(batch_size*h*h, 1, h) 37 | # dy = - yl.bmm(yr).view(batch_size, h, h, h, h) 38 | 39 | return dy 40 | 41 | 42 | def batch_pinv_dx(x): 43 | """ returns y = (x'*x)^-1 * x' and dy/dx. 
""" 44 | # y = (x'*x)^-1 * x' 45 | # = s^-1 * x' 46 | # = b * x' 47 | # d{y(j,k)}/d{x(m,n)} 48 | # = d{b(j,i) * x(k,i)}/d{x(m,n)} 49 | # = d{b(j,i)}/d{x(m,n)} * x(k,i) + b(j,i) * d{x(k,i)}/d{x(m,n)} 50 | # d{b(j,i)}/d{x(m,n)} 51 | # = d{b(j,i)}/d{s(p,q)} * d{s(p,q)}/d{x(m,n)} 52 | # = -b(j,p)*b(q,i) * d{s(p,q)}/d{x(m,n)} 53 | # d{s(p,q)}/d{x(m,n)} 54 | # = d{x(t,p)*x(t,q)}/d{x(m,n)} 55 | # = d{x(t,p)}/d{x(m,n)} * x(t,q) + x(t,p) * d{x(t,q)}/d{x(m,n)} 56 | batch_size, h, w = x.size() 57 | xt = x.transpose(1, 2) 58 | s = xt.bmm(x) 59 | b = batch_inverse(s) 60 | y = b.bmm(xt) 61 | 62 | # dx/dx 63 | ex = torch.eye(h * w).to(x).unsqueeze(0).view(1, h, w, h, w) 64 | # ds/dx = dx(t,_)/dx * x(t,_) + x(t,_) * dx(t,_)/dx 65 | ex1 = ex.view(1, h, w * h * w) # [t, p*m*n] 66 | dx1 = x.transpose(1, 2).matmul(ex1).view(batch_size, w, w, h, w) # [q, p,m,n] 67 | ds_dx = dx1.transpose(1, 2) + dx1 # [p, q, m, n] 68 | # db/ds 69 | db_ds = batch_inverse_dx(b) # [j, i, p, q] 70 | # db/dx = db/d{s(p,q)} * d{s(p,q)}/dx 71 | db1 = db_ds.view(batch_size, w * w, w * w).bmm(ds_dx.view(batch_size, w * w, h * w)) 72 | db_dx = db1.view(batch_size, w, w, h, w) # [j, i, m, n] 73 | # dy/dx = db(_,i)/dx * x(_,i) + b(_,i) * dx(_,i)/dx 74 | dy1 = db_dx.transpose(1, 2).contiguous().view(batch_size, w, w * h * w) 75 | dy1 = x.matmul(dy1).view(batch_size, h, w, h, w) # [k, j, m, n] 76 | ext = ex.transpose(1, 2).contiguous().view(1, w, h * h * w) 77 | dy2 = b.matmul(ext).view(batch_size, w, h, h, w) # [j, k, m, n] 78 | dy_dx = dy1.transpose(1, 2) + dy2 79 | 80 | return y, dy_dx 81 | 82 | 83 | class InvMatrix(torch.autograd.Function): 84 | """ M(n) -> M(n); x -> x^-1. 
85 | """ 86 | 87 | @staticmethod 88 | def forward(ctx, x): 89 | y = batch_inverse(x) 90 | ctx.save_for_backward(y) 91 | return y 92 | 93 | @staticmethod 94 | def backward(ctx, grad_output): 95 | y, = ctx.saved_tensors # v0.4 96 | # y, = ctx.saved_variables # v0.3.1 97 | batch_size, h, w = y.size() 98 | assert h == w 99 | 100 | # Let y(x) = x^-1 and assume any function f(y(x)). 101 | # compute df/dx(m,n)... 102 | # df/dx(m,n) = df/dy(j,k) * dy(j,k)/dx(m,n) 103 | # well, df/dy is 'grad_output' 104 | # and so we will return 'grad_input = df/dy(j,k) * dy(j,k)/dx(m,n)' 105 | 106 | dy = batch_inverse_dx(y) # dy(j,k,m,n) = dy(j,k)/dx(m,n) 107 | go = grad_output.contiguous().view(batch_size, 1, h * h) # [1, (j*k)] 108 | ym = dy.view(batch_size, h * h, h * h) # [(j*k), (m*n)] 109 | r = go.bmm(ym) # [1, (m*n)] 110 | grad_input = r.view(batch_size, h, h) # [m, n] 111 | 112 | return grad_input 113 | 114 | 115 | if __name__ == '__main__': 116 | def test(): 117 | x = torch.randn(2, 3, 2) 118 | x_val = x.requires_grad_() 119 | 120 | s_val = x_val.transpose(1, 2).bmm(x_val) 121 | s_inv = InvMatrix.apply(s_val) 122 | y_val = s_inv.bmm(x_val.transpose(1, 2)) 123 | y_val.sum().backward() 124 | t1 = x_val.grad 125 | 126 | y, dy_dx = batch_pinv_dx(x) 127 | t2 = dy_dx.sum(1).sum(1) 128 | 129 | print(t1) 130 | print(t2) 131 | print(t1 - t2) 132 | 133 | 134 | test() 135 | 136 | # EOF 137 | -------------------------------------------------------------------------------- /datasets/se_math/mesh.py: -------------------------------------------------------------------------------- 1 | """ 3-d mesh reader """ 2 | import os 3 | import copy 4 | import numpy 5 | from mpl_toolkits.mplot3d import Axes3D 6 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 7 | # import matplotlib.pyplot 8 | 9 | # used to read ply files 10 | from plyfile import PlyData 11 | 12 | 13 | class Mesh: 14 | def __init__(self): 15 | self._vertices = [] # array-like (N, D) 16 | self._faces = [] # array-like (M, K) 17 | 
self._edges = [] # array-like (L, 2) 18 | 19 | def clone(self): 20 | other = copy.deepcopy(self) 21 | return other 22 | 23 | def clear(self): 24 | for key in self.__dict__: 25 | self.__dict__[key] = [] 26 | 27 | def add_attr(self, name): 28 | self.__dict__[name] = [] 29 | 30 | @property 31 | def vertex_array(self): 32 | return numpy.array(self._vertices) 33 | 34 | @property 35 | def vertex_list(self): 36 | return list(map(tuple, self._vertices)) 37 | 38 | @staticmethod 39 | def faces2polygons(faces, vertices): 40 | p = list(map(lambda face: \ 41 | list(map(lambda vidx: vertices[vidx], face)), faces)) 42 | return p 43 | 44 | @property 45 | def polygon_list(self): 46 | p = Mesh.faces2polygons(self._faces, self._vertices) 47 | return p 48 | 49 | # def plot(self, fig=None, ax=None, *args, **kwargs): 50 | # p = self.polygon_list 51 | # v = self.vertex_array 52 | # if fig is None: 53 | # fig = matplotlib.pyplot.gcf() 54 | # if ax is None: 55 | # ax = Axes3D(fig) 56 | # if p: 57 | # ax.add_collection3d(Poly3DCollection(p)) 58 | # if v.shape: 59 | # ax.scatter(v[:, 0], v[:, 1], v[:, 2], *args, **kwargs) 60 | # ax.set_xlabel('X') 61 | # ax.set_ylabel('Y') 62 | # ax.set_zlabel('Z') 63 | # return fig, ax 64 | 65 | def on_unit_sphere(self, zero_mean=False): 66 | # radius == 1 67 | v = self.vertex_array # (N, D) 68 | if zero_mean: 69 | a = numpy.mean(v[:, 0:3], axis=0, keepdims=True) # (1, 3) 70 | v[:, 0:3] = v[:, 0:3] - a 71 | n = numpy.linalg.norm(v[:, 0:3], axis=1) # (N,) 72 | m = numpy.max(n) # scalar 73 | v[:, 0:3] = v[:, 0:3] / m 74 | self._vertices = v 75 | return self 76 | 77 | def on_unit_cube(self, zero_mean=False): 78 | # volume == 1 79 | v = self.vertex_array # (N, D) 80 | if zero_mean: 81 | a = numpy.mean(v[:, 0:3], axis=0, keepdims=True) # (1, 3) 82 | v[:, 0:3] = v[:, 0:3] - a 83 | m = numpy.max(numpy.abs(v)) # scalar 84 | v[:, 0:3] = v[:, 0:3] / (m * 2) 85 | self._vertices = v 86 | return self 87 | 88 | def rot_x(self): 89 | # camera local (up: +Y, front: -Z) -> 
model local (up: +Z, front: +Y). 90 | v = self.vertex_array 91 | t = numpy.copy(v[:, 1]) 92 | v[:, 1] = -numpy.copy(v[:, 2]) 93 | v[:, 2] = t 94 | self._vertices = list(map(tuple, v)) 95 | return self 96 | 97 | def rot_zc(self): 98 | # R = [0, -1; 99 | # 1, 0] 100 | v = self.vertex_array 101 | x = numpy.copy(v[:, 0]) 102 | y = numpy.copy(v[:, 1]) 103 | v[:, 0] = -y 104 | v[:, 1] = x 105 | self._vertices = list(map(tuple, v)) 106 | return self 107 | 108 | 109 | def offread(filepath, points_only=True): 110 | """ read Geomview OFF file. """ 111 | with open(filepath, 'r') as fin: 112 | mesh, fixme = _load_off(fin, points_only) 113 | if fixme: 114 | _fix_modelnet_broken_off(filepath) 115 | return mesh 116 | 117 | 118 | def _load_off(fin, points_only): 119 | """ read Geomview OFF file. """ 120 | mesh = Mesh() 121 | 122 | fixme = False 123 | sig = fin.readline().strip() 124 | if sig == 'OFF': 125 | line = fin.readline().strip() 126 | num_verts, num_faces, num_edges = tuple([int(s) for s in line.split(' ')]) 127 | elif sig[0:3] == 'OFF': # ...broken data in ModelNet (missing '\n')... 
128 | line = sig[3:] 129 | num_verts, num_faces, num_edges = tuple([int(s) for s in line.split(' ')]) 130 | fixme = True 131 | else: 132 | raise RuntimeError('unknown format') 133 | 134 | for v in range(num_verts): 135 | vp = tuple(float(s) for s in fin.readline().strip().split(' ')) 136 | mesh._vertices.append(vp) 137 | 138 | if points_only: 139 | return mesh, fixme 140 | 141 | for f in range(num_faces): 142 | fc = tuple([int(s) for s in fin.readline().strip().split(' ')][1:]) 143 | mesh._faces.append(fc) 144 | 145 | return mesh, fixme 146 | 147 | 148 | def _fix_modelnet_broken_off(filepath): 149 | oldfile = '{}.orig'.format(filepath) 150 | os.rename(filepath, oldfile) 151 | with open(oldfile, 'r') as fin: 152 | with open(filepath, 'w') as fout: 153 | sig = fin.readline().strip() 154 | line = sig[3:] 155 | print('OFF', file=fout) 156 | print(line, file=fout) 157 | for line in fin: 158 | print(line.strip(), file=fout) 159 | 160 | 161 | def objread(filepath, points_only=True): 162 | """Loads a Wavefront OBJ file. 
""" 163 | _vertices = [] 164 | _normals = [] 165 | _texcoords = [] 166 | _faces = [] 167 | _mtl_name = None 168 | 169 | material = None 170 | for line in open(filepath, "r"): 171 | if line.startswith('#'): continue 172 | values = line.split() 173 | if not values: continue 174 | if values[0] == 'v': 175 | v = tuple(map(float, values[1:4])) 176 | _vertices.append(v) 177 | elif values[0] == 'vn': 178 | v = tuple(map(float, values[1:4])) 179 | _normals.append(v) 180 | elif values[0] == 'vt': 181 | _texcoords.append(tuple(map(float, values[1:3]))) 182 | elif values[0] in ('usemtl', 'usemat'): 183 | material = values[1] 184 | elif values[0] == 'mtllib': 185 | _mtl_name = values[1] 186 | elif values[0] == 'f': 187 | face_ = [] 188 | texcoords_ = [] 189 | norms_ = [] 190 | for v in values[1:]: 191 | w = v.split('/') 192 | face_.append(int(w[0]) - 1) 193 | if len(w) >= 2 and len(w[1]) > 0: 194 | texcoords_.append(int(w[1]) - 1) 195 | else: 196 | texcoords_.append(-1) 197 | if len(w) >= 3 and len(w[2]) > 0: 198 | norms_.append(int(w[2]) - 1) 199 | else: 200 | norms_.append(-1) 201 | # _faces.append((face_, norms_, texcoords_, material)) 202 | _faces.append(face_) 203 | 204 | mesh = Mesh() 205 | mesh._vertices = _vertices 206 | if points_only: 207 | return mesh 208 | 209 | mesh._faces = _faces 210 | 211 | return mesh 212 | 213 | 214 | def plyread(filepath, points_only=True): 215 | # read binary ply file and return [x, y, z] array 216 | data = PlyData.read(filepath) 217 | vertex = data['vertex'] 218 | 219 | (x, y, z) = (vertex[t] for t in ('x', 'y', 'z')) 220 | num_verts = len(x) 221 | 222 | mesh = Mesh() 223 | 224 | for v in range(num_verts): 225 | vp = tuple(float(s) for s in [x[v], y[v], z[v]]) 226 | mesh._vertices.append(vp) 227 | 228 | return mesh 229 | 230 | 231 | if __name__ == '__main__': 232 | def test1(): 233 | mesh = objread('model_normalized.obj', points_only=False) 234 | # mesh.on_unit_sphere() 235 | mesh.rot_x() 236 | mesh.plot(c='m') 237 | # 
matplotlib.pyplot.show() 238 | 239 | 240 | def test2(): 241 | mesh = plyread('1.ply', points_only=True) 242 | # mesh.on_unit_sphere() 243 | mesh.rot_x() 244 | mesh.plot(c='m') 245 | # matplotlib.pyplot.show() 246 | 247 | 248 | test2() 249 | 250 | # EOF 251 | -------------------------------------------------------------------------------- /datasets/se_math/se3.py: -------------------------------------------------------------------------------- 1 | """ 3-d rigid body transfomation group and corresponding Lie algebra. """ 2 | import torch 3 | from .sinc import sinc1, sinc2, sinc3 4 | from . import so3 5 | 6 | 7 | def twist_prod(x, y): 8 | x_ = x.view(-1, 6) 9 | y_ = y.view(-1, 6) 10 | 11 | xw, xv = x_[:, 0:3], x_[:, 3:6] 12 | yw, yv = y_[:, 0:3], y_[:, 3:6] 13 | 14 | zw = so3.cross_prod(xw, yw) 15 | zv = so3.cross_prod(xw, yv) + so3.cross_prod(xv, yw) 16 | 17 | z = torch.cat((zw, zv), dim=1) 18 | 19 | return z.view_as(x) 20 | 21 | 22 | def liebracket(x, y): 23 | return twist_prod(x, y) 24 | 25 | 26 | def mat(x): 27 | # size: [*, 6] -> [*, 4, 4] 28 | x_ = x.view(-1, 6) 29 | w1, w2, w3 = x_[:, 0], x_[:, 1], x_[:, 2] 30 | v1, v2, v3 = x_[:, 3], x_[:, 4], x_[:, 5] 31 | O = torch.zeros_like(w1) 32 | 33 | X = torch.stack(( 34 | torch.stack((O, -w3, w2, v1), dim=1), 35 | torch.stack((w3, O, -w1, v2), dim=1), 36 | torch.stack((-w2, w1, O, v3), dim=1), 37 | torch.stack((O, O, O, O), dim=1)), dim=1) 38 | return X.view(*(x.size()[0:-1]), 4, 4) 39 | 40 | 41 | def vec(X): 42 | X_ = X.view(-1, 4, 4) 43 | w1, w2, w3 = X_[:, 2, 1], X_[:, 0, 2], X_[:, 1, 0] 44 | v1, v2, v3 = X_[:, 0, 3], X_[:, 1, 3], X_[:, 2, 3] 45 | x = torch.stack((w1, w2, w3, v1, v2, v3), dim=1) 46 | return x.view(*X.size()[0:-2], 6) 47 | 48 | 49 | def genvec(): 50 | return torch.eye(6) 51 | 52 | 53 | def genmat(): 54 | return mat(genvec()) 55 | 56 | 57 | def exp(x): 58 | x_ = x.view(-1, 6) 59 | w, v = x_[:, 0:3], x_[:, 3:6] 60 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 61 | W = so3.mat(w) 62 | S = W.bmm(W) 63 | I = 
torch.eye(3).to(w) 64 | 65 | # Rodrigues' rotation formula. 66 | # R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w'); 67 | # = eye(3) + sinc1(t)*W + sinc2(t)*S 68 | R = I + sinc1(t) * W + sinc2(t) * S 69 | 70 | # V = sinc1(t)*eye(3) + sinc2(t)*W + sinc3(t)*(w*w') 71 | # = eye(3) + sinc2(t)*W + sinc3(t)*S 72 | V = I + sinc2(t) * W + sinc3(t) * S 73 | 74 | p = V.bmm(v.contiguous().view(-1, 3, 1)) 75 | 76 | z = torch.Tensor([0, 0, 0, 1]).view(1, 1, 4).repeat(x_.size(0), 1, 1).to(x) 77 | Rp = torch.cat((R, p), dim=2) 78 | g = torch.cat((Rp, z), dim=1) 79 | 80 | return g.view(*(x.size()[0:-1]), 4, 4) 81 | 82 | 83 | def inverse(g): 84 | g_ = g.view(-1, 4, 4) 85 | R = g_[:, 0:3, 0:3] 86 | p = g_[:, 0:3, 3] 87 | Q = R.transpose(1, 2) 88 | q = -Q.matmul(p.unsqueeze(-1)) 89 | 90 | z = torch.Tensor([0, 0, 0, 1]).view(1, 1, 4).repeat(g_.size(0), 1, 1).to(g) 91 | Qq = torch.cat((Q, q), dim=2) 92 | ig = torch.cat((Qq, z), dim=1) 93 | 94 | return ig.view(*(g.size()[0:-2]), 4, 4) 95 | 96 | 97 | def log(g): 98 | g_ = g.view(-1, 4, 4) 99 | R = g_[:, 0:3, 0:3] 100 | p = g_[:, 0:3, 3] 101 | 102 | w = so3.log(R) 103 | H = so3.inv_vecs_Xg_ig(w) 104 | v = H.bmm(p.contiguous().view(-1, 3, 1)).view(-1, 3) 105 | 106 | x = torch.cat((w, v), dim=1) 107 | return x.view(*(g.size()[0:-2]), 6) 108 | 109 | 110 | def transform(g, a): 111 | # g : SE(3), * x 4 x 4 112 | # a : R^3, * x 3[x N] 113 | g_ = g.view(-1, 4, 4) 114 | R = g_[:, 0:3, 0:3].contiguous().view(*(g.size()[0:-2]), 3, 3) 115 | p = g_[:, 0:3, 3].contiguous().view(*(g.size()[0:-2]), 3) 116 | if len(g.size()) == len(a.size()): 117 | b = R.matmul(a) + p.unsqueeze(-1) 118 | else: 119 | b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p 120 | return b 121 | 122 | 123 | def group_prod(g, h): 124 | # g, h : SE(3) 125 | g1 = g.matmul(h) 126 | return g1 127 | 128 | 129 | class ExpMap(torch.autograd.Function): 130 | """ Exp: se(3) -> SE(3) 131 | """ 132 | 133 | @staticmethod 134 | def forward(ctx, x): 135 | """ Exp: R^6 -> M(4), 136 | size: [B, 6] -> 
[B, 4, 4], 137 | or [B, 1, 6] -> [B, 1, 4, 4] 138 | """ 139 | ctx.save_for_backward(x) 140 | g = exp(x) 141 | return g 142 | 143 | @staticmethod 144 | def backward(ctx, grad_output): 145 | x, = ctx.saved_tensors 146 | g = exp(x) 147 | gen_k = genmat().to(x) 148 | 149 | # Let z = f(g) = f(exp(x)) 150 | # dz = df/dgij * dgij/dxk * dxk 151 | # = df/dgij * (d/dxk)[exp(x)]_ij * dxk 152 | # = df/dgij * [gen_k*g]_ij * dxk 153 | 154 | dg = gen_k.matmul(g.view(-1, 1, 4, 4)) 155 | # (k, i, j) 156 | dg = dg.to(grad_output) 157 | 158 | go = grad_output.contiguous().view(-1, 1, 4, 4) 159 | dd = go * dg 160 | grad_input = dd.sum(-1).sum(-1) 161 | 162 | return grad_input 163 | 164 | 165 | Exp = ExpMap.apply 166 | 167 | # EOF 168 | -------------------------------------------------------------------------------- /datasets/se_math/sinc.py: -------------------------------------------------------------------------------- 1 | """ sinc(t) := sin(t) / t """ 2 | import torch 3 | from torch import sin, cos 4 | 5 | 6 | def sinc1(t): 7 | """ sinc1: t -> sin(t)/t """ 8 | e = 0.01 9 | r = torch.zeros_like(t) 10 | a = torch.abs(t) 11 | 12 | s = a < e 13 | c = (s == 0) 14 | t2 = t[s] ** 2 15 | r[s] = 1 - t2 / 6 * (1 - t2 / 20 * (1 - t2 / 42)) # Taylor series O(t^8) 16 | r[c] = sin(t[c]) / t[c] 17 | 18 | return r 19 | 20 | 21 | def sinc1_dt(t): 22 | """ d/dt(sinc1) """ 23 | e = 0.01 24 | r = torch.zeros_like(t) 25 | a = torch.abs(t) 26 | 27 | s = a < e 28 | c = (s == 0) 29 | t2 = t ** 2 30 | r[s] = -t[s] / 3 * (1 - t2[s] / 10 * (1 - t2[s] / 28 * (1 - t2[s] / 54))) # Taylor series O(t^8) 31 | r[c] = cos(t[c]) / t[c] - sin(t[c]) / t2[c] 32 | 33 | return r 34 | 35 | 36 | def sinc1_dt_rt(t): 37 | """ d/dt(sinc1) / t """ 38 | e = 0.01 39 | r = torch.zeros_like(t) 40 | a = torch.abs(t) 41 | 42 | s = a < e 43 | c = (s == 0) 44 | t2 = t ** 2 45 | r[s] = -1 / 3 * (1 - t2[s] / 10 * (1 - t2[s] / 28 * (1 - t2[s] / 54))) # Taylor series O(t^8) 46 | r[c] = (cos(t[c]) / t[c] - sin(t[c]) / t2[c]) / t[c] 47 | 48 
| return r 49 | 50 | 51 | def rsinc1(t): 52 | """ rsinc1: t -> t/sinc1(t) """ 53 | e = 0.01 54 | r = torch.zeros_like(t) 55 | a = torch.abs(t) 56 | 57 | s = a < e 58 | c = (s == 0) 59 | t2 = t[s] ** 2 60 | r[s] = (((31 * t2) / 42 + 7) * t2 / 60 + 1) * t2 / 6 + 1 # Taylor series O(t^8) 61 | r[c] = t[c] / sin(t[c]) 62 | 63 | return r 64 | 65 | 66 | def rsinc1_dt(t): 67 | """ d/dt(rsinc1) """ 68 | e = 0.01 69 | r = torch.zeros_like(t) 70 | a = torch.abs(t) 71 | 72 | s = a < e 73 | c = (s == 0) 74 | t2 = t[s] ** 2 75 | r[s] = ((((127 * t2) / 30 + 31) * t2 / 28 + 7) * t2 / 30 + 1) * t[s] / 3 # Taylor series O(t^8) 76 | r[c] = 1 / sin(t[c]) - (t[c] * cos(t[c])) / (sin(t[c]) * sin(t[c])) 77 | 78 | return r 79 | 80 | 81 | def rsinc1_dt_csc(t): 82 | """ d/dt(rsinc1) / sin(t) """ 83 | e = 0.01 84 | r = torch.zeros_like(t) 85 | a = torch.abs(t) 86 | 87 | s = a < e 88 | c = (s == 0) 89 | t2 = t[s] ** 2 90 | r[s] = t2 * (t2 * ((4 * t2) / 675 + 2 / 63) + 2 / 15) + 1 / 3 # Taylor series O(t^8) 91 | r[c] = (1 / sin(t[c]) - (t[c] * cos(t[c])) / (sin(t[c]) * sin(t[c]))) / sin(t[c]) 92 | 93 | return r 94 | 95 | 96 | def sinc2(t): 97 | """ sinc2: t -> (1 - cos(t)) / (t**2) """ 98 | e = 0.01 99 | r = torch.zeros_like(t) 100 | a = torch.abs(t) 101 | 102 | s = a < e 103 | c = (s == 0) 104 | t2 = t ** 2 105 | r[s] = 1 / 2 * (1 - t2[s] / 12 * (1 - t2[s] / 30 * (1 - t2[s] / 56))) # Taylor series O(t^8) 106 | r[c] = (1 - cos(t[c])) / t2[c] 107 | 108 | return r 109 | 110 | 111 | def sinc2_dt(t): 112 | """ d/dt(sinc2) """ 113 | e = 0.01 114 | r = torch.zeros_like(t) 115 | a = torch.abs(t) 116 | 117 | s = a < e 118 | c = (s == 0) 119 | t2 = t ** 2 120 | r[s] = -t[s] / 12 * (1 - t2[s] / 5 * (1.0 / 3 - t2[s] / 56 * (1.0 / 2 - t2[s] / 135))) # Taylor series O(t^8) 121 | r[c] = sin(t[c]) / t2[c] - 2 * (1 - cos(t[c])) / (t2[c] * t[c]) 122 | 123 | return r 124 | 125 | 126 | def sinc3(t): 127 | """ sinc3: t -> (t - sin(t)) / (t**3) """ 128 | e = 0.01 129 | r = torch.zeros_like(t) 130 | a = 
torch.abs(t) 131 | 132 | s = a < e 133 | c = (s == 0) 134 | t2 = t[s] ** 2 135 | r[s] = 1 / 6 * (1 - t2 / 20 * (1 - t2 / 42 * (1 - t2 / 72))) # Taylor series O(t^8) 136 | r[c] = (t[c] - sin(t[c])) / (t[c] ** 3) 137 | 138 | return r 139 | 140 | 141 | def sinc3_dt(t): 142 | """ d/dt(sinc3) """ 143 | e = 0.01 144 | r = torch.zeros_like(t) 145 | a = torch.abs(t) 146 | 147 | s = a < e 148 | c = (s == 0) 149 | t2 = t[s] ** 2 150 | r[s] = -t[s] / 60 * (1 - t2 / 21 * (1 - t2 / 24 * (1.0 / 2 - t2 / 165))) # Taylor series O(t^8) 151 | r[c] = (3 * sin(t[c]) - t[c] * (cos(t[c]) + 2)) / (t[c] ** 4) 152 | 153 | return r 154 | 155 | 156 | def sinc4(t): 157 | """ sinc4: t -> 1/t^2 * (1/2 - sinc2(t)) 158 | = 1/t^2 * (1/2 - (1 - cos(t))/t^2) 159 | """ 160 | e = 0.01 161 | r = torch.zeros_like(t) 162 | a = torch.abs(t) 163 | 164 | s = a < e 165 | c = (s == 0) 166 | t2 = t ** 2 167 | r[s] = 1 / 24 * (1 - t2 / 30 * (1 - t2 / 56 * (1 - t2 / 90))) # Taylor series O(t^8) 168 | r[c] = (0.5 - (1 - cos(t)) / t2) / t2 169 | 170 | 171 | class Sinc1_autograd(torch.autograd.Function): 172 | @staticmethod 173 | def forward(ctx, theta): 174 | ctx.save_for_backward(theta) 175 | return sinc1(theta) 176 | 177 | @staticmethod 178 | def backward(ctx, grad_output): 179 | theta, = ctx.saved_tensors 180 | grad_theta = None 181 | if ctx.needs_input_grad[0]: 182 | grad_theta = grad_output * sinc1_dt(theta).to(grad_output) 183 | return grad_theta 184 | 185 | 186 | Sinc1 = Sinc1_autograd.apply 187 | 188 | 189 | class RSinc1_autograd(torch.autograd.Function): 190 | @staticmethod 191 | def forward(ctx, theta): 192 | ctx.save_for_backward(theta) 193 | return rsinc1(theta) 194 | 195 | @staticmethod 196 | def backward(ctx, grad_output): 197 | theta, = ctx.saved_tensors 198 | grad_theta = None 199 | if ctx.needs_input_grad[0]: 200 | grad_theta = grad_output * rsinc1_dt(theta).to(grad_output) 201 | return grad_theta 202 | 203 | 204 | RSinc1 = RSinc1_autograd.apply 205 | 206 | 207 | class 
Sinc2_autograd(torch.autograd.Function): 208 | @staticmethod 209 | def forward(ctx, theta): 210 | ctx.save_for_backward(theta) 211 | return sinc2(theta) 212 | 213 | @staticmethod 214 | def backward(ctx, grad_output): 215 | theta, = ctx.saved_tensors 216 | grad_theta = None 217 | if ctx.needs_input_grad[0]: 218 | grad_theta = grad_output * sinc2_dt(theta).to(grad_output) 219 | return grad_theta 220 | 221 | 222 | Sinc2 = Sinc2_autograd.apply 223 | 224 | 225 | class Sinc3_autograd(torch.autograd.Function): 226 | @staticmethod 227 | def forward(ctx, theta): 228 | ctx.save_for_backward(theta) 229 | return sinc3(theta) 230 | 231 | @staticmethod 232 | def backward(ctx, grad_output): 233 | theta, = ctx.saved_tensors 234 | grad_theta = None 235 | if ctx.needs_input_grad[0]: 236 | grad_theta = grad_output * sinc3_dt(theta).to(grad_output) 237 | return grad_theta 238 | 239 | 240 | Sinc3 = Sinc3_autograd.apply 241 | 242 | # EOF 243 | -------------------------------------------------------------------------------- /datasets/se_math/so3.py: -------------------------------------------------------------------------------- 1 | """ 3-d rotation group and corresponding Lie algebra """ 2 | import torch 3 | from . 
import sinc 4 | from .sinc import sinc1, sinc2, sinc3 5 | 6 | 7 | def cross_prod(x, y): 8 | z = torch.cross(x.view(-1, 3), y.view(-1, 3), dim=1).view_as(x) 9 | return z 10 | 11 | 12 | def liebracket(x, y): 13 | return cross_prod(x, y) 14 | 15 | 16 | def mat(x): 17 | # size: [*, 3] -> [*, 3, 3] 18 | x_ = x.view(-1, 3) 19 | x1, x2, x3 = x_[:, 0], x_[:, 1], x_[:, 2] 20 | O = torch.zeros_like(x1) 21 | 22 | X = torch.stack(( 23 | torch.stack((O, -x3, x2), dim=1), 24 | torch.stack((x3, O, -x1), dim=1), 25 | torch.stack((-x2, x1, O), dim=1)), dim=1) 26 | return X.view(*(x.size()[0:-1]), 3, 3) 27 | 28 | 29 | def vec(X): 30 | X_ = X.view(-1, 3, 3) 31 | x1, x2, x3 = X_[:, 2, 1], X_[:, 0, 2], X_[:, 1, 0] 32 | x = torch.stack((x1, x2, x3), dim=1) 33 | return x.view(*X.size()[0:-2], 3) 34 | 35 | 36 | def genvec(): 37 | return torch.eye(3) 38 | 39 | 40 | def genmat(): 41 | return mat(genvec()) 42 | 43 | 44 | def RodriguesRotation(x): 45 | # for autograd 46 | w = x.view(-1, 3) 47 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 48 | W = mat(w) 49 | S = W.bmm(W) 50 | I = torch.eye(3).to(w) 51 | 52 | # Rodrigues' rotation formula. 53 | # R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w'); 54 | # R = eye(3) + sinc1(t)*W + sinc2(t)*S 55 | 56 | R = I + sinc.Sinc1(t) * W + sinc.Sinc2(t) * S 57 | 58 | return R.view(*(x.size()[0:-1]), 3, 3) 59 | 60 | 61 | def exp(x): 62 | w = x.view(-1, 3) 63 | t = w.norm(p=2, dim=1).view(-1, 1, 1) 64 | W = mat(w) 65 | S = W.bmm(W) 66 | I = torch.eye(3).to(w) 67 | 68 | # Rodrigues' rotation formula. 
69 | # R = cos(t)*eye(3) + sinc1(t)*W + sinc2(t)*(w*w'); 70 | # R = eye(3) + sinc1(t)*W + sinc2(t)*S 71 | 72 | R = I + sinc1(t) * W + sinc2(t) * S 73 | 74 | return R.view(*(x.size()[0:-1]), 3, 3) 75 | 76 | 77 | def inverse(g): 78 | R = g.view(-1, 3, 3) 79 | Rt = R.transpose(1, 2) 80 | return Rt.view_as(g) 81 | 82 | 83 | def btrace(X): 84 | # batch-trace: [B, N, N] -> [B] 85 | n = X.size(-1) 86 | X_ = X.view(-1, n, n) 87 | tr = torch.zeros(X_.size(0)).to(X) 88 | for i in range(tr.size(0)): 89 | m = X_[i, :, :] 90 | tr[i] = torch.trace(m) 91 | return tr.view(*(X.size()[0:-2])) 92 | 93 | 94 | def log(g): 95 | eps = 1.0e-7 96 | R = g.view(-1, 3, 3) 97 | tr = btrace(R) 98 | c = (tr - 1) / 2 99 | t = torch.acos(c) 100 | sc = sinc1(t) 101 | idx0 = (torch.abs(sc) <= eps) 102 | idx1 = (torch.abs(sc) > eps) 103 | sc = sc.view(-1, 1, 1) 104 | 105 | X = torch.zeros_like(R) 106 | if idx1.any(): 107 | X[idx1] = (R[idx1] - R[idx1].transpose(1, 2)) / (2 * sc[idx1]) 108 | 109 | if idx0.any(): 110 | # t[idx0] == math.pi 111 | t2 = t[idx0] ** 2 112 | A = (R[idx0] + torch.eye(3).type_as(R).unsqueeze(0)) * t2.view(-1, 1, 1) / 2 113 | aw1 = torch.sqrt(A[:, 0, 0]) 114 | aw2 = torch.sqrt(A[:, 1, 1]) 115 | aw3 = torch.sqrt(A[:, 2, 2]) 116 | sgn_3 = torch.sign(A[:, 0, 2]) 117 | sgn_3[sgn_3 == 0] = 1 118 | sgn_23 = torch.sign(A[:, 1, 2]) 119 | sgn_23[sgn_23 == 0] = 1 120 | sgn_2 = sgn_23 * sgn_3 121 | w1 = aw1 122 | w2 = aw2 * sgn_2 123 | w3 = aw3 * sgn_3 124 | w = torch.stack((w1, w2, w3), dim=-1) 125 | W = mat(w) 126 | X[idx0] = W 127 | 128 | x = vec(X.view_as(g)) 129 | return x 130 | 131 | 132 | def transform(g, a): 133 | # g in SO(3): * x 3 x 3 134 | # a in R^3: * x 3[x N] 135 | if len(g.size()) == len(a.size()): 136 | b = g.matmul(a) 137 | else: 138 | b = g.matmul(a.unsqueeze(-1)).squeeze(-1) 139 | return b 140 | 141 | 142 | def group_prod(g, h): 143 | # g, h : SO(3) 144 | g1 = g.matmul(h) 145 | return g1 146 | 147 | 148 | def vecs_Xg_ig(x): 149 | """ Vi = vec(dg/dxi * inv(g)), where g 
= exp(x) 150 | (== [Ad(exp(x))] * vecs_ig_Xg(x)) 151 | """ 152 | t = x.view(-1, 3).norm(p=2, dim=1).view(-1, 1, 1) 153 | X = mat(x) 154 | S = X.bmm(X) 155 | # B = x.view(-1,3,1).bmm(x.view(-1,1,3)) # B = x*x' 156 | I = torch.eye(3).to(X) 157 | 158 | # V = sinc1(t)*eye(3) + sinc2(t)*X + sinc3(t)*B 159 | # V = eye(3) + sinc2(t)*X + sinc3(t)*S 160 | 161 | V = I + sinc2(t) * X + sinc3(t) * S 162 | 163 | return V.view(*(x.size()[0:-1]), 3, 3) 164 | 165 | 166 | def inv_vecs_Xg_ig(x): 167 | """ H = inv(vecs_Xg_ig(x)) """ 168 | t = x.view(-1, 3).norm(p=2, dim=1).view(-1, 1, 1) 169 | X = mat(x) 170 | S = X.bmm(X) 171 | I = torch.eye(3).to(x) 172 | 173 | e = 0.01 174 | eta = torch.zeros_like(t) 175 | s = (t < e) 176 | c = (s == 0) 177 | t2 = t[s] ** 2 178 | eta[s] = ((t2 / 40 + 1) * t2 / 42 + 1) * t2 / 720 + 1 / 12 # O(t**8) 179 | eta[c] = (1 - (t[c] / 2) / torch.tan(t[c] / 2)) / (t[c] ** 2) 180 | 181 | H = I - 1 / 2 * X + eta * S 182 | return H.view(*(x.size()[0:-1]), 3, 3) 183 | 184 | 185 | class ExpMap(torch.autograd.Function): 186 | """ Exp: so(3) -> SO(3) 187 | """ 188 | 189 | @staticmethod 190 | def forward(ctx, x): 191 | """ Exp: R^3 -> M(3), 192 | size: [B, 3] -> [B, 3, 3], 193 | or [B, 1, 3] -> [B, 1, 3, 3] 194 | """ 195 | ctx.save_for_backward(x) 196 | g = exp(x) 197 | return g 198 | 199 | @staticmethod 200 | def backward(ctx, grad_output): 201 | x, = ctx.saved_tensors 202 | g = exp(x) 203 | gen_k = genmat().to(x) 204 | # gen_1 = gen_k[0, :, :] 205 | # gen_2 = gen_k[1, :, :] 206 | # gen_3 = gen_k[2, :, :] 207 | 208 | # Let z = f(g) = f(exp(x)) 209 | # dz = df/dgij * dgij/dxk * dxk 210 | # = df/dgij * (d/dxk)[exp(x)]_ij * dxk 211 | # = df/dgij * [gen_k*g]_ij * dxk 212 | 213 | dg = gen_k.matmul(g.view(-1, 1, 3, 3)) 214 | # (k, i, j) 215 | dg = dg.to(grad_output) 216 | 217 | go = grad_output.contiguous().view(-1, 1, 3, 3) 218 | dd = go * dg 219 | grad_input = dd.sum(-1).sum(-1) 220 | 221 | return grad_input 222 | 223 | 224 | Exp = ExpMap.apply 225 | 226 | # EOF 227 | 
-------------------------------------------------------------------------------- /datasets/se_math/transforms.py: -------------------------------------------------------------------------------- 1 | """ gives some transform methods for 3d points """ 2 | import math 3 | 4 | import torch 5 | import torch.utils.data 6 | 7 | from . import so3 8 | from . import se3 9 | 10 | 11 | class Mesh2Points: 12 | def __init__(self): 13 | pass 14 | 15 | def __call__(self, mesh): 16 | mesh = mesh.clone() 17 | v = mesh.vertex_array 18 | return torch.from_numpy(v).type(dtype=torch.float) 19 | 20 | 21 | class OnUnitSphere: 22 | def __init__(self, zero_mean=False): 23 | self.zero_mean = zero_mean 24 | 25 | def __call__(self, tensor): 26 | if self.zero_mean: 27 | m = tensor.mean(dim=0, keepdim=True) # [N, D] -> [1, D] 28 | v = tensor - m 29 | else: 30 | v = tensor 31 | nn = v.norm(p=2, dim=1) # [N, D] -> [N] 32 | nmax = torch.max(nn) 33 | return v / nmax 34 | 35 | 36 | class OnUnitCube: 37 | def __init__(self): 38 | pass 39 | 40 | def method1(self, tensor): 41 | m = tensor.mean(dim=0, keepdim=True) # [N, D] -> [1, D] 42 | v = tensor - m 43 | s = torch.max(v.abs()) 44 | v = v / s * 0.5 45 | return v 46 | 47 | def method2(self, tensor): 48 | c = torch.max(tensor, dim=0)[0] - torch.min(tensor, dim=0)[0] # [N, D] -> [D] 49 | s = torch.max(c) # -> scalar 50 | v = tensor / s 51 | return v - v.mean(dim=0, keepdim=True) 52 | 53 | def __call__(self, tensor): 54 | # return self.method1(tensor) 55 | return self.method2(tensor) 56 | 57 | 58 | class Resampler: 59 | """ [N, D] -> [M, D] """ 60 | 61 | def __init__(self, num): 62 | self.num = num 63 | 64 | def __call__(self, tensor): 65 | num_points, dim_p = tensor.size() 66 | out = torch.zeros(self.num, dim_p).to(tensor) 67 | 68 | selected = 0 69 | while selected < self.num: 70 | remainder = self.num - selected 71 | idx = torch.randperm(num_points) 72 | sel = min(remainder, num_points) 73 | val = tensor[idx[:sel]] 74 | out[selected:(selected + sel)] = 
val 75 | selected += sel 76 | return out 77 | 78 | 79 | class RandomTranslate: 80 | def __init__(self, mag=None, randomly=True): 81 | self.mag = 1.0 if mag is None else mag 82 | self.randomly = randomly 83 | self.igt = None 84 | 85 | def __call__(self, tensor): 86 | # tensor: [N, 3] 87 | amp = torch.rand(1) if self.randomly else 1.0 88 | t = torch.randn(1, 3).to(tensor) 89 | t = t / t.norm(p=2, dim=1, keepdim=True) * amp * self.mag 90 | 91 | g = torch.eye(4).to(tensor) 92 | g[0:3, 3] = t[0, :] 93 | self.igt = g # [4, 4] 94 | 95 | p1 = tensor + t 96 | return p1 97 | 98 | 99 | class RandomRotator: 100 | def __init__(self, mag=None, randomly=True): 101 | self.mag = math.pi if mag is None else mag 102 | self.randomly = randomly 103 | self.igt = None 104 | 105 | def __call__(self, tensor): 106 | # tensor: [N, 3] 107 | amp = torch.rand(1) if self.randomly else 1.0 108 | w = torch.randn(1, 3) 109 | w = w / w.norm(p=2, dim=1, keepdim=True) * amp * self.mag 110 | 111 | g = so3.exp(w).to(tensor) # [1, 3, 3] 112 | self.igt = g.squeeze(0) # [3, 3] 113 | 114 | p1 = so3.transform(g, tensor) # [1, 3, 3] x [N, 3] -> [N, 3] 115 | return p1 116 | 117 | 118 | class RandomRotatorZ: 119 | def __init__(self): 120 | self.mag = 2 * math.pi 121 | 122 | def __call__(self, tensor): 123 | # tensor: [N, 3] 124 | w = torch.Tensor([0, 0, 1]).view(1, 3) * torch.rand(1) * self.mag 125 | 126 | g = so3.exp(w).to(tensor) # [1, 3, 3] 127 | 128 | p1 = so3.transform(g, tensor) 129 | return p1 130 | 131 | 132 | class RandomJitter: 133 | """ generate perturbations """ 134 | 135 | def __init__(self, scale=0.01, clip=0.05): 136 | self.scale = scale 137 | self.clip = clip 138 | self.e = None 139 | 140 | def jitter(self, tensor): 141 | noise = torch.zeros_like(tensor).to(tensor) # [N, 3] 142 | noise.normal_(0, self.scale) 143 | noise.clamp_(-self.clip, self.clip) 144 | self.e = noise 145 | return tensor.add(noise) 146 | 147 | def __call__(self, tensor): 148 | return self.jitter(tensor) 149 | 150 | 151 | class 
RandomTransformSE3: 152 | """ rigid motion """ 153 | 154 | def __init__(self, mag=1, mag_randomly=False): 155 | self.mag = mag 156 | self.randomly = mag_randomly 157 | 158 | self.gt = None 159 | self.igt = None 160 | 161 | def generate_transform(self): 162 | # return: a twist-vector 163 | amp = self.mag 164 | if self.randomly: 165 | amp = torch.rand(1, 1) * self.mag 166 | x = torch.randn(1, 6) 167 | x = x / x.norm(p=2, dim=1, keepdim=True) * amp 168 | 169 | '''a = torch.rand(3) 170 | a = a * math.pi 171 | b = torch.zeros(1, 6) 172 | b[:, 0:3] = a 173 | x = x+b 174 | ''' 175 | return x # [1, 6] 176 | 177 | def apply_transform(self, p0, x): 178 | # p0: [N, 3] 179 | # x: [1, 6] 180 | g = se3.exp(x).to(p0) # [1, 4, 4] 181 | gt = se3.exp(-x).to(p0) # [1, 4, 4] 182 | 183 | p1 = se3.transform(g, p0) 184 | self.gt = gt.squeeze(0) # gt: p1 -> p0 185 | self.igt = g.squeeze(0) # igt: p0 -> p1 186 | return p1 187 | 188 | def transform(self, tensor): 189 | x = self.generate_transform() 190 | return self.apply_transform(tensor, x) 191 | 192 | def __call__(self, tensor): 193 | return self.transform(tensor) 194 | 195 | # EOF 196 | -------------------------------------------------------------------------------- /datasets/utils/commons.py: -------------------------------------------------------------------------------- 1 | import pickle, matplotlib.pyplot as plt, numpy as np 2 | 3 | def load_data(path): 4 | file = open(path, "rb") 5 | data = pickle.load(file) 6 | file.close() 7 | return data 8 | 9 | def save_data(path, data): 10 | file = open(path, "wb") 11 | pickle.dump(data, file) 12 | file.close() 13 | 14 | def line_plot(xs, ys, title, path): 15 | plt.figure() 16 | plt.plot(xs, ys) 17 | plt.title(title) 18 | plt.savefig(path) 19 | plt.close() 20 | 21 | def stack_action_seqs(action_seqs): 22 | """ 23 | :param action_seqs: (H, B, D) 24 | :return: actions (B, D) 25 | """ 26 | return action_seqs.sum(0) 27 | 28 | def cal_errors_np(rs_pred, ts_pred, rs_lb, ts_lb, is_degree=False): 29 | 
""" 30 | :param rs_pred: (B, 3) 31 | :param ts_pred: (B, 3) 32 | :param rs_lb: (B, 3) 33 | :param ts_lb: (B, 3) 34 | :param is_degree: bool 35 | :return: dict key: val (B, ) 36 | """ 37 | if not is_degree: 38 | rs_pred = np.degrees(rs_pred) 39 | rs_lb = np.degrees(rs_lb) 40 | rs_mse = np.mean((rs_pred - rs_lb) ** 2, 1) # (B, ) 41 | ts_mse = np.mean((ts_pred - ts_lb) ** 2, 1) 42 | rs_rmse = np.sqrt(rs_mse) 43 | ts_rmse = np.sqrt(ts_mse) 44 | rs_mae = np.mean(np.abs(rs_pred - rs_lb), 1) 45 | ts_mae = np.mean(np.abs(ts_pred - ts_lb), 1) 46 | 47 | return {"rs_mse": rs_mse, "ts_mse": ts_mse, 48 | "rs_rmse": rs_rmse, "ts_rmse": ts_rmse, 49 | "rs_mae": rs_mae, "ts_mae": ts_mae} 50 | -------------------------------------------------------------------------------- /datasets/utils/db_icl_nuim.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (c) 2020 NVIDIA 3 | Author: Wentao Yuan 4 | ''' 5 | 6 | import h5py 7 | import numpy as np 8 | import os 9 | import torch 10 | from scipy.spatial import cKDTree 11 | from torch.utils.data import Dataset 12 | from sklearn.neighbors import NearestNeighbors 13 | from scipy.spatial.distance import minkowski 14 | import open3d 15 | def jitter_pcd(pcd, sigma=0.01, clip=0.05): 16 | pcd += np.clip(sigma * np.random.randn(*pcd.shape), -1 * clip, clip) 17 | return pcd 18 | 19 | def random_pose(max_angle, max_trans): 20 | R = random_rotation(max_angle) 21 | t = random_translation(max_trans) 22 | return np.concatenate([np.concatenate([R, t], 1), [[0, 0, 0, 1]]], 0) 23 | 24 | 25 | def random_rotation(max_angle): 26 | axis = np.random.randn(3) 27 | axis /= np.linalg.norm(axis) 28 | angle = np.random.rand() * max_angle 29 | A = np.array([[0, -axis[2], axis[1]], 30 | [axis[2], 0, -axis[0]], 31 | [-axis[1], axis[0], 0]]) 32 | R = np.eye(3) + np.sin(angle) * A + (1 - np.cos(angle)) * np.dot(A, A) 33 | return R 34 | 35 | 36 | def random_translation(max_dist): 37 | t = np.random.randn(3) 38 | t /= 
np.linalg.norm(t) 39 | t *= np.random.rand() * max_dist 40 | return np.expand_dims(t, 1) 41 | 42 | def farthest_subsample_points(pointcloud1, pointcloud2, num_subsampled_points=768): 43 | num_points = pointcloud1.shape[0] 44 | nbrs1 = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto', 45 | metric=lambda x, y: minkowski(x, y)).fit(pointcloud1) 46 | random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1]) 47 | idx1 = nbrs1.kneighbors(random_p1, return_distance=False).reshape((num_subsampled_points,)) 48 | nbrs2 = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto', 49 | metric=lambda x, y: minkowski(x, y)).fit(pointcloud2) 50 | random_p2 = random_p1 #np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 2, -2]) 51 | idx2 = nbrs2.kneighbors(random_p2, return_distance=False).reshape((num_subsampled_points,)) 52 | return pointcloud1[idx1, :], pointcloud2[idx2, :] 53 | 54 | class TestData(Dataset): 55 | def __init__(self, path): 56 | super(TestData, self).__init__() 57 | with h5py.File(path, 'r') as f: 58 | self.source = f['source'][...] 59 | self.target = f['target'][...] 60 | self.transform = f['transform'][...] 
61 | self.n_points = 1024 62 | self.max_angle = 45 / 180 * np.pi 63 | self.max_trans = 1.0 64 | self.noisy = False 65 | self.subsampled = True 66 | self.num_subsampled_points = 768 67 | 68 | def __getitem__(self, index): 69 | np.random.seed(index) 70 | pcd1 = self.source[index][:self.n_points] 71 | pcd2 = self.target[index][:self.n_points] 72 | transform = self.transform[index] 73 | pcd1 = pcd1 @ transform[:3, :3].T + transform[:3, 3] 74 | transform = random_pose(self.max_angle, self.max_trans) 75 | pose1 = random_pose(np.pi, self.max_trans) 76 | pose2 = transform @ pose1 77 | pcd1 = pcd1 @ pose1[:3, :3].T + pose1[:3, 3] 78 | pcd2 = pcd2 @ pose2[:3, :3].T + pose2[:3, 3] 79 | R_ab = transform[:3, :3] 80 | translation_ab = transform[:3, 3] 81 | 82 | if self.subsampled: 83 | pcd1, pcd2 = farthest_subsample_points(pcd1, pcd2,num_subsampled_points = self.num_subsampled_points) 84 | 85 | # return pcd1.T.astype('float32'), pcd2.T.astype('float32'), R_ab.astype('float32'), \ 86 | # translation_ab.astype('float32') 87 | pcd_src = open3d.geometry.PointCloud() 88 | pcd_src.points = open3d.utility.Vector3dVector(pcd1) 89 | open3d.geometry.estimate_normals(pcd_src) 90 | # print(np.asarray(pcd_src.points).shape) 91 | # print(np.asarray(pcd_src.normals).shape) 92 | # exit() 93 | pcd_tgt = open3d.geometry.PointCloud() 94 | pcd_tgt.points = open3d.utility.Vector3dVector(pcd2) 95 | open3d.geometry.estimate_normals(pcd_tgt) 96 | sample = {'idx': np.array(index, dtype=np.int32),'transform_gt':np.eye(4).astype('float32')} 97 | sample['points_src'] = np.concatenate([np.asarray(pcd_src.points),np.asarray(pcd_src.normals)],axis=1).astype('float32') 98 | sample['points_ref'] = np.concatenate([np.asarray(pcd_tgt.points),np.asarray(pcd_tgt.normals)],axis=1).astype('float32') 99 | sample['transform_gt'][:3, :3] = R_ab.astype('float32') 100 | sample['transform_gt'][:3, 3] = translation_ab.astype('float32') 101 | return sample 102 | 103 | def __len__(self): 104 | return self.transform.shape[0] 
105 | 106 | 107 | class TrainData(Dataset): 108 | def __init__(self, path): 109 | super(TrainData, self).__init__() 110 | with h5py.File(path, 'r') as f: 111 | self.points = f['points'][...] 112 | self.n_points = 1024 113 | self.max_angle = 45 / 180 * np.pi 114 | self.max_trans = 1.0 115 | self.noisy = False 116 | self.subsampled = True 117 | self.num_subsampled_points = 768 118 | 119 | def __getitem__(self, index): 120 | 121 | pcd1 = self.points[index][:self.n_points] 122 | pcd2 = self.points[index][:self.n_points] 123 | transform = random_pose(self.max_angle, self.max_trans) 124 | pose1 = random_pose(np.pi, self.max_trans) 125 | pose2 = transform @ pose1 126 | pcd1 = pcd1 @ pose1[:3, :3].T + pose1[:3, 3] 127 | pcd2 = pcd2 @ pose2[:3, :3].T + pose2[:3, 3] 128 | R_ab = transform[:3, :3] 129 | translation_ab = transform[:3, 3] 130 | if self.subsampled: 131 | pcd1, pcd2 = farthest_subsample_points(pcd1, pcd2,num_subsampled_points = self.num_subsampled_points) 132 | 133 | # return pcd1.T.astype('float32'), pcd2.T.astype('float32'), R_ab.astype('float32'), \ 134 | # translation_ab.astype('float32') 135 | pcd_src = open3d.geometry.PointCloud() 136 | pcd_src.points = open3d.utility.Vector3dVector(pcd1) 137 | open3d.geometry.estimate_normals(pcd_src) 138 | # print(np.asarray(pcd_src.points).shape) 139 | # print(np.asarray(pcd_src.normals).shape) 140 | # exit() 141 | pcd_tgt = open3d.geometry.PointCloud() 142 | pcd_tgt.points = open3d.utility.Vector3dVector(pcd2) 143 | open3d.geometry.estimate_normals(pcd_tgt) 144 | sample = {'idx': np.array(index, dtype=np.int32),'transform_gt':np.eye(4).astype('float32')} 145 | sample['points_src'] = np.concatenate([np.asarray(pcd_src.points),np.asarray(pcd_src.normals)],axis=1).astype('float32') 146 | sample['points_ref'] = np.concatenate([np.asarray(pcd_tgt.points),np.asarray(pcd_tgt.normals)],axis=1).astype('float32') 147 | sample['transform_gt'][:3, :3] = R_ab.astype('float32') 148 | sample['transform_gt'][:3, 3] = 
translation_ab.astype('float32') 149 | return sample 150 | 151 | def __len__(self): 152 | return self.points.shape[0] 153 | 154 | -------------------------------------------------------------------------------- /datasets/utils/gen_normal.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d, numpy as np, pdb 2 | 3 | def gen_normal(pcs): 4 | """ 5 | :param pcs: shape (B, 3, N), np.array 6 | :return: shape (B, 6, N) 7 | """ 8 | normal_pcs = np.zeros([len(pcs), 6, pcs.shape[2]]) 9 | for idx, pc in enumerate(pcs): 10 | _pc = o3d.geometry.PointCloud() 11 | _pc.points = o3d.utility.Vector3dVector(pc.transpose([1, 0])) 12 | o3d.geometry.estimate_normals(_pc, search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)) 13 | normal_pc = np.concatenate([np.asarray(_pc.points), np.asarray(_pc.normals)], 1).transpose([1, 0]) 14 | normal_pcs[idx] = normal_pc 15 | return normal_pcs 16 | 17 | if __name__ == '__main__': 18 | pcs = np.zeros([5, 3, 1024]) 19 | gen_normal(pcs) -------------------------------------------------------------------------------- /datasets/utils/npmat2euler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.transform import Rotation 3 | 4 | def npmat2euler(mats, seq='zyx', is_degrees=True): 5 | eulers = [] 6 | for i in range(mats.shape[0]): 7 | r = Rotation.from_dcm(mats[i]) 8 | eulers.append(r.as_euler(seq, degrees=is_degrees)) 9 | return np.asarray(eulers, dtype='float32') -------------------------------------------------------------------------------- /modules/__pycache__/commons.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/modules/__pycache__/commons.cpython-36.pyc -------------------------------------------------------------------------------- 
# ---- /modules/__pycache__/dcp_net.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/modules/__pycache__/dcp_net.cpython-36.pyc
# ---- /modules/__pycache__/dgcnn.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/modules/__pycache__/dgcnn.cpython-36.pyc
# ---- /modules/__pycache__/sparsemax.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/modules/__pycache__/sparsemax.cpython-36.pyc

# ---- /modules/commons.py ----
import copy, torch.nn as nn, torch, math, torch.nn.functional as F


def knn(x, k):
    """Return the indices of the ``k`` nearest neighbours of every point.

    :param x: (batch_size, num_dims, num_points) tensor.
    :param k: number of neighbours.
    :return: (batch_size, num_points, k) index tensor. A point's distance to
        itself is 0, the largest negated distance, so it is normally among
        its own neighbours.
    """
    inner = -2 * torch.matmul(x.transpose(2, 1).contiguous(), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    # Negated squared Euclidean distance, so topk selects the closest points.
    pairwise_distance = -xx - inner - xx.transpose(2, 1).contiguous()
    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx


def get_graph_feature(x, k=20):
    """Build DGCNN edge features ``[neighbour, centre]`` for every point.

    :param x: (batch_size, num_dims, num_points) tensor.
    :param k: neighbourhood size passed to :func:`knn`.
    :return: (batch_size, 2 * num_dims, num_points, k) tensor. The first
        ``num_dims`` channels are the neighbours; the last ``num_dims`` are the
        centre point repeated ``k`` times.
    """
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()
    # Offset each cloud's indices so they address the flattened batch. The
    # offsets live on the input's device, not a hard-coded 'cuda' device.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    _, num_dims, _ = x.size()
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)  # neighbours
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)  # centres
    feature = torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
    return feature


def get_graph_feature2(x, x_features, k=20):
    """Return neighbour offsets and neighbour features.

    Neighbourhoods are computed from the coordinates ``x`` only.

    :param x: (batch_size, num_dims, num_points) coordinates.
    :param x_features: (batch_size, num_feature_dims, num_points) features.
    :param k: neighbourhood size.
    :return: ``(feature, x_feature_k)``.
        ``feature`` is the (batch_size, num_dims, num_points, k) tensor of
        neighbour-minus-centre offsets.
        ``x_feature_k`` is the (batch_size, num_feature_dims, num_points, k)
        tensor of the neighbours' features.
    """
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    _, num_dims, _ = x.size()
    _, num_dims_features, _ = x_features.size()

    x_features = x_features.transpose(2, 1).contiguous()  # (b, n, c)
    x = x.transpose(2, 1).contiguous()  # (b, n, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    x_feature_k = x_features.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)  # (b, n, k, num_dims)
    x_feature_k = x_feature_k.view(batch_size, num_points, k, num_dims_features).permute(0, 3, 1, 2)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)  # centres
    feature = (feature - x).permute(0, 3, 1, 2)  # (b, num_dims, n, k)
    return feature, x_feature_k


def pairwise_distance_batch(x, y):
    """Euclidean distance between every column of ``x`` and every column of ``y``.

    :param x: (b, c, n) tensor.
    :param y: (b, c, m) tensor.
    :return: (b, n, m) distances. Negative squared distances produced by
        round-off are clamped to 0. For those entries the sqrt is evaluated
        at 1e-16 and then masked back to exactly 0, so the gradient of sqrt
        stays finite.
    """
    xx = torch.sum(torch.mul(x, x), 1, keepdim=True)  # (b, 1, n)
    yy = torch.sum(torch.mul(y, y), 1, keepdim=True)  # (b, 1, m)
    inner = -2 * torch.matmul(x.transpose(2, 1), y)  # (b, n, m)
    pair_distance = xx.transpose(2, 1) + inner + yy  # (b, n, m)
    # zeros_like already follows the input's device; no hard-coded 'cuda'.
    zeros_matrix = torch.zeros_like(pair_distance)
    pair_distance_square = torch.where(pair_distance > 0.0, pair_distance, zeros_matrix)
    error_mask = torch.le(pair_distance_square, 0.0)
    pair_distances = torch.sqrt(pair_distance_square + error_mask.float() * 1e-16)
    pair_distances = torch.mul(pair_distances, (1.0 - error_mask.float()))
    return pair_distances


def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Uses the unbiased ``x.std`` and adds ``eps`` to the std (not the variance).
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class Identity(nn.Module):
    """Pass-through module."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        # Previously missing, so calling the module raised NotImplementedError.
        return x


class SublayerConnection(nn.Module):
    """Pre-norm residual connection: ``x + sublayer(norm(x))``.

    ``dropout`` is accepted for interface compatibility but not used.
    """

    def __init__(self, size, dropout=None):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)

    def forward(self, x, sublayer):
        return x + sublayer(self.norm(x))


def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    :param mask: optional; positions where ``mask == 0`` get a score of -1e9
        before the softmax.
    :param dropout: optional callable applied to the attention weights.
    :return: ``(weighted values, attention weights)``.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1).contiguous()) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    # The dropout argument used to be ignored silently.
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        # Dropout is disabled: the ``dropout`` argument is not used.
        self.dropout = None

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous()
             for l, x in zip(self.linears, (query, key, value))]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
            .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)


class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        # Empty Sequential == identity; a BatchNorm1d(d_ff) used to go here,
        # which is why forward() moves the channel axis to dim 1 around it.
        self.norm = nn.Sequential()  # nn.BatchNorm1d(d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = None

    def forward(self, x):
        h = F.relu(self.w_1(x)).transpose(2, 1).contiguous()
        h = self.norm(h).transpose(2, 1).contiguous()
        return self.w_2(h)


# ---- /modules/dcp_net.py ----
import torch.nn as nn, torch
from modules.dgcnn import DGCNN
from utils.mat2euler import mat2euler


def pairwise_distance_batch(x, y):
    """(b, n, m) Euclidean distances between columns of ``x`` (b, c, n) and ``y`` (b, c, m).

    Negative squared distances from round-off are clamped to 0. Those entries
    are evaluated at sqrt(1e-16) and then masked to exactly 0, which keeps the
    sqrt gradient finite.
    """
    xx = torch.sum(torch.mul(x, x), 1, keepdim=True)  # (b, 1, n)
    yy = torch.sum(torch.mul(y, y), 1, keepdim=True)  # (b, 1, m)
    inner = -2 * torch.matmul(x.transpose(2, 1), y)  # (b, n, m)
    pair_distance = xx.transpose(2, 1) + inner + yy  # (b, n, m)
    zeros_matrix = torch.zeros_like(pair_distance)  # follows the input's device
    pair_distance_square = torch.where(pair_distance > 0.0, pair_distance, zeros_matrix)
    error_mask = torch.le(pair_distance_square, 0.0)
    pair_distances = torch.sqrt(pair_distance_square + error_mask.float() * 1e-16)
    pair_distances = torch.mul(pair_distances, (1.0 - error_mask.float()))
    return pair_distances


class DCPNet(nn.Module):
    """DCP-style registration network with an optional sigma-prediction head.

    ``forward`` embeds both clouds with DGCNN and matches every source point
    to a softmax-weighted combination of target points. It then solves for
    the rigid transform with a per-sample SVD.
    """

    def __init__(self, opts):
        super(DCPNet, self).__init__()
        self.emb_nn = DGCNN(emb_dims=opts.emb_dims)
        self.emb_dims = opts.emb_dims
        # Frozen Parameter, so it is stored in the state dict (saved
        # checkpoints contain it) and moves with the module across devices.
        # Multiplying V by it flips the last singular vector.
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1

        self.opts = opts
        self.planning_horizon = self.opts.cem.planning_horizon
        # Regresses 6 values from the max-pooled, concatenated embeddings.
        # Only used when forward(..., is_sigma=True).
        self.nn = nn.Sequential(nn.Linear(self.emb_dims * 2, self.emb_dims // 4),
                                nn.BatchNorm1d(self.emb_dims // 4),
                                nn.LeakyReLU(),
                                nn.Linear(self.emb_dims // 4, self.emb_dims // 8),
                                nn.BatchNorm1d(self.emb_dims // 8),
                                nn.LeakyReLU(),
                                nn.Linear(self.emb_dims // 8, 6))

    def forward(self, srcs, tgts, is_sigma=False):
        """Estimate the rigid transform that aligns ``srcs`` to ``tgts``.

        :param srcs: source clouds, presumably (B, 3, N). The 3x3 reflection
            below requires 3 coordinate channels.
        :param tgts: target clouds, presumably (B, 3, M).
        :param is_sigma: if True, also predict 6 sigmoid-squashed values per sample.
        :return: if ``is_sigma``, a tuple ``(pose, sigmas)``. ``pose`` is
            ``cat([r, t], 1).unsqueeze(0)``. ``sigmas`` has shape (1, B, 6).
            Otherwise ``{"r": r, "t": t}``, where ``r`` is ``mat2euler(R)`` and
            ``t`` has shape (B, 3).
        """
        batch_size = len(srcs)
        srcs_emb = self.emb_nn(srcs)  # (B, emb_dims, N)
        tgts_emb = self.emb_nn(tgts)
        # Soft correspondences: the closer two embeddings, the higher the weight.
        scores = torch.softmax(-pairwise_distance_batch(srcs_emb, tgts_emb), dim=2)
        src_corr = torch.matmul(tgts, scores.transpose(2, 1).contiguous())
        src_centered = srcs - srcs.mean(dim=2, keepdim=True)
        src_corr_centered = src_corr - src_corr.mean(dim=2, keepdim=True)
        H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous())

        R = []
        for i in range(batch_size):
            u, s, v = torch.svd(H[i])
            r = torch.matmul(v, u.transpose(1, 0).contiguous())
            if torch.det(r) < 0:
                # Improper rotation (reflection): flip the last singular vector
                # of the SVD already computed above.
                v = torch.matmul(v, self.reflect)
                r = torch.matmul(v, u.transpose(1, 0).contiguous())
            R.append(r)
        R = torch.stack(R, dim=0)
        t = (torch.matmul(-R, srcs.mean(dim=2, keepdim=True))
             + src_corr.mean(dim=2, keepdim=True)).reshape(batch_size, -1)
        r = mat2euler(R)
        if is_sigma:
            sg = torch.cat([srcs_emb, tgts_emb], 1)
            # Both halves (columns :3 and 3:) get the same sigmoid * 1.0.
            sigmas = torch.sigmoid(self.nn(sg.max(dim=-1)[0]))
            return torch.cat([r, t], 1).unsqueeze(0), sigmas.unsqueeze(0)
        return {"r": r, "t": t}


# ---- /modules/dgcnn.py ----
import torch.nn as nn, torch.nn.functional as F, torch
from modules.commons import get_graph_feature


class DGCNN(nn.Module):
    """Per-point DGCNN encoder.

    Four 1x1 convolutions run on the k-NN edge features. There are 6 input
    channels: 3 coordinates for the neighbour and 3 for the centre point.
    Each layer's output is max-pooled over the neighbours, and the four
    results are concatenated (64 + 64 + 128 + 256 = 512 channels). A final
    layer projects them to ``emb_dims`` channels per point.
    """

    def __init__(self, emb_dims=512):
        super(DGCNN, self).__init__()
        self.conv1 = nn.Conv2d(6, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=1, bias=False)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=1, bias=False)
        self.conv5 = nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(emb_dims)

    def forward(self, x):
        """Map (batch, 3, num_points) points to (batch, emb_dims, num_points) features."""
        batch_size, num_dims, num_points = x.size()
        x = get_graph_feature(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x1 = x.max(dim=-1, keepdim=True)[0]
        x = F.relu(self.bn2(self.conv2(x)))
        x2 = x.max(dim=-1, keepdim=True)[0]
        x = F.relu(self.bn3(self.conv3(x)))
        x3 = x.max(dim=-1, keepdim=True)[0]
        x = F.relu(self.bn4(self.conv4(x)))
        x4 = x.max(dim=-1, keepdim=True)[0]
        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = F.relu(self.bn5(self.conv5(x))).view(batch_size, -1, num_points)
        return x


# ---- /modules/sparsemax.py ----
import torch, torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Sparsemax(nn.Module):
    """Sparsemax function."""

    def __init__(self, dim=None):
        """Initialize sparsemax activation

        Args:
            dim (int, optional): The dimension over which to apply the sparsemax function.
        """
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input):
        """Forward function.

        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size

        Returns:
            torch.Tensor: [batch_size x number_of_logits] Output tensor

        """
        # Sparsemax is computed on a 2-D view, so reshape to a convenient
        # shape and reshape back afterwards.
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort input in descending order.
        # (NOTE: Can be replaced with linear time selection method described here:
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        # Ranks 1..K built on the input's own device; this used to shadow the
        # builtin `range` and use the module-level `device`.
        ranks = torch.arange(start=1, end=number_of_logits + 1, step=1,
                             device=input.device, dtype=input.dtype).view(1, -1)
        ranks = ranks.expand_as(zs)

        # Determine sparsity of projection
        bound = 1 + ranks * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * ranks, dim, keepdim=True)[0]

        # Compute threshold function
        zs_sparse = is_gt * zs

        # Compute taus
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # Sparsemax
        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Reshape back to original shape
        output = self.output
        output = output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)

        return output

    def backward(self, grad_output):
        """Sparsemax gradient for the most recent ``forward`` call.

        NOTE: autograd never calls ``nn.Module.backward``; this runs only when
        invoked manually. ``grad_output`` must use the 2-D layout stored in
        ``self.output`` (dim 1 = logits).
        """
        dim = 1
        nonzeros = torch.ne(self.output, 0).type(grad_output.dtype)
        # keepdim=True so the per-row mean broadcasts back over the logits;
        # the former 1-D sum could not be expand_as'ed to the 2-D gradient.
        grad_mean = (torch.sum(grad_output * nonzeros, dim=dim, keepdim=True)
                     / torch.sum(nonzeros, dim=dim, keepdim=True))
        self.grad_input = nonzeros * (grad_output - grad_mean.expand_as(grad_output))

        return self.grad_input


if __name__ == '__main__':
    sparsemax = Sparsemax(dim=1)

    # Use the module-level device instead of requiring CUDA.
    logits = torch.rand(10, 5).to(device)
    print("\nLogits")
    print(logits)

    sparsemax_probs = sparsemax(logits)
    print("\nSparsemax probabilities")
    print(sparsemax_probs)

# ---- /results/icl_nuim_n768_unseen0_noise0_seed123/model.pth ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/icl_nuim_n768_unseen0_noise0_seed123/model.pth
# ---- /results/modelnet40_n768_unseen0_noise0_seed123/model.pth ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/modelnet40_n768_unseen0_noise0_seed123/model.pth
# ---- /results/modelnet40_n768_unseen0_noise1_seed123/model.pth ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/modelnet40_n768_unseen0_noise1_seed123/model.pth
# ---- /results/modelnet40_n768_unseen1_noise0_seed123/model.pth ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/modelnet40_n768_unseen1_noise0_seed123/model.pth
# ---- /results/scene7_n768_unseen0_noise0_seed123/model.pth ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/scene7_n768_unseen0_noise0_seed123/model.pth

# ---- /run.sh (shell) ----
# cd cemnet_lib
# python3 setup.py install
# cd ..
#
# cd batch_svd
# python3 setup.py install
# cd ..

# ---- /test_model.py ----
import numpy as np, torch, time
from datasets.get_dataset import get_dataset
from utils.options import opts
from utils.recorder import Recorder
from utils.attr_dict import AttrDict
from cems.guided_cem import GuidedCEM
from tqdm import tqdm

np.random.seed(opts.seed)
torch.manual_seed(opts.seed)
torch.cuda.manual_seed_all(opts.seed)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

opts.db_nm = "scene7"
opts.db = AttrDict(
    modelnet40=AttrDict(
        path="/test/datasets/registration/modelnet40/modelnet40_normal_n2048.pth",
        cls_idx=-1,  # None_class: -1, airplane: 0, car: 7, chair: 8, table: 33, lamp: 19
        is_neg_angle=False,
        unseen=False,
        gaussian_noise=False,
        n_points=1024,
        n_sub_points=768,
        factor=4
    ),
    scene7=AttrDict(
        path="/test/datasets/registration/7scene/7scene_normal_n2048.pth",
        cls_idx=-1,
        is_neg_angle=False,
        unseen=False,
        gaussian_noise=False,
        n_points=1024,
        n_sub_points=768,
        factor=4
    ),
    icl_nuim=AttrDict(
        path="/test/datasets/registration/icl_nuim/icl_nuim_normal_n2048.pth",
        cls_idx=-1,
        is_neg_angle=False,
        unseen=False,
        gaussian_noise=False,
        n_points=1024,
        n_sub_points=768,
        factor=4
    )
)
opts.db = opts.db[opts.db_nm]


def init_opts(opts):
    """Set the CEM/evaluation options and select the checkpoint to evaluate.

    The checkpoint in ``opts.model_path`` should presumably match the dataset
    chosen by ``opts.db_nm`` above.
    """
    opts.is_train = False
    opts.cem.metric_type = [["MC", 0.1], ["CD"], ["GM", 0.01]][0]
    opts.cem.n_candidates = [1000, {"minibatch": 1000}]
    opts.cem.is_fused_reward = [True, 0.5, 3]
    opts.exploration_weight = 0.5
    opts.cem.n_iters = 10
    opts.is_debug = True
    opts.cem.init_sigma = AttrDict(
        modelnet40=[1.0, 0.5],
        scene7=[1.0, 0.5],
        icl_nuim=[1.0, 0.5]
    )

    # 1. ModelNet40 - Unseen Object
    # opts.model_path = "./results/modelnet40_n768_unseen0_noise0_seed123/model.pth"

    # 2. ModelNet40 - Unseen Category
    # opts.model_path = "./results/modelnet40_n768_unseen1_noise0_seed123/model.pth"

    # 3. ModelNet40 - Noise
    # opts.model_path = "./results/modelnet40_n768_unseen0_noise1_seed123/model.pth"

    # 4. 7Scene
    opts.model_path = "./results/scene7_n768_unseen0_noise0_seed123/model.pth"

    # 5. ICL-NUIM
    # opts.model_path = "./results/icl_nuim_n768_unseen0_noise0_seed123/model.pth"

    return opts


def test(opts, model, test_loader):
    """Run ``model`` over ``test_loader`` and print registration errors.

    Rotation errors compare Euler angles after ``np.degrees``. Translation
    errors are in the dataset's own units. ``time`` is the mean forward-pass
    wall time per sample.

    NOTE(review): ``model.eval()`` is not called here. Confirm that
    ``GuidedCEM``/``load_model`` puts the BatchNorm layers in eval mode.
    """
    rcd_times, n_cnt = 0, 0
    r_mses, t_mses, r_maes, t_maes = [], [], [], []
    with torch.no_grad():
        for srcs, tgts, rs_lb, ts_lb in tqdm(test_loader):
            # Same device the model was moved to in model_test().
            srcs, tgts = srcs.to(opts.device), tgts.to(opts.device)
            t1 = time.time()
            results = model(srcs, tgts)
            # CUDA kernels run asynchronously; wait for them before reading
            # the clock, otherwise only the launch time is measured.
            if srcs.is_cuda:
                torch.cuda.synchronize()
            rcd_times += time.time() - t1
            n_cnt += len(srcs)
            r_err = np.degrees(results["r"].cpu().numpy()) - np.degrees(rs_lb.numpy())
            t_err = results["t"].cpu().numpy() - ts_lb.numpy()
            r_mses.append(np.mean(r_err ** 2, 1))
            r_maes.append(np.mean(np.abs(r_err), 1))
            t_mses.append(np.mean(t_err ** 2, 1))
            t_maes.append(np.mean(np.abs(t_err), 1))

    r_mse = np.mean(np.concatenate(r_mses, 0)).item()
    t_mse = np.mean(np.concatenate(t_mses, 0)).item()
    r_mae = np.mean(np.concatenate(r_maes, 0)).item()
    t_mae = np.mean(np.concatenate(t_maes, 0)).item()

    print("--- Test: r_mse: %.8f, t_mse: %.8f, r_rmse: %.8f, t_rmse: %.8f, r_mae: %.8f, t_mae: %.8f, time: %.8f ---" % (
        r_mse, t_mse, np.sqrt(r_mse), np.sqrt(t_mse), r_mae, t_mae, rcd_times / n_cnt))


def model_test(opts):
    """Seed all RNGs, load the checkpoint at ``opts.model_path`` and evaluate it."""
    np.random.seed(opts.seed)
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed_all(opts.seed)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    ## initial setting
    opts.recorder = Recorder(opts)
    # load_model's return value is used as the model, so it must return the module.
    model = GuidedCEM(opts).to(opts.device).load_model(opts.model_path)

    test_loader, _ = get_dataset(opts, db_nm=opts.db_nm, partition="test", is_normal=False, batch_size=opts.batch_size,
                                 shuffle=False, drop_last=False, cls_idx=opts.db.cls_idx)
    ## testing
    test(opts, model, test_loader)


if __name__ == '__main__':
    opts = init_opts(opts)
    model_test(opts)

# ---- /train_model.py ----
import numpy as np, torch, os, pdb
from utils.options import opts
from utils.transform_pc import transform_pc_torch
from utils.euler2mat import euler2mat_torch
from utils.recorder import Recorder
from utils.test import test
from utils.losses import CDLoss, GMLoss
from datasets.get_dataset import get_dataset
from cems.guided_cem import GuidedCEM
from torch.optim.lr_scheduler import MultiStepLR
from tqdm import tqdm

np.random.seed(opts.seed)
torch.manual_seed(opts.seed)
torch.cuda.manual_seed_all(opts.seed)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def init_opts(opts):
    """Set the training options and derive ``opts.results_dir``.

    The results directory is only created when ``opts.is_debug`` is False.
    """
    opts.is_debug = True
    opts.is_train = True
    opts.loss_type = [["GM", 0.01], ["CD", -1.0]][0]
    opts.cem.metric_type = [["MC", 0.1], ["CD"], ["GM", 0.01]][0]
    opts.cem.n_candidates = [1000, {"minibatch": 1000}]
    opts.cem.n_iters = 10
    opts.cem.is_fused_reward = [False, -1.0, 0]
    opts.results_dir = "./results/%s_n%d_unseen%d_noise%d_seed%s_v0" % (
        opts.db_nm, opts.db.n_sub_points, opts.db.unseen, opts.db.gaussian_noise, opts.seed)
    if not opts.is_debug:
        os.makedirs(opts.results_dir, exist_ok=True)
    return opts


def main(opts):
    """Train GuidedCEM and evaluate it on the test split after every epoch.

    The loss is applied both to the final pose (``r``/``t``) and to the
    initial pose (``r_init``/``t_init``) returned by the model.

    :raises ValueError: if ``opts.loss_type[0]`` is neither "CD" nor "GM".
    """
    ## initial setting
    opts = init_opts(opts)
    opts.recorder = Recorder(opts)
    model = GuidedCEM(opts).to(opts.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = MultiStepLR(optimizer, milestones=[35, 100, 150], gamma=0.7)
    if opts.loss_type[0] == "CD":
        loss_func = CDLoss(opts)
    elif opts.loss_type[0] == "GM":
        loss_func = GMLoss(opts)
    else:
        # Previously fell through and failed later with an unbound loss_func.
        raise ValueError("Unknown loss type: %s" % opts.loss_type[0])
    train_loader, _ = get_dataset(opts, db_nm=opts.db_nm, partition="train", is_normal=False,
                                  batch_size=opts.batch_size, shuffle=True, drop_last=False)
    test_loader, _ = get_dataset(opts, db_nm=opts.db_nm, partition="test", is_normal=False,
                                 batch_size=opts.batch_size, shuffle=False, drop_last=False)

    ## training
    print(opts.results_dir)
    for epoch in range(opts.n_epochs):
        # utils.test.test (not visible here) may switch the model to eval
        # mode; make sure every epoch trains with train-mode BatchNorm.
        model.train()
        ## train
        losses = []
        for srcs, tgts, rs_lb, ts_lb in tqdm(train_loader):
            srcs, tgts, rs_lb, ts_lb = [x.to(opts.device) for x in [srcs, tgts, rs_lb, ts_lb]]
            results = model(srcs, tgts)
            rs_pred, ts_pred, rs_prior, ts_prior = results["r"], results["t"], results["r_init"], results["t_init"]
            transform_srcs_pred = transform_pc_torch(srcs, euler2mat_torch(rs_pred), ts_pred)
            transform_srcs_prior = transform_pc_torch(srcs, euler2mat_torch(rs_prior), ts_prior)
            loss = loss_func(transform_srcs_pred, tgts) + loss_func(transform_srcs_prior, tgts)
            if torch.isnan(loss):
                print("None, skip")
                continue
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        # Avoid a NumPy "mean of empty slice" warning if every batch was skipped.
        mean_loss = np.mean(losses) if losses else float("nan")
        print("Epoch[%d], losses: %.9f, %s." % (epoch, mean_loss, opts.results_dir))

        # Since PyTorch 1.1 the scheduler must step after the optimizer
        # updates; stepping first skipped the initial LR value.
        scheduler.step()

        ## test
        test(opts, model, test_loader, epoch)


if __name__ == '__main__':
    main(opts)

# ---- /utils/__pycache__/attr_dict.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/attr_dict.cpython-36.pyc
# ---- /utils/__pycache__/batch_icp.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/batch_icp.cpython-36.pyc
# ---- /utils/__pycache__/commons.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/commons.cpython-36.pyc
# ---- /utils/__pycache__/euler2mat.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/euler2mat.cpython-36.pyc
# ---- /utils/__pycache__/losses.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/losses.cpython-36.pyc
# ---- /utils/__pycache__/mat2euler.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/mat2euler.cpython-36.pyc
# ---- /utils/__pycache__/options.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/options.cpython-36.pyc
# ---- /utils/__pycache__/recorder.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/recorder.cpython-36.pyc
# ---- /utils/__pycache__/test.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/test.cpython-36.pyc
# ---- /utils/__pycache__/transform_pc.cpython-36.pyc ----
# https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/transform_pc.cpython-36.pyc

# ---- /utils/attr_dict.py ----
class AttrDict(dict):
    """dict whose keys are also readable and writable as attributes.

    Reading a missing non-dunder attribute returns None instead of raising.
    Dunder names are never routed to the dict.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)

    def __getattr__(self, key):
        if key.startswith('__'):
            # Keep dunder lookups (copy, pickle, ...) on normal attribute semantics.
            raise AttributeError(key)
        return self.get(key, None)

    def __setattr__(self, key, value):
        if key.startswith('__'):
            raise AttributeError("Cannot set magic attribute '{}'".format(key))
        self[key] = value


# ---- /utils/batch_icp.py ----
import torch, pdb
from utils.transform_pc import transform_pc_torch
from torch_batch_svd import svd
from cemnet_lib.functions import closest_point


def one_step(srcs, tgt, Rs, ts):
    """Run one ICP iteration for a batch of source clouds against one target.

    :param srcs: (B, 3, N) source clouds.
    :param tgt: target cloud; the caller documents it as (c, n) and shared
        by the whole batch.
    :param Rs: (B, 3, 3) current rotation estimates.
    :param ts: (B, 3) current translation estimates.
    :return: ``(Rs, ts)``, the new (B, 3, 3) rotations and (B, 3)
        translations. They are solved from the *original* ``srcs`` to their
        closest target points under the current estimate.
    """
    xs_mean = srcs.mean(2, keepdim=True)  # (B, 3, 1)
    xs_centered = srcs - xs_mean  # (B, 3, N)
    # Correspondences are searched after applying the current estimate.
    ys = closest_point(transform_pc_torch(srcs, Rs, ts), tgt)

    ys_mean = ys.mean(2, keepdim=True)
    ys_centered = ys - ys_mean

    H = torch.matmul(xs_centered, ys_centered.transpose(2, 1).contiguous())
    u, _, v = svd(H)
    # Kabsch with reflection handling: R = V diag(1, 1, det(V U^T)) U^T, so
    # that det(R) = +1.
    r_det = torch.det(torch.matmul(v, u.transpose(2, 1)).contiguous())
    diag = torch.eye(3, 3, device=srcs.device).unsqueeze(0).repeat(len(r_det), 1, 1)
    diag[:, 2, 2] = r_det
    Rs = torch.matmul(torch.matmul(v, diag), u.transpose(2, 1)).contiguous()
    ts = torch.matmul(-Rs, xs_mean) + ys_mean

    return Rs, ts.squeeze(-1)


def batch_icp(opts, srcs, tgt, Rs=None, ts=None, is_path=False, n_iters=3):
    """Refine rigid transforms for a batch of sources with ``n_iters`` ICP steps.

    :param opts: unused; kept for call-site compatibility.
    :param srcs: (b, c, n) source clouds.
    :param tgt: (c, n) target cloud.
    :param Rs: optional (b, 3, 3) initial rotations; identity if None.
    :param ts: optional (b, 3) initial translations; zero if None.
    :param is_path: if True, also return the list of ``[Rs, ts]`` after
        every step, starting with the initial estimate.
    :param n_iters: number of ICP iterations (default 3, the former
        hard-coded value).
    :return: ``(Rs, ts)``, or ``(Rs, ts, paths)`` when ``is_path`` is True.
    """
    if Rs is None:
        Rs = torch.eye(3, device=srcs.device).unsqueeze(0).repeat(len(srcs), 1, 1)
    if ts is None:
        ts = torch.zeros(len(srcs), 3, device=srcs.device)
    paths = [[Rs, ts]]
    for _ in range(n_iters):
        Rs, ts = one_step(srcs, tgt, Rs, ts)
        if is_path:
            paths.append([Rs, ts])
    if is_path:
        return Rs, ts, paths
    return Rs, ts
# ---------------------------------------------------------------------------
# /utils/commons.py
# ---------------------------------------------------------------------------
import os
import pdb
import pickle

import matplotlib
import numpy as np
import torch

matplotlib.use("Agg")  # non-interactive backend: figures are only written to files
import matplotlib.pyplot as plt

from utils.mat2euler import mat2euler_torch
from utils.euler2mat import euler2mat_torch


def load_data(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    files written by ``save_data`` or other trusted sources.
    """
    # ``with`` closes the handle even if unpickling raises.
    with open(path, "rb") as file:
        return pickle.load(file)


def save_data(path, data):
    """Pickle ``data`` to ``path``, overwriting any existing file."""
    with open(path, "wb") as file:
        pickle.dump(data, file)


def chunker_list(seq, size):
    """Split ``seq`` into consecutive slices of length ``size`` (the last may be shorter)."""
    return [seq[pos: pos + size] for pos in range(0, len(seq), size)]


def chunker_num(num, size):
    """Split the integers ``0 .. num - 1`` into lists of at most ``size`` consecutive values.

    Each chunk is built from its own sub-range. The previous version
    re-materialised ``list(range(num))`` for every chunk, which is
    O(num**2 / size).
    """
    return [list(range(pos, min(pos + size, num))) for pos in range(0, num, size)]


def stack_action_seqs(action_seqs):
    """Sum a sequence of actions over its leading (horizon) axis.

    :param action_seqs: (H, B, D)
    :return: actions (B, D)
    """
    return action_seqs.sum(0)


def line_plot(xs, ys, title, path):
    """Save a single line plot of ``ys`` against ``xs`` with ``title`` to ``path``."""
    plt.figure()
    plt.plot(xs, ys)
    plt.title(title)
    plt.savefig(path)
    plt.close()


def shuffle_along_axis(a, axis):
    """Independently permute the entries of ``a`` along ``axis``.

    A random key is drawn for every element and argsorted along ``axis``, so
    each 1-D slice along that axis receives its own random permutation.
    """
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a, idx, axis=axis)


def cal_errors_np(rs_pred, ts_pred, rs_lb, ts_lb, is_degree=False):
    """Per-sample rotation/translation errors.

    Rotations are converted to degrees first unless ``is_degree`` is True.

    :param rs_pred: (B, 3)
    :param ts_pred: (B, 3)
    :param rs_lb: (B, 3)
    :param ts_lb: (B, 3)
    :param is_degree: bool, True if the rotation inputs are already in degrees
    :return: dict with keys rs_mse, ts_mse, rs_rmse, ts_rmse, rs_mae, ts_mae, each (B, )
    """
    if not is_degree:
        rs_pred = np.degrees(rs_pred)
        rs_lb = np.degrees(rs_lb)
    rs_mse = np.mean((rs_pred - rs_lb) ** 2, 1)  # (B, )
    ts_mse = np.mean((ts_pred - ts_lb) ** 2, 1)
    rs_rmse = np.sqrt(rs_mse)
    ts_rmse = np.sqrt(ts_mse)
    rs_mae = np.mean(np.abs(rs_pred - rs_lb), 1)
    ts_mae = np.mean(np.abs(ts_pred - ts_lb), 1)

    return {"rs_mse": rs_mse, "ts_mse": ts_mse,
            "rs_rmse": rs_rmse, "ts_rmse": ts_rmse,
            "rs_mae": rs_mae, "ts_mae": ts_mae}


def plot_pc(pcs, save_path):
    """Draw groups of point clouds as 3-D scatter subplots and save the figure.

    :param pcs: one entry per subplot, each a list of ``[label, pc]`` pairs, e.g.
        ``[[[nm, pc], [nm, pc]], [[nm, pc]]]``. Each ``pc`` is read as
        ``pc[0], pc[1], pc[2]`` (presumably shape (3, N)).
    :param save_path: output image path
    """
    import mpl_toolkits.mplot3d  # noqa: F401 -- registers the '3d' projection on older matplotlib

    n_plots = len(pcs)
    n_col = 3
    # Subplot grid sizes must be ints; np.ceil returns a float.
    n_row = max(int(np.ceil(n_plots / n_col)), 1)
    fig = plt.figure(figsize=(n_col * 4, n_row * 4))
    colors = [(244/255, 17/255, 10/255), (44/255, 175/255, 53/255), (18/255, 72/255, 148/255), (246/255, 130/255, 11/255)]
    for i, _pcs in enumerate(pcs):
        ax = fig.add_subplot(n_row, n_col, i + 1, projection='3d')
        for j, (lb, pc) in enumerate(_pcs):
            # Cycle the palette so groups with more than four clouds do not raise IndexError.
            ax.scatter(pc[0], pc[1], pc[2], color=colors[j % len(colors)], marker='.', label=lb)
        ax.legend(fontsize=12, frameon=True)
    fig.savefig(save_path)
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)


def stack_transforms(transforms1, transforms2):
    """Compose two batched rigid transforms: apply ``transforms1`` first, then ``transforms2``.

    Each transform row is ``[rz, ry, rx, tx, ty, tz]``, with the angles in
    radians in the layout consumed by ``euler2mat_torch``. The result is
    ``R = R2 @ R1`` and ``t = R2 @ t1 + t2``.

    :param transforms1: (B, 6), tensor
    :param transforms2: (B, 6), tensor
    :return: composed transforms (B, 6), tensor. The rotation is converted back
        through ``mat2euler_torch``, which detaches to numpy, so no gradient
        flows through the returned angles.
    """
    rs1, ts1 = transforms1[:, :3], transforms1[:, 3:]  # (B, 3)
    rs2, ts2 = transforms2[:, :3], transforms2[:, 3:]  # (B, 3)
    Rs1 = euler2mat_torch(rs1)  # (B, 3, 3)
    Rs2 = euler2mat_torch(rs2)  # (B, 3, 3)
    Rs = torch.matmul(Rs2, Rs1)  # (B, 3, 3)
    ts = (torch.matmul(Rs2, ts1.unsqueeze(2)) + ts2.unsqueeze(2)).squeeze(2)  # (B, 3)
    rs = mat2euler_torch(Rs, is_degrees=False)  # (B, 3)
    return torch.cat([rs, ts], 1)  # (B, 6)


def stack_transforms_seq(transforms):
    """Compose a sequence of batched transforms in order (element 0 is applied first).

    :param transforms: (L, B, 6), tensor
    :return: (B, 6) composed transform; for L == 1 this is ``transforms[0]``
    :raises ValueError: if the sequence is empty (previously an UnboundLocalError)
    """
    if len(transforms) == 0:
        raise ValueError("stack_transforms_seq needs at least one transform.")
    actions = transforms[0, :, :]
    for step in range(1, len(transforms)):
        actions = stack_transforms(actions, transforms[step, :, :])
    return actions  # (B, 6)


# ---------------------------------------------------------------------------
# /utils/euler2mat.py
# ---------------------------------------------------------------------------
import numpy as np, torch


def euler2mat_np(rs, seq="zyx"):
    """Convert batched Euler angles to rotation matrices (numpy).

    :param rs: (B, 3) angles in radians, columns ordered (z, y, x)
    :param seq: only "zyx" is supported
    :return: (B, 3, 3) matrices equal to Rx(x) @ Ry(y) @ Rz(z)
    """
    assert seq == "zyx", "Invalid euler seq."
    rs_z, rs_y, rs_x = rs[:, [0]], rs[:, [1]], rs[:, [2]]  # each (B, 1)
    sinx, siny, sinz = np.sin(rs_x), np.sin(rs_y), np.sin(rs_z)
    cosx, cosy, cosz = np.cos(rs_x), np.cos(rs_y), np.cos(rs_z)
    # Row-major entries of Rx @ Ry @ Rz, concatenated to (B, 9), then reshaped.
    R = np.concatenate([cosy * cosz, - cosy * sinz, siny,
                        sinx * siny * cosz + cosx * sinz, - sinx * siny * sinz + cosx * cosz, - sinx * cosy,
                        - cosx * siny * cosz + sinx * sinz, cosx * siny * sinz + sinx * cosz, cosx * cosy],
                       1).reshape([-1, 3, 3])
    return R


def euler2mat_torch(rs, seq="zyx"):
    """Convert batched Euler angles to rotation matrices (torch, differentiable).

    :param rs: (B, 3) angles in radians, columns ordered (z, y, x)
    :param seq: only "zyx" is supported
    :return: (B, 3, 3) matrices equal to Rx(x) @ Ry(y) @ Rz(z)
    """
    assert seq == "zyx", "Invalid euler seq."
    rs_z, rs_y, rs_x = rs[:, [0]], rs[:, [1]], rs[:, [2]]  # each (B, 1)
    sinx, siny, sinz = torch.sin(rs_x), torch.sin(rs_y), torch.sin(rs_z)
    cosx, cosy, cosz = torch.cos(rs_x), torch.cos(rs_y), torch.cos(rs_z)
    R = torch.cat([cosy * cosz, - cosy * sinz, siny,
                   sinx * siny * cosz + cosx * sinz, - sinx * siny * sinz + cosx * cosz, - sinx * cosy,
                   - cosx * siny * cosz + sinx * sinz, cosx * siny * sinz + sinx * cosz, cosx * cosy],
                  1).view([-1, 3, 3])
    return R


# ---------------------------------------------------------------------------
# /utils/losses.py
# ---------------------------------------------------------------------------
import torch, torch.nn as nn, pdb


class CDLoss(nn.Module):
    """Chamfer-distance loss between two batches of point clouds.

    For every point, the squared distance to its nearest neighbour in the other
    cloud is taken. The loss is the mean over target points plus the mean over
    source points.
    """

    def __init__(self, opts):
        super(CDLoss, self).__init__()
        # Kept for backward compatibility; distances are computed on the inputs' device.
        self.device = opts.device

    def forward(self, srcs, tgts):
        """
        :param srcs: (B, C, N_src) tensor, coordinates on dim 1
        :param tgts: (B, C, N_tgt) tensor
        :return: scalar tensor
        """
        P = self.pairwise_distance(srcs, tgts)  # (B, N_src, N_tgt)
        return torch.min(P, 1)[0].mean() + torch.min(P, 2)[0].mean()

    def pairwise_distance(self, srcs, tgts):
        """Squared Euclidean distance between every source/target point pair.

        :param srcs: (B, C, N_src)
        :param tgts: (B, C, N_tgt)
        :return: (B, N_src, N_tgt), ``|s_i|^2 + |t_j|^2 - 2 <s_i, t_j>``
        """
        srcs, tgts = srcs.transpose(2, 1), tgts.transpose(2, 1)  # (B, N, C)
        srcs_tgts_dist = torch.bmm(srcs, tgts.transpose(2, 1))  # (B, N_src, N_tgt)
        # Squared norms straight from the coordinates. These are the same values
        # as the diagonals of the (B, N, N) Gram matrices, without allocating them.
        rx = (srcs * srcs).sum(2)  # (B, N_src)
        ry = (tgts * tgts).sum(2)  # (B, N_tgt)
        return rx.unsqueeze(2) + ry.unsqueeze(1) - 2 * srcs_tgts_dist


class GMLoss(nn.Module):
    """Robust nearest-neighbour loss.

    Each nearest-neighbour squared distance ``d`` is mapped through
    ``mu * d / (d + mu)`` (Geman-McClure form), which saturates at ``mu`` for
    outliers.
    """

    def __init__(self, opts):
        super(GMLoss, self).__init__()
        self.device = opts.device
        self.opts = opts

    def forward(self, srcs, tgts):
        """
        :param srcs: (B, C, N_src) tensor
        :param tgts: (B, C, N_tgt) tensor
        :return: scalar tensor
        """
        # NOTE(review): assumes opts.loss_type is [name, mu]; it is not defined
        # in utils/options.py -- confirm where the caller sets it.
        mu = self.opts.loss_type[1]

        def robust(d):
            return (mu * d) / (d + mu)

        srcs, tgts = srcs.transpose(2, 1), tgts.transpose(2, 1)  # (B, N, C)
        # Sum of squared differences: same value as ``norm(...) ** 2``, but it
        # skips the sqrt, whose gradient is undefined for coincident points.
        P = ((srcs[:, :, None, :] - tgts[:, None, :, :]) ** 2).sum(-1)  # (B, N_src, N_tgt)
        # Average each direction separately. When N_src == N_tgt this equals the
        # former concatenate / sum / mean formulation. It also works for clouds
        # of different sizes, which the concatenation rejected.
        losses = robust(torch.min(P, 1)[0]).mean(1) + robust(torch.min(P, 2)[0]).mean(1)  # (B, )
        return losses.mean()

    def pairwise_distance(self, srcs, tgts):
        """Squared distances between (B, N_src, C) and (B, N_tgt, C) point sets.

        Unlike ``CDLoss.pairwise_distance``, the inputs are NOT transposed here.

        :return: (B, N_src, N_tgt)
        """
        srcs_tgts_dist = torch.bmm(srcs, tgts.transpose(2, 1))  # (B, N_src, N_tgt)
        rx = (srcs * srcs).sum(2)  # (B, N_src)
        ry = (tgts * tgts).sum(2)  # (B, N_tgt)
        return rx.unsqueeze(2) + ry.unsqueeze(1) - 2 * srcs_tgts_dist


# ---------------------------------------------------------------------------
# /utils/mat2euler.py
# ---------------------------------------------------------------------------
import numpy as np, torch
from scipy.spatial.transform import Rotation


def _rotation_from_matrices(mats):
    """Build a scipy ``Rotation`` from a (3, 3) or (N, 3, 3) array.

    ``Rotation.from_dcm`` was renamed ``from_matrix`` in SciPy 1.4 and removed
    in SciPy 1.6. Prefer the new name and fall back to the old one.
    """
    from_matrix = getattr(Rotation, "from_matrix", None) or Rotation.from_dcm
    return from_matrix(mats)


def _mats_to_eulers(mats, seq, is_degrees):
    """Convert an (N, 3, 3) numpy array of rotation matrices to (N, 3) float32 Euler angles."""
    mats = np.asarray(mats)
    if mats.shape[0] == 0:
        # scipy rejects empty stacks on some versions.
        return np.zeros((0, 3), dtype='float32')
    # One vectorised call; scipy handles each matrix of the stack independently.
    eulers = _rotation_from_matrices(mats).as_euler(seq, degrees=is_degrees)
    return np.asarray(eulers, dtype='float32')


def mat2euler_np(mats, seq='zyx', is_degrees=True):
    """Rotation matrices (N, 3, 3) -> Euler angles (N, 3) float32 numpy array, via scipy."""
    return _mats_to_eulers(mats, seq, is_degrees)


def mat2euler_torch(mats, seq='zyx', is_degrees=True):
    """Rotation matrices (N, 3, 3) tensor -> Euler angles (N, 3) float32 tensor on ``mats.device``.

    The computation goes through numpy/scipy on ``mats.detach()``, so it is not
    differentiable.
    """
    eulers = _mats_to_eulers(mats.detach().cpu().numpy(), seq, is_degrees)
    return torch.from_numpy(eulers).to(mats.device)


def mat2euler(rot_mat, seq='xyz'):
    """Differentiable rotation-matrix -> Euler-angle conversion (pure torch).

    :param rot_mat: (B, 3, 3) rotation matrices
    :param seq: 'xyz' inverts R = Rx @ Ry @ Rz (the matrix ``euler2mat_torch``
        builds); any other value inverts R = Rz @ Ry @ Rx
    :return: (B, 3) angles in radians, ordered (z, y, x) for both sequences
    """
    r11 = rot_mat[:, 0, 0]
    r12 = rot_mat[:, 0, 1]
    r13 = rot_mat[:, 0, 2]
    r21 = rot_mat[:, 1, 0]
    r22 = rot_mat[:, 1, 1]
    r23 = rot_mat[:, 1, 2]
    r31 = rot_mat[:, 2, 0]
    r32 = rot_mat[:, 2, 1]
    r33 = rot_mat[:, 2, 2]
    if seq == 'xyz':
        z = torch.atan2(-r12, r11)
        y = torch.asin(r13)
        x = torch.atan2(-r23, r33)
    else:
        y = torch.asin(-r31)
        x = torch.atan2(r32, r33)
        z = torch.atan2(r21, r11)
    return torch.stack((z, y, x), dim=1)


# ---------------------------------------------------------------------------
# /utils/options.py
# ---------------------------------------------------------------------------
from utils.attr_dict import AttrDict
import numpy as np, torch

opts = AttrDict()
## general setting
opts.db_nm = "scene7"  # one of "modelnet40", "scene7", "icl_nuim"
opts.is_debug = False
opts.device = torch.device("cuda")
opts.seed = 123
opts.batch_size = 35
opts.minibatch_size = 35
opts.n_epochs = 30

## dataset
opts.db = AttrDict(
    modelnet40=AttrDict(
        path="/test/datasets/registration/modelnet40/modelnet40_normal_n2048.pth",
        cls_idx=-1,  # None_class: -1, airplane: 0, car: 7, chair: 8, table: 33, lamp: 19
        is_neg_angle=False,
        unseen=False,
        gaussian_noise=True,
        n_points=1024,
        n_sub_points=768,
        factor=4
    ),
    scene7=AttrDict(
        path="/test/datasets/registration/7scene/7scene_normal_n2048.pth",
        cls_idx=-1,
        is_neg_angle=False,
        unseen=False,
        gaussian_noise=False,
        n_points=1024,
        n_sub_points=768,
        factor=4
    ),
    icl_nuim=AttrDict(
        path="/test/datasets/registration/icl_nuim/icl_nuim_normal_n2048.pth",
        cls_idx=-1,
        is_neg_angle=False,
        unseen=False,
        gaussian_noise=False,
        n_points=1024,
        n_sub_points=768,
        factor=4
    )
)
# Keep only the settings of the selected dataset.
opts.db = opts.db[opts.db_nm]

## cem module
opts.cem = AttrDict(
    metric_type=["iou", {"epsilon": 0.01}],  # "iou" or "cd"
    n_candidates=[1000, {"minibatch": 1000}],
    n_elites=25,
    n_iters=10,
    r_range=[-np.pi, np.pi],
    t_range=AttrDict(
        modelnet40=[-1.0, 1.0],
        scene7=[-1.0, 1.0],
        icl_nuim=[-1.0, 1.0]
    ),
    init_sigma=AttrDict(
        modelnet40=[1.0, 0.5],
        scene7=[1.0, 0.5],
        icl_nuim=[1.0, 0.5]
    ),
    planning_horizon=1,
    is_icp_modification=[True, 0.5, 3]
)

# network setting
opts.pointer = "identity"  # "identity" or "transformer"
opts.head = "svd"  # "mlp" or "svd"
opts.eval = False
opts.emb_nn = "dgcnn"
opts.emb_dims = 512
opts.ff_dim = 1024
opts.n_blocks = 1
opts.n_heads = 4
opts.dropout = 0.
# ---------------------------------------------------------------------------
# /utils/recorder.py
# ---------------------------------------------------------------------------
import numpy as np, os
from utils.commons import save_data, line_plot
from collections import defaultdict


class Recorder:
    """Accumulate named result series, then plot or pickle them under ``opts.results_dir``."""

    def __init__(self, opts):
        self.results_dir = opts.results_dir
        self.results = defaultdict(list)  # series name -> list of recorded values
        self.opts = opts

    def add_res(self, res):
        """Record every ``key: value`` of ``res``; list values are spliced in element-wise."""
        for key, value in res.items():
            if isinstance(value, list):
                self.results[key].extend(value)
            else:
                self.results[key].append(value)

    def add_reslist(self, reslist):
        """Record every value of ``reslist`` as one entry (lists are NOT flattened)."""
        for key, value in reslist.items():
            self.results[key].append(value)

    def line_plot(self, nms):
        """Save one PDF line plot per series name in ``nms``."""
        for nm in nms:
            res = self.results[nm]
            title = "%s_seed%d" % (nm, self.opts.seed)
            # Module-level ``line_plot`` from utils.commons, not this method.
            line_plot(np.arange(len(res)), res, title, os.path.join(self.opts.results_dir, "%s.pdf" % title))

    def save(self, file_nm=None):
        """Pickle all series to ``<file_nm>_<seed>.pth`` or, by default, ``res_seed<seed>.pth``."""
        if file_nm is not None:
            save_data(os.path.join(self.opts.results_dir, "%s_%d.pth" % (file_nm, self.opts.seed)), self.results)
        else:
            save_data(os.path.join(self.opts.results_dir, "res_seed%d.pth" % (self.opts.seed)), self.results)


# ---------------------------------------------------------------------------
# /utils/test.py
# ---------------------------------------------------------------------------
from tqdm import tqdm
import torch, numpy as np, os


def _pose_errors(preds, lbs, to_degrees):
    """Per-sample MSE and MAE, each of shape (B,), between predictions and labels.

    :param preds: (B, D) prediction tensor on any device
    :param lbs: (B, D) label tensor on CPU (converted with ``.numpy()``)
    :param to_degrees: convert both from radians to degrees before comparing
    """
    preds, lbs = preds.cpu().numpy(), lbs.numpy()
    if to_degrees:
        preds, lbs = np.degrees(preds), np.degrees(lbs)
    diff = preds - lbs
    return np.mean(diff ** 2, 1), np.mean(np.abs(diff), 1)


def test(opts, model, test_loader, epoch):
    """Evaluate ``model`` on ``test_loader``, checkpoint it, and print mean errors.

    Errors are reported for the final estimate (``results["r"]``/``["t"]``)
    and the initial estimate (``["r_init"]``/``["t_init"]``). Rotation errors
    are in degrees.

    NOTE(review): this function never calls ``model.eval()``. Confirm the caller
    does, otherwise normalisation layers run in training mode here.
    """
    # Metric name -> list of per-batch (B,) arrays.
    scores = {nm: [] for nm in ("r_mse", "t_mse", "r_mae", "t_mae",
                                "r_prior_mse", "t_prior_mse", "r_prior_mae", "t_prior_mae")}
    with torch.no_grad():
        for srcs, tgts, rs_lb, ts_lb in tqdm(test_loader):
            srcs, tgts = srcs.cuda(), tgts.cuda()
            results = model(srcs, tgts)
            # Older models returned these under "rs"/"ts"/"rs_init"/"ts_init".
            for prefix, r_key, t_key in (("", "r", "t"), ("prior_", "r_init", "t_init")):
                r_mse, r_mae = _pose_errors(results[r_key], rs_lb, to_degrees=True)
                t_mse, t_mae = _pose_errors(results[t_key], ts_lb, to_degrees=False)
                scores["r_%smse" % prefix].append(r_mse)
                scores["r_%smae" % prefix].append(r_mae)
                scores["t_%smse" % prefix].append(t_mse)
                scores["t_%smae" % prefix].append(t_mae)

    # Mean over every test sample of every batch.
    s = {nm: np.mean(np.concatenate(vals, 0)).item() for nm, vals in scores.items()}

    if not opts.is_debug:
        torch.save(model.state_dict(), os.path.join(opts.results_dir, 'model_epoch%d.pth' % (epoch)))

    print("[%d] Test: r_mse: %.8f, t_mse: %.8f, r_mae: %.8f, t_mae: %.8f" % (
        epoch, s["r_mse"], s["t_mse"], s["r_mae"], s["t_mae"]))
    print("[%d] Prior test: r_mse: %.8f, t_mse: %.8f, r_mae: %.8f, t_mae: %.8f" % (
        epoch, s["r_prior_mse"], s["t_prior_mse"], s["r_prior_mae"], s["t_prior_mae"]))


# ---------------------------------------------------------------------------
# /utils/transform_pc.py
# ---------------------------------------------------------------------------
import torch, numpy as np
from utils.euler2mat import euler2mat_torch, euler2mat_np


def transform_pc_torch(pcs, Rs, ts):
    """Apply batched rigid transforms ``R @ pc + t``.

    :param pcs: point clouds, (B, 3, N)
    :param Rs: rotation matrices, (B, 3, 3)
    :param ts: translation vectors, (B, 3)
    :return: transformed pcs, (B, 3, N)
    """
    return torch.matmul(Rs, pcs) + ts.unsqueeze(2)


def transform_pc_np(pcs, Rs, ts):
    """Numpy counterpart of ``transform_pc_torch``.

    :param pcs: point clouds, (B, 3, N)
    :param Rs: rotation matrices, (B, 3, 3)
    :param ts: translation vectors, (B, 3)
    :return: transformed pcs, (B, 3, N)
    """
    return np.matmul(Rs, pcs) + np.expand_dims(ts, axis=2)


def transform_pc_action_pytorch(pcs, actions):
    """Apply transforms given as 6-D actions.

    :param pcs: point clouds, (B, 3, N)
    :param actions: (B, 6) rows ``[rz, ry, rx, tx, ty, tz]``. The angles are
        converted with ``euler2mat_torch(..., seq="zyx")``; this is not a
        rotation matrix.
    :return: transformed pcs, (B, 3, N)
    """
    Rs = euler2mat_torch(actions[:, :3], seq="zyx")
    return transform_pc_torch(pcs, Rs, actions[:, 3:])


def transform_pc_action_np(pcs, actions):
    """Numpy counterpart of ``transform_pc_action_pytorch``.

    :param pcs: point clouds, (B, 3, N)
    :param actions: (B, 6) rows ``[rz, ry, rx, tx, ty, tz]``, converted with
        ``euler2mat_np(..., seq="zyx")``
    :return: transformed pcs, (B, 3, N)
    """
    Rs = euler2mat_np(actions[:, :3], seq="zyx")
    return transform_pc_np(pcs, Rs, actions[:, 3:])