├── .idea
├── .gitignore
├── CEMNet.iml
├── inspectionProfiles
│ ├── Project_Default.xml
│ └── profiles_settings.xml
├── misc.xml
├── modules.xml
└── vcs.xml
├── LICENSE
├── README.md
├── batch_svd
├── README.md
├── __init__.py
├── setup.py
├── tests
│ └── tests.py
├── torch_batch_svd.egg-info
│ ├── PKG-INFO
│ ├── SOURCES.txt
│ ├── dependency_links.txt
│ └── top_level.txt
└── torch_batch_svd
│ ├── __init__.py
│ ├── batch_svd.py
│ ├── csrc
│ ├── bindings.cpp
│ └── torch_batch_svd.cpp
│ └── include
│ ├── torch_batch_svd.h
│ └── utils.h
├── cemnet_lib
├── __init__.py
├── cemnet_lib.egg-info
│ ├── PKG-INFO
│ ├── SOURCES.txt
│ ├── dependency_links.txt
│ └── top_level.txt
├── functions
│ ├── __init__.py
│ └── distances.py
├── setup.py
└── src
│ ├── cemnet_lib_api.cpp
│ ├── closest_point
│ ├── closest_point_cuda.cpp
│ ├── closest_point_cuda_kernel.cu
│ ├── closest_point_cuda_kernel.h
│ └── test.cpp
│ ├── cuda_utils.h
│ ├── distances
│ ├── distances_cuda.cpp
│ ├── distances_cuda_kernel.cu
│ └── distances_cuda_kernel.h
│ ├── ops
│ ├── ops_cuda.cpp
│ ├── ops_cuda_kernel.cu
│ └── ops_cuda_kernel.h
│ ├── ops_saved
│ ├── ops_cuda.cpp
│ ├── ops_cuda_kernel.cu
│ └── ops_cuda_kernel.h
│ └── prcem_lib_api.cpp
├── cems
├── __pycache__
│ ├── base_cem.cpython-36.pyc
│ └── guided_cem.cpython-36.pyc
├── base_cem.py
└── guided_cem.py
├── datasets
├── __pycache__
│ ├── base_dataset.cpython-36.pyc
│ ├── dataset.cpython-36.pyc
│ └── get_dataset.cpython-36.pyc
├── base_dataset.py
├── dataset.py
├── get_dataset.py
├── process_dataset.py
├── se_math
│ ├── __init__.py
│ ├── invmat.py
│ ├── mesh.py
│ ├── se3.py
│ ├── sinc.py
│ ├── so3.py
│ └── transforms.py
└── utils
│ ├── commons.py
│ ├── db_icl_nuim.py
│ ├── gen_normal.py
│ └── npmat2euler.py
├── modules
├── __pycache__
│ ├── commons.cpython-36.pyc
│ ├── dcp_net.cpython-36.pyc
│ ├── dgcnn.cpython-36.pyc
│ └── sparsemax.cpython-36.pyc
├── commons.py
├── dcp_net.py
├── dgcnn.py
└── sparsemax.py
├── results
├── icl_nuim_n768_unseen0_noise0_seed123
│ └── model.pth
├── modelnet40_n768_unseen0_noise0_seed123
│ └── model.pth
├── modelnet40_n768_unseen0_noise1_seed123
│ └── model.pth
├── modelnet40_n768_unseen1_noise0_seed123
│ └── model.pth
└── scene7_n768_unseen0_noise0_seed123
│ └── model.pth
├── run.sh
├── test_model.py
├── train_model.py
└── utils
├── __pycache__
├── attr_dict.cpython-36.pyc
├── batch_icp.cpython-36.pyc
├── commons.cpython-36.pyc
├── euler2mat.cpython-36.pyc
├── losses.cpython-36.pyc
├── mat2euler.cpython-36.pyc
├── options.cpython-36.pyc
├── recorder.cpython-36.pyc
├── test.cpython-36.pyc
└── transform_pc.cpython-36.pyc
├── attr_dict.py
├── batch_icp.py
├── commons.py
├── euler2mat.py
├── losses.py
├── mat2euler.py
├── options.py
├── recorder.py
├── test.py
└── transform_pc.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/CEMNet.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Jiang-HB
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sampling Network Guided Cross-Entropy Method for Unsupervised Point Cloud Registration (ICCV2021)
2 |
3 | PyTorch implementation of CEMNet for ICCV'2021 paper ["Sampling Network Guided Cross-Entropy Method for Unsupervised Point Cloud Registration"](https://openaccess.thecvf.com/content/ICCV2021/papers/Jiang_Sampling_Network_Guided_Cross-Entropy_Method_for_Unsupervised_Point_Cloud_Registration_ICCV_2021_paper.pdf), by *Haobo Jiang, Yaqi Shen, Jin Xie, Jun Li, Jianjun Qian, and Jian Yang* from PCA Lab, Nanjing University of Science and Technology, China.
4 |
5 | This paper focuses on unsupervised deep learning for 3D point cloud registration. If you find this project useful, please cite:
6 |
7 | ```bash
8 | @inproceedings{jiang2021sampling,
9 | title={{S}ampling {N}etwork {G}uided {C}ross-{E}ntropy {M}ethod for {U}nsupervised {P}oint {C}loud {R}egistration},
10 | author={Jiang, Haobo and Shen, Yaqi and Xie, Jin and Li, Jun and Qian, Jianjun and Yang, Jian},
11 | booktitle={ICCV},
12 | year={2021}
13 | }
14 | ```
15 |
16 | ## Introduction
17 |
18 | In this paper, by modeling the point cloud registration task as a Markov decision process, we propose an end-to-end deep model embedded with the cross-entropy method (CEM) for unsupervised 3D registration.
19 | Our model consists of a sampling network module and a differentiable CEM module. In our sampling network module, given a pair of point clouds, the sampling network learns a prior sampling distribution over the transformation space. The learned sampling distribution can be used as a "good" initialization of the differentiable CEM module. In our differentiable CEM module, we first propose a maximum consensus criterion based alignment metric as the reward function for the point cloud registration task. Based on the reward function, for each state, we then construct a fused score function to evaluate the sampled transformations, where we weight the current and future rewards of the transformations. Particularly, the future rewards of the sampled transforms are obtained by performing the iterative closest point (ICP) algorithm on the transformed state. Extensive experimental results demonstrate the good registration performance of our method on benchmark datasets.
20 |
21 | ## Requirements
22 |
23 | Before running our code, you need to install the `cemnet_lib` and `batch_svd` libraries via:
24 |
25 | ```bash
26 | bash run.sh
27 | ```
28 | (If you encounter errors when installing `batch_svd`, please refer to [torch-batch-svd](https://github.com/KinglittleQ/torch-batch-svd).)
29 |
30 | ## Dataset Preprocessing
31 |
32 | We generated the dataset files `modelnet40_normal_n2048.pth`, `7scene_normal_n2048.pth` and `icl_nuim_normal_n2048.pth` by preprocessing the raw point clouds of *ModelNet40*, *7Scene* and *ICL-NUIM*, and uploaded them to [Google Drive](https://drive.google.com/drive/folders/1ne9naYI8M8v4Lv0L9AcQm60Jqb8ciQ6t?usp=sharing). Alternatively, you can generate them yourself via:
33 |
34 | ```bash
35 | cd datasets
36 | python3 process_dataset.py
37 | ```
38 |
39 | After that, you need to modify the dataset paths in `utils/options.py`.
40 |
41 | ## Pretrained Model
42 |
43 | We uploaded the pretrained models as below:
44 |
45 | *ModelNet40*: `results/modelnet40_n768_unseen0_noise0_seed123/model.pth`,
46 |
47 | *7Scene* : `results/scene7_n768_unseen0_noise0_seed123/model.pth`,
48 |
49 | *ICL-NUIM*: `results/icl_nuim_n768_unseen0_noise0_seed123/model.pth`.
50 |
51 | ## Acknowledgments
52 | We thank the authors of
53 | - [DCP](https://github.com/WangYueFt/dcp)
54 | - [torch-batch-svd](https://github.com/KinglittleQ/torch-batch-svd)
55 |
56 | for open sourcing their methods.
57 |
--------------------------------------------------------------------------------
/batch_svd/README.md:
--------------------------------------------------------------------------------
1 | # Pytorch Batch SVD
2 |
3 | ## 1) Introduction
4 |
5 | A batched version of SVD for PyTorch, implemented using cuSOLVER
6 | and including both the forward and backward functions.
7 | In terms of speed, it is superior to `torch.svd`.
8 |
9 | ``` python
10 | import torch
11 | from torch_batch_svd import svd
12 |
13 | A = torch.rand(1000000, 3, 3).cuda()
14 | u, s, v = svd(A)
15 | u, s, v = torch.svd(A) # probably you should take a coffee break here
16 | ```
17 |
18 | The catch here is that it only works for matrices whose row and column are smaller than `32`.
19 | Other than that, `torch_batch_svd.svd` can be a drop-in for the native one.
20 |
21 | The forward function is modified from [ShigekiKarita/pytorch-cusolver](https://github.com/ShigekiKarita/pytorch-cusolver) and I fixed several bugs of it. The backward function is adapted from pytorch official [svd backward function](https://github.com/pytorch/pytorch/blob/b0545aa85f7302be5b9baf8320398981365f003d/tools/autograd/templates/Functions.cpp#L1476). I converted it to a batch version.
22 |
23 | **NOTE**: `batch_svd` supports all `torch.half`, `torch.float` and `torch.double` tensors now.
24 |
25 | **NOTE**: SVD for `torch.half` is performed by casting to `torch.float`
26 | as there is no CuSolver implementation for `c10::half`.
27 |
28 | **NOTE**: Sometimes, tests will fail for `torch.double` tensor due to numerical imprecision.
29 |
30 | ## 2) Requirements
31 |
32 | - Pytorch >= 1.0
33 |
34 | > `diag_embed()` is used by the backward function in torch_batch_svd.cpp. PyTorch versions lower than 1.0 do not provide `diag_embed()`; to use this library with an older PyTorch, replace `diag_embed()` with an equivalent existing function.
35 |
36 | - CUDA 9.0/10.2 (should work with 10.0/10.1 too)
37 |
38 | - Tested in Pytorch 1.4 & 1.5, with CUDA 10.2
39 |
40 | ## 3) Install
41 |
42 | Set environment variables
43 |
44 | ``` shell
45 | export CUDA_HOME=/your/cuda/home/directory/
46 | export LIBRARY_PATH=$LIBRARY_PATH:/your/cuda/lib64/ (optional)
47 | ```
48 |
49 | Run `setup.py`
50 |
51 | ``` shell
52 | python setup.py install
53 | ```
54 |
55 | Run `tests/tests.py`
56 |
57 | ```shell
58 | cd tests
59 | python -m pytest tests.py
60 | ```
61 |
62 | ## 4) Differences from `torch.svd()`
63 |
64 | - The sign of column vectors at U and V may be different from `torch.svd()`.
65 |
66 | - `batch_svd()` is much faster than calling `torch.svd()` in a loop.
67 |
68 | ## 5) Example
69 |
70 | See `tests/tests.py` and [introduction](#1-introduction).
71 |
--------------------------------------------------------------------------------
/batch_svd/__init__.py:
--------------------------------------------------------------------------------
1 | from svd.batch_svd import batch_svd
2 |
--------------------------------------------------------------------------------
/batch_svd/setup.py:
--------------------------------------------------------------------------------
"""Build script for the ``torch_batch_svd_cuda`` C++/cuSOLVER extension.

Run from the directory containing this file (``python setup.py install``):
the extension sources are located relative to the current working directory.
"""
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os
import glob

libname = "torch_batch_svd"

# Sorted so the compile order (and thus the build) is deterministic.
ext_src = sorted(glob.glob(os.path.join(libname, 'csrc', '*.cpp')))
if not ext_src:
    # Without this check setup() would try to build an extension with no
    # sources, which fails much later with a far less helpful error.
    raise RuntimeError(
        "No C++ sources found under {!r}; run setup.py from the directory "
        "that contains it.".format(os.path.join(libname, 'csrc')))

setup(name=libname,
      packages=find_packages(exclude=('tests', 'build', 'csrc', 'include', 'torch_batch_svd.egg-info')),
      ext_modules=[CUDAExtension(
          # Top-level module name imported by torch_batch_svd/batch_svd.py.
          "torch_batch_svd_cuda",
          sources=ext_src,
          # Headers (torch_batch_svd.h, utils.h) live in <libname>/include.
          include_dirs=[os.path.abspath(os.path.join(libname, 'include'))],
          libraries=["cusolver", "cublas"],
          extra_compile_args={'cxx': ['-O2'], 'nvcc': ['-O2']},
      )],
      cmdclass={'build_ext': BuildExtension})
23 |
--------------------------------------------------------------------------------
/batch_svd/tests/tests.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import testing
3 |
4 | from torch_batch_svd import svd
5 |
6 |
def test_float():
    """float32: batched SVD matches torch.svd on the first matrix (up to sign) and reconstructs the batch."""
    torch.manual_seed(0)
    batch = torch.randn(1000000, 9, 3).cuda()
    reference = batch.clone()
    batch.requires_grad = True
    reference.requires_grad = True

    U, S, V = svd(batch)
    (U.sum() + S.sum() + V.sum()).backward()

    # Ground truth from torch.svd on a single matrix of the batch.
    u_ref, s_ref, v_ref = torch.svd(reference[0], some=True, compute_uv=True)
    (u_ref.sum() + s_ref.sum() + v_ref.sum()).backward()

    # Singular vectors are only determined up to sign, so compare magnitudes.
    for ours, theirs in ((U[0], u_ref), (S[0], s_ref), (V[0], v_ref)):
        testing.assert_allclose(ours.abs(), theirs.abs())
    rebuilt = torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1))
    testing.assert_allclose(batch, rebuilt)
26 |
27 |
def test_double():
    """float64: outputs and input gradient stay double, and results match torch.svd (up to sign)."""
    torch.manual_seed(0)
    batch = torch.randn(10, 9, 3).cuda().double()
    reference = batch.clone()
    batch.requires_grad = True
    reference.requires_grad = True

    U, S, V = svd(batch)
    (U.sum() + S.sum() + V.sum()).backward()

    u_ref, s_ref, v_ref = torch.svd(reference[0], some=True, compute_uv=True)
    (u_ref.sum() + s_ref.sum() + v_ref.sum()).backward()

    for produced in (U, S, V, batch.grad):
        assert produced.dtype == torch.double
    # Singular vectors are only determined up to sign, so compare magnitudes.
    for ours, theirs in ((U[0], u_ref), (S[0], s_ref), (V[0], v_ref)):
        testing.assert_allclose(ours.abs(), theirs.abs())
    rebuilt = torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1))
    testing.assert_allclose(batch, rebuilt)
51 |
52 |
def test_half():
    """float16: outputs and input gradient keep half precision and the batch is reconstructed."""
    torch.manual_seed(0)
    batch = torch.randn(10, 9, 3).cuda().half()
    untouched = batch.clone()
    batch.requires_grad = True
    untouched.requires_grad = True

    U, S, V = svd(batch)
    (U.sum() + S.sum() + V.sum()).backward()

    for produced in (U, S, V, batch.grad):
        assert produced.dtype == torch.half
    rebuilt = torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1))
    testing.assert_allclose(batch, rebuilt)
69 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: torch-batch-svd
3 | Version: 0.0.0
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | Author: UNKNOWN
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | README.md
2 | setup.py
3 | torch_batch_svd/__init__.py
4 | torch_batch_svd/batch_svd.py
5 | torch_batch_svd.egg-info/PKG-INFO
6 | torch_batch_svd.egg-info/SOURCES.txt
7 | torch_batch_svd.egg-info/dependency_links.txt
8 | torch_batch_svd.egg-info/top_level.txt
9 | torch_batch_svd/csrc/bindings.cpp
10 | torch_batch_svd/csrc/torch_batch_svd.cpp
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | torch_batch_svd
2 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd/__init__.py:
--------------------------------------------------------------------------------
1 | from .batch_svd import svd
2 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd/batch_svd.py:
--------------------------------------------------------------------------------
1 | import torch, torch_batch_svd_cuda
2 |
3 | # from . import _c
4 |
5 |
6 |
class BatchSVDFunction(torch.autograd.Function):
    """Batched SVD of small real matrices backed by cuSOLVER's ``gesvdjBatched``.

    Use it through the module-level alias ``svd = BatchSVDFunction.apply``; the
    arguments mirror ``torch.svd(input, some, compute_uv, out)`` for a 3-D input
    of shape ``(b, m, n)``.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, some=True, compute_uv=True, out=None):
        """Return ``(U, S, V)`` with ``input[i] = U[i] @ diag(S[i]) @ V[i]^T``.

        :param ctx: autograd context; stores the factors needed by ``backward``.
        :param input: CUDA tensor of shape ``(b, m, n)`` (half, float or double)
            with ``m, n < 32``. Half inputs are solved in float32 and the results
            cast back to half.
        :param some: if True return the reduced factors ``U (b, m, k)`` and
            ``V (b, n, k)`` with ``k = min(m, n)``; otherwise the full
            ``U (b, m, m)`` and ``V (b, n, n)``.
        :param compute_uv: if False, ``U`` and ``V`` are returned zero-filled
            with shapes ``(b, m, m)`` and ``(b, n, n)``.
        :param out: optional preallocated ``(U, S, V)`` buffers of shapes
            ``(b, m, m)``, ``(b, k)``, ``(b, n, n)`` that are written in place.
        :return: the tuple ``(U, S, V)``.
        """
        assert input.shape[-1] < 32 and input.shape[-2] < 32, \
            'This implementation only supports matrices having dims smaller than 32'

        # Needed on every path (including `out` given + compute_uv=False).
        b, m, n = input.shape
        is_double = input.dtype == torch.double
        ctx.is_half = input.dtype == torch.half
        # The forward returns full-width U/V either when some=False or when
        # compute_uv=False (zero matrices); backward must then narrow the
        # incoming gradients to the first k columns.
        ctx.full_matrices = (not some) or (not compute_uv)
        if ctx.is_half:
            # There is no cuSOLVER implementation for half; solve in float32.
            input = input.float()

        if out is None:
            U = torch.empty(b, m, m, dtype=input.dtype, device=input.device)
            S = torch.empty(b, min(m, n), dtype=input.dtype, device=input.device)
            V = torch.empty(b, n, n, dtype=input.dtype, device=input.device)
        else:
            U, S, V = out

        # Arguments: sort singular values, tolerance 1e-7, at most 100 sweeps.
        torch_batch_svd_cuda.batch_svd_forward(input, U, S, V, True, 1e-7, 100, is_double)
        # The solver writes U and V column-major; transpose back in place.
        U.transpose_(1, 2)
        V.transpose_(1, 2)
        if ctx.is_half:
            U, S, V = U.half(), S.half(), V.half()

        k = S.size(1)
        U_reduced: torch.Tensor = U[:, :, :k]
        V_reduced: torch.Tensor = V[:, :, :k]
        ctx.save_for_backward(input, U_reduced, S, V_reduced)

        if not compute_uv:
            U = torch.zeros(b, m, m, dtype=S.dtype, device=input.device)
            V = torch.zeros(b, n, n, dtype=S.dtype, device=input.device)
            return U, S, V

        return (U_reduced, S, V_reduced) if some else (U, S, V)

    @staticmethod
    def backward(ctx, grad_u: torch.Tensor, grad_s: torch.Tensor, grad_v: torch.Tensor):
        """Gradient w.r.t. ``input``; ``some``, ``compute_uv`` and ``out`` get ``None``."""
        A, U, S, V = ctx.saved_tensors
        if ctx.is_half:
            grad_u, grad_s, grad_v = grad_u.float(), grad_s.float(), grad_v.float()

        # some=False makes the C++ backward narrow full-width grad_u / grad_v to
        # the k columns of the saved (reduced) factors.
        grad_out: torch.Tensor = torch_batch_svd_cuda.batch_svd_backward(
            [grad_u, grad_s, grad_v],
            A, not ctx.full_matrices, True, U.to(A.dtype), S.to(A.dtype), V.to(A.dtype)
        )
        if ctx.is_half:
            grad_out = grad_out.half()

        # Autograd expects one entry per forward input; trailing Nones are
        # dropped automatically when fewer arguments were passed.
        return grad_out, None, None, None


svd = BatchSVDFunction.apply
74 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd/csrc/bindings.cpp:
--------------------------------------------------------------------------------
1 | #include "torch_batch_svd.h"
2 |
3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4 | m.def("batch_svd_forward", &batch_svd_forward,
5 | "cusolver based batch svd forward");
6 | m.def("batch_svd_backward", &batch_svd_backward,
7 | "batch svd backward");
8 | }
9 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd/csrc/torch_batch_svd.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | #include "torch_batch_svd.h"
7 | #include "utils.h"
8 |
9 | // solve U S V = svd(A) a.k.a. syevj, where A (b, m, n), U (b, m, m), S (b, min(m, n)), V (b, n, n)
10 | // see also https://docs.nvidia.com/cuda/cusolver/index.html#batchgesvdj-example1
11 | void batch_svd_forward(at::Tensor a, at::Tensor U, at::Tensor s, at::Tensor V, bool is_sort, double tol, int max_sweeps, bool is_double)
12 | {
13 | CHECK_CUDA(a);
14 | CHECK_CUDA(U);
15 | CHECK_CUDA(s);
16 | CHECK_CUDA(V);
17 | CHECK_IS_FLOAT(a);
18 |
19 | auto handle_ptr = unique_allocate(cusolverDnCreate, cusolverDnDestroy);
20 | const auto A = a.contiguous().clone().transpose(1, 2).contiguous().transpose(1, 2); // important
21 | const auto batch_size = A.size(0);
22 | const auto m = A.size(1);
23 | TORCH_CHECK(m <= 32, "matrix row should be <= 32");
24 | const auto n = A.size(2);
25 | TORCH_CHECK(n <= 32, "matrix col should be <= 32");
26 | const auto lda = m;
27 | const auto minmn = std::min(m, n);
28 | const auto ldu = m;
29 | const auto ldv = n;
30 |
31 | auto params = unique_allocate(cusolverDnCreateGesvdjInfo, cusolverDnDestroyGesvdjInfo);
32 | auto status = cusolverDnXgesvdjSetTolerance(params.get(), tol);
33 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status);
34 | status = cusolverDnXgesvdjSetMaxSweeps(params.get(), max_sweeps);
35 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status);
36 | status = cusolverDnXgesvdjSetSortEig(params.get(), is_sort);
37 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status);
38 |
39 | auto jobz = CUSOLVER_EIG_MODE_VECTOR; // compute eigenvalues and eigenvectors
40 | int lwork;
41 | auto info_ptr = unique_cuda_ptr(batch_size);
42 |
43 | if (is_double) {
44 | const auto d_A = A.data();
45 | auto d_s = s.data();
46 | const auto d_U = U.data();
47 | const auto d_V = V.data();
48 |
49 | auto status_buffer = cusolverDnDgesvdjBatched_bufferSize(
50 | handle_ptr.get(),
51 | jobz,
52 | m,
53 | n,
54 | d_A,
55 | lda,
56 | d_s,
57 | d_U,
58 | ldu,
59 | d_V,
60 | ldv,
61 | &lwork,
62 | params.get(),
63 | batch_size);
64 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status_buffer);
65 | auto work_ptr = unique_cuda_ptr(lwork);
66 |
67 | status = cusolverDnDgesvdjBatched(
68 | handle_ptr.get(),
69 | jobz,
70 | m,
71 | n,
72 | d_A,
73 | lda,
74 | d_s,
75 | d_U,
76 | ldu,
77 | d_V,
78 | ldv,
79 | work_ptr.get(),
80 | lwork,
81 | info_ptr.get(),
82 | params.get(),
83 | batch_size
84 | );
85 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status);
86 | }
87 | else {
88 | const auto d_A = A.data();
89 | auto d_s = s.data();
90 | const auto d_U = U.data();
91 | const auto d_V = V.data();
92 |
93 | auto status_buffer = cusolverDnSgesvdjBatched_bufferSize(
94 | handle_ptr.get(),
95 | jobz,
96 | m,
97 | n,
98 | d_A,
99 | lda,
100 | d_s,
101 | d_U,
102 | ldu,
103 | d_V,
104 | ldv,
105 | &lwork,
106 | params.get(),
107 | batch_size);
108 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status_buffer);
109 | auto work_ptr = unique_cuda_ptr(lwork);
110 |
111 | status = cusolverDnSgesvdjBatched(
112 | handle_ptr.get(),
113 | jobz,
114 | m,
115 | n,
116 | d_A,
117 | lda,
118 | d_s,
119 | d_U,
120 | ldu,
121 | d_V,
122 | ldv,
123 | work_ptr.get(),
124 | lwork,
125 | info_ptr.get(),
126 | params.get(),
127 | batch_size
128 | );
129 | TORCH_CHECK(CUSOLVER_STATUS_SUCCESS == status);
130 | }
131 |
132 | std::vector hinfo(batch_size);
133 | auto status_memcpy = cudaMemcpy(hinfo.data(), info_ptr.get(), sizeof(int) * batch_size, cudaMemcpyDeviceToHost);
134 | TORCH_CHECK(cudaSuccess == status_memcpy);
135 |
136 | for(int i = 0 ; i < batch_size; ++i)
137 | {
138 | if ( 0 == hinfo[i] )
139 | {
140 | continue;
141 | }
142 | else if ( 0 > hinfo[i] )
143 | {
144 | std::cout << "Error: " << -hinfo[i] << "-th parameter is wrong" << std::endl;
145 | TORCH_CHECK(false);
146 | }
147 | else
148 | {
149 | std::cout << "WARNING: matrix " << i << ", info = " << hinfo[i] << ": Jacobi method does not converge" << std::endl;
150 | }
151 | }
152 | }
153 |
154 |
155 |
156 | // https://j-towns.github.io/papers/svd-derivative.pdf
157 | //
158 | // This makes no assumption on the signs of sigma.
159 | at::Tensor batch_svd_backward(const std::vector &grads, const at::Tensor& self,
160 | bool some, bool compute_uv, const at::Tensor& raw_u, const at::Tensor& sigma, const at::Tensor& raw_v) {
161 | TORCH_CHECK(compute_uv,
162 | "svd_backward: Setting compute_uv to false in torch.svd doesn't compute singular matrices, ",
163 | "and hence we cannot compute backward. Please use torch.svd(compute_uv=True)");
164 |
165 | // A [b, m, n]
166 | // auto b = self.size(0);
167 | auto m = self.size(1);
168 | auto n = self.size(2);
169 | auto k = sigma.size(1);
170 | auto gsigma = grads[1];
171 |
172 | auto u = raw_u;
173 | auto v = raw_v;
174 | auto gu = grads[0];
175 | auto gv = grads[2];
176 |
177 | if (!some) {
178 | // We ignore the free subspace here because possible base vectors cancel
179 | // each other, e.g., both -v and +v are valid base for a dimension.
180 | // Don't assume behavior of any particular implementation of svd.
181 | u = raw_u.narrow(2, 0, k);
182 | v = raw_v.narrow(2, 0, k);
183 | if (gu.defined()) {
184 | gu = gu.narrow(2, 0, k);
185 | }
186 | if (gv.defined()) {
187 | gv = gv.narrow(2, 0, k);
188 | }
189 | }
190 | auto vt = v.transpose(1, 2);
191 |
192 | at::Tensor sigma_term;
193 | if (gsigma.defined()) {
194 | sigma_term = u.bmm(gsigma.diag_embed()).bmm(vt);
195 | } else {
196 | sigma_term = at::zeros({1}, self.options()).expand_as(self);
197 | }
198 | // in case that there are no gu and gv, we can avoid the series of kernel
199 | // calls below
200 | if (!gv.defined() && !gu.defined()) {
201 | return sigma_term;
202 | }
203 |
204 | auto ut = u.transpose(1, 2);
205 | auto im = at::eye(m, self.options()); // work if broadcast
206 | auto in = at::eye(n, self.options());
207 | auto sigma_mat = sigma.diag_embed();
208 | auto sigma_mat_inv = sigma.pow(-1).diag_embed();
209 | auto sigma_expanded_sq = sigma.pow(2).unsqueeze(1).expand_as(sigma_mat);
210 | auto F = sigma_expanded_sq - sigma_expanded_sq.transpose(1, 2);
211 | // The following two lines invert values of F, and fills the diagonal with 0s.
212 | // Notice that F currently has 0s on diagonal. So we fill diagonal with +inf
213 | // first to prevent nan from appearing in backward of this function.
214 | F.diagonal(0, -2, -1).fill_(INFINITY);
215 | F = F.pow(-1);
216 |
217 | at::Tensor u_term, v_term;
218 |
219 | if (gu.defined()) {
220 | u_term = u.bmm(F.mul(ut.bmm(gu) - gu.transpose(1, 2).bmm(u))).bmm(sigma_mat);
221 | if (m > k) {
222 | u_term = u_term + (im - u.bmm(ut)).bmm(gu).bmm(sigma_mat_inv);
223 | }
224 | u_term = u_term.bmm(vt);
225 | } else {
226 | u_term = at::zeros({1}, self.options()).expand_as(self);
227 | }
228 |
229 | if (gv.defined()) {
230 | auto gvt = gv.transpose(1, 2);
231 | v_term = sigma_mat.bmm(F.mul(vt.bmm(gv) - gvt.bmm(v))).bmm(vt);
232 | if (n > k) {
233 | v_term = v_term + sigma_mat_inv.bmm(gvt.bmm(in - v.bmm(vt)));
234 | }
235 | v_term = u.bmm(v_term);
236 | } else {
237 | v_term = at::zeros({1}, self.options()).expand_as(self);
238 | }
239 |
240 | return u_term + sigma_term + v_term;
241 | }
242 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd/include/torch_batch_svd.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | void batch_svd_forward(at::Tensor a, at::Tensor U, at::Tensor s, at::Tensor V,
6 | bool is_sort, double tol=1e-7, int max_sweeps=100, bool is_double=false);
7 | at::Tensor batch_svd_backward(const std::vector &grads, const at::Tensor& self,
8 | bool some, bool compute_uv, const at::Tensor& raw_u, const at::Tensor& sigma, const at::Tensor& raw_v);
9 |
--------------------------------------------------------------------------------
/batch_svd/torch_batch_svd/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include
3 | #include
4 |
5 | #include
6 | #include
7 | #include
8 | #include
9 |
// Argument checks for the extension entry points. Each macro expands to a
// single statement (do { ... } while (0)) so it can be used like a function
// call followed by a semicolon.

// x must live on a CUDA device.
#define CHECK_CUDA(x)                                                        \
    do {                                                                     \
        TORCH_CHECK((x).is_cuda(), #x " must be a CUDA tensor");             \
    } while (0)

// x must be contiguous in memory.
#define CHECK_CONTIGUOUS(x)                                                  \
    do {                                                                     \
        TORCH_CHECK((x).is_contiguous(), #x " must be a contiguous tensor"); \
    } while (0)

// x must be a floating-point tensor (double, float or half).
#define CHECK_IS_FLOAT(x)                                                    \
    do {                                                                     \
        TORCH_CHECK(((x).scalar_type() == at::ScalarType::Float) ||          \
                    ((x).scalar_type() == at::ScalarType::Half) ||           \
                    ((x).scalar_type() == at::ScalarType::Double),           \
                    #x " must be a double, float or half tensor");           \
    } while (0)
25 |
26 |
27 | template // , class A = Status(*)(P), class D = Status(*)(T)>
28 | std::unique_ptr unique_allocate(Status(allocator)(T**), Status(deleter)(T*))
29 | {
30 | T* ptr;
31 | auto stat = allocator(&ptr);
32 | TORCH_CHECK(stat == success);
33 | return {ptr, deleter};
34 | }
35 |
36 | template
37 | std::unique_ptr unique_cuda_ptr(size_t len) {
38 | T* ptr;
39 | auto stat = cudaMalloc(&ptr, sizeof(T) * len);
40 | TORCH_CHECK(stat == cudaSuccess);
41 | return {ptr, cudaFree};
42 | }
43 |
--------------------------------------------------------------------------------
/cemnet_lib/__init__.py:
--------------------------------------------------------------------------------
1 | from .functions import cd_distance, mc_distance, gm_distance, closest_point
2 |
--------------------------------------------------------------------------------
/cemnet_lib/cemnet_lib.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: cemnet-lib
3 | Version: 0.0.0
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | Author: UNKNOWN
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/cemnet_lib/cemnet_lib.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | setup.py
2 | cemnet_lib.egg-info/PKG-INFO
3 | cemnet_lib.egg-info/SOURCES.txt
4 | cemnet_lib.egg-info/dependency_links.txt
5 | cemnet_lib.egg-info/top_level.txt
6 | src/cemnet_lib_api.cpp
7 | src/ops/ops_cuda.cpp
8 | src/ops/ops_cuda_kernel.cu
--------------------------------------------------------------------------------
/cemnet_lib/cemnet_lib.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/cemnet_lib/cemnet_lib.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | cemnet_lib_cuda
2 |
--------------------------------------------------------------------------------
/cemnet_lib/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .distances import mc_distance, cd_distance, gm_distance, closest_point
--------------------------------------------------------------------------------
/cemnet_lib/functions/distances.py:
--------------------------------------------------------------------------------
1 | import torch, cemnet_lib_cuda
2 | from torch.autograd import Function
3 |
# closest point
class ClosestPoint(Function):
    """Look up, for every source point, its nearest point in a shared target cloud.

    The operation has no gradient: ``backward`` returns ``None`` for each input.
    """

    @staticmethod
    def forward(ctx, srcs, tgt):
        """
        input: srcs: (B, 3, N), tgt: (3, N) -- float32 CUDA tensors
        return: closest_points: (B, 3, N), the coordinates of the target point
            nearest (squared Euclidean distance) to each source point
        """
        # The CUDA kernel walks raw memory assuming a dense row-major layout.
        srcs = srcs.contiguous()
        tgt = tgt.contiguous()
        # zeros_like already allocates on srcs.device; no extra .to() copy needed.
        closest_points = torch.zeros_like(srcs)
        cemnet_lib_cuda.closest_point_cuda(srcs, tgt, closest_points)
        return closest_points

    @staticmethod
    def backward(ctx, *grad_outputs):
        # autograd calls backward(ctx, *grads) and expects one entry per forward
        # input (srcs, tgt); None marks both as non-differentiable.
        return None, None


closest_point = ClosestPoint.apply
21 |
# Chamfer distance
class CD(Function):
    """Chamfer distance between each source cloud and one target cloud; non-differentiable."""

    @staticmethod
    def forward(ctx, srcs, tgt):
        """
        srcs: (B, 3, N)
        tgt: (3, N)
        return: distances: (B), per batch element the mean over points of
            (squared NN distance src->tgt) + (squared NN distance tgt->src)
        """
        # The CUDA kernel indexes raw memory, so inputs must be contiguous.
        srcs = srcs.contiguous()
        tgt = tgt.contiguous()
        # (B, N, 2): [..., 0] source->target, [..., 1] target->source.
        # Allocate directly on the device instead of building on CPU and copying.
        distances = torch.zeros(len(srcs), srcs.size(2), 2, device=srcs.device)
        cemnet_lib_cuda.cd_distance_cuda(srcs, tgt, distances)
        return distances.sum(2).mean(1)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # One gradient slot per forward input (srcs, tgt); neither is differentiable.
        return None, None


cd_distance = CD.apply
41 |
# Geman-McClure estimator based distance
class GM(Function):
    """Robust (Geman-McClure) Chamfer-style distance; non-differentiable."""

    @staticmethod
    def forward(ctx, srcs, tgt, mu):
        """
        srcs: (B, 3, N)
        tgt: (3, N)
        mu: scale of the Geman-McClure estimator
        return: distances: (B)

        Each squared nearest-neighbour distance d (in both directions) is mapped to
        mu * d / (d + mu), summed over the two directions and averaged over points.
        """
        # The CUDA kernel indexes raw memory, so inputs must be contiguous.
        srcs = srcs.contiguous()
        tgt = tgt.contiguous()
        distances = torch.zeros(len(srcs), srcs.size(2), 2, device=srcs.device)
        cemnet_lib_cuda.cd_distance_cuda(srcs, tgt, distances)  # (B, N, 2)
        return ((distances * mu) / (distances + mu)).sum(2).mean(1)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # forward takes three inputs (srcs, tgt, mu), so autograd requires three
        # gradient slots; the former two-element tuple raised a RuntimeError.
        return None, None, None


gm_distance = GM.apply
61 |
# Maximum consensus based distance
class MC(Function):
    """Maximum-consensus (inlier-scoring) distance; non-differentiable."""

    @staticmethod
    def forward(ctx, srcs, tgt, epsilon):
        """
        srcs: (B, 3, N)
        tgt: (3, N)
        epsilon: float, inlier radius
        return: distances: (B)

        For every point and direction (src->tgt, tgt->src) the kernel scores an
        inlier as 1 - d / epsilon, where d is the nearest-neighbour distance, and
        leaves outliers (d > epsilon) at 0. The result is 2 minus the mean summed
        score, so it lies in [0, 2] and smaller means more consensus.
        """
        # The CUDA kernel indexes raw memory, so inputs must be contiguous.
        srcs = srcs.contiguous()
        tgt = tgt.contiguous()
        distances = torch.zeros(len(srcs), srcs.size(2), 2, device=srcs.device)
        # Nearest-neighbour indices filled by the kernel (int32); not returned.
        min_idxs = torch.zeros(len(srcs), srcs.size(2), 2, dtype=torch.int, device=srcs.device)
        # The C++ binding takes a plain float; coerce numpy / 0-dim tensor scalars.
        cemnet_lib_cuda.mc_distance_cuda(srcs, tgt, distances, float(epsilon), min_idxs)
        return 2.0 - distances.sum(2).mean(1)

    @staticmethod
    def backward(ctx, *grad_outputs):
        # forward takes three inputs (srcs, tgt, epsilon), so autograd requires three
        # gradient slots; the former two-element tuple raised a RuntimeError.
        return None, None, None


mc_distance = MC.apply
--------------------------------------------------------------------------------
/cemnet_lib/setup.py:
--------------------------------------------------------------------------------
# Build and install the CUDA extension with:  python3 setup.py install
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

libname = "cemnet_lib"

# Only the "ops" sources are compiled into cemnet_lib_cuda; the distances/,
# closest_point/ and ops_saved/ directories are not part of the build.
_sources = [
    "src/cemnet_lib_api.cpp",
    "src/ops/ops_cuda.cpp",
    "src/ops/ops_cuda_kernel.cu",
]

_compile_args = {
    "cxx": ["-O2", "-I{}".format("{}/include".format(libname))],
    "nvcc": ["-O2"],
}

setup(
    name="cemnet_lib",
    ext_modules=[
        CUDAExtension("cemnet_lib_cuda", _sources, extra_compile_args=_compile_args),
    ],
    cmdclass={"build_ext": BuildExtension},
)
--------------------------------------------------------------------------------
/cemnet_lib/src/cemnet_lib_api.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "ops/ops_cuda_kernel.h"
5 |
6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7 | m.def("closest_point_cuda", &closest_point_cuda, "closest_point_cuda"); // name in python, cpp function address, docs
8 | m.def("cd_distance_cuda", &cd_distance_cuda, "cd_distance_cuda");
9 | // m.def("iou_distance_cuda", &iou_distance_cuda, "iou_distance_cuda");
10 | m.def("mc_distance_cuda", &mc_distance_cuda, "mc_distance_cuda");
11 | // m.def("cycle_distance_cuda", &cycle_distance_cuda, "cycle_distance_cuda");
12 | }
--------------------------------------------------------------------------------
/cemnet_lib/src/closest_point/closest_point_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "closest_point_cuda_kernel.h"
6 |
7 | extern THCState *state;
8 |
9 | void closest_point_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor closest_idxs_tensor){
10 | int b = srcs_tensor.size(0);
11 | int c = srcs_tensor.size(1);
12 | int n = srcs_tensor.size(2);
13 | const float *srcs = srcs_tensor.data();
14 | const float *tgt = tgt_tensor.data();
15 | int *closest_idxs = closest_idxs_tensor.data();
16 | closest_point_cuda_launcher(b, c, n, closest_idxs, srcs, tgt);
17 | }
18 |
19 |
20 |
21 |
22 | //#include
23 | //#include
24 | //#include
25 | //#include
26 | //#include "sampling_cuda_kernel.h"
27 | //
28 | //extern THCState *state;
29 | //
30 | //void gathering_forward_cuda(int b, int c, int n, int m, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor)
31 | //{
32 | // const float *points = points_tensor.data();
33 | // const int *idx = idx_tensor.data();
34 | // float *out = out_tensor.data();
35 | // gathering_forward_cuda_launcher(b, c, n, m, points, idx, out);
36 | //}
37 | //
38 | //void gathering_backward_cuda(int b, int c, int n, int m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor)
39 | //{
40 | //
41 | // const float *grad_out = grad_out_tensor.data();
42 | // const int *idx = idx_tensor.data();
43 | // float *grad_points = grad_points_tensor.data();
44 | // gathering_backward_cuda_launcher(b, c, n, m, grad_out, idx, grad_points);
45 | //}
46 | //
47 | //void furthestsampling_cuda(int b, int n, int m, at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor)
48 | //{
49 | // const float *points = points_tensor.data();
50 | // float *temp = temp_tensor.data();
51 | // int *idx = idx_tensor.data();
52 | // furthestsampling_cuda_launcher(b, n, m, points, temp, idx);
53 | //}
54 |
--------------------------------------------------------------------------------
/cemnet_lib/src/closest_point/closest_point_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "../cuda_utils.h"
2 | #include "closest_point_cuda_kernel.h"
3 | #include
4 |
5 | // input: srcs(B, 3, n) tgt(3, n)
6 | // output: out(b, n)
7 | //__global__ void closest_point_cuda_kernel(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt)
8 | //{
9 | // int n_block = gridDim.x;
10 | // int n_thread = blockDim.x;
11 | // int idx_block = blockIdx.x;
12 | // int idx_thread = threadIdx.x;
13 | //
14 | // if (idx_block < n && idx_thread < b)
15 | // {
16 | // float min_val = 1000.0;
17 | // int idx = -1;
18 | // for (int k = 0; k < n; k += 1)
19 | // {
20 | // float val = 0.0;
21 | // for (int i = 0; i < c; i += 1)
22 | // {
23 | // val += pow(srcs[idx_thread * c * n + i * n + idx_block] - tgt[i * n + k], 2.0);
24 | // }
25 | // if (val < min_val)
26 | // {
27 | // min_val = val;
28 | // idx = k;
29 | // }
30 | // }
31 | // closest_idxs[idx_thread * n + idx_block] = idx;
32 | // }
33 | //}
34 |
35 | //void closest_point_cuda_launcher(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt)
36 | //{
37 | // dim3 grid(n, 1, 1);
38 | // dim3 block(b, 1, 1);
39 | // closest_point_cuda_kernel<<>>(b, c, n, closest_idxs, srcs, tgt);
40 | //}
41 |
42 |
// For every source point, store the index of its nearest target point.
// input:  srcs (b, c, n), tgt (c, n)   -- contiguous
// output: closest_idxs (b, n)
// Launch: grid (b, ceil(n / blockDim.x)); one thread per (batch element, source point).
__global__ void closest_point_cuda_kernel(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float *src = srcs + bs_idx * c * n;  // (c, n) slice of this batch element
    float min_val = 0.0f;
    int min_idx = -1;
    for (int k = 0; k < n; ++k)
    {
        float val = 0.0f;
        for (int i = 0; i < c; ++i)
        {
            // diff * diff instead of pow(diff, 2.0): stays in single precision.
            const float diff = src[i * n + pt_idx] - tgt[i * n + k];
            val += diff * diff;
        }
        // Seed with the first candidate: the old 1000.0 sentinel left idx == -1
        // whenever every squared distance was >= 1000.
        if (min_idx < 0 || val < min_val)
        {
            min_val = val;
            min_idx = k;
        }
    }
    closest_idxs[bs_idx * n + pt_idx] = min_idx;
}

void closest_point_cuda_launcher(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile the n points over several blocks: one block of n threads cannot launch
    // once n exceeds the 1024 threads-per-block limit.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    closest_point_cuda_kernel<<<grid, block>>>(b, c, n, closest_idxs, srcs, tgt);
}
--------------------------------------------------------------------------------
/cemnet_lib/src/closest_point/closest_point_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _SAMPLING_CUDA_KERNEL
2 | #define _SAMPLING_CUDA_KERNEL
3 | #include
4 | #include
5 | #include
6 |
7 | void closest_point_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor closest_idxs);
8 |
9 | #ifdef __cplusplus
10 | extern "C" {
11 | #endif
12 |
13 | void closest_point_cuda_launcher(int b, int c, int n, int *closest_idxs, const float *srcs, const float *tgt);
14 |
15 | #ifdef __cplusplus
16 | }
17 | #endif
18 | #endif
19 |
20 |
21 | //void gathering_forward_cuda(int b, int c, int n, int m, at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
22 | //void gathering_backward_cuda(int b, int c, int n, int m, at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
23 | //void furthestsampling_cuda(int b, int n, int m, at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
24 | //
25 | //#ifdef __cplusplus
26 | //extern "C" {
27 | //#endif
28 | //
29 | //void gathering_forward_cuda_launcher(int b, int c, int n, int m, const float *points, const int *idx, float *out);
30 | //void gathering_backward_cuda_launcher(int b, int c, int n, int m, const float *grad_out, const int *idx, float *grad_points);
31 | //void furthestsampling_cuda_launcher(int b, int n, int m, const float *dataset, float *temp, int *idxs);
32 | //
33 | //#ifdef __cplusplus
34 | //}
35 | //#endif
36 | //#endif
37 |
--------------------------------------------------------------------------------
/cemnet_lib/src/closest_point/test.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | using namespace std;
int main()
{
    // Toolchain smoke test: print a fixed greeting and exit successfully.
    std::cout << "Hello, world, from Visual C++!" << std::endl;
    return 0;
}
--------------------------------------------------------------------------------
/cemnet_lib/src/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef _CUDA_UTILS_H
2 | #define _CUDA_UTILS_H
3 |
4 | #include
5 |
6 | #define TOTAL_THREADS 1500
7 |
8 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
9 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
11 |
12 | #define THREADS_PER_BLOCK 800
13 | #define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
14 |
15 | inline int opt_n_threads(int work_size) {
16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0);
17 | return max(min(1 << pow_2, TOTAL_THREADS), 1);
18 | }
19 |
20 | inline dim3 opt_block_config(int x, int y) {
21 | // const int x_threads = opt_n_threads(x);
22 | // const int y_threads = max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1);
23 | // dim3 block_config(x_threads, y_threads, 1);
24 | // return block_config;
25 | dim3 block_config(TOTAL_THREADS, 1, 1);
26 | return block_config;
27 | }
28 |
29 | #endif
30 |
--------------------------------------------------------------------------------
/cemnet_lib/src/distances/distances_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "distances_cuda_kernel.h"
6 |
7 | extern THCState *state;
8 |
9 | void cd_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor)
10 | {
11 | int b = srcs_tensor.size(0);
12 | int c = srcs_tensor.size(1);
13 | int n = srcs_tensor.size(2);
14 | const float *srcs = srcs_tensor.data();
15 | const float *tgts = tgt_tensor.data();
16 | float *distances = distances_tensor.data();
17 | cd_distance_cuda_launcher(b, c, n, srcs, tgts, distances);
18 | }
19 |
20 | void iou_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor, float r)
21 | {
22 | int b = srcs_tensor.size(0);
23 | int c = srcs_tensor.size(1);
24 | int n = srcs_tensor.size(2);
25 | const float *srcs = srcs_tensor.data();
26 | const float *tgt = tgt_tensor.data();
27 | float *distances = distances_tensor.data();
28 | iou_distance_cuda_launcher(b, c, n, r, srcs, tgt, distances);
29 | }
--------------------------------------------------------------------------------
/cemnet_lib/src/distances/distances_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "../cuda_utils.h"
2 | #include "distances_cuda_kernel.h"
3 | #include
4 |
// input:  srcs (b, c, n), tgt (c, n), r
// output: distances (b), accumulated with atomicAdd -- the caller must zero it first.
// Each thread handles source point p (nearest target) and target point p (nearest
// source); for each direction whose nearest squared distance d satisfies d <= r it
// adds 1 - d / r. NOTE(review): this file compares squared distances with r directly,
// while ops_saved/ uses sqrt(d) against r * r -- confirm which is intended.
__global__ void iou_distance_cuda_kernel(int b, int c, int n, float r, const float *srcs, const float *tgt, float *distances)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float *src = srcs + bs_idx * c * n;
    // Properly initialised minima: the old `float a, b = 0.0, 0.0;` did not compile
    // as intended and a minimum seeded with 0 could never decrease.
    float min_distance1 = 0.0f;
    float min_distance2 = 0.0f;
    for (int i = 0; i < n; ++i)
    {
        float distance1 = 0.0f;  // src[pt_idx] <-> tgt[i]
        float distance2 = 0.0f;  // src[i] <-> tgt[pt_idx]
        for (int j = 0; j < c; ++j)
        {
            const float d1 = src[j * n + pt_idx] - tgt[j * n + i];
            const float d2 = src[j * n + i] - tgt[j * n + pt_idx];
            distance1 += d1 * d1;
            distance2 += d2 * d2;
        }
        if (i == 0 || distance1 < min_distance1) min_distance1 = distance1;
        if (i == 0 || distance2 < min_distance2) min_distance2 = distance2;
    }

    float score = 0.0f;
    if (min_distance1 <= r) score += 1.0f - min_distance1 / r;
    if (min_distance2 <= r) score += 1.0f - min_distance2 / r;
    // All threads of a batch element target the same slot; a plain += was a data race.
    atomicAdd(&distances[bs_idx], score);
}

void iou_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float *distances)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile points over blocks: one block of n threads fails when n > 1024.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    iou_distance_cuda_kernel<<<grid, block>>>(b, c, n, r, srcs, tgt, distances);
}
49 |
// input:  srcs (b, c, n), tgt (c, n)
// output: distances (b), accumulated with atomicAdd -- the caller must zero it first.
// Each thread adds the squared nearest-neighbour distance of source point p (to the
// target cloud) plus that of target point p (to the source cloud).
__global__ void cd_distance_cuda_kernel(int b, int c, int n, const float *srcs, const float *tgt, float *distances)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float *src = srcs + bs_idx * c * n;
    // Properly initialised minima: the old `float a, b = 0.0, 0.0;` did not compile
    // as intended and a minimum seeded with 0 could never decrease.
    float min_distance1 = 0.0f;
    float min_distance2 = 0.0f;
    for (int i = 0; i < n; ++i)
    {
        float distance1 = 0.0f;  // src[pt_idx] <-> tgt[i]
        float distance2 = 0.0f;  // src[i] <-> tgt[pt_idx]
        for (int j = 0; j < c; ++j)
        {
            const float d1 = src[j * n + pt_idx] - tgt[j * n + i];
            const float d2 = src[j * n + i] - tgt[j * n + pt_idx];
            distance1 += d1 * d1;
            distance2 += d2 * d2;
        }
        if (i == 0 || distance1 < min_distance1) min_distance1 = distance1;
        if (i == 0 || distance2 < min_distance2) min_distance2 = distance2;
    }
    // All threads of a batch element target the same slot; a plain += was a data race.
    atomicAdd(&distances[bs_idx], min_distance1 + min_distance2);
}

void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float *distances)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile points over blocks: one block of n threads fails when n > 1024.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    cd_distance_cuda_kernel<<<grid, block>>>(b, c, n, srcs, tgt, distances);
}
--------------------------------------------------------------------------------
/cemnet_lib/src/distances/distances_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _SAMPLING_CUDA_KERNEL
2 | #define _SAMPLING_CUDA_KERNEL
3 | #include
4 | #include
5 | #include
6 |
7 | void cd_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances);
8 | void iou_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances, float r);
9 |
10 | #ifdef __cplusplus
11 | extern "C" {
12 | #endif
13 |
14 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances);
15 | void iou_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances);
16 |
17 | #ifdef __cplusplus
18 | }
19 | #endif
20 | #endif
--------------------------------------------------------------------------------
/cemnet_lib/src/ops/ops_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "ops_cuda_kernel.h"
6 |
7 | extern THCState *state;
8 |
9 | void closest_point_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor closest_points_tensor){
10 | int b = srcs_tensor.size(0);
11 | int c = srcs_tensor.size(1);
12 | int n = srcs_tensor.size(2);
13 | const float *srcs = srcs_tensor.data();
14 | const float *tgt = tgt_tensor.data();
15 | float *closest_points = closest_points_tensor.data();
16 | closest_point_cuda_launcher(b, c, n, closest_points, srcs, tgt);
17 | }
18 |
19 | void cd_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor)
20 | {
21 | int b = srcs_tensor.size(0);
22 | int c = srcs_tensor.size(1);
23 | int n = srcs_tensor.size(2);
24 | const float *srcs = srcs_tensor.data();
25 | const float *tgts = tgt_tensor.data();
26 | float *distances = distances_tensor.data();
27 | cd_distance_cuda_launcher(b, c, n, srcs, tgts, distances);
28 | }
29 |
30 | void mc_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor, float r, at::Tensor min_idxs_tensor)
31 | {
32 | int b = srcs_tensor.size(0);
33 | int c = srcs_tensor.size(1);
34 | int n = srcs_tensor.size(2);
35 | const float *srcs = srcs_tensor.data();
36 | const float *tgt = tgt_tensor.data();
37 | float *distances = distances_tensor.data();
38 | int *min_idxs = min_idxs_tensor.data();
39 | mc_distance_cuda_launcher(b, c, n, r, srcs, tgt, distances, min_idxs);
40 | }
41 |
42 | void cycle_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor, int N, at::Tensor min_idxs_tensor, at::Tensor n_distances_tensor, at::Tensor n_idxs_tensor)
43 | {
44 | int b = srcs_tensor.size(0);
45 | int c = srcs_tensor.size(1);
46 | int n = srcs_tensor.size(2);
47 | const float *srcs = srcs_tensor.data();
48 | const float *tgt = tgt_tensor.data();
49 | float *distances = distances_tensor.data();
50 | int *min_idxs = min_idxs_tensor.data();
51 | float *n_distances = n_distances_tensor.data();
52 | int *n_idxs = n_idxs_tensor.data();
53 | cycle_distance_cuda_launcher(b, c, n, N, srcs, tgt, distances, min_idxs, n_distances, n_idxs);
54 | }
--------------------------------------------------------------------------------
/cemnet_lib/src/ops/ops_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "../cuda_utils.h"
2 | #include "ops_cuda_kernel.h"
3 | #include "math.h"
4 |
// For every source point, copy the coordinates of its nearest target point.
// input:  srcs (b, c, n), tgt (c, n)   -- contiguous
// output: closest_points (b, c, n)
// Launch: grid (b, ceil(n / blockDim.x)); one thread per (batch element, source point).
__global__ void closest_point_cuda_kernel(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float *src = srcs + bs_idx * c * n;  // (c, n) slice of this batch element
    float min_val = 0.0f;
    int min_idx = -1;
    for (int k = 0; k < n; ++k)
    {
        float val = 0.0f;
        for (int i = 0; i < c; ++i)
        {
            const float diff = src[i * n + pt_idx] - tgt[i * n + k];
            val += diff * diff;
        }
        // Seed with the first candidate: the old 1000.0 sentinel left idx == -1
        // (and then read tgt[-1]) whenever every squared distance was >= 1000.
        if (min_idx < 0 || val < min_val)
        {
            min_val = val;
            min_idx = k;
        }
    }

    // Copy all c coordinates (previously hard-coded to exactly 3 channels).
    float *out = closest_points + bs_idx * c * n;
    for (int i = 0; i < c; ++i)
    {
        out[i * n + pt_idx] = tgt[i * n + min_idx];
    }
}

void closest_point_cuda_launcher(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile the n points over several blocks: one block of n threads cannot launch
    // once n exceeds the 1024 threads-per-block limit.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    closest_point_cuda_kernel<<<grid, block>>>(b, c, n, closest_points, srcs, tgt);
}
41 |
// Maximum-consensus inlier scores per point and direction.
// input:  srcs (b, c, n), tgt (c, n), r (inlier radius)
// output: distances (b, n, 2), min_idxs (b, n, 2)
//   [..., 0]: source point p -> nearest target point; [..., 1]: target point p -> nearest source point.
//   An entry is written only when the nearest squared distance d <= r * r, as
//   1 - sqrt(d) / r together with the nearest index; other entries keep the caller's values.
__global__ void mc_distance_cuda_kernel(int b, int c, int n, float r, const float *srcs, const float *tgt, float *distances, int *min_idxs)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float rr = r * r;
    const float *src = srcs + bs_idx * c * n;
    float min_distance1 = 0.0f;
    float min_distance2 = 0.0f;
    int min_idx1 = -1;
    int min_idx2 = -1;
    for (int i = 0; i < n; ++i)
    {
        float distance1 = 0.0f;  // src[pt_idx] <-> tgt[i]
        float distance2 = 0.0f;  // src[i] <-> tgt[pt_idx]
        for (int j = 0; j < c; ++j)
        {
            const float d1 = src[j * n + pt_idx] - tgt[j * n + i];
            const float d2 = src[j * n + i] - tgt[j * n + pt_idx];
            distance1 += d1 * d1;
            distance2 += d2 * d2;
        }
        // Seed with i == 0 rather than a 1000.0 sentinel, so the minima are exact
        // and min_idx* is always a valid index.
        if (i == 0 || distance1 < min_distance1)
        {
            min_distance1 = distance1;
            min_idx1 = i;
        }
        if (i == 0 || distance2 < min_distance2)
        {
            min_distance2 = distance2;
            min_idx2 = i;
        }
    }

    const int out = (bs_idx * n + pt_idx) * 2;
    if (min_distance1 <= rr)
    {
        distances[out] = 1.0f - sqrtf(min_distance1) / r;
        min_idxs[out] = min_idx1;
    }
    if (min_distance2 <= rr)
    {
        distances[out + 1] = 1.0f - sqrtf(min_distance2) / r;
        min_idxs[out + 1] = min_idx2;
    }
}

void mc_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float *distances, int *min_idxs)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile points over blocks: one block of n threads fails when n > 1024.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    mc_distance_cuda_kernel<<<grid, block>>>(b, c, n, r, srcs, tgt, distances, min_idxs);
}
94 |
// Chamfer terms per point, in both directions.
// input:  srcs (b, c, n), tgt (c, n)
// output: distances (b, n, 2): [..., 0] squared distance from source point p to its
//         nearest target point; [..., 1] squared distance from target point p to its
//         nearest source point.
__global__ void cd_distance_cuda_kernel(int b, int c, int n, const float *srcs, const float *tgt, float *distances)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float *src = srcs + bs_idx * c * n;
    float min_distance1 = 0.0f;
    float min_distance2 = 0.0f;
    for (int i = 0; i < n; ++i)
    {
        float distance1 = 0.0f;  // src[pt_idx] <-> tgt[i]
        float distance2 = 0.0f;  // src[i] <-> tgt[pt_idx]
        for (int j = 0; j < c; ++j)
        {
            const float d1 = src[j * n + pt_idx] - tgt[j * n + i];
            const float d2 = src[j * n + i] - tgt[j * n + pt_idx];
            distance1 += d1 * d1;
            distance2 += d2 * d2;
        }
        // Seed with the first candidate: the old 1000.0 sentinel silently capped
        // every term at 1000 for distant clouds.
        if (i == 0 || distance1 < min_distance1) min_distance1 = distance1;
        if (i == 0 || distance2 < min_distance2) min_distance2 = distance2;
    }

    const int out = (bs_idx * n + pt_idx) * 2;
    distances[out] = min_distance1;
    distances[out + 1] = min_distance2;
}

void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float *distances)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile points over blocks: one block of n threads fails when n > 1024.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    cd_distance_cuda_kernel<<<grid, block>>>(b, c, n, srcs, tgt, distances);
}
135 |
136 |
// Experimental "cycle" distance (not bound in cemnet_lib_api.cpp).
// input:  srcs (b, c, n), tgt (c, n), N
// output: distances (b, n) = -(squared distance from source point p to its nearest
//         target point), min_idxs (b, n) = that target index; written only if cnt < N.
// n_distances / n_idxs (>= N entries) are reset to 100.0 / -1; nothing reads them yet.
__global__ void cycle_distance_cuda_kernel(int b, int c, int n, int N, const float *srcs, const float *tgt, float *distances, int *min_idxs, float *n_distances, int *n_idxs)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    // Every thread used to rewrite these same N global entries; a single writer
    // leaves identical contents without the redundant traffic.
    if (bs_idx == 0 && pt_idx == 0)
    {
        for (int i = 0; i < N; ++i)
        {
            n_distances[i] = 100.0f;
            n_idxs[i] = -1;
        }
    }

    // xy: nearest target point of source point pt_idx.
    const float *src = srcs + bs_idx * c * n;
    float min_distance1 = 0.0f;
    int min_idx1 = -1;
    for (int i = 0; i < n; ++i)
    {
        float distance1 = 0.0f;
        for (int j = 0; j < c; ++j)
        {
            const float d = src[j * n + pt_idx] - tgt[j * n + i];
            distance1 += d * d;
        }
        // Seed with i == 0 rather than a 1000.0 sentinel (which could leave idx == -1).
        if (i == 0 || distance1 < min_distance1)
        {
            min_distance1 = distance1;
            min_idx1 = i;
        }
    }

    // yx: intended to count source points closer to tgt[min_idx1] than this one.
    // TODO: the counting pass is disabled, so cnt stays 0 and the output is written
    // whenever N > 0. The sketch compares against `distance1`, which is out of scope
    // here -- min_distance1 is presumably meant; confirm before re-enabling.
    int cnt = 0;
    // for (int i = 0; i < n; i += 1)
    // {
    //     if (cnt >= N) break;
    //     float distance2 = 0.0;
    //     for (int j = 0; j < c; j += 1)
    //     {
    //         distance2 += (src[j * n + i] - tgt[j * n + min_idx1]) * (src[j * n + i] - tgt[j * n + min_idx1]);
    //     }
    //     if (distance2 < distance1) cnt += 1;
    // }

    if (cnt < N)
    {
        distances[bs_idx * n + pt_idx] = -min_distance1;
        min_idxs[bs_idx * n + pt_idx] = min_idx1;
    }
}

void cycle_distance_cuda_launcher(int b, int c, int n, int N, const float *srcs, const float *tgt, float *distances, int *min_idxs, float *n_distances, int *n_idxs)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile points over blocks: one block of n threads fails when n > 1024.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    cycle_distance_cuda_kernel<<<grid, block>>>(b, c, n, N, srcs, tgt, distances, min_idxs, n_distances, n_idxs);
}
--------------------------------------------------------------------------------
/cemnet_lib/src/ops/ops_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _SAMPLING_CUDA_KERNEL
2 | #define _SAMPLING_CUDA_KERNEL
3 | #include
4 | #include
5 | #include
6 |
7 | void closest_point_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor closest_points);
8 | void cd_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances);
9 | void mc_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances, float r, at::Tensor min_idxs);
10 | void cycle_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances, int N, at::Tensor min_idxs, at::Tensor n_distances, at::Tensor n_idxs);
11 |
12 | #ifdef __cplusplus
13 | extern "C" {
14 | #endif
15 |
16 | void closest_point_cuda_launcher(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt);
17 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances);
18 | void mc_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances, int *min_idxs);
19 | void cycle_distance_cuda_launcher(int b, int c, int n, int N, const float *srcs, const float *tgt, float * distances, int *min_idxs, float * n_distances, int * n_idxs);
20 |
21 | #ifdef __cplusplus
22 | }
23 | #endif
24 | #endif
--------------------------------------------------------------------------------
/cemnet_lib/src/ops_saved/ops_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "ops_cuda_kernel.h"
6 |
7 | extern THCState *state;
8 |
9 | void closest_point_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor closest_points_tensor){
10 | int b = srcs_tensor.size(0);
11 | int c = srcs_tensor.size(1);
12 | int n = srcs_tensor.size(2);
13 | const float *srcs = srcs_tensor.data();
14 | const float *tgt = tgt_tensor.data();
15 | float *closest_points = closest_points_tensor.data();
16 | closest_point_cuda_launcher(b, c, n, closest_points, srcs, tgt);
17 | }
18 |
19 | void cd_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor)
20 | {
21 | int b = srcs_tensor.size(0);
22 | int c = srcs_tensor.size(1);
23 | int n = srcs_tensor.size(2);
24 | const float *srcs = srcs_tensor.data();
25 | const float *tgts = tgt_tensor.data();
26 | float *distances = distances_tensor.data();
27 | cd_distance_cuda_launcher(b, c, n, srcs, tgts, distances);
28 | }
29 |
30 | void iou_distance_cuda(at::Tensor srcs_tensor, at::Tensor tgt_tensor, at::Tensor distances_tensor, float r)
31 | {
32 | int b = srcs_tensor.size(0);
33 | int c = srcs_tensor.size(1);
34 | int n = srcs_tensor.size(2);
35 | const float *srcs = srcs_tensor.data();
36 | const float *tgt = tgt_tensor.data();
37 | float *distances = distances_tensor.data();
38 | iou_distance_cuda_launcher(b, c, n, r, srcs, tgt, distances);
39 | }
--------------------------------------------------------------------------------
/cemnet_lib/src/ops_saved/ops_cuda_kernel.cu:
--------------------------------------------------------------------------------
1 | #include "../cuda_utils.h"
2 | #include "ops_cuda_kernel.h"
3 | #include "math.h"
4 |
// For every source point, copy the coordinates of its nearest target point.
// input:  srcs (b, c, n), tgt (c, n)   -- contiguous
// output: closest_points (b, c, n)
// Launch: grid (b, ceil(n / blockDim.x)); one thread per (batch element, source point).
__global__ void closest_point_cuda_kernel(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float *src = srcs + bs_idx * c * n;
    float min_val = 0.0f;
    int min_idx = -1;
    for (int k = 0; k < n; ++k)
    {
        float val = 0.0f;
        for (int i = 0; i < c; ++i)
        {
            const float diff = src[i * n + pt_idx] - tgt[i * n + k];
            val += diff * diff;
        }
        // Seed with the first candidate: the old 1000.0 sentinel left idx == -1
        // (and then read tgt[-1]) whenever every squared distance was >= 1000.
        if (min_idx < 0 || val < min_val)
        {
            min_val = val;
            min_idx = k;
        }
    }

    // Copy all c coordinates (previously hard-coded to exactly 3 channels).
    float *out = closest_points + bs_idx * c * n;
    for (int i = 0; i < c; ++i)
    {
        out[i * n + pt_idx] = tgt[i * n + min_idx];
    }
}

void closest_point_cuda_launcher(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile points over blocks: one block of n threads fails when n > 1024.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    closest_point_cuda_kernel<<<grid, block>>>(b, c, n, closest_points, srcs, tgt);
}
41 |
// Inlier scores per point and direction.
// input:  srcs (b, c, n), tgt (c, n), r (inlier radius)
// output: distances (b, n, 2); [..., 0] source point p -> nearest target point,
//         [..., 1] target point p -> nearest source point. An entry is written only
//         when the nearest squared distance d <= r * r, as 1 - sqrt(d) / r.
__global__ void iou_distance_cuda_kernel(int b, int c, int n, float r, const float *srcs, const float *tgt, float *distances)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float rr = r * r;
    const float *src = srcs + bs_idx * c * n;
    float min_distance1 = 0.0f;
    float min_distance2 = 0.0f;
    for (int i = 0; i < n; ++i)
    {
        float distance1 = 0.0f;  // src[pt_idx] <-> tgt[i]
        float distance2 = 0.0f;  // src[i] <-> tgt[pt_idx]
        for (int j = 0; j < c; ++j)
        {
            const float d1 = src[j * n + pt_idx] - tgt[j * n + i];
            const float d2 = src[j * n + i] - tgt[j * n + pt_idx];
            distance1 += d1 * d1;
            distance2 += d2 * d2;
        }
        // Seed with i == 0 rather than a 1000.0 sentinel that capped the minima.
        if (i == 0 || distance1 < min_distance1) min_distance1 = distance1;
        if (i == 0 || distance2 < min_distance2) min_distance2 = distance2;
    }

    const int out = (bs_idx * n + pt_idx) * 2;
    if (min_distance1 <= rr)
    {
        distances[out] = 1.0f - sqrtf(min_distance1) / r;
    }
    if (min_distance2 <= rr)
    {
        distances[out + 1] = 1.0f - sqrtf(min_distance2) / r;
    }
}

void iou_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float *distances)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile points over blocks: one block of n threads fails when n > 1024.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    iou_distance_cuda_kernel<<<grid, block>>>(b, c, n, r, srcs, tgt, distances);
}
88 |
// Chamfer terms per point, in both directions.
// input:  srcs (b, c, n), tgt (c, n)
// output: distances (b, n, 2): [..., 0] squared distance from source point p to its
//         nearest target point; [..., 1] squared distance from target point p to its
//         nearest source point.
__global__ void cd_distance_cuda_kernel(int b, int c, int n, const float *srcs, const float *tgt, float *distances)
{
    const int bs_idx = blockIdx.x;
    const int pt_idx = blockIdx.y * blockDim.x + threadIdx.x;
    if (bs_idx >= b || pt_idx >= n) return;

    const float *src = srcs + bs_idx * c * n;
    float min_distance1 = 0.0f;
    float min_distance2 = 0.0f;
    for (int i = 0; i < n; ++i)
    {
        float distance1 = 0.0f;  // src[pt_idx] <-> tgt[i]
        float distance2 = 0.0f;  // src[i] <-> tgt[pt_idx]
        for (int j = 0; j < c; ++j)
        {
            const float d1 = src[j * n + pt_idx] - tgt[j * n + i];
            const float d2 = src[j * n + i] - tgt[j * n + pt_idx];
            distance1 += d1 * d1;
            distance2 += d2 * d2;
        }
        // Seed with the first candidate: the old 1000.0 sentinel silently capped
        // every term at 1000 for distant clouds.
        if (i == 0 || distance1 < min_distance1) min_distance1 = distance1;
        if (i == 0 || distance2 < min_distance2) min_distance2 = distance2;
    }

    const int out = (bs_idx * n + pt_idx) * 2;
    distances[out] = min_distance1;
    distances[out + 1] = min_distance2;
}

void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float *distances)
{
    if (b <= 0 || n <= 0) return;  // an empty grid is an invalid launch configuration
    // Tile points over blocks: one block of n threads fails when n > 1024.
    const int threads = 256;
    dim3 grid(b, DIVUP(n, threads), 1);
    dim3 block(threads, 1, 1);
    cd_distance_cuda_kernel<<<grid, block>>>(b, c, n, srcs, tgt, distances);
}
--------------------------------------------------------------------------------
/cemnet_lib/src/ops_saved/ops_cuda_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _SAMPLING_CUDA_KERNEL
2 | #define _SAMPLING_CUDA_KERNEL
3 | #include
4 | #include
5 | #include
6 |
7 | void closest_point_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor closest_points);
8 | void cd_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances);
9 | void iou_distance_cuda(at::Tensor srcs, at::Tensor tgt, at::Tensor distances, float r);
10 |
11 | #ifdef __cplusplus
12 | extern "C" {
13 | #endif
14 |
15 | void closest_point_cuda_launcher(int b, int c, int n, float *closest_points, const float *srcs, const float *tgt);
16 | void cd_distance_cuda_launcher(int b, int c, int n, const float *srcs, const float *tgt, float * distances);
17 | void iou_distance_cuda_launcher(int b, int c, int n, float r, const float *srcs, const float *tgt, float * distances);
18 |
19 | #ifdef __cplusplus
20 | }
21 | #endif
22 | #endif
--------------------------------------------------------------------------------
/cemnet_lib/src/prcem_lib_api.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 |
4 | #include "ops/ops_cuda_kernel.h"
5 |
6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
7 | m.def("closest_point_cuda", &closest_point_cuda, "closest_point_cuda"); // name in python, cpp function address, docs
8 | m.def("cd_distance_cuda", &cd_distance_cuda, "cd_distance_cuda");
9 | // m.def("iou_distance_cuda", &iou_distance_cuda, "iou_distance_cuda");
10 | m.def("mc_distance_cuda", &mc_distance_cuda, "mc_distance_cuda");
11 | // m.def("cycle_distance_cuda", &cycle_distance_cuda, "cycle_distance_cuda");
12 | }
--------------------------------------------------------------------------------
/cems/__pycache__/base_cem.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/cems/__pycache__/base_cem.cpython-36.pyc
--------------------------------------------------------------------------------
/cems/__pycache__/guided_cem.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/cems/__pycache__/guided_cem.cpython-36.pyc
--------------------------------------------------------------------------------
/cems/base_cem.py:
--------------------------------------------------------------------------------
1 | import torch, cemnet_lib
2 | from utils.euler2mat import euler2mat_torch
3 | from utils.commons import stack_transforms_seq
4 | from utils.transform_pc import transform_pc_torch
5 | from utils.batch_icp import batch_icp
6 |
class BaseCEM(torch.nn.Module):
    """Cross-entropy-method (CEM) search for rigid transforms aligning srcs to tgts.

    An action is 6-D: three rotation angles, converted with
    ``euler2mat_torch(..., seq="zyx")``, followed by a 3-D translation.

    Each iteration does the following:
    - samples ``n_candidates`` actions per batch element from a diagonal Gaussian;
    - scores them with ``distance`` (lower is better);
    - keeps the ``n_elites`` best;
    - refits the Gaussian to those elites.
    """

    def __init__(self, opts):
        super(BaseCEM, self).__init__()
        self.opts = opts
        self.n_iters = opts.cem.n_iters
        self.n_candidates = opts.cem.n_candidates[0]
        self.planning_horizon = opts.cem.planning_horizon
        self.n_elites = opts.cem.n_elites
        self.r_size, self.t_size, self.action_size = 3, 3, 6
        self.min_r, self.max_r = opts.cem.r_range
        self.min_t, self.max_t = opts.cem.t_range[opts.db_nm]
        self.r_init_sigma, self.t_init_sigma = opts.cem.init_sigma[opts.db_nm]
        self.recorder = opts.recorder
        self.device = opts.device

    def init_distrib(self):
        """Set the initial Gaussian.

        Means sit at the centre of the rotation / translation ranges; stds are
        ``r_init_sigma`` / ``t_init_sigma``. Requires ``self.batch_size``
        (set in ``forward``).
        """
        mus_r = torch.zeros([self.planning_horizon, self.batch_size, 1, self.r_size]) + (self.min_r + self.max_r) / 2.  # (H, B, 1, 3)
        mus_t = torch.zeros([self.planning_horizon, self.batch_size, 1, self.t_size]) + (self.min_t + self.max_t) / 2.  # (H, B, 1, 3)
        self.mus = torch.cat([mus_r, mus_t], 3).to(self.device)  # (H, B, 1, 6)
        init_r_t = stack_transforms_seq(self.mus.squeeze(2))  # (B, 6)
        self.r_init = init_r_t[:, :3]  # (B, 3)
        self.t_init = init_r_t[:, 3:]  # (B, 3)
        self.sigmas = torch.ones_like(self.mus)  # (H, B, 1, 6), already on self.device
        self.sigmas[:, :, :, :3] *= self.r_init_sigma
        self.sigmas[:, :, :, 3:] *= self.t_init_sigma
        # update_distrib rebinds (never mutates) mus/sigmas, so this aliasing is safe.
        self.mus_init = self.mus
        self.sigmas_init = self.sigmas

    def sample_candidates(self):
        """Draw (H, B, C, 6) candidate actions from N(mus, sigmas)."""
        # Noise comes from the default CPU generator and is then moved, so the
        # random stream does not depend on self.device.
        noise = torch.randn(self.planning_horizon, self.batch_size, self.n_candidates, self.action_size).to(self.device)
        self.candidates = self.mus + self.sigmas * noise  # (H, B, C, 6)

    def perform_candidates(self):
        """Apply every candidate transform to its source cloud; returns (B, C, 3, N)."""
        stack_candidates = stack_transforms_seq(self.candidates)  # (B * C, 6)
        Rs = euler2mat_torch(stack_candidates[:, :self.r_size], seq="zyx")  # (B * C, 3, 3)
        ts = stack_candidates[:, self.r_size:]  # (B * C, 3)
        return transform_pc_torch(self.srcs_repeat, Rs, ts).reshape(self.batch_size, self.n_candidates, 3, -1)  # (B, C, 3, N)

    def distance(self, srcs, tgt):
        """Per-candidate alignment error of ``srcs`` against ``tgt`` (lower is better).

        ``opts.cem.metric_type[0]`` selects the metric: "MC", "CD" or "GM".
        "MC" and "GM" also read ``metric_type[1]`` as a parameter. The "GM"
        value is multiplied by 100 (presumably to rescale it; confirm).

        Raises:
            ValueError: for any other metric name. Previously this returned
                None, which only failed later when stored into the error tensor.
        """
        metric_type = self.opts.cem.metric_type
        name = metric_type[0]
        if name == "MC":
            return cemnet_lib.mc_distance(srcs, tgt, metric_type[1])
        if name == "CD":
            return cemnet_lib.cd_distance(srcs, tgt)
        if name == "GM":
            return cemnet_lib.gm_distance(srcs, tgt, metric_type[1]) * 100
        raise ValueError("Unknown CEM metric type %r; expected 'MC', 'CD' or 'GM'." % (name,))

    def evaluate_candidates(self):
        """Return (B, C) alignment errors (lower is better), detached from autograd.

        Fused reward applies when ``opts.cem.is_fused_reward[0]`` is set and the
        iteration counter ``self.k`` is below ``is_fused_reward[2]``. In that
        case each error is blended with the error after applying the pose
        returned by ``batch_icp``:
        ``alpha * error + (1 - alpha) * error_after_icp``, where
        ``alpha = is_fused_reward[1]``.
        """
        transformed_srcs_repeat = self.perform_candidates()  # (B, C, 3, N)
        # Current reward.
        self.alignment_errors = torch.zeros(self.batch_size, self.n_candidates).to(self.device)  # (B, C)
        for k, (transformed_src_repeat, tgt) in enumerate(zip(transformed_srcs_repeat, self.tgts)):
            self.alignment_errors[k] = self.distance(transformed_src_repeat, tgt).detach()
        # Future potential: error after an ICP refinement of every candidate.
        fused_reward = self.opts.cem.is_fused_reward
        if fused_reward[0] and self.k < fused_reward[2]:
            alpha = fused_reward[1]
            for k, (transformed_src_repeat, tgt) in enumerate(zip(transformed_srcs_repeat, self.tgts)):
                Rs_icp, ts_icp = batch_icp(self.opts, transformed_src_repeat, tgt)  # (C, 3, 3), (C, 3)
                transformed_src_repeat_icp = transform_pc_torch(transformed_src_repeat, Rs_icp, ts_icp)  # (C, 3, N)
                potential = self.distance(transformed_src_repeat_icp, tgt).detach()
                self.alignment_errors[k] = alpha * self.alignment_errors[k] + (1 - alpha) * potential
        self.alignment_errors = self.alignment_errors.detach()
        return self.alignment_errors

    def update_distrib(self):
        """Refit the Gaussian to the elites (population std, not unbiased)."""
        self.mus = self.elites.mean(dim=2, keepdim=True)  # (H, B, 1, 6)
        self.sigmas = self.elites.std(dim=2, unbiased=False, keepdim=True)  # (H, B, 1, 6)

    def elite_selection(self):
        """Keep the ``n_elites`` lowest-error candidates per batch element; returns (H, B, K, 6)."""
        self.candidates = self.candidates.reshape(self.planning_horizon, self.batch_size * self.n_candidates, self.action_size)  # (H, B * C, 6)
        self.alignment_errors = self.evaluate_candidates()  # (B, C)
        self.elite_errors, self.elite_idxs = self.alignment_errors.topk(self.n_elites, dim=1, largest=False, sorted=False)  # (B, K)
        # Convert per-batch candidate indices into indices of the flattened B * C axis.
        offsets = self.n_candidates * torch.arange(0, self.batch_size, dtype=torch.int64, device=self.device).reshape(self.batch_size, 1)
        self.elite_idxs += offsets  # (B, K)
        self.elite_idxs = self.elite_idxs.view(-1)  # (B * K,)
        self.elites = self.candidates[:, self.elite_idxs].reshape(self.planning_horizon, self.batch_size, self.n_elites, self.action_size)  # (H, B, K, 6)
        return self.elites

    def forward(self, srcs, tgts):
        """Run ``n_iters`` CEM iterations and return the final mean action.

        :param srcs: (B, 3, N), torch.FloatTensor
        :param tgts: (B, 3, N), torch.FloatTensor
        :return: dict with "r" / "t" (final mean rotation / translation, (B, 3)
            each) and "r_init" / "t_init" (the initial ones from init_distrib).
        """
        self.batch_size, self.n_points = srcs.size(0), srcs.size(2)
        self.srcs, self.tgts = srcs.to(self.device), tgts.to(self.device)  # (B, 3, N)
        self.srcs_repeat = self.srcs.unsqueeze(1).repeat(1, self.n_candidates, 1, 1).reshape(
            self.batch_size * self.n_candidates, 3, self.n_points)  # (B * C, 3, N)
        self.init_distrib()  # (H, B, 1, 6)
        self.k = 0  # iteration counter, read by evaluate_candidates
        for _ in range(self.n_iters):
            self.sample_candidates()  # 1. sample candidates
            self.elite_selection()    # 2. elite selection
            self.update_distrib()     # 3. update distribution
            self.k += 1

        actions = stack_transforms_seq(self.mus.squeeze(2))  # (B, 6)
        return {"r": actions[:, :3], "t": actions[:, 3:], "r_init": self.r_init, "t_init": self.t_init}
--------------------------------------------------------------------------------
/cems/guided_cem.py:
--------------------------------------------------------------------------------
1 | from cems.base_cem import BaseCEM
2 | from modules.dcp_net import DCPNet
3 | from modules.sparsemax import Sparsemax
4 | from utils.commons import stack_transforms_seq
5 | import pdb, torch
6 |
class VanillaCEM(BaseCEM):
    """Plain CEM: BaseCEM's range-centred initial Gaussian and top-k elite refit, unchanged."""

    def __init__(self, opts):
        super().__init__(opts)
        # BaseCEM.__init__ already stores opts; re-assigned here for explicitness.
        self.opts = opts
12 |
class GuidedCEM(BaseCEM):
    """CEM whose initial Gaussian is predicted by a coarse DCPNet policy.

    Candidates are not hard-selected. Instead they are weighted by sparsemax
    over their negated alignment errors, and the Gaussian is refit as the
    weighted mean / std.
    """

    def __init__(self, opts):
        super(GuidedCEM, self).__init__(opts)
        self.coarse_policy = DCPNet(opts)
        self.top_k = Sparsemax(dim=1)
        self.opts = opts

    def add_exploration_noise(self, mus, sigmas):
        """Blend (mus, sigmas) towards a zero-mean prior with the initial stds.

        With ``w = opts.exploration_weight`` this returns
        ``((1 - w) * mus + w * 0, (1 - w) * sigmas + w * prior_sigmas)``.
        The prior std is ``r_init_sigma`` for the first 3 (rotation) action
        components and ``t_init_sigma`` for the last 3 (translation). Any
        leading shape works, e.g. (H, B, 6).
        """
        prior_mus = torch.zeros_like(mus).to(self.device)
        prior_sigmas = torch.ones_like(prior_mus).to(self.device)
        # Index the last (action) axis. mus arrives as (H, B, 6) from
        # init_distrib, so [:, :3] would slice the batch axis instead.
        prior_sigmas[..., :3] *= self.r_init_sigma
        prior_sigmas[..., 3:] *= self.t_init_sigma
        w = self.opts.exploration_weight
        mus = (1 - w) * mus + w * prior_mus
        sigmas = (1 - w) * sigmas + w * prior_sigmas
        return mus, sigmas

    def init_distrib(self):
        """Initialise the Gaussian from the coarse policy's prediction."""
        mus, sigmas = self.coarse_policy(self.srcs, self.tgts, is_sigma=True)  # (H, B, 6)
        if not self.opts.is_train:
            # At evaluation time only, blend the prediction towards the prior.
            mus, sigmas = self.add_exploration_noise(mus, sigmas)
        self.mus, self.sigmas = mus.unsqueeze(2), sigmas.unsqueeze(2)  # (H, B, 1, 6)
        self.mus_init, self.sigmas_init = self.mus, self.sigmas
        init_r_t = stack_transforms_seq(self.mus.squeeze(2)).clone()  # (B, 6)
        self.r_init = init_r_t[:, :3]  # (B, 3)
        self.t_init = init_r_t[:, 3:]  # (B, 3)

    def elite_selection(self):
        """Weight all candidates by sparsemax(-error) instead of keeping hard top-k elites."""
        # The parent call evaluates candidates and sets self.alignment_errors (B, C).
        # Its top-k self.elites are overwritten below.
        super().elite_selection()
        self.sparsemax_probs = self.top_k(-self.alignment_errors).reshape(1, self.batch_size, self.n_candidates, 1)  # (1, B, C, 1)
        self.elites = self.candidates.reshape(self.planning_horizon, self.batch_size, self.n_candidates, self.action_size) * self.sparsemax_probs  # (H, B, C, 6)

    def update_distrib(self):
        """Refit mus / sigmas as the sparsemax-weighted mean / std of the candidates."""
        self.mus = self.elites.sum(dim=2, keepdim=True)  # (H, B, 1, 6)
        candidates = self.candidates.reshape(self.planning_horizon, self.batch_size, self.n_candidates, self.action_size)
        self.sigmas = (((candidates - self.mus) ** 2) * self.sparsemax_probs).sum(dim=2, keepdim=True).sqrt()  # (H, B, 1, 6)

    def load_model(self, model_path):
        """Load a state dict from *model_path* into this module and return self.

        Tensors are mapped to ``self.device``, so a checkpoint saved on another
        device (e.g. a GPU) still loads. ``torch.load`` unpickles its input, so
        only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(model_path, map_location=self.device))
        return self
--------------------------------------------------------------------------------
/datasets/__pycache__/base_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/datasets/__pycache__/base_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/datasets/__pycache__/dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/get_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/datasets/__pycache__/get_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | from scipy.spatial.transform import Rotation
3 | from sklearn.neighbors import NearestNeighbors
4 | from scipy.spatial.distance import minkowski
5 | import numpy as np, pdb, pickle
6 |
def load_data(path):
    """Unpickle and return the object stored at *path*.

    The file is closed even if unpickling fails. ``pickle.load`` can execute
    arbitrary code, so only use this on trusted files.
    """
    with open(path, "rb") as f:
        return pickle.load(f)
12 |
class BaseDataset(Dataset):
    """Registration dataset over an array of point clouds ``self.pcs``.

    Each item is ``(src, tgt, r_lb, t_lb)`` as float32 arrays.

    ``tgt`` is ``src`` rotated by ZYX Euler angles ``r_lb = [z, y, x]`` (each
    drawn from [0, pi / factor]) and translated by ``t_lb`` (each component
    drawn from [-0.5, 0.5]). Both clouds are then independently shuffled, and
    optionally jittered and cropped to ``n_sub_points``.

    Subclasses fill ``self.pcs`` (and ``self.lbs``) in their ``__init__``.
    """

    def __init__(self, opts, partition, is_normal=False, cls_idx=-1):
        self.opts = opts
        self.cls_idx = cls_idx  # -1 keeps all labels; otherwise only this label
        self.n_points = opts.db.n_points
        self.n_sub_points = opts.db.n_sub_points
        self.partition = partition
        self.gaussian_noise = opts.db.gaussian_noise
        self.unseen = opts.db.unseen
        self.factor = opts.db.factor
        self.is_normal = is_normal
        self.pcs = None

    def load_data(self, opts, partition):
        """Load ``(pcs, lbs)`` for *partition* from the pickle at ``opts.db.path``.

        Only xyz (the first three channels) is kept. Rows are filtered to
        ``cls_idx`` when it is not -1.
        """
        db_path = opts.db.path
        db = load_data(db_path)[partition]
        # NOTE(review): this reads opts.infos.db_nm while BaseCEM reads opts.db_nm;
        # confirm which attribute the config provides (subclasses override this).
        if opts.infos.db_nm == "scene7":
            db["normal_pcs"] = db["normal_pcs"].transpose(0, 2, 1)
        pcs = db["normal_pcs"][:, :, :3]
        lbs = db["lbs"]
        if self.cls_idx != -1:
            pcs = pcs[lbs == self.cls_idx]
            lbs = lbs[lbs == self.cls_idx]
        return pcs, lbs

    def jitter_pointcloud(self, pc, sigma=0.01, clip=0.05):
        """Return a copy of *pc* with Gaussian noise added.

        The noise has std *sigma* and is clipped to [-clip, clip]. The input
        array is left unmodified and the result keeps its dtype.
        """
        noisy = pc.copy()
        noisy += np.clip(sigma * np.random.randn(*pc.shape), -clip, clip)
        return noisy

    def farthest_subsample_points(self, pc1, pc2, n_sub_points=768):
        """Crop both (3, N) clouds to their *n_sub_points* points nearest a far anchor.

        Despite the name, this is not farthest-point sampling. The anchor is
        U[0, 1)^3 plus 500 on every axis with one shared random sign. Keeping
        the nearest neighbours of that anchor in each cloud yields a same-side
        partial view. Returns two (3, n_sub_points) arrays.
        """
        pc1, pc2 = pc1.T, pc2.T
        # The callable metric is Minkowski with p=2, i.e. Euclidean distance.
        nbrs1 = NearestNeighbors(n_neighbors=n_sub_points, algorithm='auto',
                                 metric=lambda x, y: minkowski(x, y)).fit(pc1)
        random_p1 = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1])
        idx1 = nbrs1.kneighbors(random_p1, return_distance=False).reshape((n_sub_points,))
        nbrs2 = NearestNeighbors(n_neighbors=n_sub_points, algorithm='auto',
                                 metric=lambda x, y: minkowski(x, y)).fit(pc2)
        random_p2 = random_p1
        idx2 = nbrs2.kneighbors(random_p2, return_distance=False).reshape((n_sub_points,))
        return pc1[idx1, :].T, pc2[idx2, :].T

    def __getitem__(self, item):
        pc = self.pcs[item][:self.n_points]  # (N, 3)
        if self.partition != 'train':
            # Reproducible per-item transform, noise and crop outside training.
            np.random.seed(item)

        # The order of random draws is part of the reproducible contract: x, y, z, then t.
        angle_x = np.random.uniform(0., np.pi / self.factor)
        angle_y = np.random.uniform(0., np.pi / self.factor)
        angle_z = np.random.uniform(0., np.pi / self.factor)
        t_lb = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5)])

        pc1 = pc.T  # (3, N)
        r_lb = np.array([angle_z, angle_y, angle_x])
        pc2 = Rotation.from_euler('zyx', r_lb).apply(pc1.T).T + np.expand_dims(t_lb, axis=1)

        # Shuffle each cloud independently so point correspondences are not index-aligned.
        pc1 = np.random.permutation(pc1.T).T
        pc2 = np.random.permutation(pc2.T).T

        if self.gaussian_noise:
            pc1 = self.jitter_pointcloud(pc1)
            pc2 = self.jitter_pointcloud(pc2)

        if self.n_points != self.n_sub_points:
            pc1, pc2 = self.farthest_subsample_points(pc1, pc2, n_sub_points=self.n_sub_points)

        return pc1.astype('float32'), pc2.astype('float32'), r_lb.astype('float32'), t_lb.astype('float32')

    def __len__(self):
        return len(self.pcs)
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | from datasets.base_dataset import BaseDataset
2 | from utils.commons import load_data
3 |
class DB_ModelNet40(BaseDataset):
    """ModelNet40 registration dataset.

    With ``opts.db.unseen`` set, the "train" partition keeps only labels < 20
    and the "test" partition only labels >= 20, so test categories are unseen
    during training.
    """

    def __init__(self, opts, partition, is_normal=False, cls_idx=-1):
        super(DB_ModelNet40, self).__init__(opts, partition, is_normal, cls_idx)
        self.pcs, self.lbs = self.load_data(opts, partition)  # (B, N, 3), (B,)

    def load_data(self, opts, partition):
        """Return xyz clouds and labels for *partition*, filtered by unseen / cls_idx."""
        split = load_data(opts.db.path)[partition]
        pcs, lbs = split["normal_pcs"][:, :, :3], split["lbs"]

        if self.unseen and self.partition in ('test', 'train'):
            keep = (lbs >= 20) if self.partition == 'test' else (lbs < 20)
            pcs, lbs = pcs[keep], lbs[keep]

        if self.cls_idx != -1:
            keep = lbs == self.cls_idx
            pcs, lbs = pcs[keep], lbs[keep]

        return pcs, lbs
25 |
class DB_7Scene(BaseDataset):
    """7Scene registration dataset (xyz only, optional single-label filter)."""

    def __init__(self, opts, partition, is_normal=False, cls_idx=-1):
        super(DB_7Scene, self).__init__(opts, partition, is_normal, cls_idx)
        self.pcs, self.lbs = self.load_data(opts, partition)  # (B, N, 3), (B,)

    def load_data(self, opts, partition):
        """Return (pcs, lbs) for *partition*. The stored clouds have their last two axes swapped first."""
        split = load_data(opts.db.path)[partition]
        # Presumably stored channel-first (B, C, N); swap to (B, N, C) before taking xyz.
        split["normal_pcs"] = split["normal_pcs"].transpose(0, 2, 1)
        pcs, lbs = split["normal_pcs"][:, :, :3], split["lbs"]
        if self.cls_idx == -1:
            return pcs, lbs
        keep = lbs == self.cls_idx
        return pcs[keep], lbs[keep]
41 |
class DB_ICL_NUIM(BaseDataset):
    """ICL-NUIM registration dataset (xyz only, optional single-label filter)."""

    def __init__(self, opts, partition, is_normal=False, cls_idx=-1):
        super(DB_ICL_NUIM, self).__init__(opts, partition, is_normal, cls_idx)
        self.pcs, self.lbs = self.load_data(opts, partition)  # (B, N, 3), (B,)

    def load_data(self, opts, partition):
        """Return (pcs, lbs) for *partition*. The stored clouds have their last two axes swapped first."""
        split = load_data(opts.db.path)[partition]
        # Presumably stored channel-first (B, C, N); swap to (B, N, C) before taking xyz.
        split["normal_pcs"] = split["normal_pcs"].transpose(0, 2, 1)
        pcs, lbs = split["normal_pcs"][:, :, :3], split["lbs"]
        if self.cls_idx == -1:
            return pcs, lbs
        keep = lbs == self.cls_idx
        return pcs[keep], lbs[keep]
--------------------------------------------------------------------------------
/datasets/get_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from datasets.dataset import DB_ModelNet40, DB_7Scene, DB_ICL_NUIM
3 |
def get_dataset(opts, db_nm, partition, batch_size, shuffle, drop_last, is_normal=False, cls_idx=-1, n_cores=1):
    """Build the dataset named *db_nm* and a DataLoader over it.

    Returns ``(loader, db)``.

    Raises:
        ValueError: if *db_nm* is not "modelnet40", "scene7" or "icl_nuim".
            Previously an unknown name silently returned ``(None, None)``.
    """
    if db_nm == "modelnet40":
        db_cls = DB_ModelNet40
    elif db_nm == "scene7":
        db_cls = DB_7Scene
    elif db_nm == "icl_nuim":
        db_cls = DB_ICL_NUIM
    else:
        raise ValueError("Unknown dataset %r; expected 'modelnet40', 'scene7' or 'icl_nuim'." % (db_nm,))
    db = db_cls(opts, partition, is_normal, cls_idx=cls_idx)
    loader = DataLoader(db, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=n_cores)
    return loader, db
16 |
--------------------------------------------------------------------------------
/datasets/process_dataset.py:
--------------------------------------------------------------------------------
1 | import h5py, sys, os, glob, numpy as np, pdb
2 | sys.path.insert(0, "..")
3 | from datasets.utils.gen_normal import gen_normal
4 | from datasets.utils.commons import save_data
5 | import datasets.se_math.mesh as mesh, datasets.se_math.transforms as transforms, torchvision, torch
6 |
7 | # ============ ModelNet40 ==============
def process_modelnet40():
    """Pack the ModelNet40 HDF5 shards into one ``modelnet40_normal_n2048.pth`` file.

    For each partition ("train", "test") the shards ``ply_data_<partition>*.h5``
    under the hard-coded ``db_dir`` are read. Their ``data`` and ``normal``
    arrays are concatenated into ``normal_pcs`` (B, N, 6) and their labels into
    ``lbs`` (B,). The result is written via ``save_data``.

    Raises:
        FileNotFoundError: if no shard matches for a partition.
    """
    db_dir = "/test/datasets/registration/modelnet40"
    res = {"train": {"normal_pcs": [], "lbs": []},
           "test": {"normal_pcs": [], "lbs": []}}
    for partition in ["train", "test"]:
        pattern = os.path.join(db_dir, "ply_data_%s*.h5" % partition)
        # Sorted so the shard order, and hence item indices, is reproducible.
        db_paths = sorted(glob.glob(pattern))
        if not db_paths:
            raise FileNotFoundError("No ModelNet40 shards match %s" % pattern)
        for db_path in db_paths:
            with h5py.File(db_path, "r") as f:
                normal_pcs_batch = np.concatenate([f['data'][:], f['normal'][:]], axis=2)  # (B, N, 6)
                lbs_batch = f['label'][:].astype('int64')  # (B, 1)
            res[partition]["normal_pcs"].append(normal_pcs_batch)
            res[partition]["lbs"].append(lbs_batch)
            print("--- %s, finished. ---" % (db_path))
        res[partition]["normal_pcs"] = np.concatenate(res[partition]["normal_pcs"], 0)
        res[partition]["lbs"] = np.concatenate(res[partition]["lbs"], 0).reshape(-1)

    save_data(os.path.join(db_dir, "modelnet40_normal_n2048.pth"), res)
25 |
26 | # ============ 7Scene ==============
def get_cls(cls_path):
    """Read one category name per line from *cls_path*.

    Returns ``(names, name_to_label)``. ``names`` is sorted, and
    ``name_to_label`` maps each name to its index in that sorted list.
    The file is closed after reading.
    """
    with open(cls_path) as f:
        names = [line.rstrip('\n') for line in f]
    names.sort()
    name_to_label = {name: i for i, name in enumerate(names)}
    return names, name_to_label
32 |
def process_7scene():
    """Convert the 7Scene .ply meshes into one ``7scene_normal_n2048.pth`` file.

    Categories come from ``categories/7scene_<partition>.txt``; the label is
    the category's index in the sorted name list. For every mesh under
    ``7scene/<category>/*.ply`` the following steps run:
    - convert the mesh to points, normalise to the unit cube and resample to
      2048 points;
    - generate normals with ``gen_normal``.
    The result is written via ``save_data``.

    Raises:
        FileNotFoundError: if a partition yields no .ply files.
    """
    db_dir = "/test/datasets/registration/7scene"
    res = {"train": {"normal_pcs": [], "lbs": []},
           "test": {"normal_pcs": [], "lbs": []}}
    transform = torchvision.transforms.Compose([transforms.Mesh2Points(),
                                                transforms.OnUnitCube(),
                                                transforms.Resampler(2048)])
    pc_dir = os.path.join(db_dir, "7scene")
    for partition in ["train", "test"]:
        cls_path = os.path.join(db_dir, "categories/7scene_%s.txt" % (partition))
        cls_nms, nms_to_lbs = get_cls(cls_path)
        for cls_nm in cls_nms:
            lb = nms_to_lbs[cls_nm]
            # Sorted for a reproducible item order within each category.
            for pc_path in sorted(glob.glob(os.path.join(pc_dir, cls_nm, "*.ply"))):
                pc = mesh.plyread(pc_path)
                pc = transform(pc).numpy().transpose([1, 0])  # [3, 2048]
                normal_pc = gen_normal(pc[None, :])
                res[partition]["normal_pcs"].append(normal_pc)
                res[partition]["lbs"].append(lb)
                print("--- %s, finished. ---" % (pc_path))
        if not res[partition]["normal_pcs"]:
            raise FileNotFoundError("No 7Scene .ply files found for partition %r under %s" % (partition, pc_dir))
        res[partition]["normal_pcs"] = np.concatenate(res[partition]["normal_pcs"], 0)
        res[partition]["lbs"] = np.asarray(res[partition]["lbs"])
    save_data(os.path.join(db_dir, "7scene_normal_n2048.pth"), res)
57 |
58 | # ============ ICL-NUIM ==============
def process_icl_nuim():
    """Convert the ICL-NUIM HDF5 files into one pickle via ``save_data``.

    The train file stores clouds under ``points``; the test file stores the
    source clouds under ``source`` (its targets and transforms are not used).
    Each cloud is normalised to the unit cube, resampled to 2048 points and
    given normals by ``gen_normal``. Its label is its index in the file.
    """
    db_dir = "/test/datasets/registration/icl_nuim"
    res = {"train": {"normal_pcs": [], "lbs": []},
           "test": {"normal_pcs": [], "lbs": []}}
    transform = torchvision.transforms.Compose([transforms.OnUnitCube(),
                                                transforms.Resampler(2048)])
    h5_keys = {"train": "points", "test": "source"}
    for partition in ["train", "test"]:
        db_path = os.path.join(db_dir, "icl_nuim_%s.h5" % partition)
        # Read the clouds into memory and close the HDF5 file right away.
        with h5py.File(db_path, "r") as f:
            pcs = f[h5_keys[partition]][...]
        for idx, pc in enumerate(pcs):
            pc = transform(torch.FloatTensor(pc)).numpy().transpose([1, 0])
            normal_pc = gen_normal(pc[None, :])
            res[partition]["normal_pcs"].append(normal_pc)
            res[partition]["lbs"].append(idx)
        res[partition]["normal_pcs"] = np.concatenate(res[partition]["normal_pcs"], 0)
        res[partition]["lbs"] = np.asarray(res[partition]["lbs"])

    # NOTE(review): the output name is "ic_nuim_..." rather than "icl_nuim_...".
    # It is kept as-is because dataset configs may point at it; confirm before renaming.
    save_data(os.path.join(db_dir, "ic_nuim_normal_n2048.pth"), res)
90 |
if __name__ == "__main__":
    # Dataset converters to run, in order. Add process_7scene and/or
    # process_icl_nuim to this tuple to regenerate those datasets as well.
    for _convert in (process_modelnet40,):
        _convert()
--------------------------------------------------------------------------------
/datasets/se_math/__init__.py:
--------------------------------------------------------------------------------
1 | from . import invmat, se3, sinc, so3, mesh, transforms
--------------------------------------------------------------------------------
/datasets/se_math/invmat.py:
--------------------------------------------------------------------------------
1 | """ inverse matrix """
2 |
3 | import torch
4 |
5 |
def batch_inverse(x):
    """ M(n) -> M(n); x -> x^-1

    Invert every square matrix of the (B, n, n) batch *x*, one matrix at a
    time. Returns a tensor shaped like *x*.
    """
    n_batch, rows, cols = x.size()
    assert rows == cols
    inv = torch.zeros_like(x)
    for idx, mat in enumerate(x):
        inv[idx, :, :] = mat.inverse()
    return inv
14 |
15 |
def batch_inverse_dx(y):
    """ backward

    Full Jacobian of the matrix inverse, given the inverses themselves.

    Args:
        y: (B, h, h) batch of inverses, y = x^-1.

    Returns:
        (B, h, h, h, h) tensor dy with
        dy[b, j, k, m, n] = d y(j,k) / d x(m,n) = - y(j,m) * y(n,k).
        It materialises h^4 entries per batch element.
    """
    # Let y(x) = x^-1.
    # compute dy
    # dy = dy(j,k)
    # = - y(j,m) * dx(m,n) * y(n,k)
    # = - y(j,m) * y(n,k) * dx(m,n)
    # therefore,
    # dy(j,k)/dx(m,n) = - y(j,m) * y(n,k)
    batch_size, h, w = y.size()
    assert h == w
    # compute dy(j,k,m,n) = dy(j,k)/dx(m,n) = - y(j,m) * y(n,k)
    # = - (y(j,:))' * y'(k,:)
    # yl[(b, j, n), m, 0] = y[b, j, m]  (each row of y tiled h times, as column vectors)
    yl = y.repeat(1, 1, h).view(batch_size * h * h, h, 1)
    # yr[(b, j, n), 0, k] = y[b, k, n]  (each column of y, repeated for every j, as row vectors)
    yr = y.transpose(1, 2).repeat(1, h, 1).view(batch_size * h * h, 1, h)
    # Outer products give [(b, j, n), m, k]; the view reorders the indices to dy[b, j, k, m, n].
    dy = - yl.bmm(yr).view(batch_size, h, h, h, h)

    # compute dy(m,n,j,k) = dy(j,k)/dx(m,n) = - y(j,m) * y(n,k)
    # = - (y'(m,:))' * y(n,:)
    # yl = y.transpose(1, 2).repeat(1, 1, h).view(batch_size*h*h, h, 1)
    # yr = y.repeat(1, h, 1).view(batch_size*h*h, 1, h)
    # dy = - yl.bmm(yr).view(batch_size, h, h, h, h)

    return dy
40 |
41 |
def batch_pinv_dx(x):
    """ returns y = (x'*x)^-1 * x' and dy/dx.

    Args:
        x: (B, h, w) batch of matrices; x' * x must be invertible.

    Returns:
        y: (B, w, h) pseudo-inverses (x'x)^-1 x'.
        dy_dx: (B, w, h, h, w) Jacobian with
            dy_dx[b, j, k, m, n] = d y(j,k) / d x(m,n).
    """
    # y = (x'*x)^-1 * x'
    # = s^-1 * x'
    # = b * x'
    # d{y(j,k)}/d{x(m,n)}
    # = d{b(j,i) * x(k,i)}/d{x(m,n)}
    # = d{b(j,i)}/d{x(m,n)} * x(k,i) + b(j,i) * d{x(k,i)}/d{x(m,n)}
    # d{b(j,i)}/d{x(m,n)}
    # = d{b(j,i)}/d{s(p,q)} * d{s(p,q)}/d{x(m,n)}
    # = -b(j,p)*b(q,i) * d{s(p,q)}/d{x(m,n)}
    # d{s(p,q)}/d{x(m,n)}
    # = d{x(t,p)*x(t,q)}/d{x(m,n)}
    # = d{x(t,p)}/d{x(m,n)} * x(t,q) + x(t,p) * d{x(t,q)}/d{x(m,n)}
    batch_size, h, w = x.size()
    xt = x.transpose(1, 2)
    s = xt.bmm(x)  # (B, w, w) Gram matrix x'x
    b = batch_inverse(s)  # (B, w, w)
    y = b.bmm(xt)  # (B, w, h)

    # dx/dx
    # Identity over flattened (m, n) index pairs: ex[0, t, p, m, n] = 1 iff (t, p) == (m, n).
    ex = torch.eye(h * w).to(x).unsqueeze(0).view(1, h, w, h, w)
    # ds/dx = dx(t,_)/dx * x(t,_) + x(t,_) * dx(t,_)/dx
    ex1 = ex.view(1, h, w * h * w) # [t, p*m*n]
    dx1 = x.transpose(1, 2).matmul(ex1).view(batch_size, w, w, h, w) # [q, p,m,n]
    ds_dx = dx1.transpose(1, 2) + dx1 # [p, q, m, n]
    # db/ds
    db_ds = batch_inverse_dx(b) # [j, i, p, q]
    # db/dx = db/d{s(p,q)} * d{s(p,q)}/dx
    db1 = db_ds.view(batch_size, w * w, w * w).bmm(ds_dx.view(batch_size, w * w, h * w))
    db_dx = db1.view(batch_size, w, w, h, w) # [j, i, m, n]
    # dy/dx = db(_,i)/dx * x(_,i) + b(_,i) * dx(_,i)/dx
    dy1 = db_dx.transpose(1, 2).contiguous().view(batch_size, w, w * h * w)
    dy1 = x.matmul(dy1).view(batch_size, h, w, h, w) # [k, j, m, n]
    ext = ex.transpose(1, 2).contiguous().view(1, w, h * h * w)
    dy2 = b.matmul(ext).view(batch_size, w, h, h, w) # [j, k, m, n]
    # Swap dy1's first two index axes to [j, k, m, n] before summing the two terms.
    dy_dx = dy1.transpose(1, 2) + dy2

    return y, dy_dx
81 |
82 |
class InvMatrix(torch.autograd.Function):
    """ M(n) -> M(n); x -> x^-1, for a batch x of shape (B, n, n).

    Forward inverts each matrix with ``batch_inverse``. Backward uses the
    closed form df/dx = -(x^-1)^T (df/dy) (x^-1)^T. This needs O(B n^2) memory
    instead of materialising the (B, n, n, n, n) Jacobian.
    """

    @staticmethod
    def forward(ctx, x):
        y = batch_inverse(x)
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors  # v0.4
        # y, = ctx.saved_variables # v0.3.1
        _, h, w = y.size()
        assert h == w

        # Let y(x) = x^-1 and g = grad_output = df/dy. Since
        #   dy(j,k)/dx(m,n) = -y(j,m) * y(n,k),
        #   df/dx(m,n) = sum_{j,k} g(j,k) * (-y(j,m) * y(n,k)) = -(y^T g y^T)(m,n).
        yt = y.transpose(1, 2)
        grad_input = -yt.bmm(grad_output).bmm(yt)

        return grad_input
113 |
114 |
if __name__ == '__main__':
    def test():
        """Print autograd-through-InvMatrix vs. analytic batch_pinv_dx gradients and their difference."""
        x = torch.randn(2, 3, 2).requires_grad_()

        # Autograd path: pinv(x) = (x^T x)^-1 x^T, with the inverse done by InvMatrix.
        xt = x.transpose(1, 2)
        pinv = InvMatrix.apply(xt.bmm(x)).bmm(xt)
        pinv.sum().backward()
        grad_autograd = x.grad

        # Analytic path: sum the Jacobian over the two output indices.
        _, jacobian = batch_pinv_dx(x)
        grad_analytic = jacobian.sum(1).sum(1)

        for value in (grad_autograd, grad_analytic, grad_autograd - grad_analytic):
            print(value)

    test()
135 |
136 | # EOF
137 |
--------------------------------------------------------------------------------
/datasets/se_math/mesh.py:
--------------------------------------------------------------------------------
1 | """ 3-d mesh reader """
2 | import os
3 | import copy
4 | import numpy
5 | from mpl_toolkits.mplot3d import Axes3D
6 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection
7 | # import matplotlib.pyplot
8 |
9 | # used to read ply files
10 | from plyfile import PlyData
11 |
12 |
class Mesh:
    """Minimal mesh container: vertices (N, D), faces (M, K) and edges (L, 2)."""

    def __init__(self):
        self._vertices = []  # array-like (N, D)
        self._faces = []  # array-like (M, K)
        self._edges = []  # array-like (L, 2)

    def clone(self):
        """Return a deep copy of this mesh."""
        other = copy.deepcopy(self)
        return other

    def clear(self):
        """Reset every instance attribute (vertices, faces, edges, extras) to []."""
        for key in self.__dict__:
            self.__dict__[key] = []

    def add_attr(self, name):
        """Add an extra per-mesh attribute initialised to []."""
        self.__dict__[name] = []

    @property
    def vertex_array(self):
        """Vertices as a freshly created numpy array (N, D)."""
        return numpy.array(self._vertices)

    @property
    def vertex_list(self):
        """Vertices as a list of tuples."""
        return list(map(tuple, self._vertices))

    @staticmethod
    def faces2polygons(faces, vertices):
        """Replace every vertex index in each face by the vertex itself."""
        p = list(map(lambda face: \
            list(map(lambda vidx: vertices[vidx], face)), faces))
        return p

    @property
    def polygon_list(self):
        """Faces expanded to lists of vertex coordinates."""
        p = Mesh.faces2polygons(self._faces, self._vertices)
        return p

    def on_unit_sphere(self, zero_mean=False):
        """Scale xyz so the farthest vertex has norm 1 (optionally centring xyz first)."""
        # radius == 1
        v = self.vertex_array  # (N, D)
        if zero_mean:
            a = numpy.mean(v[:, 0:3], axis=0, keepdims=True)  # (1, 3)
            v[:, 0:3] = v[:, 0:3] - a
        n = numpy.linalg.norm(v[:, 0:3], axis=1)  # (N,)
        m = numpy.max(n)  # scalar
        v[:, 0:3] = v[:, 0:3] / m
        self._vertices = v
        return self

    def on_unit_cube(self, zero_mean=False):
        """Scale xyz so all coordinates lie in [-0.5, 0.5] (optionally centring xyz first).

        Like on_unit_sphere, the scale is computed from the first three columns
        only. Extra per-vertex columns neither affect the scale nor get rescaled.
        """
        # volume == 1
        v = self.vertex_array  # (N, D)
        if zero_mean:
            a = numpy.mean(v[:, 0:3], axis=0, keepdims=True)  # (1, 3)
            v[:, 0:3] = v[:, 0:3] - a
        m = numpy.max(numpy.abs(v[:, 0:3]))  # scalar
        v[:, 0:3] = v[:, 0:3] / (m * 2)
        self._vertices = v
        return self

    def rot_x(self):
        """Map (x, y, z) -> (x, -z, y)."""
        # camera local (up: +Y, front: -Z) -> model local (up: +Z, front: +Y).
        v = self.vertex_array
        t = numpy.copy(v[:, 1])
        v[:, 1] = -numpy.copy(v[:, 2])
        v[:, 2] = t
        self._vertices = list(map(tuple, v))
        return self

    def rot_zc(self):
        """Rotate 90 degrees about z: (x, y) -> (-y, x)."""
        # R = [0, -1;
        #      1, 0]
        v = self.vertex_array
        x = numpy.copy(v[:, 0])
        y = numpy.copy(v[:, 1])
        v[:, 0] = -y
        v[:, 1] = x
        self._vertices = list(map(tuple, v))
        return self
107 |
108 |
def offread(filepath, points_only=True):
    """Read a Geomview OFF file into a Mesh.

    If the header is the broken ModelNet variant (``OFF`` fused with the
    counts line), the file on disk is rewritten in the standard layout after
    parsing; the original is kept as ``<filepath>.orig``.

    :param filepath: path of the .off file
    :param points_only: if True, the face list is not parsed
    :return: Mesh
    """
    with open(filepath, 'r') as fin:
        mesh, fixme = _load_off(fin, points_only)
    # Repair only after the reader is closed: renaming/rewriting a file that is
    # still open fails on Windows.
    if fixme:
        _fix_modelnet_broken_off(filepath)
    return mesh
116 |
117 |
def _load_off(fin, points_only):
    """Parse an OFF text stream into a Mesh.

    :param fin: text stream positioned at the ``OFF`` signature line
    :param points_only: if True, return right after reading the vertices
    :return: (mesh, fixme); ``fixme`` is True when the header was the broken
        ModelNet form where ``OFF`` and the counts share one line
    :raises RuntimeError: if the first line does not start with ``OFF``
    """
    mesh = Mesh()

    fixme = False
    sig = fin.readline().strip()
    if sig == 'OFF':
        line = fin.readline().strip()
    elif sig[0:3] == 'OFF':  # ...broken data in ModelNet (missing '\n')...
        line = sig[3:]
        fixme = True
    else:
        raise RuntimeError('unknown format')
    # split() (not split(' ')) tolerates repeated spaces and tabs.
    num_verts, num_faces, num_edges = tuple(int(s) for s in line.split())

    for _ in range(num_verts):
        vp = tuple(float(s) for s in fin.readline().split())
        mesh._vertices.append(vp)

    if points_only:
        return mesh, fixme

    for _ in range(num_faces):
        fields = fin.readline().split()
        # First field is the vertex count; anything after the indices (e.g.
        # optional face colors) is not a vertex index.
        count = int(fields[0])
        fc = tuple(int(s) for s in fields[1:1 + count])
        mesh._faces.append(fc)

    return mesh, fixme
146 |
147 |
148 | def _fix_modelnet_broken_off(filepath):
149 | oldfile = '{}.orig'.format(filepath)
150 | os.rename(filepath, oldfile)
151 | with open(oldfile, 'r') as fin:
152 | with open(filepath, 'w') as fout:
153 | sig = fin.readline().strip()
154 | line = sig[3:]
155 | print('OFF', file=fout)
156 | print(line, file=fout)
157 | for line in fin:
158 | print(line.strip(), file=fout)
159 |
160 |
def objread(filepath, points_only=True):
    """Load a Wavefront OBJ file into a Mesh.

    Only vertex positions (``v``) and face vertex indices (``f``) are stored
    on the returned mesh. ``vn``/``vt``/``usemtl``/``usemat``/``mtllib`` lines
    are still parsed (so malformed ones raise) but are not stored.

    :param filepath: path of the .obj file
    :param points_only: if True, the returned mesh has vertices only
    :return: Mesh whose faces (when loaded) hold 0-based vertex indices
    """
    def _to_zero_based(token, count):
        # OBJ indices are 1-based; negative indices count back from the
        # elements defined so far (-1 is the most recent one).
        i = int(token)
        return i - 1 if i >= 0 else count + i

    _vertices = []
    _normals = []
    _texcoords = []
    _faces = []
    _mtl_name = None

    material = None
    # `with` closes the file even if parsing raises (it was never closed before).
    with open(filepath, "r") as fin:
        for line in fin:
            if line.startswith('#'):
                continue
            values = line.split()
            if not values:
                continue
            if values[0] == 'v':
                _vertices.append(tuple(map(float, values[1:4])))
            elif values[0] == 'vn':
                _normals.append(tuple(map(float, values[1:4])))
            elif values[0] == 'vt':
                _texcoords.append(tuple(map(float, values[1:3])))
            elif values[0] in ('usemtl', 'usemat'):
                material = values[1]
            elif values[0] == 'mtllib':
                _mtl_name = values[1]
            elif values[0] == 'f':
                face_ = []
                texcoords_ = []
                norms_ = []
                for v in values[1:]:
                    w = v.split('/')
                    face_.append(_to_zero_based(w[0], len(_vertices)))
                    if len(w) >= 2 and len(w[1]) > 0:
                        texcoords_.append(_to_zero_based(w[1], len(_texcoords)))
                    else:
                        texcoords_.append(-1)
                    if len(w) >= 3 and len(w[2]) > 0:
                        norms_.append(_to_zero_based(w[2], len(_normals)))
                    else:
                        norms_.append(-1)
                # _faces.append((face_, norms_, texcoords_, material))
                _faces.append(face_)

    mesh = Mesh()
    mesh._vertices = _vertices
    if points_only:
        return mesh

    mesh._faces = _faces
    return mesh
212 |
213 |
def plyread(filepath, points_only=True):
    """Read the x/y/z vertex properties of a PLY file into a Mesh.

    Faces are never read, so ``points_only`` has no effect; it is accepted
    for signature parity with ``offread``/``objread``.
    """
    vertex = PlyData.read(filepath)['vertex']
    xs, ys, zs = vertex['x'], vertex['y'], vertex['z']

    mesh = Mesh()
    mesh._vertices.extend(
        (float(px), float(py), float(pz)) for px, py, pz in zip(xs, ys, zs))
    return mesh
229 |
230 |
if __name__ == '__main__':
    # Manual smoke tests for the readers. Mesh.plot is commented out in this
    # file, so calling it raised AttributeError; report array sizes instead.
    def test1():
        mesh = objread('model_normalized.obj', points_only=False)
        # mesh.on_unit_sphere()
        mesh.rot_x()
        print('vertices:', mesh.vertex_array.shape, 'faces:', len(mesh._faces))
        # matplotlib.pyplot.show()

    def test2():
        mesh = plyread('1.ply', points_only=True)
        # mesh.on_unit_sphere()
        mesh.rot_x()
        print('vertices:', mesh.vertex_array.shape)
        # matplotlib.pyplot.show()

    test2()
249 |
250 | # EOF
251 |
--------------------------------------------------------------------------------
/datasets/se_math/se3.py:
--------------------------------------------------------------------------------
1 | """ 3-d rigid body transformation group and corresponding Lie algebra. """
2 | import torch
3 | from .sinc import sinc1, sinc2, sinc3
4 | from . import so3
5 |
6 |
def twist_prod(x, y):
    """Lie bracket of twists in se(3).

    Twists are 6-vectors (w, v) with the rotation part first. For
    x = (w1, v1) and y = (w2, v2) the result is
    (w1 x w2, w1 x v2 + v1 x w2), reshaped like ``x``.
    """
    a = x.view(-1, 6)
    b = y.view(-1, 6)
    a_rot, a_lin = a[:, 0:3], a[:, 3:6]
    b_rot, b_lin = b[:, 0:3], b[:, 3:6]

    rot_part = so3.cross_prod(a_rot, b_rot)
    lin_part = so3.cross_prod(a_rot, b_lin) + so3.cross_prod(a_lin, b_rot)
    return torch.cat((rot_part, lin_part), dim=1).view_as(x)
20 |
21 |
def liebracket(x, y):
    """Lie bracket [x, y] on se(3); delegates to ``twist_prod``."""
    bracket = twist_prod(x, y)
    return bracket
24 |
25 |
def mat(x):
    """Hat operator for se(3): twists [*, 6] -> matrices [*, 4, 4].

    For x = (w, v) the result is [[hat(w), v], [0, 0, 0, 0]].
    """
    w1, w2, w3, v1, v2, v3 = x.view(-1, 6).unbind(dim=1)
    zero = torch.zeros_like(w1)
    rows = (
        (zero, -w3, w2, v1),
        (w3, zero, -w1, v2),
        (-w2, w1, zero, v3),
        (zero, zero, zero, zero),
    )
    X = torch.stack([torch.stack(r, dim=1) for r in rows], dim=1)
    return X.view(*x.size()[:-1], 4, 4)
39 |
40 |
def vec(X):
    """Vee operator for se(3): matrices [*, 4, 4] -> twists [*, 6] as (w1, w2, w3, v1, v2, v3)."""
    M = X.view(-1, 4, 4)
    parts = (M[:, 2, 1], M[:, 0, 2], M[:, 1, 0], M[:, 0, 3], M[:, 1, 3], M[:, 2, 3])
    return torch.stack(parts, dim=1).view(*X.size()[:-2], 6)
47 |
48 |
def genvec():
    """Standard basis of se(3) as twist vectors: a 6x6 identity (row k is e_k)."""
    basis = torch.eye(6)
    return basis
51 |
52 |
def genmat():
    """The six se(3) generators as 4x4 matrices, shape [6, 4, 4]."""
    basis = genvec()
    return mat(basis)
55 |
56 |
def exp(x):
    """Exponential map se(3) -> SE(3): twists [*, 6] -> homogeneous matrices [*, 4, 4].

    With theta = |w|, W = hat(w) and S = W @ W:
      R = I + sinc1(theta) W + sinc2(theta) S   (Rodrigues' rotation formula)
      V = I + sinc2(theta) W + sinc3(theta) S
    and the translation column is p = V v.
    """
    flat = x.view(-1, 6)
    w = flat[:, 0:3]
    v = flat[:, 3:6]
    theta = w.norm(p=2, dim=1).view(-1, 1, 1)
    W = so3.mat(w)
    S = W.bmm(W)
    I = torch.eye(3).to(w)

    R = I + sinc1(theta) * W + sinc2(theta) * S
    V = I + sinc2(theta) * W + sinc3(theta) * S
    p = V.bmm(v.contiguous().view(-1, 3, 1))

    bottom = torch.Tensor([0, 0, 0, 1]).view(1, 1, 4).repeat(flat.size(0), 1, 1).to(x)
    top = torch.cat((R, p), dim=2)
    return torch.cat((top, bottom), dim=1).view(*x.size()[:-1], 4, 4)
81 |
82 |
def inverse(g):
    """Invert homogeneous transforms [*, 4, 4]: [R p; 0 1]^-1 = [R^T, -R^T p; 0 1]."""
    flat = g.view(-1, 4, 4)
    rot_t = flat[:, 0:3, 0:3].transpose(1, 2)
    trans = -rot_t.matmul(flat[:, 0:3, 3].unsqueeze(-1))
    last_row = torch.Tensor([0, 0, 0, 1]).view(1, 1, 4).repeat(flat.size(0), 1, 1).to(g)
    upper = torch.cat((rot_t, trans), dim=2)
    return torch.cat((upper, last_row), dim=1).view(*g.size()[:-2], 4, 4)
95 |
96 |
def log(g):
    """Logarithm map SE(3) -> se(3): matrices [*, 4, 4] -> twists [*, 6].

    The rotation goes through so3.log. The translation is mapped back with
    so3.inv_vecs_Xg_ig(w), the inverse of the V matrix used by ``exp``.
    """
    flat = g.view(-1, 4, 4)
    rot = flat[:, 0:3, 0:3]
    trans = flat[:, 0:3, 3].contiguous().view(-1, 3, 1)

    w = so3.log(rot)
    v = so3.inv_vecs_Xg_ig(w).bmm(trans).view(-1, 3)
    return torch.cat((w, v), dim=1).view(*g.size()[:-2], 6)
108 |
109 |
def transform(g, a):
    """Apply rigid transforms g (SE(3), [*, 4, 4]) to points a.

    If ``a`` has the same rank as ``g`` it is [*, 3, N] (points as columns).
    Otherwise ``a`` holds row-vector points [..., 3] that broadcast against
    g's batch dimensions.
    """
    batch = g.size()[:-2]
    flat = g.view(-1, 4, 4)
    R = flat[:, 0:3, 0:3].contiguous().view(*batch, 3, 3)
    p = flat[:, 0:3, 3].contiguous().view(*batch, 3)
    if len(g.size()) != len(a.size()):
        return R.matmul(a.unsqueeze(-1)).squeeze(-1) + p
    return R.matmul(a) + p.unsqueeze(-1)
121 |
122 |
def group_prod(g, h):
    """Compose SE(3) elements: returns g @ h (h is applied first, then g)."""
    return torch.matmul(g, h)
127 |
128 |
class ExpMap(torch.autograd.Function):
    """ Exp: se(3) -> SE(3), with a hand-written backward pass.
    """

    @staticmethod
    def forward(ctx, x):
        """ Exp: R^6 -> M(4),
            size: [B, 6] -> [B, 4, 4],
              or  [B, 1, 6] -> [B, 1, 4, 4]
        """
        ctx.save_for_backward(x)
        g = exp(x)
        return g

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        g = exp(x)
        gen_k = genmat().to(x)  # [6, 4, 4]

        # Let z = f(g) = f(exp(x))
        # dz = df/dgij * dgij/dxk * dxk
        #    = df/dgij * (d/dxk)[exp(x)]_ij * dxk
        #    = df/dgij * [gen_k*g]_ij * dxk

        dg = gen_k.matmul(g.view(-1, 1, 4, 4))
        # (B, k, i, j)
        dg = dg.to(grad_output)

        go = grad_output.contiguous().view(-1, 1, 4, 4)
        dd = go * dg
        grad_input = dd.sum(-1).sum(-1)  # (B, 6)

        # Restore x's leading dims. For x of shape [B, 1, 6] the flat [B, 6]
        # gradient was rejected by autograd as a shape mismatch.
        return grad_input.view_as(x)


Exp = ExpMap.apply
166 |
167 | # EOF
168 |
--------------------------------------------------------------------------------
/datasets/se_math/sinc.py:
--------------------------------------------------------------------------------
1 | """ sinc(t) := sin(t) / t """
2 | import torch
3 | from torch import sin, cos
4 |
5 |
def sinc1(t):
    """ sinc1: t -> sin(t)/t, elementwise.

    Uses a Taylor expansion for |t| < 0.01 to avoid 0/0 at t = 0.
    """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts2 = t[small] ** 2
    out[small] = 1 - ts2 / 6 * (1 - ts2 / 20 * (1 - ts2 / 42))  # Taylor series O(t^8)
    tl = t[large]
    out[large] = sin(tl) / tl
    return out
19 |
20 |
def sinc1_dt(t):
    """ d/dt(sinc1) = cos(t)/t - sin(t)/t^2, elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts = t[small]
    ts2 = ts ** 2
    out[small] = -ts / 3 * (1 - ts2 / 10 * (1 - ts2 / 28 * (1 - ts2 / 54)))  # Taylor series O(t^8)
    tl = t[large]
    out[large] = cos(tl) / tl - sin(tl) / tl ** 2
    return out
34 |
35 |
def sinc1_dt_rt(t):
    """ d/dt(sinc1) / t, elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts2 = t[small] ** 2
    out[small] = -1 / 3 * (1 - ts2 / 10 * (1 - ts2 / 28 * (1 - ts2 / 54)))  # Taylor series O(t^8)
    tl = t[large]
    out[large] = (cos(tl) / tl - sin(tl) / tl ** 2) / tl
    return out
49 |
50 |
def rsinc1(t):
    """ rsinc1: t -> t/sin(t), elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts2 = t[small] ** 2
    out[small] = (((31 * ts2) / 42 + 7) * ts2 / 60 + 1) * ts2 / 6 + 1  # Taylor series O(t^8)
    tl = t[large]
    out[large] = tl / sin(tl)
    return out
64 |
65 |
def rsinc1_dt(t):
    """ d/dt(rsinc1) = 1/sin(t) - t cos(t)/sin(t)^2, elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts = t[small]
    ts2 = ts ** 2
    out[small] = ((((127 * ts2) / 30 + 31) * ts2 / 28 + 7) * ts2 / 30 + 1) * ts / 3  # Taylor series O(t^8)
    tl = t[large]
    sl = sin(tl)
    out[large] = 1 / sl - (tl * cos(tl)) / (sl * sl)
    return out
79 |
80 |
def rsinc1_dt_csc(t):
    """ d/dt(rsinc1) / sin(t), elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts2 = t[small] ** 2
    out[small] = ts2 * (ts2 * ((4 * ts2) / 675 + 2 / 63) + 2 / 15) + 1 / 3  # Taylor series O(t^8)
    tl = t[large]
    sl = sin(tl)
    out[large] = (1 / sl - (tl * cos(tl)) / (sl * sl)) / sl
    return out
94 |
95 |
def sinc2(t):
    """ sinc2: t -> (1 - cos(t)) / t^2, elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts2 = t[small] ** 2
    out[small] = 1 / 2 * (1 - ts2 / 12 * (1 - ts2 / 30 * (1 - ts2 / 56)))  # Taylor series O(t^8)
    tl = t[large]
    out[large] = (1 - cos(tl)) / tl ** 2
    return out
109 |
110 |
def sinc2_dt(t):
    """ d/dt(sinc2) = sin(t)/t^2 - 2 (1 - cos(t))/t^3, elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts = t[small]
    ts2 = ts ** 2
    out[small] = -ts / 12 * (1 - ts2 / 5 * (1.0 / 3 - ts2 / 56 * (1.0 / 2 - ts2 / 135)))  # Taylor series O(t^8)
    tl = t[large]
    tl2 = tl ** 2
    out[large] = sin(tl) / tl2 - 2 * (1 - cos(tl)) / (tl2 * tl)
    return out
124 |
125 |
def sinc3(t):
    """ sinc3: t -> (t - sin(t)) / t^3, elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts2 = t[small] ** 2
    out[small] = 1 / 6 * (1 - ts2 / 20 * (1 - ts2 / 42 * (1 - ts2 / 72)))  # Taylor series O(t^8)
    tl = t[large]
    out[large] = (tl - sin(tl)) / (tl ** 3)
    return out
139 |
140 |
def sinc3_dt(t):
    """ d/dt(sinc3) = (3 sin(t) - t (cos(t) + 2)) / t^4, elementwise (Taylor series for |t| < 0.01). """
    out = torch.zeros_like(t)
    small = torch.abs(t) < 0.01
    large = ~small

    ts = t[small]
    ts2 = ts ** 2
    out[small] = -ts / 60 * (1 - ts2 / 21 * (1 - ts2 / 24 * (1.0 / 2 - ts2 / 165)))  # Taylor series O(t^8)
    tl = t[large]
    out[large] = (3 * sin(tl) - tl * (cos(tl) + 2)) / (tl ** 4)
    return out
154 |
155 |
def sinc4(t):
    """ sinc4: t -> 1/t^2 * (1/2 - sinc2(t))
                 = 1/t^2 * (1/2 - (1 - cos(t))/t^2)

    Evaluated elementwise; |t| < 0.01 uses a Taylor expansion.
    """
    e = 0.01
    r = torch.zeros_like(t)
    a = torch.abs(t)

    s = a < e
    c = ~s
    # Index every operand with its own mask so each right-hand side matches
    # the shape of r[s] / r[c]. The previous version mixed full and masked
    # tensors, which failed on assignment.
    t2s = t[s] ** 2
    r[s] = 1 / 24 * (1 - t2s / 30 * (1 - t2s / 56 * (1 - t2s / 90)))  # Taylor series O(t^8)
    tc = t[c]
    t2c = tc ** 2
    r[c] = (0.5 - (1 - cos(tc)) / t2c) / t2c

    return r  # previously missing: the function returned None
169 |
170 |
class Sinc1_autograd(torch.autograd.Function):
    """Autograd wrapper for sinc1; backward uses the analytic derivative sinc1_dt."""

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return sinc1(theta)

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            return None
        return grad_output * sinc1_dt(theta).to(grad_output)


Sinc1 = Sinc1_autograd.apply
187 |
188 |
class RSinc1_autograd(torch.autograd.Function):
    """Autograd wrapper for rsinc1; backward uses the analytic derivative rsinc1_dt."""

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return rsinc1(theta)

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            return None
        return grad_output * rsinc1_dt(theta).to(grad_output)


RSinc1 = RSinc1_autograd.apply
205 |
206 |
class Sinc2_autograd(torch.autograd.Function):
    """Autograd wrapper for sinc2; backward uses the analytic derivative sinc2_dt."""

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return sinc2(theta)

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            return None
        return grad_output * sinc2_dt(theta).to(grad_output)


Sinc2 = Sinc2_autograd.apply
223 |
224 |
class Sinc3_autograd(torch.autograd.Function):
    """Autograd wrapper for sinc3; backward uses the analytic derivative sinc3_dt."""

    @staticmethod
    def forward(ctx, theta):
        ctx.save_for_backward(theta)
        return sinc3(theta)

    @staticmethod
    def backward(ctx, grad_output):
        (theta,) = ctx.saved_tensors
        if not ctx.needs_input_grad[0]:
            return None
        return grad_output * sinc3_dt(theta).to(grad_output)


Sinc3 = Sinc3_autograd.apply
241 |
242 | # EOF
243 |
--------------------------------------------------------------------------------
/datasets/se_math/so3.py:
--------------------------------------------------------------------------------
1 | """ 3-d rotation group and corresponding Lie algebra """
2 | import torch
3 | from . import sinc
4 | from .sinc import sinc1, sinc2, sinc3
5 |
6 |
def cross_prod(x, y):
    """Row-wise cross product of 3-vectors; the result is shaped like ``x``."""
    flat_x = x.view(-1, 3)
    flat_y = y.view(-1, 3)
    return torch.cross(flat_x, flat_y, dim=1).view_as(x)
10 |
11 |
def liebracket(x, y):
    """Lie bracket [x, y] on so(3), which is the cross product."""
    bracket = cross_prod(x, y)
    return bracket
14 |
15 |
def mat(x):
    """Hat operator: vectors [*, 3] -> skew-symmetric matrices [*, 3, 3]."""
    x1, x2, x3 = x.view(-1, 3).unbind(dim=1)
    zero = torch.zeros_like(x1)
    rows = [torch.stack(r, dim=1) for r in ((zero, -x3, x2), (x3, zero, -x1), (-x2, x1, zero))]
    return torch.stack(rows, dim=1).view(*x.size()[:-1], 3, 3)
27 |
28 |
def vec(X):
    """Vee operator: skew matrices [*, 3, 3] -> vectors [*, 3] read from (2,1), (0,2), (1,0)."""
    M = X.view(-1, 3, 3)
    parts = (M[:, 2, 1], M[:, 0, 2], M[:, 1, 0])
    return torch.stack(parts, dim=1).view(*X.size()[:-2], 3)
34 |
35 |
def genvec():
    """Standard basis of so(3) as vectors: the 3x3 identity (row k is e_k)."""
    identity = torch.eye(3)
    return identity
38 |
39 |
def genmat():
    """The three so(3) generators as skew-symmetric matrices, shape [3, 3, 3]."""
    basis = genvec()
    return mat(basis)
42 |
43 |
def RodriguesRotation(x):
    """Rotation matrices [*, 3, 3] from rotation vectors [*, 3] via Rodrigues' formula.

    R = I + sinc1(theta) W + sinc2(theta) W @ W with theta = |x| and W = hat(x).
    Uses the autograd wrappers sinc.Sinc1 / sinc.Sinc2, whose backward passes
    use the analytic derivatives.
    """
    w = x.view(-1, 3)
    theta = w.norm(p=2, dim=1).view(-1, 1, 1)
    W = mat(w)
    WW = W.bmm(W)
    eye = torch.eye(3).to(w)
    R = eye + sinc.Sinc1(theta) * W + sinc.Sinc2(theta) * WW
    return R.view(*x.size()[:-1], 3, 3)
59 |
60 |
def exp(x):
    """Exponential map so(3) -> SO(3): rotation vectors [*, 3] -> matrices [*, 3, 3].

    Rodrigues' formula R = I + sinc1(theta) W + sinc2(theta) W @ W with
    theta = |x| and W = hat(x).
    """
    w = x.view(-1, 3)
    theta = w.norm(p=2, dim=1).view(-1, 1, 1)
    W = mat(w)
    WW = W.bmm(W)
    eye = torch.eye(3).to(w)
    R = eye + sinc1(theta) * W + sinc2(theta) * WW
    return R.view(*x.size()[:-1], 3, 3)
75 |
76 |
def inverse(g):
    """Invert rotation matrices [*, 3, 3] by transposing each one."""
    return g.view(-1, 3, 3).transpose(1, 2).view_as(g)
81 |
82 |
def btrace(X):
    """Batch trace: [*, N, N] -> [*], the sum of each matrix's diagonal.

    Unbatched input [N, N] gives a 0-d tensor.
    """
    n = X.size(-1)
    X_ = X.view(-1, n, n)
    # One vectorized reduction instead of a Python loop calling torch.trace
    # once per matrix. reshape (not view(*...)) also handles the empty
    # batch-shape tuple of an unbatched [N, N] input.
    tr = X_.diagonal(dim1=-2, dim2=-1).sum(-1)
    return tr.reshape(X.size()[0:-2])
92 |
93 |
def log(g):
    """Logarithm map SO(3) -> so(3): rotation matrices [*, 3, 3] -> rotation vectors [*, 3].

    The angle is theta = acos((trace(R) - 1) / 2).
    - Where |sinc1(theta)| > eps, the generator is (R - R^T) / (2 sinc1(theta)).
    - Otherwise (theta ~ pi), the axis magnitudes come from the diagonal of
      (R + I) theta^2 / 2, and their signs from its off-diagonal entries.
    """
    eps = 1.0e-7
    R = g.view(-1, 3, 3)
    tr = btrace(R)
    # Round-off can push (tr - 1) / 2 slightly outside [-1, 1], where acos
    # returns NaN; clamp to acos's domain.
    c = ((tr - 1) / 2).clamp(-1.0, 1.0)
    t = torch.acos(c)
    sc = sinc1(t)
    idx0 = (torch.abs(sc) <= eps)
    idx1 = (torch.abs(sc) > eps)
    sc = sc.view(-1, 1, 1)

    X = torch.zeros_like(R)
    if idx1.any():
        X[idx1] = (R[idx1] - R[idx1].transpose(1, 2)) / (2 * sc[idx1])

    if idx0.any():
        # t[idx0] == math.pi
        t2 = t[idx0] ** 2
        A = (R[idx0] + torch.eye(3).type_as(R).unsqueeze(0)) * t2.view(-1, 1, 1) / 2
        # In exact arithmetic the diagonal is (theta * axis_i)^2 >= 0; clamp
        # tiny negative round-off so sqrt does not return NaN.
        aw1 = torch.sqrt(A[:, 0, 0].clamp(min=0))
        aw2 = torch.sqrt(A[:, 1, 1].clamp(min=0))
        aw3 = torch.sqrt(A[:, 2, 2].clamp(min=0))
        sgn_3 = torch.sign(A[:, 0, 2])
        sgn_3[sgn_3 == 0] = 1
        sgn_23 = torch.sign(A[:, 1, 2])
        sgn_23[sgn_23 == 0] = 1
        sgn_2 = sgn_23 * sgn_3
        w1 = aw1
        w2 = aw2 * sgn_2
        w3 = aw3 * sgn_3
        w = torch.stack((w1, w2, w3), dim=-1)
        W = mat(w)
        X[idx0] = W

    x = vec(X.view_as(g))
    return x
130 |
131 |
def transform(g, a):
    """Rotate points by g in SO(3) ([*, 3, 3]).

    If ``a`` has the same rank as ``g`` it is [*, 3, N] (points as columns).
    Otherwise ``a`` holds row-vector points [..., 3] that broadcast against g.
    """
    if g.dim() == a.dim():
        return g.matmul(a)
    return g.matmul(a.unsqueeze(-1)).squeeze(-1)
140 |
141 |
def group_prod(g, h):
    """Compose SO(3) elements: returns g @ h (h is applied first, then g)."""
    return torch.matmul(g, h)
146 |
147 |
def vecs_Xg_ig(x):
    """ Vi = vec(dg/dxi * inv(g)), where g = exp(x)
        (== [Ad(exp(x))] * vecs_ig_Xg(x))

    Closed form: V = I + sinc2(theta) X + sinc3(theta) X @ X, with
    theta = |x| and X = hat(x).
    """
    theta = x.view(-1, 3).norm(p=2, dim=1).view(-1, 1, 1)
    X = mat(x)
    XX = X.bmm(X)
    eye = torch.eye(3).to(X)
    V = eye + sinc2(theta) * X + sinc3(theta) * XX
    return V.view(*x.size()[:-1], 3, 3)
164 |
165 |
def inv_vecs_Xg_ig(x):
    """ H = inv(vecs_Xg_ig(x))

    Closed form: H = I - X/2 + eta(theta) X @ X, with theta = |x|, X = hat(x)
    and eta(theta) = (1 - (theta/2) / tan(theta/2)) / theta^2. A Taylor
    series is used for theta < 0.01.
    """
    theta = x.view(-1, 3).norm(p=2, dim=1).view(-1, 1, 1)
    X = mat(x)
    XX = X.bmm(X)
    eye = torch.eye(3).to(x)

    eta = torch.zeros_like(theta)
    near_zero = theta < 0.01
    far = ~near_zero
    th2 = theta[near_zero] ** 2
    eta[near_zero] = ((th2 / 40 + 1) * th2 / 42 + 1) * th2 / 720 + 1 / 12  # O(t**8)
    th = theta[far]
    eta[far] = (1 - (th / 2) / torch.tan(th / 2)) / (th ** 2)

    H = eye - 1 / 2 * X + eta * XX
    return H.view(*x.size()[:-1], 3, 3)
183 |
184 |
class ExpMap(torch.autograd.Function):
    """ Exp: so(3) -> SO(3), with a hand-written backward pass.
    """

    @staticmethod
    def forward(ctx, x):
        """ Exp: R^3 -> M(3),
            size: [B, 3] -> [B, 3, 3],
              or  [B, 1, 3] -> [B, 1, 3, 3]
        """
        ctx.save_for_backward(x)
        g = exp(x)
        return g

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        g = exp(x)
        gen_k = genmat().to(x)  # [3, 3, 3]; gen_k[k] is the k-th generator

        # Let z = f(g) = f(exp(x))
        # dz = df/dgij * dgij/dxk * dxk
        #    = df/dgij * (d/dxk)[exp(x)]_ij * dxk
        #    = df/dgij * [gen_k*g]_ij * dxk

        dg = gen_k.matmul(g.view(-1, 1, 3, 3))
        # (B, k, i, j)
        dg = dg.to(grad_output)

        go = grad_output.contiguous().view(-1, 1, 3, 3)
        dd = go * dg
        grad_input = dd.sum(-1).sum(-1)  # (B, 3)

        # Restore x's leading dims. For x of shape [B, 1, 3] the flat [B, 3]
        # gradient was rejected by autograd as a shape mismatch.
        return grad_input.view_as(x)


Exp = ExpMap.apply
225 |
226 | # EOF
227 |
--------------------------------------------------------------------------------
/datasets/se_math/transforms.py:
--------------------------------------------------------------------------------
1 | """ gives some transform methods for 3d points """
2 | import math
3 |
4 | import torch
5 | import torch.utils.data
6 |
7 | from . import so3
8 | from . import se3
9 |
10 |
class Mesh2Points:
    """Transform: mesh -> float32 torch tensor of its vertex array.

    The array is read from ``mesh.clone()`` rather than from ``mesh`` itself.
    """

    def __init__(self):
        pass

    def __call__(self, mesh):
        array = mesh.clone().vertex_array
        return torch.from_numpy(array).type(dtype=torch.float)
19 |
20 |
class OnUnitSphere:
    """Scale a point set [N, D] so its farthest row has Euclidean norm 1.

    With ``zero_mean=True`` the rows are first centred on their mean.
    """

    def __init__(self, zero_mean=False):
        self.zero_mean = zero_mean

    def __call__(self, tensor):
        centred = tensor - tensor.mean(dim=0, keepdim=True) if self.zero_mean else tensor
        radius = torch.max(centred.norm(p=2, dim=1))  # [N, D] -> [N] -> scalar
        return centred / radius
34 |
35 |
class OnUnitCube:
    """Normalize a point set [N, D] into a unit cube centred on its mean.

    Calling the transform uses ``method2``.
    """

    def __init__(self):
        pass

    def method1(self, tensor):
        # Centre on the mean, then scale so the largest |coordinate| is 0.5.
        centred = tensor - tensor.mean(dim=0, keepdim=True)
        return centred / torch.max(centred.abs()) * 0.5

    def method2(self, tensor):
        # Scale by the largest bounding-box extent, then centre on the mean.
        extent = tensor.max(dim=0)[0] - tensor.min(dim=0)[0]  # [D]
        scaled = tensor / torch.max(extent)
        return scaled - scaled.mean(dim=0, keepdim=True)

    def __call__(self, tensor):
        return self.method2(tensor)
56 |
57 |
class Resampler:
    """ [N, D] -> [M, D]: draw M = ``num`` rows from the input.

    Rows are taken in random order without replacement. When M > N the input
    is reshuffled and drawn again until M rows are filled.
    """

    def __init__(self, num):
        self.num = num

    def __call__(self, tensor):
        num_points, dim_p = tensor.size()
        if num_points == 0 and self.num > 0:
            # Nothing to draw from; the fill loop below would never advance.
            raise ValueError(
                'Resampler: cannot draw {} points from an empty tensor'.format(self.num))
        out = torch.zeros(self.num, dim_p).to(tensor)

        selected = 0
        while selected < self.num:
            remainder = self.num - selected
            idx = torch.randperm(num_points)
            sel = min(remainder, num_points)
            out[selected:(selected + sel)] = tensor[idx[:sel]]
            selected += sel
        return out
77 |
78 |
class RandomTranslate:
    """Shift points [N, 3] by a random-direction offset of length amp * mag.

    ``amp`` is uniform in [0, 1) when ``randomly`` is True, otherwise 1.
    The 4x4 homogeneous transform that was applied is stored in ``igt``.
    """

    def __init__(self, mag=None, randomly=True):
        self.mag = 1.0 if mag is None else mag
        self.randomly = randomly
        self.igt = None

    def __call__(self, tensor):
        amp = torch.rand(1) if self.randomly else 1.0
        direction = torch.randn(1, 3).to(tensor)
        offset = direction / direction.norm(p=2, dim=1, keepdim=True) * amp * self.mag

        g = torch.eye(4).to(tensor)
        g[0:3, 3] = offset[0, :]
        self.igt = g  # [4, 4]

        return tensor + offset
97 |
98 |
class RandomRotator:
    """Rotate points [N, 3] about a random axis by angle amp * mag.

    ``mag`` defaults to pi. ``amp`` is uniform in [0, 1) when ``randomly`` is
    True, otherwise 1. The applied 3x3 rotation is stored in ``igt``.
    """

    def __init__(self, mag=None, randomly=True):
        self.mag = math.pi if mag is None else mag
        self.randomly = randomly
        self.igt = None

    def __call__(self, tensor):
        amp = torch.rand(1) if self.randomly else 1.0
        axis = torch.randn(1, 3)
        w = axis / axis.norm(p=2, dim=1, keepdim=True) * amp * self.mag

        g = so3.exp(w).to(tensor)  # [1, 3, 3]
        self.igt = g.squeeze(0)  # [3, 3]
        return so3.transform(g, tensor)  # [N, 3]
116 |
117 |
class RandomRotatorZ:
    """Rotate points [N, 3] about +Z by an angle drawn uniformly from [0, 2*pi)."""

    def __init__(self):
        self.mag = 2 * math.pi

    def __call__(self, tensor):
        z_axis = torch.Tensor([0, 0, 1]).view(1, 3)
        w = z_axis * torch.rand(1) * self.mag
        rotation = so3.exp(w).to(tensor)  # [1, 3, 3]
        return so3.transform(rotation, tensor)
130 |
131 |
132 | class RandomJitter:
133 | """ generate perturbations """
134 |
135 | def __init__(self, scale=0.01, clip=0.05):
136 | self.scale = scale
137 | self.clip = clip
138 | self.e = None
139 |
140 | def jitter(self, tensor):
141 | noise = torch.zeros_like(tensor).to(tensor) # [N, 3]
142 | noise.normal_(0, self.scale)
143 | noise.clamp_(-self.clip, self.clip)
144 | self.e = noise
145 | return tensor.add(noise)
146 |
147 | def __call__(self, tensor):
148 | return self.jitter(tensor)
149 |
150 |
class RandomTransformSE3:
    """Apply a random rigid motion to points [N, 3].

    A random twist of norm ``mag`` (or a uniform fraction of it when
    ``mag_randomly``) is drawn. After a call, ``igt`` is the 4x4 transform
    p0 -> p1 and ``gt`` is the one generated by the negated twist (p1 -> p0).
    """

    def __init__(self, mag=1, mag_randomly=False):
        self.mag = mag
        self.randomly = mag_randomly

        self.gt = None
        self.igt = None

    def generate_transform(self):
        # return: a random twist-vector [1, 6]
        amp = torch.rand(1, 1) * self.mag if self.randomly else self.mag
        twist = torch.randn(1, 6)
        return twist / twist.norm(p=2, dim=1, keepdim=True) * amp

    def apply_transform(self, p0, x):
        # p0: [N, 3], x: [1, 6]
        g = se3.exp(x).to(p0)  # [1, 4, 4]
        gt = se3.exp(-x).to(p0)  # [1, 4, 4]

        p1 = se3.transform(g, p0)
        self.gt = gt.squeeze(0)  # gt: p1 -> p0
        self.igt = g.squeeze(0)  # igt: p0 -> p1
        return p1

    def transform(self, tensor):
        return self.apply_transform(tensor, self.generate_transform())

    def __call__(self, tensor):
        return self.transform(tensor)
194 |
195 | # EOF
196 |
--------------------------------------------------------------------------------
/datasets/utils/commons.py:
--------------------------------------------------------------------------------
1 | import pickle, matplotlib.pyplot as plt, numpy as np
2 |
def load_data(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(security): pickle can execute arbitrary code while loading; only use
    this on files the project wrote itself, never on untrusted input.
    """
    # `with` closes the handle even when unpickling raises (it leaked before).
    with open(path, "rb") as file:
        return pickle.load(file)
8 |
def save_data(path, data):
    """Pickle ``data`` to ``path``, overwriting any existing file."""
    # `with` closes (and flushes) the handle even when pickling raises.
    with open(path, "wb") as file:
        pickle.dump(data, file)
13 |
def line_plot(xs, ys, title, path):
    """Save a single titled line plot of ys against xs to ``path`` and close the figure."""
    fig, ax = plt.subplots()
    ax.plot(xs, ys)
    ax.set_title(title)
    fig.savefig(path)
    plt.close(fig)
20 |
def stack_action_seqs(action_seqs):
    """Collapse per-step actions into one action per batch item.

    :param action_seqs: (H, B, D)
    :return: actions (B, D), the sum over the H axis
    """
    total = action_seqs.sum(0)
    return total
27 |
def cal_errors_np(rs_pred, ts_pred, rs_lb, ts_lb, is_degree=False):
    """Per-sample rotation and translation errors.

    :param rs_pred: (B, 3)
    :param ts_pred: (B, 3)
    :param rs_lb: (B, 3)
    :param ts_lb: (B, 3)
    :param is_degree: bool; if False, rotations are converted from radians to
        degrees before comparing
    :return: dict key: val (B, ) with MSE / RMSE / MAE for rotation ("rs_*")
        and translation ("ts_*")
    """
    if not is_degree:
        rs_pred, rs_lb = np.degrees(rs_pred), np.degrees(rs_lb)
    rs_diff = rs_pred - rs_lb
    ts_diff = ts_pred - ts_lb
    rs_mse = np.mean(rs_diff ** 2, 1)  # (B, )
    ts_mse = np.mean(ts_diff ** 2, 1)
    return {"rs_mse": rs_mse, "ts_mse": ts_mse,
            "rs_rmse": np.sqrt(rs_mse), "ts_rmse": np.sqrt(ts_mse),
            "rs_mae": np.mean(np.abs(rs_diff), 1), "ts_mae": np.mean(np.abs(ts_diff), 1)}
50 |
--------------------------------------------------------------------------------
/datasets/utils/db_icl_nuim.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) 2020 NVIDIA
3 | Author: Wentao Yuan
4 | '''
5 |
6 | import h5py
7 | import numpy as np
8 | import os
9 | import torch
10 | from scipy.spatial import cKDTree
11 | from torch.utils.data import Dataset
12 | from sklearn.neighbors import NearestNeighbors
13 | from scipy.spatial.distance import minkowski
14 | import open3d
def jitter_pcd(pcd, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise to ``pcd`` in place and return the same array.

    Per-coordinate noise is sigma * N(0, 1), clipped to [-clip, clip].
    """
    noise = np.clip(sigma * np.random.randn(*pcd.shape), -clip, clip)
    pcd += noise
    return pcd
18 |
def random_pose(max_angle, max_trans):
    """Random 4x4 rigid transform [[R, t], [0, 0, 0, 1]].

    R comes from random_rotation(max_angle) and t from
    random_translation(max_trans), drawn in that order.
    """
    pose = np.eye(4)
    pose[:3, :3] = random_rotation(max_angle)
    pose[:3, 3:] = random_translation(max_trans)
    return pose
23 |
24 |
def random_rotation(max_angle):
    """Random 3x3 rotation about a random axis by an angle uniform in [0, max_angle).

    The axis is a normalized Gaussian sample. R is built with Rodrigues'
    formula R = I + sin(a) K + (1 - cos(a)) K @ K, where K = hat(axis).
    """
    axis = np.random.randn(3)
    axis /= np.linalg.norm(axis)
    angle = np.random.rand() * max_angle
    kx, ky, kz = axis
    K = np.array([[0, -kz, ky],
                  [kz, 0, -kx],
                  [-ky, kx, 0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * K.dot(K)
34 |
35 |
def random_translation(max_dist):
    """Random translation as a (3, 1) column vector.

    The direction is a normalized Gaussian sample; the length is uniform in
    [0, max_dist).
    """
    direction = np.random.randn(3)
    direction /= np.linalg.norm(direction)
    length = np.random.rand() * max_dist
    return (direction * length).reshape(3, 1)
41 |
def farthest_subsample_points(pointcloud1, pointcloud2, num_subsampled_points=768):
    """Crop both clouds to the points nearest one shared random far-away anchor.

    The anchor is a [0, 1)^3 jitter plus (500, 500, 500) times a random sign.
    From each cloud, the ``num_subsampled_points`` points closest to it
    (Minkowski/Euclidean distance) are kept.

    :return: (sub1, sub2), each of shape (num_subsampled_points, 3)
    """
    def _nearest_rows(cloud, anchor):
        # Indices of the num_subsampled_points rows of `cloud` closest to `anchor`.
        nbrs = NearestNeighbors(n_neighbors=num_subsampled_points, algorithm='auto',
                                metric=minkowski).fit(cloud)
        return nbrs.kneighbors(anchor, return_distance=False).reshape((num_subsampled_points,))

    anchor = np.random.random(size=(1, 3)) + np.array([[500, 500, 500]]) * np.random.choice([1, -1, 1, -1])
    idx1 = _nearest_rows(pointcloud1, anchor)
    idx2 = _nearest_rows(pointcloud2, anchor)
    return pointcloud1[idx1, :], pointcloud2[idx2, :]
53 |
class TestData(Dataset):
    """Paired source/target scans loaded from an HDF5 file.

    Each item reseeds numpy's global RNG with its index, so the random poses
    are reproducible. Items are dicts with:
    - 'idx'
    - 'points_src' / 'points_ref': xyz plus open3d-estimated normals, (N, 6) float32
    - 'transform_gt': the 4x4 float32 relative pose (pose2 = transform @ pose1)
    """

    def __init__(self, path):
        super(TestData, self).__init__()
        with h5py.File(path, 'r') as f:
            self.source = f['source'][...]
            self.target = f['target'][...]
            self.transform = f['transform'][...]
        self.n_points = 1024
        self.max_angle = 45 / 180 * np.pi
        self.max_trans = 1.0
        self.noisy = False  # not read by this class
        self.subsampled = True
        self.num_subsampled_points = 768

    @staticmethod
    def _with_normals(points):
        # Estimate normals with open3d; return xyz and normals side by side.
        cloud = open3d.geometry.PointCloud()
        cloud.points = open3d.utility.Vector3dVector(points)
        open3d.geometry.estimate_normals(cloud)
        return np.concatenate([np.asarray(cloud.points), np.asarray(cloud.normals)], axis=1)

    def __getitem__(self, index):
        np.random.seed(index)
        src = self.source[index][:self.n_points]
        ref = self.target[index][:self.n_points]
        stored = self.transform[index]
        src = src @ stored[:3, :3].T + stored[:3, 3]

        rel = random_pose(self.max_angle, self.max_trans)
        pose1 = random_pose(np.pi, self.max_trans)
        pose2 = rel @ pose1
        src = src @ pose1[:3, :3].T + pose1[:3, 3]
        ref = ref @ pose2[:3, :3].T + pose2[:3, 3]

        if self.subsampled:
            src, ref = farthest_subsample_points(src, ref, num_subsampled_points=self.num_subsampled_points)

        transform_gt = np.eye(4).astype('float32')
        transform_gt[:3, :3] = rel[:3, :3].astype('float32')
        transform_gt[:3, 3] = rel[:3, 3].astype('float32')
        return {
            'idx': np.array(index, dtype=np.int32),
            'transform_gt': transform_gt,
            'points_src': self._with_normals(src).astype('float32'),
            'points_ref': self._with_normals(ref).astype('float32'),
        }

    def __len__(self):
        return self.transform.shape[0]
105 |
106 |
class TrainData(Dataset):
    """Training clouds loaded from an HDF5 file; each cloud is paired with a
    randomly posed copy of itself.

    Items are dicts with:
    - 'idx'
    - 'points_src' / 'points_ref': xyz plus open3d-estimated normals, (N, 6) float32
    - 'transform_gt': the 4x4 float32 relative pose (pose2 = transform @ pose1)
    """

    def __init__(self, path):
        super(TrainData, self).__init__()
        with h5py.File(path, 'r') as f:
            self.points = f['points'][...]
        self.n_points = 1024
        self.max_angle = 45 / 180 * np.pi
        self.max_trans = 1.0
        self.noisy = False  # not read by this class
        self.subsampled = True
        self.num_subsampled_points = 768

    @staticmethod
    def _with_normals(points):
        # Estimate normals with open3d; return xyz and normals side by side.
        cloud = open3d.geometry.PointCloud()
        cloud.points = open3d.utility.Vector3dVector(points)
        open3d.geometry.estimate_normals(cloud)
        return np.concatenate([np.asarray(cloud.points), np.asarray(cloud.normals)], axis=1)

    def __getitem__(self, index):
        base = self.points[index][:self.n_points]

        rel = random_pose(self.max_angle, self.max_trans)
        pose1 = random_pose(np.pi, self.max_trans)
        pose2 = rel @ pose1
        src = base @ pose1[:3, :3].T + pose1[:3, 3]
        ref = base @ pose2[:3, :3].T + pose2[:3, 3]

        if self.subsampled:
            src, ref = farthest_subsample_points(src, ref, num_subsampled_points=self.num_subsampled_points)

        transform_gt = np.eye(4).astype('float32')
        transform_gt[:3, :3] = rel[:3, :3].astype('float32')
        transform_gt[:3, 3] = rel[:3, 3].astype('float32')
        return {
            'idx': np.array(index, dtype=np.int32),
            'transform_gt': transform_gt,
            'points_src': self._with_normals(src).astype('float32'),
            'points_ref': self._with_normals(ref).astype('float32'),
        }

    def __len__(self):
        return self.points.shape[0]
153 |
154 |
--------------------------------------------------------------------------------
/datasets/utils/gen_normal.py:
--------------------------------------------------------------------------------
1 | import open3d as o3d, numpy as np, pdb
2 |
def gen_normal(pcs):
    """Append Open3D-estimated normals to a batch of point clouds.

    :param pcs: shape (B, 3, N), np.array
    :return: shape (B, 6, N) float64 array; rows 0-2 are xyz, rows 3-5 the normals
    """
    # Hybrid KD-tree search (radius 0.1, at most 30 neighbours); the same parameter
    # object is reused for every cloud.
    search_param = o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)
    # Old Open3D exposes a module-level estimate_normals; newer releases only have the
    # PointCloud.estimate_normals method. Support both.
    legacy_estimate = getattr(o3d.geometry, 'estimate_normals', None)
    normal_pcs = np.zeros([len(pcs), 6, pcs.shape[2]])
    for idx, pc in enumerate(pcs):
        _pc = o3d.geometry.PointCloud()
        _pc.points = o3d.utility.Vector3dVector(pc.transpose([1, 0]))
        if legacy_estimate is not None:
            legacy_estimate(_pc, search_param=search_param)
        else:
            _pc.estimate_normals(search_param=search_param)
        normal_pcs[idx] = np.concatenate([np.asarray(_pc.points), np.asarray(_pc.normals)], 1).transpose([1, 0])
    return normal_pcs
16 |
if __name__ == '__main__':
    # Smoke run: five all-zero clouds of 1024 points in (B, 3, N) layout.
    demo_pcs = np.zeros((5, 3, 1024))
    gen_normal(demo_pcs)
--------------------------------------------------------------------------------
/datasets/utils/npmat2euler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.transform import Rotation
3 |
def npmat2euler(mats, seq='zyx', is_degrees=True):
    """Convert a batch of rotation matrices to Euler angles.

    :param mats: array-like of shape (B, 3, 3)
    :param seq: axis sequence passed to scipy's ``Rotation.as_euler``
    :param is_degrees: return degrees (True) or radians (False)
    :return: float32 array of shape (B, 3); shape (0,) when ``mats`` is empty
    """
    mats = np.asarray(mats)
    if mats.shape[0] == 0:
        return np.asarray([], dtype='float32')
    # ``Rotation.from_dcm`` was renamed ``from_matrix`` in SciPy 1.4 and removed in 1.6;
    # prefer the new name and fall back for old SciPy versions.
    from_matrix = getattr(Rotation, 'from_matrix', None) or Rotation.from_dcm
    # Converting the whole stack at once is equivalent to the per-matrix loop.
    return np.asarray(from_matrix(mats).as_euler(seq, degrees=is_degrees), dtype='float32')
--------------------------------------------------------------------------------
/modules/__pycache__/commons.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/modules/__pycache__/commons.cpython-36.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/dcp_net.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/modules/__pycache__/dcp_net.cpython-36.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/dgcnn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/modules/__pycache__/dgcnn.cpython-36.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/sparsemax.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/modules/__pycache__/sparsemax.cpython-36.pyc
--------------------------------------------------------------------------------
/modules/commons.py:
--------------------------------------------------------------------------------
1 | import copy, torch.nn as nn, torch, math, torch.nn.functional as F
2 |
def knn(x, k):
    """Return the indices of the k nearest points (squared Euclidean) for every point.

    :param x: tensor of shape (B, C, N)
    :param k: number of neighbours
    :return: long tensor of shape (B, N, k), nearest first
    """
    x_t = x.transpose(2, 1).contiguous()
    inner = -2 * torch.matmul(x_t, x)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    # Negated squared distance, so topk (largest) selects the closest points.
    neg_sq_dist = -sq_norm - inner - sq_norm.transpose(2, 1).contiguous()
    _, nearest = neg_sq_dist.topk(k=k, dim=-1)
    return nearest
10 |
def get_graph_feature(x, k=20):
    """Build k-NN graph features for every point.

    :param x: tensor of shape (B, C, N)
    :param k: neighbours per point (as selected by ``knn``)
    :return: tensor of shape (B, 2C, N, k); channels [:C] hold the neighbour
             coordinates and channels [C:] repeat the centre point.
    """
    idx = knn(x, k=k)  # (B, N, k)
    batch_size, num_points, _ = idx.size()
    # Offset per-batch indices so they address rows of the flattened (B*N, C) tensor.
    # Build the offsets on x's own device instead of a hard-coded 'cuda', so CPU
    # tensors and non-default GPUs work.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    _, num_dims, _ = x.size()
    x = x.transpose(2, 1).contiguous()  # (B, N, C)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)  # neighbours, (B, N, k, C)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)  # centre points, (B, N, k, C)
    return torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)
26 |
def get_graph_feature2(x, x_features, k=20):
    """k-NN graph features from coordinates, plus per-point features gathered at the same neighbours.

    :param x: (B, C, N) coordinates used for the neighbour search
    :param x_features: (B, Cf, N) features gathered at the neighbour indices
    :param k: neighbours per point
    :return: ``(feature, x_feature_k)`` with shapes (B, C, N, k), holding neighbour
             minus centre coordinates, and (B, Cf, N, k), holding neighbour features.
    """
    idx = knn(x, k=k)  # (B, N, k)
    batch_size, num_points, _ = idx.size()
    # Offsets are built on x's device rather than a hard-coded 'cuda'.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    _, num_dims, _ = x.size()
    _, num_dims_features, _ = x_features.size()

    x_features = x_features.transpose(2, 1).contiguous()  # (B, N, Cf)
    x = x.transpose(2, 1).contiguous()  # (B, N, C)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    x_feature_k = x_features.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)  # (B, N, k, C)
    x_feature_k = x_feature_k.view(batch_size, num_points, k, num_dims_features).permute(0, 3, 1, 2)  # (B, Cf, N, k)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)  # centre points, (B, N, k, C)
    feature = (feature - x).permute(0, 3, 1, 2)  # (B, C, N, k)
    return feature, x_feature_k
48 |
def pairwise_distance_batch(x, y):
    """Euclidean distance between every column of ``x`` and every column of ``y``.

    :param x: tensor of shape (B, C, N)
    :param y: tensor of shape (B, C, M)
    :return: tensor of shape (B, N, M) with out[b, i, j] = ||x[b, :, i] - y[b, :, j]||
    """
    xx = torch.sum(torch.mul(x, x), 1, keepdim=True)  # (B, 1, N)
    yy = torch.sum(torch.mul(y, y), 1, keepdim=True)  # (B, 1, M)
    inner = -2 * torch.matmul(x.transpose(2, 1), y)  # (B, N, M)
    pair_distance = xx.transpose(2, 1) + inner + yy  # squared distances, (B, N, M)
    # Rounding can make the expansion slightly negative; clamp those to 0.
    # zeros_like follows the input's device/dtype instead of forcing 'cuda'.
    zeros_matrix = torch.zeros_like(pair_distance)
    pair_distance_square = torch.where(pair_distance > 0.0, pair_distance, zeros_matrix)
    error_mask = torch.le(pair_distance_square, 0.0)
    # A tiny epsilon where the distance is 0 keeps sqrt's gradient finite; those entries are then zeroed.
    pair_distances = torch.sqrt(pair_distance_square + error_mask.float() * 1e-16)
    pair_distances = torch.mul(pair_distances, (1.0 - error_mask.float()))
    return pair_distances
66 |
def clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    return nn.ModuleList(map(copy.deepcopy, [module] * N))
69 |
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Uses ``Tensor.std`` (unbiased by default) and adds ``eps`` to the standard
    deviation, not the variance.
    """

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        centred = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centred / denom + self.b_2
81 |
class Identity(nn.Module):
    """Module that returns its input unchanged.

    The previous version defined no ``forward``, so calling it raised instead of
    acting as an identity.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
85 |
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + sublayer(LayerNorm(x))``.

    ``dropout`` is accepted for interface compatibility but not used.
    """

    def __init__(self, size, dropout=None):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)

    def forward(self, x, sublayer):
        normed = self.norm(x)
        return x + sublayer(normed)
93 |
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Positions where ``mask == 0`` get a logit of -1e9 before the softmax.
    ``dropout`` is accepted but not applied.

    :return: ``(weighted values, attention weights)``
    """
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1).contiguous()) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = F.softmax(logits, dim=-1)
    return torch.matmul(weights, value), weights
101 |
class MultiHeadedAttention(nn.Module):
    """Multi-head attention: project Q/K/V, attend per head, merge heads, project out.

    ``dropout`` is accepted for interface compatibility but unused. The most recent
    attention weights are stored on ``self.attn``.
    """

    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h  # per-head width (values share the key width)
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)  # Q, K, V and output projections
        self.attn = None
        self.dropout = None

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask broadcast over every head.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        def split_heads(proj, t):
            # (B, L, d_model) -> (B, h, L, d_k)
            return proj(t).view(nbatches, -1, self.h, self.d_k).transpose(1, 2).contiguous()

        q_heads = split_heads(self.linears[0], query)
        k_heads = split_heads(self.linears[1], key)
        v_heads = split_heads(self.linears[2], value)

        x, self.attn = attention(q_heads, k_heads, v_heads, mask=mask, dropout=self.dropout)

        merged = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)
134 |
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward block: ``w_2(norm(relu(w_1(x))))``.

    ``norm`` is an empty ``nn.Sequential`` (identity). It is applied with the last two
    axes swapped, presumably so a ``BatchNorm1d`` (the commented alternative) could be
    dropped in. ``dropout`` is accepted but unused.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.norm = nn.Sequential()  # nn.BatchNorm1d(d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = None

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.norm(hidden.transpose(2, 1).contiguous())
        hidden = hidden.transpose(2, 1).contiguous()
        return self.w_2(hidden)
--------------------------------------------------------------------------------
/modules/dcp_net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn, torch
2 | from modules.dgcnn import DGCNN
3 | from utils.mat2euler import mat2euler
4 |
def pairwise_distance_batch(x, y):
    """Euclidean distance between every column of ``x`` and every column of ``y``.

    :param x: tensor of shape (B, C, N)
    :param y: tensor of shape (B, C, M)
    :return: tensor of shape (B, N, M)
    """
    xx = torch.sum(torch.mul(x, x), 1, keepdim=True)  # (B, 1, N)
    yy = torch.sum(torch.mul(y, y), 1, keepdim=True)  # (B, 1, M)
    inner = -2 * torch.matmul(x.transpose(2, 1), y)  # (B, N, M)
    pair_distance = xx.transpose(2, 1) + inner + yy  # squared distances
    # Clamp rounding-induced negatives to 0; zeros_like keeps the input's device
    # rather than a hard-coded 'cuda'.
    zeros_matrix = torch.zeros_like(pair_distance)
    pair_distance_square = torch.where(pair_distance > 0.0, pair_distance, zeros_matrix)
    error_mask = torch.le(pair_distance_square, 0.0)
    # Epsilon keeps sqrt's gradient finite at 0; those entries are zeroed afterwards.
    pair_distances = torch.sqrt(pair_distance_square + error_mask.float() * 1e-16)
    pair_distances = torch.mul(pair_distances, (1.0 - error_mask.float()))
    return pair_distances
17 |
class DCPNet(nn.Module):
    """DCP-style registration network.

    DGCNN embeddings of source and target give soft correspondences. A per-sample
    SVD (Kabsch) then yields the rotation and translation aligning ``srcs`` to
    those correspondences. With ``is_sigma=True`` an MLP on the max-pooled
    embeddings also predicts six values squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, opts):
        super(DCPNet, self).__init__()
        self.emb_nn = DGCNN(emb_dims=opts.emb_dims)
        self.emb_dims = opts.emb_dims
        # diag(1, 1, -1): flips V's last column when the SVD solution is a reflection.
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1

        self.opts = opts
        self.planning_horizon = self.opts.cem.planning_horizon
        self.nn = nn.Sequential(nn.Linear(self.emb_dims * 2, self.emb_dims // 4),
                                nn.BatchNorm1d(self.emb_dims // 4),
                                nn.LeakyReLU(),
                                nn.Linear(self.emb_dims // 4, self.emb_dims // 8),
                                nn.BatchNorm1d(self.emb_dims // 8),
                                nn.LeakyReLU(),
                                nn.Linear(self.emb_dims // 8, 6))

    def forward(self, srcs, tgts, is_sigma=False):
        batch_size = len(srcs)
        srcs_emb = self.emb_nn(srcs)  # (B, emb_dims, N), as returned by DGCNN.forward
        tgts_emb = self.emb_nn(tgts)
        # Soft correspondences: softmax over target points of negative embedding distance.
        scores = torch.softmax(-pairwise_distance_batch(srcs_emb, tgts_emb), dim=2)
        src_corr = torch.matmul(tgts, scores.transpose(2, 1).contiguous())
        src_centered = srcs - srcs.mean(dim=2, keepdim=True)
        src_corr_centered = src_corr - src_corr.mean(dim=2, keepdim=True)
        H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous())

        R = []
        for i in range(srcs.size(0)):
            u, _, v = torch.svd(H[i])
            r = torch.matmul(v, u.transpose(1, 0).contiguous())
            if torch.det(r) < 0:
                # Reflection: flip V's last column, reusing this SVD instead of recomputing it.
                v = torch.matmul(v, self.reflect)
                r = torch.matmul(v, u.transpose(1, 0).contiguous())
            R.append(r)
        R = torch.stack(R, dim=0)
        t = (torch.matmul(-R, srcs.mean(dim=2, keepdim=True)) + src_corr.mean(dim=2, keepdim=True)).reshape(batch_size, -1)
        r = mat2euler(R)
        if is_sigma:
            sg = torch.cat([srcs_emb, tgts_emb], 1)
            sigmas = self.nn(sg.max(dim=-1)[0])
            sigmas[:, :3] = torch.nn.Sigmoid()(sigmas[:, :3]) * 1.0
            sigmas[:, 3:] = torch.nn.Sigmoid()(sigmas[:, 3:]) * 1.0
            return torch.cat([r, t], 1).unsqueeze(0), sigmas.unsqueeze(0)
        return {"r": r, "t": t}
--------------------------------------------------------------------------------
/modules/dgcnn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn, torch.nn.functional as F, torch
2 | from modules.commons import get_graph_feature
3 |
class DGCNN(nn.Module):
    """Point-cloud encoder on k-NN graph features.

    Four 1x1 conv + BN + ReLU stages are each max-pooled over the neighbour axis.
    The pooled maps are concatenated (64+64+128+256 = 512 channels) and fused by a
    final 1x1 conv. ``conv1`` takes 6 channels, i.e. ``get_graph_feature`` of a
    3-channel input. Output has shape (B, emb_dims, N).
    """

    def __init__(self, emb_dims=512):
        super(DGCNN, self).__init__()
        self.conv1 = nn.Conv2d(6, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=1, bias=False)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=1, bias=False)
        self.conv5 = nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(emb_dims)

    def forward(self, x):
        batch_size, _, num_points = x.size()
        h = get_graph_feature(x)
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2),
                  (self.conv3, self.bn3), (self.conv4, self.bn4))
        pooled = []
        for conv, bn in stages:
            h = F.relu(bn(conv(h)))
            pooled.append(h.max(dim=-1, keepdim=True)[0])  # max over neighbours
        h = torch.cat(pooled, dim=1)
        return F.relu(self.bn5(self.conv5(h))).view(batch_size, -1, num_points)
--------------------------------------------------------------------------------
/modules/sparsemax.py:
--------------------------------------------------------------------------------
1 | import torch, torch.nn as nn
2 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Sparsemax(nn.Module):
    """Sparsemax activation: Euclidean projection onto the probability simplex."""

    def __init__(self, dim=None):
        """Initialize sparsemax activation.

        Args:
            dim (int, optional): The dimension over which to apply sparsemax (default -1).
        """
        super(Sparsemax, self).__init__()
        self.dim = -1 if dim is None else dim

    def forward(self, input):
        """Apply sparsemax along ``self.dim``.

        Args:
            input (torch.Tensor): Input tensor. First dimension should be the batch size

        Returns:
            torch.Tensor: tensor of the same shape as ``input``; each slice along
            ``self.dim`` is non-negative and sums to 1.
        """
        # Move the target dim to the front, flatten the rest, and transpose so that every
        # row of the 2-D tensor is one distribution over `number_of_logits` entries.
        input = input.transpose(0, self.dim)
        original_size = input.size()
        input = input.reshape(input.size(0), -1)
        input = input.transpose(0, 1)
        dim = 1

        number_of_logits = input.size(dim)

        # Translate input by max for numerical stability
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort descending (a linear-time selection method could replace the sort, see
        # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html).
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        # Ranks 1..K on the input's own device (not the module-level `device`), so CPU
        # inputs work when CUDA is available and vice versa.
        ranks = torch.arange(start=1, end=number_of_logits + 1, step=1,
                             device=input.device, dtype=input.dtype).view(1, -1)
        ranks = ranks.expand_as(zs)

        # Support size k: the largest rank with 1 + rank * z_(rank) > cumsum(z)_(rank).
        bound = 1 + ranks * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * ranks, dim, keepdim=True)[0]

        # Threshold tau = (sum of the top-k entries - 1) / k
        zs_sparse = is_gt * zs
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        self.output = torch.max(torch.zeros_like(input), input - taus)

        # Undo the reshaping
        output = self.output.transpose(0, 1)
        output = output.reshape(original_size)
        output = output.transpose(0, self.dim)
        return output

    def backward(self, grad_output):
        """Sparsemax vector-Jacobian product for the 2-D layout stored in ``self.output``.

        Autograd does not call this (it differentiates ``forward`` directly); it is
        kept for manual use. ``grad_output`` must have the shape of ``self.output``.
        """
        dim = 1
        nonzeros = torch.ne(self.output, 0).to(grad_output.dtype)
        # keepdim=True so each row's mean over its support broadcasts across that row;
        # the squeezed version failed for more than one row.
        support_mean = (torch.sum(grad_output * nonzeros, dim=dim, keepdim=True)
                        / torch.sum(nonzeros, dim=dim, keepdim=True))
        self.grad_input = nonzeros * (grad_output - support_mean.expand_as(grad_output))
        return self.grad_input
80 |
if __name__ == '__main__':
    # Small demo: print random logits and their sparsemax probabilities.
    sparsemax = Sparsemax(dim=1)

    # Use the module-level `device` (CPU fallback) instead of an unconditional .cuda(),
    # so the demo also runs on machines without a GPU.
    logits = torch.rand(10, 5).to(device)
    print("\nLogits")
    print(logits)

    sparsemax_probs = sparsemax(logits)
    print("\nSparsemax probabilities")
    print(sparsemax_probs)
--------------------------------------------------------------------------------
/results/icl_nuim_n768_unseen0_noise0_seed123/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/icl_nuim_n768_unseen0_noise0_seed123/model.pth
--------------------------------------------------------------------------------
/results/modelnet40_n768_unseen0_noise0_seed123/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/modelnet40_n768_unseen0_noise0_seed123/model.pth
--------------------------------------------------------------------------------
/results/modelnet40_n768_unseen0_noise1_seed123/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/modelnet40_n768_unseen0_noise1_seed123/model.pth
--------------------------------------------------------------------------------
/results/modelnet40_n768_unseen1_noise0_seed123/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/modelnet40_n768_unseen1_noise0_seed123/model.pth
--------------------------------------------------------------------------------
/results/scene7_n768_unseen0_noise0_seed123/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/results/scene7_n768_unseen0_noise0_seed123/model.pth
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Build and install the two native extensions CEMNet depends on.
# Abort on the first failure, and resolve the subdirectories relative to this script
# so a failed `cd` can never run setup.py in the wrong directory.
set -e
cd "$(dirname "$0")"

(cd cemnet_lib && python3 setup.py install)
(cd batch_svd && python3 setup.py install)
--------------------------------------------------------------------------------
/test_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch, time
2 | from datasets.get_dataset import get_dataset
3 | from utils.options import opts
4 | from utils.recorder import Recorder
5 | from utils.attr_dict import AttrDict
6 | from cems.guided_cem import GuidedCEM
7 | from tqdm import tqdm
8 |
# Seed every RNG so evaluation is reproducible.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(opts.seed)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Benchmark to evaluate; must be a key of the settings table below.
opts.db_nm = "scene7"


def _db_settings(path):
    """Test-time settings shared by every benchmark; only the data ``path`` differs."""
    return AttrDict(
        path=path,
        cls_idx=-1,  # modelnet40 classes: None_class: -1, airplane: 0, car: 7, chair: 8, table: 33, lamp: 19
        is_neg_angle=False,
        unseen=False,
        gaussian_noise=False,
        n_points=1024,
        n_sub_points=768,
        factor=4,
    )


opts.db = AttrDict(
    modelnet40=_db_settings("/test/datasets/registration/modelnet40/modelnet40_normal_n2048.pth"),
    scene7=_db_settings("/test/datasets/registration/7scene/7scene_normal_n2048.pth"),
    icl_nuim=_db_settings("/test/datasets/registration/icl_nuim/icl_nuim_normal_n2048.pth"),
)
# Keep only the selected benchmark's settings.
opts.db = opts.db[opts.db_nm]
49 |
def init_opts(opts):
    """Set evaluation options (CEM search settings, checkpoint path) on ``opts`` and return it."""
    opts.is_train = False
    cem = opts.cem
    cem.metric_type = ["MC", 0.1]  # alternatives: ["CD"], ["GM", 0.01]
    cem.n_candidates = [1000, {"minibatch": 1000}]
    cem.is_fused_reward = [True, 0.5, 3]
    opts.exploration_weight = 0.5
    cem.n_iters = 10
    opts.is_debug = True
    cem.init_sigma = AttrDict(
        modelnet40=[1.0, 0.5],
        scene7=[1.0, 0.5],
        icl_nuim=[1.0, 0.5],
    )

    # Pretrained checkpoints; set opts.model_path to the one matching opts.db_nm:
    #   1. ModelNet40 - unseen object:   "./results/modelnet40_n768_unseen0_noise0_seed123/model.pth"
    #   2. ModelNet40 - unseen category: "./results/modelnet40_n768_unseen1_noise0_seed123/model.pth"
    #   3. ModelNet40 - noise:           "./results/modelnet40_n768_unseen0_noise1_seed123/model.pth"
    #   4. 7Scene:                       (active below)
    #   5. ICL-NUIM:                     "./results/icl_nuim_n768_unseen0_noise0_seed123/model.pth"
    opts.model_path = "./results/scene7_n768_unseen0_noise0_seed123/model.pth"
    return opts
80 |
def test(opts, model, test_loader):
    """Evaluate ``model`` on ``test_loader`` without gradients and print error statistics.

    Rotation errors compare ``np.degrees`` of the model's ``"r"`` output with
    ``np.degrees`` of the ``rs_lb`` labels. Translation errors compare ``"t"`` with
    ``ts_lb``. Per-sample MSE/MAE are averaged over the whole set, and the reported
    time is the mean forward time per sample. ``opts`` is accepted for interface
    compatibility but not used.
    """
    rcd_times, n_cnt = 0.0, 0
    r_mses, t_mses, r_maes, t_maes = [], [], [], []
    with torch.no_grad():
        for srcs, tgts, rs_lb, ts_lb in tqdm(test_loader):
            srcs, tgts = srcs.cuda(), tgts.cuda()
            # CUDA runs asynchronously: synchronise around the forward pass so the timer
            # measures the computation rather than just kernel launches.
            torch.cuda.synchronize()
            t1 = time.time()
            results = model(srcs, tgts)
            torch.cuda.synchronize()
            rcd_times += time.time() - t1
            n_cnt += len(srcs)
            # Convert once and reuse the same error for both MSE and MAE.
            r_err = np.degrees(results["r"].cpu().numpy()) - np.degrees(rs_lb.numpy())
            t_err = results["t"].cpu().numpy() - ts_lb.numpy()
            r_mses.append(np.mean(r_err ** 2, 1))
            r_maes.append(np.mean(np.abs(r_err), 1))
            t_mses.append(np.mean(t_err ** 2, 1))
            t_maes.append(np.mean(np.abs(t_err), 1))

    if n_cnt == 0:
        # np.concatenate([]) would raise and the time average would divide by zero.
        print("--- Test: no samples in test_loader ---")
        return

    r_mse = np.mean(np.concatenate(r_mses, 0)).item()
    t_mse = np.mean(np.concatenate(t_mses, 0)).item()
    r_mae = np.mean(np.concatenate(r_maes, 0)).item()
    t_mae = np.mean(np.concatenate(t_maes, 0)).item()

    print("--- Test: r_mse: %.8f, t_mse: %.8f, r_rmse: %.8f, t_rmse: %.8f, r_mae: %.8f, t_mae: %.8f, time: %.8f ---" % (
        r_mse, t_mse, np.sqrt(r_mse), np.sqrt(t_mse), r_mae, t_mae, rcd_times / n_cnt))
104 |
def model_test(opts):
    """Seed RNGs, load GuidedCEM from ``opts.model_path`` and evaluate it on the test split."""
    seed = opts.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = True

    opts.recorder = Recorder(opts)
    model = GuidedCEM(opts).to(opts.device).load_model(opts.model_path)

    test_loader, _ = get_dataset(
        opts, db_nm=opts.db_nm, partition="test", is_normal=False,
        batch_size=opts.batch_size, shuffle=False, drop_last=False,
        cls_idx=opts.db.cls_idx,
    )
    test(opts, model, test_loader)
119 |
if __name__ == "__main__":
    # init_opts configures and returns the same shared `opts` object.
    model_test(init_opts(opts))
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch, os, pdb
2 | from utils.options import opts
3 | from utils.transform_pc import transform_pc_torch
4 | from utils.euler2mat import euler2mat_torch
5 | from utils.recorder import Recorder
6 | from utils.test import test
7 | from utils.losses import CDLoss, GMLoss
8 | from datasets.get_dataset import get_dataset
9 | from cems.guided_cem import GuidedCEM
10 | from torch.optim.lr_scheduler import MultiStepLR
11 | from tqdm import tqdm
12 |
# Seed every RNG for reproducible training; enable cuDNN autotuning.
for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(opts.seed)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
18 |
def init_opts(opts):
    """Fill in the training configuration on ``opts`` and return it.

    ``opts.results_dir`` is created on disk only when ``opts.is_debug`` is false
    (it is set to True here, so nothing is created by default).
    """
    opts.is_debug = True
    opts.is_train = True
    opts.loss_type = ["GM", 0.01]  # alternative: ["CD", -1.0]
    cem = opts.cem
    cem.metric_type = ["MC", 0.1]  # alternatives: ["CD"], ["GM", 0.01]
    cem.n_candidates = [1000, {"minibatch": 1000}]
    cem.n_iters = 10
    cem.is_fused_reward = [False, -1.0, 0]
    db = opts.db
    opts.results_dir = "./results/%s_n%d_unseen%d_noise%d_seed%s_v0" % (
        opts.db_nm, db.n_sub_points, db.unseen, db.gaussian_noise, opts.seed)
    if opts.is_debug:
        return opts
    os.makedirs(opts.results_dir, exist_ok=True)
    return opts
32 |
def main(opts):
    """Train GuidedCEM on ``opts.db_nm``, evaluating on the test split after every epoch.

    The loss is applied to both the final estimate (``r``/``t``) and the initial
    guess (``r_init``/``t_init``) from the model. Batches whose loss is NaN are skipped.

    Raises:
        ValueError: if ``opts.loss_type[0]`` is neither "CD" nor "GM".
    """
    ## initial setting
    opts = init_opts(opts)
    opts.recorder = Recorder(opts)
    model = GuidedCEM(opts).to(opts.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = MultiStepLR(optimizer, milestones=[35, 100, 150], gamma=0.7)
    if opts.loss_type[0] == "CD":
        loss_func = CDLoss(opts)
    elif opts.loss_type[0] == "GM":
        loss_func = GMLoss(opts)
    else:
        # Previously this fell through and failed later with a NameError on loss_func.
        raise ValueError("Unknown loss type: %r" % (opts.loss_type[0],))
    train_loader, _ = get_dataset(opts, db_nm=opts.db_nm, partition="train", is_normal=False, batch_size=opts.batch_size, shuffle=True, drop_last=False)
    test_loader, _ = get_dataset(opts, db_nm=opts.db_nm, partition="test", is_normal=False, batch_size=opts.batch_size, shuffle=False, drop_last=False)

    ## training
    print(opts.results_dir)
    for epoch in range(opts.n_epochs):
        ## train
        losses = []
        for srcs, tgts, rs_lb, ts_lb in tqdm(train_loader):
            # Only the clouds are needed on the device; the loss does not use the labels.
            srcs, tgts = srcs.to(opts.device), tgts.to(opts.device)
            results = model(srcs, tgts)
            rs_pred, ts_pred, rs_prior, ts_prior = results["r"], results["t"], results["r_init"], results["t_init"]
            transform_srcs_pred = transform_pc_torch(srcs, euler2mat_torch(rs_pred), ts_pred)
            transform_srcs_prior = transform_pc_torch(srcs, euler2mat_torch(rs_prior), ts_prior)
            loss = loss_func(transform_srcs_pred, tgts) + loss_func(transform_srcs_prior, tgts)
            if torch.isnan(loss):
                print("None, skip")
                continue
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        # Avoid np.mean([]) (RuntimeWarning) when every batch was skipped.
        mean_loss = np.mean(losses) if losses else float("nan")
        print("Epoch[%d], losses: %.9f, %s." % (epoch, mean_loss, opts.results_dir))
        # Since PyTorch 1.1 the LR scheduler must be stepped after optimizer.step().
        # Stepping it at the start of the epoch skipped the initial LR value and applied
        # every milestone decay one epoch early.
        scheduler.step()

        ## test
        test(opts, model, test_loader, epoch)
71 |
if __name__ == "__main__":
    # Entry point: train with the module-level options object.
    main(opts)
74 |
75 |
76 |
77 |
78 |
79 |
80 |
--------------------------------------------------------------------------------
/utils/__pycache__/attr_dict.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/attr_dict.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/batch_icp.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/batch_icp.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/commons.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/commons.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/euler2mat.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/euler2mat.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/losses.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/losses.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/mat2euler.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/mat2euler.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/options.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/recorder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/recorder.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/test.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/transform_pc.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiang-HB/CEMNet/2ec7959733e13e007fa8c4daf3974312cf1bf825/utils/__pycache__/transform_pc.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/attr_dict.py:
--------------------------------------------------------------------------------
class AttrDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    * ``d.key`` returns ``d.get(key, None)`` -- a missing key reads as ``None``
      instead of raising.
    * ``d.key = v`` stores ``d["key"] = v``.
    * Names starting with ``'__'`` are never routed to the dict: reading or
      writing one raises ``AttributeError``. This keeps protocol probes such as
      ``getattr(obj, '__setstate__', None)`` (used by ``copy``/``pickle``)
      behaving as for a plain object instead of receiving ``None``.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails.
        if key.startswith('__'):
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(type(self).__name__, key))
        return self.get(key, None)

    def __setattr__(self, key, value):
        if key.startswith('__'):
            raise AttributeError("Cannot set magic attribute '{}'".format(key))
        self[key] = value
--------------------------------------------------------------------------------
/utils/batch_icp.py:
--------------------------------------------------------------------------------
1 | import torch, pdb
2 | from utils.transform_pc import transform_pc_torch
3 | from torch_batch_svd import svd
4 | from cemnet_lib.functions import closest_point
5 |
def one_step(srcs, tgt, Rs, ts):
    """Run one ICP iteration: re-match points, then re-solve the rigid transform.

    :param srcs: source point clouds, (B, 3, N)
    :param tgt: target passed unchanged to ``closest_point``
    :param Rs: current rotation estimates, (B, 3, 3)
    :param ts: current translation estimates, (B, 3)
    :return: ``(Rs, ts)`` -- refined rotations (B, 3, 3) and translations (B, 3)
        mapping the ORIGINAL ``srcs`` onto their matched target points
    """
    xs_mean = srcs.mean(2, keepdim=True)  # (B, 3, 1)
    xs_centered = srcs - xs_mean  # (B, 3, N)
    # Correspondences are searched from the currently-transformed source.
    ys = closest_point(transform_pc_torch(srcs, Rs, ts), tgt)
    ys_mean = ys.mean(2, keepdim=True)
    ys_centered = ys - ys_mean

    # Kabsch / orthogonal Procrustes on the cross-covariance H.
    # NOTE(review): assumes torch_batch_svd.svd returns V (not V^T), as the
    # R = V U^T formula used here requires.
    H = torch.matmul(xs_centered, ys_centered.transpose(2, 1).contiguous())
    u, _, v = svd(H)
    u_t = u.transpose(2, 1)
    # det(V U^T) is -1 when the unconstrained solution is a reflection;
    # flipping the last singular direction turns it into a proper rotation.
    r_det = torch.det(torch.matmul(v, u_t))
    # Build the correction matrix directly in the SVD output's dtype/device
    # (a default float32 CPU eye breaks the matmul for non-float32 batches).
    diag = torch.eye(3, dtype=v.dtype, device=v.device).unsqueeze(0).repeat(H.shape[0], 1, 1)
    diag[:, 2, 2] = r_det
    Rs = torch.matmul(torch.matmul(v, diag), u_t).contiguous()
    ts = torch.matmul(-Rs, xs_mean) + ys_mean  # (B, 3, 1)

    return Rs, ts.squeeze(-1)
31 |
def batch_icp(opts, srcs, tgt, Rs=None, ts=None, is_path=False, n_iters=3):
    """Refine a batch of rigid transforms with a fixed number of ICP steps.

    :param opts: unused; kept for interface compatibility
    :param srcs: source point clouds, (B, C, N)
    :param tgt: target point cloud (C, N), forwarded to ``one_step``
    :param Rs: initial rotations (B, 3, 3); identity if None
    :param ts: initial translations (B, 3); zeros if None
    :param is_path: if True, also return the list of intermediate ``[Rs, ts]``
        (starting with the initial estimate)
    :param n_iters: number of ICP iterations (default 3, the previous fixed value)
    :return: ``(Rs, ts)`` or ``(Rs, ts, paths)`` when ``is_path`` is True
    """
    batch_size = len(srcs)
    # Match the input's dtype and device so the first transform_pc step does
    # not mix float32 defaults with e.g. float64 point clouds.
    if Rs is None:
        Rs = torch.eye(3, dtype=srcs.dtype, device=srcs.device).unsqueeze(0).repeat(batch_size, 1, 1)
    if ts is None:
        ts = srcs.new_zeros(batch_size, 3)
    paths = [[Rs, ts]]
    for _ in range(n_iters):
        Rs, ts = one_step(srcs, tgt, Rs, ts)
        if is_path:
            paths.append([Rs, ts])
    if is_path:
        return Rs, ts, paths
    return Rs, ts
--------------------------------------------------------------------------------
/utils/commons.py:
--------------------------------------------------------------------------------
1 | import pickle, numpy as np, matplotlib, pdb, torch, os
2 | matplotlib.use("Agg")
3 | import matplotlib.pyplot as plt
4 | from utils.mat2euler import mat2euler_torch
5 | from utils.euler2mat import euler2mat_torch
6 |
def load_data(path):
    """Unpickle and return the object stored at *path*.

    NOTE: unpickling can execute arbitrary code -- only load trusted files.
    """
    # ``with`` closes the handle even if unpickling raises.
    with open(path, "rb") as file:
        return pickle.load(file)
12 |
def save_data(path, data):
    """Pickle *data* to *path*, overwriting any existing file."""
    # ``with`` closes (and flushes) the handle even if pickling raises.
    with open(path, "wb") as file:
        pickle.dump(data, file)
17 |
def chunker_list(seq, size):
    """Split *seq* into consecutive slices of at most *size* items (last may be shorter)."""
    starts = range(0, len(seq), size)
    return [seq[start:start + size] for start in starts]
20 |
def chunker_num(num, size):
    """Split the indices ``0 .. num-1`` into consecutive lists of at most *size*.

    Example: ``chunker_num(5, 2) -> [[0, 1], [2, 3], [4]]``.
    """
    # Build each chunk from its own sub-range instead of re-creating
    # list(range(num)) for every chunk (which was O(num**2 / size)).
    return [list(range(pos, min(pos + size, num))) for pos in range(0, num, size)]
23 |
def stack_action_seqs(action_seqs):
    """Collapse a horizon of actions into one action per batch element by summing.

    :param action_seqs: (H, B, D) array or tensor
    :return: (B, D) -- sum over the leading (horizon) axis
    """
    horizon_axis = 0
    return action_seqs.sum(horizon_axis)
30 |
def line_plot(xs, ys, title, path):
    """Plot *ys* against *xs* as one line chart titled *title* and save it to *path*."""
    fig = plt.figure()
    try:
        plt.plot(xs, ys)
        plt.title(title)
        plt.savefig(path)
    finally:
        # Release the figure even if saving fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)
37 |
def shuffle_along_axis(a, axis):
    """Independently permute *a* along *axis* (e.g. shuffle each row when axis=1).

    Draws one uniform key per element from NumPy's global RNG and sorts by it.
    """
    keys = np.random.rand(*a.shape)
    order = np.argsort(keys, axis=axis)
    return np.take_along_axis(a, order, axis=axis)
41 |
def cal_errors_np(rs_pred, ts_pred, rs_lb, ts_lb, is_degree=False):
    """Per-sample rotation/translation error statistics.

    :param rs_pred: predicted rotations (B, 3); radians unless *is_degree*
    :param ts_pred: predicted translations (B, 3)
    :param rs_lb: ground-truth rotations (B, 3); same unit as *rs_pred*
    :param ts_lb: ground-truth translations (B, 3)
    :param is_degree: True if the rotations are already in degrees
    :return: dict of (B,) arrays: rs/ts MSE, RMSE and MAE; rotation errors in degrees
    """
    if is_degree:
        r_pred_deg, r_lb_deg = rs_pred, rs_lb
    else:
        r_pred_deg, r_lb_deg = np.degrees(rs_pred), np.degrees(rs_lb)
    r_diff = r_pred_deg - r_lb_deg
    t_diff = ts_pred - ts_lb

    rs_mse = np.mean(r_diff ** 2, 1)  # (B, )
    ts_mse = np.mean(t_diff ** 2, 1)
    rs_mae = np.mean(np.abs(r_diff), 1)
    ts_mae = np.mean(np.abs(t_diff), 1)
    return {
        "rs_mse": rs_mse,
        "ts_mse": ts_mse,
        "rs_rmse": np.sqrt(rs_mse),
        "ts_rmse": np.sqrt(ts_mse),
        "rs_mae": rs_mae,
        "ts_mae": ts_mae,
    }
64 |
def plot_pc(pcs, save_path):
    """Save a grid (3 columns) of 3-D scatter plots, one subplot per group of clouds.

    :param pcs: list of groups; each group is a list of ``(label, pc)`` pairs,
        where ``pc`` is indexed as ``pc[0], pc[1], pc[2]`` (x, y, z), e.g. (3, N):
        ``[[[nm, pc], [nm, pc]], [[nm, pc]]]``
    :param save_path: output path passed to ``savefig``
    """
    # Importing mplot3d registers the '3d' projection on older matplotlib
    # versions; it is a no-op on newer ones.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    n_col = 3
    # subplot grid sizes must be ints; np.ceil returns a float.
    n_row = int(np.ceil(len(pcs) / n_col))
    colors = [(244/255, 17/255, 10/255), (44/255, 175/255, 53/255), (18/255, 72/255, 148/255), (246/255, 130/255, 11/255)]
    fig = plt.figure(figsize=(n_col * 4, n_row * 4))
    try:
        for i, group in enumerate(pcs):
            ax = fig.add_subplot(n_row, n_col, i + 1, projection='3d')
            for j, (lb, pc) in enumerate(group):
                # Cycle the palette so groups with more clouds than colors still plot.
                ax.scatter(pc[0], pc[1], pc[2], color=colors[j % len(colors)], marker='.', label=lb)
            ax.legend(fontsize=12, frameon=True)
        fig.savefig(save_path)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
79 |
def stack_transforms(transforms1, transforms2):
    """Compose two batches of rigid transforms: *transforms1* first, then *transforms2*.

    Each input is a (B, 6) tensor: columns 0-2 are Euler angles in radians (as
    consumed by ``euler2mat_torch``), columns 3-5 the translation. The result
    is ``R = R2 @ R1`` and ``t = R2 @ t1 + t2``, re-encoded as a (B, 6) tensor.
    Note: ``mat2euler_torch`` goes through NumPy, so the rotation part of the
    result is detached from the autograd graph.
    """
    rot_first = euler2mat_torch(transforms1[:, :3])  # (B, 3, 3)
    rot_second = euler2mat_torch(transforms2[:, :3])  # (B, 3, 3)
    trans_first = transforms1[:, 3:].unsqueeze(2)  # (B, 3, 1)
    trans_second = transforms2[:, 3:].unsqueeze(2)  # (B, 3, 1)

    rot_total = torch.matmul(rot_second, rot_first)
    trans_total = (torch.matmul(rot_second, trans_first) + trans_second).squeeze(2)  # (B, 3)
    angles_total = mat2euler_torch(rot_total, is_degrees=False)  # (B, 3)
    return torch.cat([angles_total, trans_total], 1)  # (B, 6)
93 |
def stack_transforms_seq(transforms):
    """Fold a sequence of rigid transforms into one; ``transforms[0]`` is applied first.

    :param transforms: (L, B, 6) tensor (Euler angles + translation per step)
    :return: (B, 6) composed transform (``transforms[0]`` itself when L == 1)
    :raises ValueError: if the sequence is empty (L == 0)
    """
    if len(transforms) == 0:
        raise ValueError("stack_transforms_seq needs at least one transform")
    actions = transforms[0, :, :]
    for step in range(1, len(transforms)):
        actions = stack_transforms(actions, transforms[step, :, :])
    return actions  # (B, 6)
109 |
--------------------------------------------------------------------------------
/utils/euler2mat.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 |
def euler2mat_np(rs, seq="zyx"):
    """Convert (B, 3) Euler angles ordered (z, y, x), in radians, to (B, 3, 3) matrices.

    The returned matrix is ``Rx(x) @ Ry(y) @ Rz(z)``. Only ``seq="zyx"`` is supported.
    """
    assert seq == "zyx", "Invalid euler seq."
    ang_z, ang_y, ang_x = rs[:, [0]], rs[:, [1]], rs[:, [2]]  # each (B, 1)
    sx, sy, sz = np.sin(ang_x), np.sin(ang_y), np.sin(ang_z)
    cx, cy, cz = np.cos(ang_x), np.cos(ang_y), np.cos(ang_z)
    row0 = [cy * cz, - cy * sz, sy]
    row1 = [sx * sy * cz + cx * sz, - sx * sy * sz + cx * cz, - sx * cy]
    row2 = [- cx * sy * cz + sx * sz, cx * sy * sz + sx * cz, cx * cy]
    flat = np.concatenate(row0 + row1 + row2, 1)  # (B, 9), row-major
    return flat.reshape([-1, 3, 3])
13 |
def euler2mat_torch(rs, seq="zyx"):
    """Convert a (B, 3) tensor of Euler angles (z, y, x), in radians, to (B, 3, 3).

    The returned matrix is ``Rx(x) @ Ry(y) @ Rz(z)``. Only ``seq="zyx"`` is supported.
    Differentiable w.r.t. *rs*.
    """
    assert seq == "zyx", "Invalid euler seq."
    ang_z, ang_y, ang_x = rs[:, [0]], rs[:, [1]], rs[:, [2]]  # each (B, 1)
    sx, sy, sz = torch.sin(ang_x), torch.sin(ang_y), torch.sin(ang_z)
    cx, cy, cz = torch.cos(ang_x), torch.cos(ang_y), torch.cos(ang_z)
    row0 = [cy * cz, - cy * sz, sy]
    row1 = [sx * sy * cz + cx * sz, - sx * sy * sz + cx * cz, - sx * cy]
    row2 = [- cx * sy * cz + sx * sz, cx * sy * sz + sx * cz, cx * cy]
    flat = torch.cat(row0 + row1 + row2, 1)  # (B, 9), row-major
    return flat.view([-1, 3, 3])
24 |
25 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch, torch.nn as nn, pdb
2 |
class CDLoss(nn.Module):
    """Chamfer-style loss built on squared Euclidean point-to-point distances."""

    def __init__(self, opts):
        super(CDLoss, self).__init__()
        # Kept for backward compatibility; distances now follow the inputs' device.
        self.device = opts.device

    def forward(self, srcs, tgts):
        """
        :param srcs: (B, C, N) point clouds
        :param tgts: (B, C, M) point clouds
        :return: scalar -- mean over (B, M) of the nearest-source squared distance
            plus mean over (B, N) of the nearest-target squared distance
        """
        P = self.pairwise_distance(srcs, tgts)  # (B, N, M)
        return torch.min(P, 1)[0].mean() + torch.min(P, 2)[0].mean()

    def pairwise_distance(self, srcs, tgts):
        """Return P[b, i, j] = ||srcs[b, :, i] - tgts[b, :, j]||^2, shape (B, N, M).

        :param srcs: (B, C, N); :param tgts: (B, C, M) -- transposed internally.
        """
        srcs, tgts = srcs.transpose(2, 1), tgts.transpose(2, 1)  # (B, N, C), (B, M, C)
        # ||x||^2 + ||y||^2 - 2 x.y. The squared norms are computed directly
        # instead of reading the diagonals of full (B, N, N) / (B, M, M) Gram
        # matrices, which cost O(N^2) memory for only N useful values.
        srcs_sq = (srcs ** 2).sum(-1)  # (B, N)
        tgts_sq = (tgts ** 2).sum(-1)  # (B, M)
        cross = torch.bmm(srcs, tgts.transpose(2, 1))  # (B, N, M)
        return srcs_sq.unsqueeze(2) + tgts_sq.unsqueeze(1) - 2 * cross
25 |
class GMLoss(nn.Module):
    """Geman-McClure robust loss on nearest-neighbour squared distances.

    For a squared distance d, rho(d) = mu * d / (d + mu), with
    ``mu = opts.loss_type[1]``.
    """

    def __init__(self, opts):
        super(GMLoss, self).__init__()
        self.device = opts.device
        self.opts = opts

    def forward(self, srcs, tgts):
        """
        :param srcs: (B, C, N) point clouds
        :param tgts: (B, C, M) point clouds (M may differ from N)
        :return: scalar -- per batch element mean rho over target points plus
            mean rho over source points, averaged over the batch
        """
        mu = self.opts.loss_type[1]
        srcs, tgts = srcs.transpose(2, 1), tgts.transpose(2, 1)
        # Squared distances directly (the old norm(...).pow(2) took a sqrt only
        # to square it again).
        P = (srcs[:, :, None, :] - tgts[:, None, :, :]).pow(2).sum(-1)  # (B, N, M)
        d_tgt = torch.min(P, 1)[0]  # (B, M): nearest source for each target point
        d_src = torch.min(P, 2)[0]  # (B, N): nearest target for each source point

        def rho(d):
            return (mu * d) / (d + mu)

        # Averaging each direction separately also works when N != M (the
        # previous torch.cat of the two minima required N == M); for N == M it
        # equals the previous sum(2).mean(1).
        return (rho(d_tgt).mean(1) + rho(d_src).mean(1)).mean()

    def pairwise_distance(self, srcs, tgts):
        """Return squared distances (B, N, M) for point-major inputs.

        :param srcs: (B, N, C) -- no transpose is applied here
        :param tgts: (B, M, C)
        """
        srcs_sq = (srcs ** 2).sum(-1)  # (B, N)
        tgts_sq = (tgts ** 2).sum(-1)  # (B, M)
        cross = torch.bmm(srcs, tgts.transpose(2, 1))  # (B, N, M)
        return srcs_sq.unsqueeze(2) + tgts_sq.unsqueeze(1) - 2 * cross
--------------------------------------------------------------------------------
/utils/mat2euler.py:
--------------------------------------------------------------------------------
1 | import numpy as np, torch
2 | from scipy.spatial.transform import Rotation
3 |
def mat2euler_np(mats, seq='zyx', is_degrees=True):
    """Convert a stack of rotation matrices (B, 3, 3) to Euler angles (B, 3), float32.

    Uses SciPy's ``Rotation.as_euler(seq)``; angles are ordered as in *seq* and
    are in degrees if *is_degrees* else radians.
    """
    # Rotation.from_dcm was renamed from_matrix (SciPy 1.4) and later removed
    # (SciPy 1.6); prefer the new name and fall back for old SciPy.
    from_matrix = getattr(Rotation, "from_matrix", None) or Rotation.from_dcm
    eulers = [from_matrix(mat).as_euler(seq, degrees=is_degrees) for mat in mats]
    return np.asarray(eulers, dtype='float32')
10 |
def mat2euler_torch(mats, seq='zyx', is_degrees=True):
    """Tensor wrapper around SciPy: (B, 3, 3) rotations -> (B, 3) float32 Euler angles.

    The input is detached and copied to NumPy, so the result carries no
    gradient; it is returned on ``mats.device``. Angles are ordered as in *seq*,
    in degrees if *is_degrees* else radians.
    """
    # Rotation.from_dcm was renamed from_matrix (SciPy 1.4) and later removed
    # (SciPy 1.6); prefer the new name and fall back for old SciPy.
    from_matrix = getattr(Rotation, "from_matrix", None) or Rotation.from_dcm
    mats_np = mats.detach().cpu().numpy()
    eulers = [from_matrix(mat).as_euler(seq, degrees=is_degrees) for mat in mats_np]
    return torch.FloatTensor(np.asarray(eulers, dtype='float32')).to(mats.device)
18 |
def mat2euler(rot_mat, seq='xyz'):
    """Convert a batch of rotation matrices to Euler angles (radians), in torch.

    :param rot_mat: (B, 3, 3) rotation matrices. For ``seq='xyz'`` the matrix
        is taken to be ``Rx @ Ry @ Rz`` (z applied first); for any other value
        it is taken to be ``Rz @ Ry @ Rx`` (the 'zyx' case).
    :param seq: 'xyz' or 'zyx' (anything other than 'xyz' is treated as 'zyx')
    :return: (B, 3) tensor of angles ordered (z, y, x) in both cases
    """
    r11 = rot_mat[:, 0, 0]
    r12 = rot_mat[:, 0, 1]
    r13 = rot_mat[:, 0, 2]
    r21 = rot_mat[:, 1, 0]
    r23 = rot_mat[:, 1, 2]
    r31 = rot_mat[:, 2, 0]
    r32 = rot_mat[:, 2, 1]
    r33 = rot_mat[:, 2, 2]
    if seq == 'xyz':
        z = torch.atan2(-r12, r11)
        # Clamp: rounding can push |sin(y)| slightly above 1, where asin is NaN.
        y = torch.asin(r13.clamp(-1.0, 1.0))
        x = torch.atan2(-r23, r33)
    else:
        y = torch.asin((-r31).clamp(-1.0, 1.0))
        x = torch.atan2(r32, r33)
        z = torch.atan2(r21, r11)
    return torch.stack((z, y, x), dim=1)
--------------------------------------------------------------------------------
/utils/options.py:
--------------------------------------------------------------------------------
1 | from utils.attr_dict import AttrDict
2 | import numpy as np, torch
3 |
# Module-level experiment configuration. Every setting lives on the AttrDict
# `opts`; reading an unset key returns None instead of raising.
# NOTE(review): opts.results_dir (read by Recorder and test()) is not set in
# this module -- presumably assigned by the entry script; verify.
opts = AttrDict()
## general setting
opts.db_nm = "scene7" # one of "modelnet40", "scene7", "icl_nuim"; picks the opts.db entry below
opts.is_debug = False
opts.device=torch.device("cuda") # hard-coded to CUDA
opts.seed = 123
opts.batch_size = 35
opts.minibatch_size = 35
opts.n_epochs = 30

## dataset
# One settings entry per dataset; only the one named by opts.db_nm is kept.
opts.db = AttrDict(
    modelnet40 = AttrDict(
        path = "/test/datasets/registration/modelnet40/modelnet40_normal_n2048.pth",
        cls_idx = -1, # None_class: -1, airplane: 0, car: 7, chair: 8, table: 33, lamp: 19 (-1 presumably means no class filter)
        is_neg_angle = False,
        unseen = False,
        gaussian_noise = True,
        n_points = 1024,
        n_sub_points = 768,
        factor = 4
    ),
    scene7 = AttrDict(
        path = "/test/datasets/registration/7scene/7scene_normal_n2048.pth",
        cls_idx = -1,
        is_neg_angle = False,
        unseen = False,
        gaussian_noise = False,
        n_points = 1024,
        n_sub_points = 768,
        factor = 4
    ),
    icl_nuim = AttrDict(
        path = "/test/datasets/registration/icl_nuim/icl_nuim_normal_n2048.pth",
        cls_idx = -1,
        is_neg_angle = False,
        unseen = False,
        gaussian_noise = False,
        n_points = 1024,
        n_sub_points = 768,
        factor = 4
    )
)
# Order matters: this replaces the per-dataset table with the selected entry.
opts.db = opts.db[opts.db_nm]

## cem module
# NOTE: t_range / init_sigma below stay keyed by dataset name (not reduced to
# the selected dataset here).
opts.cem = AttrDict(
    metric_type = ["iou", {"epsilon": 0.01}], # "iou" or "cd", plus metric parameters
    n_candidates = [1000, {"minibatch": 1000}],
    n_elites = 25,
    n_iters = 10,
    r_range = [-np.pi, np.pi],
    t_range = AttrDict(
        modelnet40 = [-1.0, 1.0],
        scene7 = [-1.0, 1.0],
        icl_nuim = [-1.0, 1.0]
    ),
    init_sigma = AttrDict(
        modelnet40 = [1.0, 0.5],
        scene7 = [1.0, 0.5],
        icl_nuim = [1.0, 0.5]
    ),
    planning_horizon = 1,
    is_icp_modification = [True, 0.5, 3] # NOTE(review): field meanings (enable flag, ?, ?) are defined by the consumer; verify
)

# network setting
opts.pointer = "identity" # "identity" or "transformer"
opts.head = "svd" # "svd" or "mlp"
opts.eval = False
opts.emb_nn = "dgcnn"
opts.emb_dims = 512
opts.ff_dim = 1024
opts.n_blocks = 1
opts.n_heads = 4
opts.dropout = 0.
80 |
--------------------------------------------------------------------------------
/utils/recorder.py:
--------------------------------------------------------------------------------
1 | import numpy as np, os
2 | from utils.commons import save_data, line_plot
3 | from collections import defaultdict
4 |
class Recorder:
    """Accumulate named result series, then plot them or pickle them to opts.results_dir."""

    def __init__(self, opts):
        self.results_dir = opts.results_dir
        self.results = defaultdict(list)  # key -> list of recorded values
        self.opts = opts

    def add_res(self, res):
        """Record every value in the dict *res*; list values are extended (flattened one level)."""
        # Distinct loop name: the old loop reused `res` and shadowed the parameter.
        for key, value in res.items():
            if isinstance(value, list):
                self.results[key].extend(value)
            else:
                self.results[key].append(value)

    def add_reslist(self, reslist):
        """Record every value in the dict *reslist* as one entry (lists are NOT flattened)."""
        for key, value in reslist.items():
            self.results[key].append(value)

    def line_plot(self, nms):
        """Save one line plot per key in *nms* as '<key>_seed<seed>.pdf' in the results dir."""
        for nm in nms:
            res = self.results[nm]
            title = "%s_seed%d" % (nm, self.opts.seed)
            line_plot(np.arange(len(res)), res, title, os.path.join(self.opts.results_dir, "%s.pdf" % title))

    def save(self, file_nm=None):
        """Pickle all results to '<file_nm>_<seed>.pth', or 'res_seed<seed>.pth' by default."""
        if file_nm is not None:
            fname = "%s_%d.pth" % (file_nm, self.opts.seed)
        else:
            fname = "res_seed%d.pth" % (self.opts.seed)
        save_data(os.path.join(self.opts.results_dir, fname), self.results)
33 |
--------------------------------------------------------------------------------
/utils/test.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch, numpy as np, os
3 |
def _batch_errors(pred, lb, to_degrees):
    """Return per-sample (MSE, MAE) over axis 1 of two (B, D) arrays; optionally rad->deg first."""
    if to_degrees:
        pred, lb = np.degrees(pred), np.degrees(lb)
    diff = pred - lb
    return np.mean(diff ** 2, 1), np.mean(np.abs(diff), 1)


def test(opts, model, test_loader, epoch):
    """Evaluate *model* on *test_loader*, print error metrics, and checkpoint it.

    Metrics (mean over all samples) are reported for the final prediction
    (``results["r"]``, ``results["t"]``) and the prior (``results["r_init"]``,
    ``results["t_init"]``). Rotations are converted with ``np.degrees`` first,
    so rotation errors are in degrees. Unless ``opts.is_debug``, the model's
    state_dict is saved to ``opts.results_dir/model_epoch<epoch>.pth``.
    """
    cal_score = lambda x: np.mean(np.concatenate(x, 0)).item()
    # name -> ([per-batch MSE arrays], [per-batch MAE arrays])
    errs = {nm: ([], []) for nm in ("r", "t", "r_prior", "t_prior")}
    with torch.no_grad():
        for srcs, tgts, rs_lb, ts_lb in tqdm(test_loader):
            srcs, tgts = srcs.cuda(), tgts.cuda()
            results = model(srcs, tgts)
            # Convert labels to NumPy once per batch instead of once per metric.
            rs_lb_np, ts_lb_np = rs_lb.numpy(), ts_lb.numpy()
            tracked = (("r", results["r"], rs_lb_np, True),
                       ("t", results["t"], ts_lb_np, False),
                       ("r_prior", results["r_init"], rs_lb_np, True),
                       ("t_prior", results["t_init"], ts_lb_np, False))
            for nm, pred, lb, is_rotation in tracked:
                mse, mae = _batch_errors(pred.cpu().numpy(), lb, is_rotation)
                errs[nm][0].append(mse)
                errs[nm][1].append(mae)

    r_mse, r_mae = cal_score(errs["r"][0]), cal_score(errs["r"][1])
    t_mse, t_mae = cal_score(errs["t"][0]), cal_score(errs["t"][1])
    r_prior_mse, r_prior_mae = cal_score(errs["r_prior"][0]), cal_score(errs["r_prior"][1])
    t_prior_mse, t_prior_mae = cal_score(errs["t_prior"][0]), cal_score(errs["t_prior"][1])

    if not opts.is_debug:
        torch.save(model.state_dict(), os.path.join(opts.results_dir, 'model_epoch%d.pth' % (epoch)))

    print("[%d] Test: r_mse: %.8f, t_mse: %.8f, r_mae: %.8f, t_mae: %.8f" % (
        epoch, r_mse, t_mse, r_mae, t_mae))
    print("[%d] Prior test: r_mse: %.8f, t_mse: %.8f, r_mae: %.8f, t_mae: %.8f" % (
        epoch, r_prior_mse, t_prior_mse, r_prior_mae, t_prior_mae))
--------------------------------------------------------------------------------
/utils/transform_pc.py:
--------------------------------------------------------------------------------
1 | import torch, numpy as np
2 | from utils.euler2mat import euler2mat_torch, euler2mat_np
3 |
def transform_pc_torch(pcs, Rs, ts):
    """Apply the per-sample rigid transform x -> R x + t to a batch of point clouds.

    :param pcs: (B, 3, N) tensor, one point per column
    :param Rs: (B, 3, 3) rotation matrices
    :param ts: (B, 3) translation vectors
    :return: (B, 3, N) transformed point clouds
    """
    rotated = Rs @ pcs
    offsets = ts.unsqueeze(2)  # (B, 3, 1), broadcast over the N points
    return rotated + offsets
12 |
def transform_pc_np(pcs, Rs, ts):
    """NumPy version of x -> R x + t for a batch of point clouds.

    :param pcs: (B, 3, N) array, one point per column
    :param Rs: (B, 3, 3) rotation matrices
    :param ts: (B, 3) translation vectors
    :return: (B, 3, N) transformed point clouds
    """
    rotated = np.matmul(Rs, pcs)
    offsets = np.expand_dims(ts, axis=2)  # (B, 3, 1), broadcast over the N points
    return rotated + offsets
21 |
def transform_pc_action_pytorch(pcs, actions):
    """Apply rigid transforms given as 6-D action vectors to a batch of point clouds.

    :param pcs: point clouds, (B, 3, N) tensor
    :param actions: (B, 6) tensor -- columns 0-2 are Euler angles (z, y, x) in
        radians as consumed by ``euler2mat_torch(..., seq="zyx")``; columns 3-5
        are the translation
    :return: (B, 3, N) transformed point clouds
    """
    angles, shifts = actions[:, :3], actions[:, 3:]
    rot = euler2mat_torch(angles, seq="zyx")  # (B, 3, 3)
    return torch.matmul(rot, pcs) + shifts.unsqueeze(2)
30 |
def transform_pc_action_np(pcs, actions):
    """NumPy version: apply 6-D action vectors (Euler zyx + translation) to point clouds.

    :param pcs: point clouds, (B, 3, N) array
    :param actions: (B, 6) array -- columns 0-2 are Euler angles (z, y, x) in
        radians as consumed by ``euler2mat_np(..., seq="zyx")``; columns 3-5 are
        the translation
    :return: (B, 3, N) transformed point clouds
    """
    angles, shifts = actions[:, :3], actions[:, 3:]
    rot = euler2mat_np(angles, seq="zyx")  # (B, 3, 3)
    return np.matmul(rot, pcs) + np.expand_dims(shifts, axis=2)
39 |
--------------------------------------------------------------------------------