├── pytorch_ops ├── __init__.py ├── grouping │ ├── __init__.py │ ├── _ext │ │ ├── __init__.py │ │ └── grouping │ │ │ └── __init__.py │ ├── src │ │ ├── __init.py │ │ ├── group_cuda_kernel.cu.o │ │ ├── group_cuda.h │ │ ├── group_cuda_kernel.h │ │ ├── group_cuda.c │ │ └── group_cuda_kernel.cu │ ├── build.py │ └── group.py ├── losses │ ├── __init__.py │ ├── cd │ │ ├── __init__.py │ │ ├── cd_cuda_kernel.o │ │ ├── cd_cuda.h │ │ ├── cd_cuda_kernel.h │ │ ├── cd.py │ │ ├── cd_cuda.c │ │ └── cd_cuda_kernel.cu │ ├── _ext │ │ ├── __init__.py │ │ ├── cd │ │ │ └── __init__.py │ │ └── emd │ │ │ └── __init__.py │ ├── emd │ │ ├── __init__.py │ │ ├── emd_cuda_kernel.o │ │ ├── emd_cuda.h │ │ ├── emd_cuda_kernel.h │ │ ├── emd.py │ │ ├── emd_cuda.c │ │ └── emd_cuda_kernel.cu │ └── build.py ├── sampling │ ├── __init__.py │ ├── _ext │ │ ├── __init__.py │ │ └── farthestpointsampling │ │ │ └── __init__.py │ ├── src │ │ ├── __init__.py │ │ ├── sample_cuda_kernel.cu.o │ │ ├── sample_cuda.h │ │ ├── sample_cuda_kernel.h │ │ ├── sample_cuda.c │ │ └── sample_cuda_kernel.cu │ ├── build.py │ └── sample.py └── interpolation │ ├── __init__.py │ ├── _ext │ ├── __init__.py │ └── interpolate │ │ └── __init__.py │ ├── src │ ├── __init__.py │ ├── interpolate_cuda_kernel.o │ ├── interpolation_cuda.h │ ├── interpolation_cuda_kernel.h │ ├── interpolation_cuda.c │ ├── cuda_utils.h │ └── interpolation_cuda_kernel.cu │ ├── build.py │ └── interpolate.py ├── models └── place_pre_trained_weights_here.txt ├── picture └── airplane.png ├── data └── airplane │ └── demo.mat ├── part_color_mapping.json ├── LICENSE ├── .gitignore ├── util.py ├── torchfoldext.py ├── pointnet2.py ├── train.py ├── test_demo.py ├── README.md ├── ap_evaluate.py ├── dataloader_grass.py ├── dataloader_symh.py ├── dataloader.py └── partnet.py /pytorch_ops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/grouping/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/losses/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/losses/cd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/grouping/_ext/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/grouping/src/__init.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/losses/_ext/__init__.py: -------------------------------------------------------------------------------- 1 | 
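# (Package marker.) The cd/ and emd/ subpackages below wrap the compiled FFI
# modules _cd and _emd, which are expected to be generated into this _ext
# package by running losses/build.py.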
-------------------------------------------------------------------------------- /pytorch_ops/losses/emd/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/sampling/_ext/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/sampling/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/place_pre_trained_weights_here.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/_ext/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /picture/airplane.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENGGENYU/CVPR2019_PartNet/HEAD/picture/airplane.png -------------------------------------------------------------------------------- /data/airplane/demo.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENGGENYU/CVPR2019_PartNet/HEAD/data/airplane/demo.mat -------------------------------------------------------------------------------- /pytorch_ops/losses/cd/cd_cuda_kernel.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENGGENYU/CVPR2019_PartNet/HEAD/pytorch_ops/losses/cd/cd_cuda_kernel.o -------------------------------------------------------------------------------- /pytorch_ops/losses/emd/emd_cuda_kernel.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENGGENYU/CVPR2019_PartNet/HEAD/pytorch_ops/losses/emd/emd_cuda_kernel.o -------------------------------------------------------------------------------- /pytorch_ops/grouping/src/group_cuda_kernel.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENGGENYU/CVPR2019_PartNet/HEAD/pytorch_ops/grouping/src/group_cuda_kernel.cu.o -------------------------------------------------------------------------------- /pytorch_ops/sampling/src/sample_cuda_kernel.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENGGENYU/CVPR2019_PartNet/HEAD/pytorch_ops/sampling/src/sample_cuda_kernel.cu.o -------------------------------------------------------------------------------- /pytorch_ops/interpolation/src/interpolate_cuda_kernel.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FENGGENYU/CVPR2019_PartNet/HEAD/pytorch_ops/interpolation/src/interpolate_cuda_kernel.o -------------------------------------------------------------------------------- /pytorch_ops/losses/emd/emd_cuda.h: 
-------------------------------------------------------------------------------- 1 | int approxmatch_cuda_forward(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *match, THCudaTensor *temp); 2 | int matchcost_cuda_forward(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *match, THCudaTensor *out); 3 | int matchcost_cuda_backward(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *match, THCudaTensor *grad1, THCudaTensor *grad2); -------------------------------------------------------------------------------- /pytorch_ops/sampling/src/sample_cuda.h: -------------------------------------------------------------------------------- 1 | int farthestpointsampling_forward_cuda(int b, int n, int m, THCudaTensor *inp, THCudaTensor *temp, THCudaIntTensor *out); 2 | int gatherpoint_forward_cuda(int b, int n, int m, THCudaTensor *inp, THCudaIntTensor *idx, THCudaTensor *out); 3 | int gatherpoint_backward_cuda(int b, int n, int m, THCudaTensor *out_g, THCudaIntTensor *idx, THCudaTensor *inp_g); -------------------------------------------------------------------------------- /pytorch_ops/losses/cd/cd_cuda.h: -------------------------------------------------------------------------------- 1 | int cd_forward_cuda(int b, int n, THCudaTensor *xyz, int m, THCudaTensor *xyz2, THCudaTensor *result, THCudaIntTensor *result_i, THCudaTensor *result2, THCudaIntTensor *result2_i); 2 | int cd_backward_cuda(int b, int n, THCudaTensor *xyz1, int m, THCudaTensor *xyz2, THCudaTensor *grad_dist1, THCudaIntTensor *idx1, THCudaTensor *grad_dist2, THCudaIntTensor *idx2, THCudaTensor *grad_xyz1, THCudaTensor *grad_xyz2); 3 | -------------------------------------------------------------------------------- /pytorch_ops/sampling/src/sample_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | int farthestpointsamplingLauncher(int b, int n, int m, const float *inp, float *temp, int *out); 6 | int gatherpoint_forward_Launcher(int b, int n, int m, const float *inp, const int *idx, float *out); 7 | int gatherpoint_backward_Launcher(int b, int n, int m, const float *out_g, const int *idx, float *inp_g); 8 | 9 | #ifdef __cplusplus 10 | } 11 | #endif 12 | -------------------------------------------------------------------------------- /pytorch_ops/losses/_ext/cd/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._cd import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /pytorch_ops/losses/_ext/emd/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._emd import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- 
/pytorch_ops/grouping/_ext/grouping/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._grouping import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/_ext/interpolate/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._interpolate import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /pytorch_ops/sampling/_ext/farthestpointsampling/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._farthestpointsampling import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /pytorch_ops/losses/emd/emd_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | int approxmatch_forward_Launcher(int b, int n, int m, const float *xyz1, const float *xyz2, float *match, float *temp); 6 | int matchcost_forward_Launcher(int b, int n, int m, const float *xyz1, const float *xyz2, const float *match, float *out); 7 | int matchcost_backward_Launcher(int b, int n, int m, const float *xyz1, const float *xyz2, const float *match, float *grad1, float *grad2); 8 | 9 | #ifdef __cplusplus 10 | } 11 | #endif 12 | 13 | -------------------------------------------------------------------------------- /pytorch_ops/losses/cd/cd_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _CD_KERNEL 2 | #define _CD_KERNEL 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | int cd_forward_Launcher(int b, int n, const float *xyz, int m, const float *xyz2, float *result, int *result_i, float *result2, int *result2_i); 9 | int cd_backward_Launcher(int b, int n, const float *xyz1, int m, const float *xyz2, const float *grad_dist1, const int *idx1, const float *grad_dist2, const int *idx2, float *grad_xyz1, float *grad_xyz2); 10 | 11 | #ifdef __cplusplus 12 | } 13 | #endif 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /pytorch_ops/grouping/src/group_cuda.h: -------------------------------------------------------------------------------- 1 | int queryBallPoint_cuda(int b, int n, int m, float radius, int nsample, THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaIntTensor *idx, 
THCudaIntTensor *pts_cnt); 2 | int selectionSort_cuda(int b, int n, int m, int k, THCudaTensor *dist, THCudaIntTensor *outi, THCudaTensor *out); 3 | int groupPoint_forward_cuda(int b, int n, int c, int m, int nsample, THCudaTensor *points, THCudaIntTensor *idx, THCudaTensor *out); 4 | int groupPoint_backward_cuda(int b, int n, int c, int m, int nsample, THCudaTensor *grad_out, THCudaIntTensor *idx, THCudaTensor *grad_points); 5 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/src/interpolation_cuda.h: -------------------------------------------------------------------------------- 1 | int three_interpolate_wrapper(int b, int c, int n, int m, 2 | THCudaTensor *points_tensor, 3 | THCudaIntTensor *idx_tensor, 4 | THCudaTensor *weight_tensor, 5 | THCudaTensor *out_tensor); 6 | 7 | int three_interpolate_grad_wrapper(int b, int c, int n, int m, 8 | THCudaTensor *grad_out_tensor, 9 | THCudaIntTensor *idx_tensor, 10 | THCudaTensor *weight_tensor, 11 | THCudaTensor *grad_points_tensor); -------------------------------------------------------------------------------- /pytorch_ops/grouping/src/group_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | int queryBallPointLauncher(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt); 6 | int selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out); 7 | int groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out); 8 | int groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points); 9 | 10 | #ifdef __cplusplus 11 | } 12 | #endif 13 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/src/interpolation_cuda_kernel.h: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | // input: features(b, c, m), idxs(b, n, 3), weights(b, n, 3) 6 | // output: out(b, c, n) 7 | int three_interpolate_kernel_wrapper(int b, int c, int n, int m, const float *points, const int *idx, const float *weight, float *out); 8 | // input: grad_out(b, c, n), idxs(b, n, k), weights(b, n, k) 9 | // output: grad_points(b, c, m) 10 | int three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points); 11 | 12 | #ifdef __cplusplus 13 | } 14 | #endif -------------------------------------------------------------------------------- /pytorch_ops/grouping/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.ffi import create_extension 4 | this_file = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | extra_objects = ['src/group_cuda_kernel.cu.o'] 7 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 8 | ffi = create_extension( 9 | name='_ext.grouping', 10 | headers=['src/group_cuda.h'], 11 | sources=['src/group_cuda.c'], 12 | define_macros=[('WITH_CUDA', None)], 13 | with_cuda=True, 14 | relative_to=__file__, 15 | extra_objects=extra_objects, 16 | extra_compile_args=["-I/usr/local/cuda-8.0/include"]) 17 | ffi.build() 18 | 
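# Note: torch.utils.ffi was deprecated in PyTorch 0.4 and removed in 1.0, so this
# script assumes an old PyTorch with CUDA 8.0 headers at the include path above.
# Hypothetical build sequence (run from pytorch_ops/grouping/; the nvcc step is
# only needed if the checked-in src/group_cuda_kernel.cu.o must be rebuilt):
#   nvcc -c -o src/group_cuda_kernel.cu.o src/group_cuda_kernel.cu -x cu -Xcompiler -fPIC
#   python build.py   # places the compiled _grouping module under _ext/grouping/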
-------------------------------------------------------------------------------- /pytorch_ops/sampling/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.ffi import create_extension 4 | this_file = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | extra_objects = ['src/sample_cuda_kernel.cu.o'] 7 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 8 | ffi = create_extension( 9 | name='_ext.farthestpointsampling', 10 | headers=['src/sample_cuda.h'], 11 | sources=['src/sample_cuda.c'], 12 | define_macros=[('WITH_CUDA', None)], 13 | with_cuda=True, 14 | relative_to=__file__, 15 | extra_objects=extra_objects, 16 | extra_compile_args=["-I/usr/local/cuda-8.0/include"]) 17 | ffi.build() 18 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.ffi import create_extension 4 | this_file = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | extra_objects = ['src/interpolate_cuda_kernel.o'] 7 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 8 | ffi = create_extension( 9 | name='_ext.interpolate', 10 | headers=['src/interpolation_cuda.h'], 11 | sources=['src/interpolation_cuda.c'], 12 | define_macros=[('WITH_CUDA', None)], 13 | with_cuda=True, 14 | relative_to=__file__, 15 | extra_objects=extra_objects, 16 | extra_compile_args=["-I/usr/local/cuda-8.0/include"]) 17 | ffi.build() -------------------------------------------------------------------------------- /part_color_mapping.json: -------------------------------------------------------------------------------- 1 | [[0.65, 0.95, 0.05], [0.35, 0.05, 0.35], [0.65, 0.35, 0.65], [0.95, 0.95, 0.65], [0.95, 0.65, 0.05], [0.35, 0.05, 0.05], [0.65, 0.05, 0.05], [0.65, 0.35, 0.95], [0.05, 0.05, 0.65], [0.65, 0.05, 0.35], [0.05, 0.35, 0.35], [0.65, 0.65, 0.35], [0.35, 0.95, 0.05], [0.05, 0.35, 0.65], [0.95, 0.95, 0.35], [0.65, 0.65, 0.65], [0.95, 0.95, 0.05], [0.65, 0.35, 0.05], [0.35, 0.65, 0.05], [0.95, 0.65, 0.95], [0.95, 0.35, 0.65], [0.05, 0.65, 0.95], [0.65, 0.95, 0.65], [0.95, 0.35, 0.95], [0.05, 0.05, 0.95], [0.65, 0.05, 0.95], [0.65, 0.05, 0.65], [0.35, 0.35, 0.95], [0.95, 0.95, 0.95], [0.05, 0.05, 0.05], [0.05, 0.35, 0.95], [0.65, 0.95, 0.95], [0.95, 0.05, 0.05], [0.35, 0.95, 0.35], [0.05, 0.35, 0.05], [0.05, 0.65, 0.35], [0.05, 0.95, 0.05], [0.95, 0.65, 0.65], [0.35, 0.95, 0.95], [0.05, 0.95, 0.35], [0.95, 0.35, 0.05], [0.65, 0.35, 0.35], [0.35, 0.95, 0.65], [0.35, 0.35, 0.65], [0.65, 0.95, 0.35], [0.05, 0.95, 0.65], [0.65, 0.65, 0.95], [0.35, 0.05, 0.95], [0.35, 0.65, 0.95], [0.35, 0.05, 0.65]] -------------------------------------------------------------------------------- /pytorch_ops/interpolation/interpolate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.module import Module 4 | from ._ext import interpolate 5 | 6 | 7 | class InterpolateFunction(Function): 8 | def forward(ctx, points, idx, weight): 9 | # points: (b, c, m) 10 | # idx: (b, n, 3) 11 | # weight: (b, n, 3) 12 | b = points.size(0) 13 | c = points.size(1) 14 | m = points.size(2) 15 | n = idx.size(1) 16 | out = torch.zeros(b, c, n).cuda() 17 | interpolate.three_interpolate_wrapper(b, c, n, m, points, idx, weight, out) 18 | 19 
| ctx.b = b 20 | ctx.c = c 21 | ctx.n = n 22 | ctx.m = m 23 | ctx.save_for_backward(idx, weight) 24 | return out 25 | 26 | def backward(ctx, out_grad): 27 | points_grad = torch.zeros(ctx.b, ctx.c, ctx.m).cuda() 28 | idx, weight = ctx.saved_tensors 29 | interpolate.three_interpolate_grad_wrapper(ctx.b, ctx.c, ctx.n, ctx.m, out_grad, idx, weight, points_grad) 30 | return points_grad, None, None 31 | -------------------------------------------------------------------------------- /pytorch_ops/losses/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torch.utils.ffi import create_extension 4 | this_file = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | extra_objects = ['cd/cd_cuda_kernel.o'] 7 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 8 | ffi = create_extension( 9 | name='_ext.cd', 10 | headers=['cd/cd_cuda.h'], 11 | sources=['cd/cd_cuda.c'], 12 | define_macros=[('WITH_CUDA', None)], 13 | with_cuda=True, 14 | relative_to=__file__, 15 | extra_objects=extra_objects, 16 | extra_compile_args=["-I/usr/local/cuda-8.0/include"]) 17 | ffi.build() 18 | 19 | extra_objects = ['emd/emd_cuda_kernel.o'] 20 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 21 | ffi1 = create_extension( 22 | name='_ext.emd', 23 | headers=['emd/emd_cuda.h'], 24 | sources=['emd/emd_cuda.c'], 25 | define_macros=[('WITH_CUDA', None)], 26 | with_cuda=True, 27 | relative_to=__file__, 28 | extra_objects=extra_objects, 29 | extra_compile_args=["-I/usr/local/cuda-8.0/include"]) 30 | ffi1.build() 31 | 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 FoggYu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pytorch_ops/sampling/src/sample_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | 3 | #include "sample_cuda_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | int farthestpointsampling_forward_cuda(int b, int n, int m, THCudaTensor *inp, THCudaTensor *temp, THCudaIntTensor *out) 8 | { 9 | float *inp_data = THCudaTensor_data(state, inp); 10 | float *temp_data = THCudaTensor_data(state, temp); 11 | int *out_data = THCudaIntTensor_data(state, out); 12 | farthestpointsamplingLauncher(b, n, m, inp_data, temp_data, out_data); 13 | return 1; 14 | } 15 | 16 | int gatherpoint_forward_cuda(int b, int n, int m, THCudaTensor *inp, THCudaIntTensor *idx, THCudaTensor *out) 17 | { 18 | float *inp_data = THCudaTensor_data(state, inp); 19 | int *idx_data = THCudaIntTensor_data(state, idx); 20 | float *out_data = THCudaTensor_data(state, out); 21 | gatherpoint_forward_Launcher(b, n, m, inp_data, idx_data, out_data); 22 | return 1; 23 | } 24 | 25 | int gatherpoint_backward_cuda(int b, int n, int m, THCudaTensor *out_g, THCudaIntTensor *idx, THCudaTensor *inp_g) 26 | { 27 | float *outg_data = THCudaTensor_data(state, out_g); 28 | int *idx_data = THCudaIntTensor_data(state, idx); 29 | float *inpg_data = THCudaTensor_data(state, inp_g); 30 | gatherpoint_backward_Launcher(b, n, m, outg_data, idx_data, inpg_data); 31 | return 1; 32 | } -------------------------------------------------------------------------------- /pytorch_ops/losses/cd/cd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.module import Module 4 | from .._ext import cd 5 | 6 | 7 | class CDFunction(Function): 8 | def forward(ctx, xyz1, xyz2): 9 | b = xyz1.size()[0] 10 | n = xyz1.size()[1] 11 | m = xyz2.size()[1] 12 | ctx.b=b 13 | ctx.n=n 14 | ctx.m=m 15 | dist1 = torch.zeros(b, n).cuda() 16 | dist2 = torch.zeros(b, m).cuda() 17 | idx1 = torch.IntTensor(b, n).zero_().cuda() 18 | idx2 = torch.IntTensor(b, m).zero_().cuda() 19 | cd.cd_forward_cuda(b, n, xyz1, m, xyz2, dist1, idx1, dist2, idx2) 20 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 21 | return dist1, idx1, dist2, idx2 22 | 23 | def backward(ctx, grad_dist1, grad_idx1, grad_dist2, grad_idx2): 24 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 25 | b = ctx.b 26 | n = ctx.n 27 | m = ctx.m 28 | grad_xyz1 = torch.zeros(b, n, 3).cuda() 29 | grad_xyz2 = torch.zeros(b, m, 3).cuda() 30 | cd.cd_backward_cuda(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_dist2, 31 | idx2, grad_xyz1, grad_xyz2) 32 | return grad_xyz1, grad_xyz2 33 | 34 | 35 | class CDModule(Module): 36 | def forward(self, input1, input2): 37 | return CDFunction()(input1, input2) 38 | -------------------------------------------------------------------------------- /pytorch_ops/losses/emd/emd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function,Variable 3 | from torch.nn.modules.module import Module 4 | from .._ext import emd 5 | 6 | 7 | class EMDFunction(Function): 8 | def forward(ctx, xyz1, xyz2): 9 | b = xyz1.size()[0] 10 | n = xyz1.size()[1] 11 | m = xyz2.size()[1] 12 | ctx.b=b 13 | ctx.n=n 14 | ctx.m=m 15 | match = torch.zeros(b, m, n).cuda() 16 | temp = torch.zeros(b, (n + m) * 2).cuda() 17 | cost = torch.zeros(b).cuda() 18 | emd.approxmatch_cuda_forward(xyz1, xyz2, match, temp) 19 |
ctx.save_for_backward(xyz1, xyz2) 20 | ctx.match = match 21 | emd.matchcost_cuda_forward(xyz1, xyz2, match, cost) 22 | return cost 23 | 24 | def backward(ctx,grad_cost): 25 | b = ctx.b 26 | n = ctx.n 27 | m = ctx.m 28 | xyz1, xyz2 = ctx.saved_tensors 29 | grad1 = torch.zeros(b, n, 3).cuda() 30 | grad2 = torch.zeros(b, m, 3).cuda() 31 | emd.matchcost_cuda_backward(xyz1, xyz2, ctx.match, grad1, grad2) 32 | grad_cost=torch.unsqueeze(torch.unsqueeze(grad_cost,1),2) 33 | grad1 = grad1*grad_cost 34 | grad2 = grad2*grad_cost 35 | return grad1, grad2 36 | 37 | 38 | class EMDModule(Module): 39 | def forward(self, input1, input2): 40 | return EMDFunction()(input1, input2) 41 | -------------------------------------------------------------------------------- /pytorch_ops/sampling/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.module import Module 4 | from ._ext import farthestpointsampling as fps 5 | 6 | 7 | class SampleFunction(Function): 8 | def __init__(self, npoints): 9 | self.npoints = npoints 10 | def forward(ctx, inp): 11 | b = inp.size()[0] 12 | n = inp.size()[1] 13 | temp = torch.zeros(32, n).cuda() 14 | idx = torch.IntTensor(b, ctx.npoints).zero_().cuda() 15 | out = torch.zeros(b, ctx.npoints, 3).cuda() 16 | fps.farthestpointsampling_forward_cuda(b, n, ctx.npoints, inp, temp, idx) 17 | fps.gatherpoint_forward_cuda(b, n, ctx.npoints, inp, idx, out) 18 | ctx.save_for_backward(idx) 19 | ctx.b = b 20 | ctx.n = n 21 | return out, idx 22 | 23 | def backward(ctx, out_grad, idx_grad): 24 | idx = ctx.saved_tensors[0] 25 | b = ctx.b 26 | n = ctx.n 27 | m = ctx.npoints 28 | inp_grad = torch.zeros(b, n, 3).cuda() 29 | fps.gatherpoint_backward_cuda(b, n, m, out_grad, idx, inp_grad) 30 | return inp_grad, None 31 | 32 | 33 | class FarthestSample(Module): 34 | def __init__(self, npoints): 35 | super(FarthestSample, self).__init__() 36 | self.npoints = npoints 37 | def forward(self, inp): 38 | return SampleFunction(self.npoints)(inp) 39 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/src/interpolation_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | #include "interpolation_cuda_kernel.h" 3 | 4 | extern THCState *state; 5 | 6 | int three_interpolate_wrapper(int b, int c, int n, int m, 7 | THCudaTensor *points_tensor, 8 | THCudaIntTensor *idx_tensor, 9 | THCudaTensor *weight_tensor, 10 | THCudaTensor *out_tensor) { 11 | 12 | const float *points = THCudaTensor_data(state, points_tensor); 13 | const float *weight = THCudaTensor_data(state, weight_tensor); 14 | float *out = THCudaTensor_data(state, out_tensor); 15 | const int *idx = THCudaIntTensor_data(state, idx_tensor); 16 | 17 | three_interpolate_kernel_wrapper(b, c, n, m, points, idx, weight, out); 18 | return 1; 19 | } 20 | 21 | int three_interpolate_grad_wrapper(int b, int c, int n, int m, 22 | THCudaTensor *grad_out_tensor, 23 | THCudaIntTensor *idx_tensor, 24 | THCudaTensor *weight_tensor, 25 | THCudaTensor *grad_points_tensor) { 26 | 27 | const float *grad_out = THCudaTensor_data(state, grad_out_tensor); 28 | const int *idx = THCudaIntTensor_data(state, idx_tensor); 29 | const float *weight = THCudaTensor_data(state, weight_tensor); 30 | float *grad_points = THCudaTensor_data(state, grad_points_tensor); 31 | 32 | three_interpolate_grad_kernel_wrapper(b, c, n, m, grad_out, idx, weight, grad_points); 33 | return
1; 34 | } -------------------------------------------------------------------------------- /pytorch_ops/interpolation/src/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef CUDA_UTILS_H 2 | #define CUDA_UTILS_H 3 | 4 | #include <cmath> 5 | #define TOTAL_THREADS 512 6 | 7 | inline int opt_n_threads(int work_size) { 8 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0); 9 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 10 | } 11 | 12 | inline dim3 opt_block_config(int x, int y) { 13 | const int x_threads = opt_n_threads(x); 14 | const int y_threads = max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 15 | dim3 block_config(x_threads, y_threads, 1); 16 | 17 | return block_config; 18 | } 19 | 20 | // #define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor") 21 | // #define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous") 22 | // #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 23 | // #define CHECK_INPUT_TYPE(x, y) AT_ASSERT(x.type().scalarType() == y, #x " must be " #y) 24 | 25 | #endif 26 | /* 27 | 28 | #ifndef CUDA_UTILS_H 29 | #define CUDA_UTILS_H 30 | 31 | #include <cmath> 32 | #include <algorithm> 33 | #define TOTAL_THREADS 512 34 | 35 | inline int opt_n_threads(int work_size) { 36 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0); 37 | return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1); 38 | } 39 | 40 | inline dim3 opt_block_config(int x, int y) { 41 | const int x_threads = opt_n_threads(x); 42 | const int y_threads = std::max(std::min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 43 | dim3 block_config(x_threads, y_threads, 1); 44 | 45 | return block_config; 46 | } 47 | 48 | #endif 49 | */ 50 | -------------------------------------------------------------------------------- /pytorch_ops/losses/cd/cd_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | 3 | #include "cd_cuda_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | int cd_forward_cuda(int b, int n, THCudaTensor *xyz, int m, THCudaTensor *xyz2, THCudaTensor *result, THCudaIntTensor *result_i, THCudaTensor *result2, THCudaIntTensor *result2_i) 8 | { 9 | float *xyz_data = THCudaTensor_data(state, xyz); 10 | float *xyz2_data = THCudaTensor_data(state, xyz2); 11 | float *result_data = THCudaTensor_data(state, result); 12 | int *result_i_data = THCudaIntTensor_data(state, result_i); 13 | float *result2_data = THCudaTensor_data(state, result2); 14 | int *result2_i_data = THCudaIntTensor_data(state, result2_i); 15 | cd_forward_Launcher(b, n, xyz_data, m, xyz2_data, result_data, result_i_data, result2_data, result2_i_data); 16 | return 1; 17 | } 18 | 19 | int cd_backward_cuda(int b, int n, THCudaTensor *xyz1, int m, THCudaTensor *xyz2, THCudaTensor *grad_dist1, THCudaIntTensor *idx1, THCudaTensor *grad_dist2, THCudaIntTensor *idx2, THCudaTensor *grad_xyz1, THCudaTensor *grad_xyz2) 20 | { 21 | float *xyz1_data = THCudaTensor_data(state, xyz1); 22 | float *xyz2_data = THCudaTensor_data(state, xyz2); 23 | float *grad_dist1_data = THCudaTensor_data(state, grad_dist1); 24 | float *grad_dist2_data = THCudaTensor_data(state, grad_dist2); 25 | int *idx1_data = THCudaIntTensor_data(state, idx1); 26 | int *idx2_data = THCudaIntTensor_data(state, idx2); 27 | float *grad_xyz1_data = THCudaTensor_data(state, grad_xyz1); 28 | float *grad_xyz2_data = THCudaTensor_data(state, grad_xyz2); 29 | cd_backward_Launcher(b, n, xyz1_data, m, xyz2_data, grad_dist1_data,
idx1_data, grad_dist2_data, idx2_data, grad_xyz1_data, grad_xyz2_data); 30 | return 1; 31 | } 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /pytorch_ops/grouping/src/group_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | 3 | #include "group_cuda_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | int queryBallPoint_cuda(int b, int n, int m, float radius, int nsample, THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaIntTensor *idx, THCudaIntTensor *pts_cnt) 8 | { 9 | float *xyz1_data = THCudaTensor_data(state, xyz1); 10 | float *xyz2_data = THCudaTensor_data(state, xyz2); 11 | int *idx_data = THCudaIntTensor_data(state, idx); 12 | int *pts_data = THCudaIntTensor_data(state, pts_cnt); 13 | queryBallPointLauncher(b, n, m, radius, nsample, xyz1_data, xyz2_data, idx_data, pts_data); 14 | return 1; 15 | } 16 | 17 | int selectionSort_cuda(int b, int n, int m, int k, THCudaTensor *dist, THCudaIntTensor *outi, THCudaTensor *out) 18 | { 19 | float *dist_data = THCudaTensor_data(state, dist); 20 | int *outi_data = THCudaIntTensor_data(state, outi); 21 | float *out_data = THCudaTensor_data(state, out); 22 | selectionSortLauncher(b, n, m, k, dist_data, outi_data, out_data); 23 | return 1; 24 | } 25 | 26 | int groupPoint_forward_cuda(int b, int n, int c, int m, int nsample, THCudaTensor *points, THCudaIntTensor *idx, THCudaTensor *out) 27 | { 28 | float *points_data = THCudaTensor_data(state, points); 29 | int *idx_data = THCudaIntTensor_data(state, idx); 30 | float *out_data = THCudaTensor_data(state, out); 31 | groupPointLauncher(b, n, c, m, nsample, points_data,
idx_data, out_data); 32 | return 1; 33 | } 34 | 35 | int groupPoint_backward_cuda(int b, int n, int c, int m, int nsample, THCudaTensor *grad_out, THCudaIntTensor *idx, THCudaTensor *grad_points) 36 | { 37 | float *grad_out_data = THCudaTensor_data(state, grad_out); 38 | int *idx_data = THCudaIntTensor_data(state, idx); 39 | float *grad_points_data = THCudaTensor_data(state, grad_points); 40 | groupPointGradLauncher(b, n, c, m, nsample, grad_out_data, idx_data, grad_points_data); 41 | return 1; 42 | } 43 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | def get_args(): 5 | parser = ArgumentParser(description='grass_pytorch') 6 | parser.add_argument('--part_code_size', type=int, default=128) 7 | parser.add_argument('--feature_size', type=int, default=128) 8 | parser.add_argument('--hidden_size', type=int, default=256) 9 | 10 | parser.add_argument('--epochs', type=int, default=1000) 11 | parser.add_argument('--batch_size', type=int, default=10) 12 | parser.add_argument('--show_log_every', type=int, default=1) 13 | parser.add_argument('--save_log', action='store_true', default=False) 14 | parser.add_argument('--save_log_every', type=int, default=3) 15 | parser.add_argument('--save_snapshot', action='store_true', default=False) 16 | parser.add_argument('--save_snapshot_every', type=int, default=5) 17 | parser.add_argument('--no_plot', action='store_true', default=True) 18 | parser.add_argument('--lr', type=float, default=.001) 19 | parser.add_argument('--lr_decay_by', type=float, default=1) 20 | parser.add_argument('--lr_decay_every', type=float, default=1) 21 | parser.add_argument('--no_cuda', action='store_true', default=False) 22 | parser.add_argument('--gpu', type=int, default=0) 23 | parser.add_argument('--data_path', type=str, default='./data/airplane/') 24 | parser.add_argument('--save_path', type=str, default='./models/airplane/') 25 | parser.add_argument('--output_path', type=str, default='./results/airplane/') 26 | #training_split_num airplane:510 sofa:510 bike:125 helico:80 table:450 chair:800 27 | #total_num airplane:630 sofa:630 bike:155 helico:100 table:583 chair:999 28 | parser.add_argument('--split_num', type=int, default=510) 29 | parser.add_argument('--total_num', type=int, default=630) 30 | parser.add_argument('--training', type=bool, default=False) 31 | #label_category airplane:4 sofa:4 bike:8 helico:5 table:4 32 | parser.add_argument('--label_category', type=int, default=4) 33 | parser.add_argument('--resume_snapshot', type=str, default='') 34 | args = parser.parse_args() 35 | return args 36 | -------------------------------------------------------------------------------- /pytorch_ops/losses/emd/emd_cuda.c: -------------------------------------------------------------------------------- 1 | #include <THC/THC.h> 2 | 3 | #include "emd_cuda_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | int approxmatch_cuda_forward(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *match, THCudaTensor *temp) 8 | { 9 | int b = THCudaTensor_size(state, xyz1, 0); 10 | int b1 = THCudaTensor_size(state, xyz2, 0); 11 | if (b != b1) 12 | { 13 | return 0; 14 | } 15 | int n = THCudaTensor_size(state, xyz1, 1); 16 | int m = THCudaTensor_size(state, xyz2, 1); 17 | float *xyz1_data = THCudaTensor_data(state, xyz1); 18 | float *xyz2_data = THCudaTensor_data(state, xyz2); 19 | float *match_data = THCudaTensor_data(state, match);
20 | float *temp_data = THCudaTensor_data(state, temp); 21 | approxmatch_forward_Launcher(b, n, m, xyz1_data, xyz2_data, match_data, temp_data); 22 | return 1; 23 | } 24 | 25 | int matchcost_cuda_forward(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *match, THCudaTensor *out) 26 | { 27 | int b = THCudaTensor_size(state, xyz1, 0); 28 | int n = THCudaTensor_size(state, xyz1, 1); 29 | int m = THCudaTensor_size(state, xyz2, 1); 30 | float *xyz1_data = THCudaTensor_data(state, xyz1); 31 | float *xyz2_data = THCudaTensor_data(state, xyz2); 32 | float *match_data = THCudaTensor_data(state, match); 33 | float *out_data = THCudaTensor_data(state, out); 34 | matchcost_forward_Launcher(b, n, m, xyz1_data, xyz2_data, match_data, out_data); 35 | return 1; 36 | } 37 | 38 | int matchcost_cuda_backward(THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaTensor *match, THCudaTensor *grad1, THCudaTensor *grad2) 39 | { 40 | int b = THCudaTensor_size(state, xyz1, 0); 41 | int n = THCudaTensor_size(state, xyz1, 1); 42 | int m = THCudaTensor_size(state, xyz2, 1); 43 | float *xyz1_data = THCudaTensor_data(state, xyz1); 44 | float *xyz2_data = THCudaTensor_data(state, xyz2); 45 | float *match_data = THCudaTensor_data(state, match); 46 | float *grad1_data = THCudaTensor_data(state, grad1); 47 | float *grad2_data = THCudaTensor_data(state, grad2); 48 | matchcost_backward_Launcher(b, n, m, xyz1_data, xyz2_data, match_data, grad1_data, grad2_data); 49 | return 1; 50 | } 51 | -------------------------------------------------------------------------------- /torchfoldext.py: -------------------------------------------------------------------------------- 1 | import torchfold 2 | from torchfold import Fold 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | 7 | class FoldExt(Fold): 8 | 9 | def __init__(self, volatile=False, cuda=False): 10 | Fold.__init__(self, volatile, cuda) 11 | 12 | 13 | def add(self, op, *args): 14 | """Add op to the fold.""" 15 | self.total_nodes += 1 16 | if not all([isinstance(arg, ( 17 | Fold.Node, int, torch.Tensor, torch.FloatTensor, torch.LongTensor, Variable)) for arg in args]): 18 | raise ValueError( 19 | "All args should be Tensor, Variable, int or Node, got: %s" % str(args)) 20 | if args not in self.cached_nodes[op]: 21 | step = max([0] + [arg.step + 1 for arg in args 22 | if isinstance(arg, Fold.Node)]) 23 | node = Fold.Node(op, step, len(self.steps[step][op]), *args) 24 | self.steps[step][op].append(args) 25 | self.cached_nodes[op][args] = node 26 | return self.cached_nodes[op][args] 27 | 28 | 29 | def _batch_args(self, arg_lists, values): 30 | res = [] 31 | for arg in arg_lists: 32 | r = [] 33 | if isinstance(arg[0], Fold.Node): 34 | if arg[0].batch: 35 | for x in arg: 36 | r.append(x.get(values)) 37 | res.append(torch.cat(r, 0)) 38 | else: 39 | for i in range(2, len(arg)): 40 | if arg[i] != arg[0]: 41 | raise ValueError("Cannot use more than one nobatch argument, got: %s."
% str(arg)) 42 | x = arg[0] 43 | res.append(x.get(values)) 44 | else: 45 | # Below is what this extension changes against the original version: 46 | # We make Fold handle float tensor 47 | try: 48 | if (isinstance(arg[0], Variable)): 49 | var = torch.cat(arg, 0) 50 | else: 51 | var = Variable(torch.cat(arg, 0), volatile=self.volatile) 52 | if self._cuda: 53 | var = var.cuda() 54 | res.append(var) 55 | except: 56 | print("Constructing float tensor from %s" % str(arg)) 57 | raise 58 | return res -------------------------------------------------------------------------------- /pytorch_ops/grouping/group.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.module import Module 4 | from ._ext import grouping 5 | 6 | 7 | class GroupPoints(Function): 8 | def forward(ctx, points, idx): 9 | b = points.size()[0] 10 | n = points.size()[1] 11 | c = points.size()[2] 12 | m = idx.size()[1] 13 | nsamples = idx.size()[2] 14 | out = torch.zeros(b, m, nsamples, c).cuda() 15 | ctx.save_for_backward(idx) 16 | ctx.b = b 17 | ctx.n = n 18 | ctx.c = c 19 | grouping.groupPoint_forward_cuda(b, n, c, m, nsamples, points, idx, out) 20 | return out 21 | 22 | def backward(ctx, grad_out): 23 | idx = ctx.saved_tensors[0] 24 | b = ctx.b 25 | n = ctx.n 26 | c = ctx.c 27 | m = idx.size()[1] 28 | nsamples = idx.size()[2] 29 | grad_points = torch.zeros(b, n, c).cuda() 30 | grouping.groupPoint_backward_cuda(b, n, c, m, nsamples, grad_out, idx, grad_points) 31 | return grad_points, None 32 | 33 | 34 | class QueryBallPoint(Function): 35 | def __init__(self, radius, nsample): 36 | super(QueryBallPoint, self).__init__() 37 | self.radius = radius 38 | self.nsample = nsample 39 | #self.requires_grad = False 40 | 41 | def forward(ctx, xyz1, xyz2): 42 | b = xyz1.size()[0] 43 | n = xyz1.size()[1] 44 | m = xyz2.size()[1] 45 | idx = torch.IntTensor(b, m, ctx.nsample).cuda() 46 | pts_cnt = torch.IntTensor(b, m).cuda() 47 | grouping.queryBallPoint_cuda(b, n, m, ctx.radius, ctx.nsample, xyz1, xyz2, idx, pts_cnt) 48 | return idx, pts_cnt 49 | 50 | def backward(ctx, idx_grad, pts_grad): 51 | print('QueryBallPoint backward') 52 | return None, None 53 | 54 | 55 | class GroupPointsModule(Module): 56 | def forward(self, points, idx): 57 | return GroupPoints()(points, idx) 58 | 59 | """ 60 | def knn_point(k, xyz1, xyz2): 61 | ''' 62 | Input: 63 | k: int32, number of k in k-nn search 64 | xyz1: (batch_size, ndataset, c) float32 array, input points 65 | xyz2: (batch_size, npoint, c) float32 array, query points 66 | Output: 67 | val: (batch_size, npoint, k) float32 array, L2 distances 68 | idx: (batch_size, npoint, k) int32 array, indices to input points 69 | ''' 70 | b = xyz1.get_shape()[0].value 71 | n = xyz1.get_shape()[1].value 72 | c = xyz1.get_shape()[2].value 73 | m = xyz2.get_shape()[1].value 74 | print b, n, c, m 75 | print xyz1, (b, 1, n, c) 76 | xyz1 = tf.tile(tf.reshape(xyz1, (b, 1, n, c)), [1, m, 1, 1]) 77 | xyz2 = tf.tile(tf.reshape(xyz2, (b, m, 1, c)), [1, 1, n, 1]) 78 | dist = tf.reduce_sum((xyz1 - xyz2)**2, -1) 79 | print dist, k 80 | outi, out = select_top_k(k, dist) 81 | idx = tf.slice(outi, [0, 0, 0], [-1, -1, k]) 82 | val = tf.slice(out, [0, 0, 0], [-1, -1, k]) 83 | print idx, val 84 | #val, idx = tf.nn.top_k(-dist, k=k) # ONLY SUPPORT CPU 85 | return val, idx 86 | """ 87 | -------------------------------------------------------------------------------- /pytorch_ops/interpolation/src/interpolation_cuda_kernel.cu:
-------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | #include "cuda_utils.h" 6 | #include "interpolation_cuda_kernel.h" 7 | 8 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 9 | // output: out(b, c, n) 10 | __global__ void three_interpolate_kernel(int b, int c, int n, int m, 11 | const float *__restrict__ points, 12 | const int *__restrict__ idx, 13 | const float *__restrict__ weight, 14 | float *__restrict__ out) { 15 | 16 | int batch_index = blockIdx.x; 17 | points += batch_index * m * c; 18 | idx += batch_index * n * 3; 19 | weight += batch_index * n * 3; 20 | out += batch_index * n * c; 21 | 22 | for (int i = threadIdx.y; i < c; i += blockDim.y) { 23 | for (int j = threadIdx.x; j < n; j += blockDim.x) { 24 | float w1 = weight[j * 3 + 0]; 25 | float w2 = weight[j * 3 + 1]; 26 | float w3 = weight[j * 3 + 2]; 27 | 28 | int i1 = idx[j * 3 + 0]; 29 | int i2 = idx[j * 3 + 1]; 30 | int i3 = idx[j * 3 + 2]; 31 | 32 | out[i * n + j] = points[i * m + i1] * w1 + points[i * m + i2] * w2 + points[i * m + i3] * w3; 33 | } 34 | } 35 | } 36 | 37 | // input: grad_out(b, c, n), idxs(b, n, 3), weights(b, n, 3) 38 | // output: grad_points(b, c, m) 39 | __global__ void three_interpolate_grad_kernel(int b, int c, int n, int m, 40 | const float *__restrict__ grad_out, 41 | const int *__restrict__ idx, 42 | const float *__restrict__ weight, 43 | float *__restrict__ grad_points) { 44 | const int batch_index = blockIdx.x; 45 | grad_out += batch_index * n * c; 46 | idx += batch_index * n * 3; 47 | weight += batch_index * n * 3; 48 | grad_points += batch_index * m * c; 49 | 50 | for (int i = threadIdx.y; i < c; i += blockDim.y) { 51 | for (int j = threadIdx.x; j < n; j += blockDim.x) { 52 | float w1 = weight[j * 3 + 0]; 53 | float w2 = weight[j * 3 + 1]; 54 | float w3 = weight[j * 3 + 2]; 55 | 56 | int i1 = idx[j * 3 + 0]; 57 | int i2 = idx[j * 3 + 1]; 58 | int i3 = idx[j * 3 + 2]; 59 | 60 | atomicAdd(grad_points + i * m + i1, grad_out[i * n + j] * w1); 61 | atomicAdd(grad_points + i * m + i2, grad_out[i * n + j] * w2); 62 | atomicAdd(grad_points + i * m + i3, grad_out[i * n + j] * w3); 63 | } 64 | } 65 | } 66 | 67 | // input: features(b, c, m), idxs(b, n, 3), weights(b, n, 3) 68 | // output: out(b, c, n) 69 | int three_interpolate_kernel_wrapper(int b, int c, int n, int m, const float *points, const int *idx, const float *weight, float *out) { 70 | three_interpolate_kernel<<<b, opt_block_config(n, c)>>>(b, c, n, m, points, idx, weight, out); 71 | return 1; 72 | } 73 | 74 | // input: grad_out(b, c, n), idxs(b, n, k), weights(b, n, k) 75 | // output: grad_points(b, c, m) 76 | int three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, const float *grad_out, const int *idx, const float *weight, float *grad_points) { 77 | three_interpolate_grad_kernel<<<b, opt_block_config(n, c)>>>(b, c, n, m, grad_out, idx, weight, grad_points); 78 | return 1; 79 | } 80 | 81 | #ifdef __cplusplus 82 | } 83 | #endif 84 | -------------------------------------------------------------------------------- /pytorch_ops/sampling/src/sample_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | #include "sample_cuda_kernel.h" 6 | 7 | 8 | __global__ void farthestpointsamplingKernel(int b,int n,int m,const float * __restrict__ dataset,float * __restrict__ temp,int * __restrict__ idxs){ 9 | if (m<=0) 10 | return; 11 | const int BlockSize=512; 12 | __shared__ float dists[BlockSize];
13 | __shared__ int dists_i[BlockSize]; 14 | const int BufferSize=3072; 15 | __shared__ float buf[BufferSize*3]; 16 | for (int i=blockIdx.x;ibest){ 50 | best=d2; 51 | besti=k; 52 | } 53 | } 54 | dists[threadIdx.x]=best; 55 | dists_i[threadIdx.x]=besti; 56 | for (int u=0;(1<>(u+1))){ 59 | int i1=(threadIdx.x*2)<>>(b,n,m,inp,temp,out); 100 | return 1; 101 | } 102 | 103 | int gatherpoint_forward_Launcher(int b,int n,int m,const float * inp,const int * idx,float * out){ 104 | gatherpointKernel<<>>(b,n,m,inp,idx,out); 105 | return 1; 106 | } 107 | 108 | int gatherpoint_backward_Launcher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g){ 109 | scatteraddpointKernel<<>>(b,n,m,out_g,idx,inp_g); 110 | return 1; 111 | } 112 | 113 | #ifdef __cplusplus 114 | } 115 | #endif 116 | -------------------------------------------------------------------------------- /pointnet2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch_ops.grouping.group import QueryBallPoint, GroupPoints 5 | from pytorch_ops.sampling.sample import SampleFunction 6 | from pytorch_ops.interpolation.interpolate import InterpolateFunction 7 | 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self): 11 | super(Encoder, self).__init__() 12 | self.sa1 = pointnet_sa_module(npoint=512, radius=0.2, nsample=32, in_channel=6, mlp=[64, 64, 128], group_all=False) 13 | self.sa2 = pointnet_sa_module(npoint=128, radius=0.4, nsample=64, in_channel=131, mlp=[128, 256], group_all=False) 14 | self.sa3 = pointnet_sa_module(npoint=None, radius=None, nsample=None, in_channel=259, mlp=[256, 128], group_all=True) 15 | 16 | self.fp1 = pointnet_fp_module(in_channel=256+128, mlp=[256, 256]) 17 | self.fp2 = pointnet_fp_module(in_channel=128+256, mlp=[256, 128]) 18 | self.fp3 = pointnet_fp_module(in_channel=6+128, mlp=[128, 128, 128]) 19 | 20 | def forward(self, input_data): # input:(b, n, 3) 21 | l0_xyz = input_data[:, :, :3] 22 | l0_points = input_data[:, :, 3:] 23 | l1_xyz, l1_points = self.sa1(l0_xyz, l0_points) # 128 24 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) # 256 25 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) # 128 26 | 27 | l2_points = self.fp1(l2_xyz, l3_xyz, l2_points.transpose(1, 2), l3_points.transpose(1, 2)) 28 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points.transpose(1, 2), l2_points) 29 | l0_points = self.fp3(l0_xyz, l1_xyz, torch.cat([l0_xyz, l0_points], 2).transpose(1, 2), l1_points) 30 | 31 | return l0_points 32 | 33 | 34 | class pointnet_sa_module(nn.Module): 35 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all=False, bn=True): 36 | super(pointnet_sa_module, self).__init__() 37 | self.sample_points = SampleFunction(npoint) 38 | self.radius = radius 39 | self.nsample = nsample 40 | self.group_all = group_all 41 | channels = in_channel 42 | models = [] 43 | for x in mlp: 44 | models.append(nn.Conv2d(channels, x, 1, 1)) 45 | models.append(nn.BatchNorm2d(x)) 46 | models.append(nn.ReLU()) 47 | channels = x 48 | self.Model = nn.Sequential(*models) 49 | print(self.Model) 50 | 51 | def forward(self, xyz, points): 52 | if self.group_all: 53 | new_xyz = torch.zeros(xyz.size()[0], 1, 3) 54 | if xyz.is_cuda: 55 | new_xyz = new_xyz.cuda() 56 | new_points = torch.cat([xyz, points], 2) 57 | new_points = new_points.unsqueeze(1) 58 | else: 59 | new_xyz, _ = self.sample_points(xyz) 60 | idx, pts_cnt = QueryBallPoint(self.radius, self.nsample)(xyz, new_xyz) 61 | grouped_xyz = 
GroupPoints()(xyz, idx) 62 | grouped_xyz -= torch.unsqueeze(new_xyz, 2).repeat(1, 1, self.nsample, 1) 63 | if points is not None: 64 | grouped_points = GroupPoints()(points, idx) 65 | new_points = torch.cat([grouped_xyz, grouped_points], -1) 66 | else: 67 | new_points = grouped_xyz 68 | new_points = new_points.permute(0, 3, 1, 2) 69 | 70 | new_points = self.Model(new_points) 71 | new_points = new_points.permute(0, 2, 3, 1) 72 | new_points, _ = torch.max(new_points, 2) 73 | return new_xyz, new_points 74 | 75 | 76 | class pointnet_fp_module(nn.Module): 77 | def __init__(self, in_channel, mlp): 78 | super().__init__() 79 | self.convs = [] 80 | channels = in_channel 81 | models = [] 82 | for x in mlp: 83 | models.append(nn.Conv2d(channels, x, 1, 1)) 84 | models.append(nn.BatchNorm2d(x)) 85 | models.append(nn.ReLU()) 86 | channels = x 87 | self.Model = nn.Sequential(*models) 88 | 89 | def forward(self, xyz1, xyz2, points1, points2): 90 | # xyz1:(b,n,3) 91 | # xyz2:(b,m,3) m < n 92 | # points1:(b,c1,n) 93 | # points2:(b,c2,m) 94 | # out:(b,mlp[-1],n) 95 | if xyz2.size(1) == 1: 96 | interpolate = points2.repeat(1, 1, xyz1.size(1)) 97 | else: 98 | D = xyz1.unsqueeze(2) - xyz2.unsqueeze(1) 99 | D = torch.sum(torch.pow(D, 2), -1) 100 | dist, idx = torch.topk(D, 3, 2, False) 101 | idx = idx.int() 102 | dist = torch.clamp(dist, min=1e-10) 103 | norm = torch.sum(1.0/dist, 2, True).repeat(1, 1, 3) 104 | weight = (1.0/dist) / norm 105 | interpolate = InterpolateFunction()(points2, idx, weight) 106 | if points1 is not None: 107 | new_points1 = torch.cat([interpolate, points1], 1) # B,nchannel1+nchannel2,ndataset1 108 | else: 109 | new_points1 = interpolate 110 | new_points1 = new_points1.unsqueeze(-1) 111 | new_points1 = self.Model(new_points1) 112 | return new_points1.squeeze(-1) 113 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | from time import gmtime, strftime 4 | from datetime import datetime 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Variable 8 | import torch.utils.data 9 | from torchfoldext import FoldExt 10 | import util 11 | from tensorboard_logger import configure, log_value 12 | 13 | from dataloader import Data_Loader 14 | import partnet as partnet_model 15 | 16 | config = util.get_args() 17 | 18 | config.cuda = not config.no_cuda 19 | if config.gpu < 0 and config.cuda: 20 | config.gpu = 0 21 | torch.cuda.set_device(config.gpu) 22 | if config.cuda and torch.cuda.is_available(): 23 | print("Using CUDA on GPU ", config.gpu) 24 | else: 25 | print("Not using CUDA.") 26 | 27 | SEED = 8701 28 | torch.manual_seed(SEED) 29 | torch.cuda.manual_seed(SEED) 30 | torch.backends.cudnn.deterministic=True 31 | 32 | net = partnet_model.PARTNET(config) 33 | 34 | if config.cuda: 35 | net = net.cuda() 36 | 37 | print("Loading data ...... 
", end='\n', flush=True) 38 | data_loader_batch = Data_Loader(config.data_path, config.training, config.split_num, config.total_num) 39 | 40 | def my_collate(batch): 41 | a = torch.cat([x.shape for x in batch], 0) 42 | return batch, a 43 | 44 | train_iter = torch.utils.data.DataLoader( 45 | data_loader_batch, 46 | batch_size=config.batch_size, 47 | shuffle=True, 48 | collate_fn=my_collate) 49 | print("DONE") 50 | 51 | opt = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5) 52 | scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.8) 53 | 54 | print("Start training ...... ") 55 | 56 | start = time.time() 57 | 58 | net.train() 59 | 60 | header = 'Time Epoch Iteration Progress(%) LabelLoss SegLoss Seg_acc(%)' 61 | log_template = ' '.join('{:>9s},{:>5.0f}/{:<5.0f},{:>5.0f}/{:<5.0f},{:>9.1f}%,{:>11.2f},{:>10.2f},{:>10.2f}'.split(',')) 62 | 63 | if not os.path.exists(config.save_path): 64 | os.makedirs(config.save_path) 65 | 66 | configure(config.save_path + "training_log/") 67 | total_iter = config.epochs * len(train_iter) 68 | step = 0 69 | for epoch in range(config.epochs): 70 | scheduler.step() 71 | print(header) 72 | for batch_idx, batch in enumerate(train_iter): 73 | # compute points feature 74 | input_data = batch[1].cuda() 75 | jitter_input = torch.randn(input_data.size()).cuda() 76 | jitter_input = torch.clamp(0.01*jitter_input, min=-0.05, max=0.05) 77 | jitter_input += input_data 78 | points_feature = net.pointnet(jitter_input) 79 | # Split into a list of fold nodes per example 80 | enc_points_feature = torch.split(points_feature, 1, 0) 81 | # Initialize torchfold for *decoding* 82 | dec_fold = FoldExt(cuda=config.cuda) 83 | # Collect computation nodes recursively from decoding process 84 | dec_fold_nodes_label = [] 85 | dec_fold_nodes_box = [] 86 | dec_fold_nodes_acc = [] 87 | for example, points_f in zip(batch[0], enc_points_feature): 88 | labelloss, boxloss, acc = partnet_model.decode_structure_fold(dec_fold, example, points_f) 89 | dec_fold_nodes_label.append(labelloss) 90 | dec_fold_nodes_box.append(boxloss) 91 | dec_fold_nodes_acc.append(acc) 92 | # Apply the computations on the decoder model 93 | dec_loss = dec_fold.apply(net, [dec_fold_nodes_label, dec_fold_nodes_box, dec_fold_nodes_acc]) 94 | num_nodes = torch.cat([x.n_nodes for x in batch[0]], 0) 95 | label_loss = torch.mean(dec_loss[0]/num_nodes) 96 | seg_loss = torch.mean(dec_loss[1]/num_nodes) 97 | acc_mean = torch.mean(dec_loss[2]/num_nodes) 98 | acc = acc_mean.item()/2048. 99 | log_value('label_loss', label_loss.item(), step) 100 | log_value('seg_loss', seg_loss.item(), step) 101 | log_value('acc', acc, step) 102 | total_loss = label_loss + seg_loss 103 | # Do parameter optimization 104 | opt.zero_grad() 105 | total_loss.backward() 106 | opt.step() 107 | # Report statistics 108 | if batch_idx % config.show_log_every == 0: 109 | print( 110 | log_template.format( 111 | strftime("%H:%M:%S", time.gmtime(time.time() - start)), 112 | epoch, config.epochs, 1 + batch_idx, len(train_iter), 113 | 100. * (1 + batch_idx + len(train_iter) * epoch) / 114 | (len(train_iter) * config.epochs), label_loss.item(), seg_loss.item(), acc*100)) 115 | step += 1 116 | if (epoch+1) % 100 ==0: 117 | #print("Saving temp models ...... ", flush=True) 118 | print("Saving models ...... ", end='', flush=True) 119 | torch.save(net.state_dict(), config.save_path + '/partnet_temp_%d.pkl'%epoch) 120 | 121 | # Save the final models 122 | print("Saving final models ...... 
", end='', flush=True) 123 | torch.save(net.state_dict(), config.save_path + '/partnet_final.pkl') 124 | print("DONE") 125 | -------------------------------------------------------------------------------- /pytorch_ops/grouping/src/group_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | #include "group_cuda_kernel.h" 6 | 7 | // input: radius (1), nsample (1), xyz1 (b,n,3), xyz2 (b,m,3) 8 | // output: idx (b,m,nsample), pts_cnt (b,m) 9 | __global__ void query_ball_point_gpu(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt) { 10 | int batch_index = blockIdx.x; 11 | xyz1 += n*3*batch_index; 12 | xyz2 += m*3*batch_index; 13 | idx += m*nsample*batch_index; 14 | pts_cnt += m*batch_index; // counting how many unique points selected in local region 15 | 16 | int index = threadIdx.x; 17 | int stride = blockDim.x; 18 | 19 | for (int j=index;j>>(b,n,m,radius,nsample,xyz1,xyz2,idx,pts_cnt); 133 | return 1; 134 | //cudaDeviceSynchronize(); 135 | } 136 | int selectionSortLauncher(int b, int n, int m, int k, const float *dist, int *outi, float *out) { 137 | selection_sort_gpu<<>>(b,n,m,k,dist,outi,out); 138 | return 1; 139 | //cudaDeviceSynchronize(); 140 | } 141 | int groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out){ 142 | group_point_gpu<<>>(b,n,c,m,nsample,points,idx,out); 143 | return 1; 144 | //cudaDeviceSynchronize(); 145 | } 146 | int groupPointGradLauncher(int b, int n, int c, int m, int nsample, const float *grad_out, const int *idx, float *grad_points){ 147 | group_point_grad_gpu<<>>(b,n,c,m,nsample,grad_out,idx,grad_points); 148 | return 1; 149 | //group_point_grad_gpu<<<1,1>>>(b,n,c,m,nsample,grad_out,idx,grad_points); 150 | //cudaDeviceSynchronize(); 151 | } 152 | 153 | #ifdef __cplusplus 154 | } 155 | #endif 156 | -------------------------------------------------------------------------------- /test_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | import scipy.io as sio 4 | import os 5 | import torch 6 | from torch import nn 7 | from dataloader import Data_Loader 8 | import util 9 | import torch.utils.data 10 | from torchfoldext import FoldExt 11 | import sys 12 | import json 13 | import partnet as partnet_model 14 | 15 | with open('./part_color_mapping.json', 'r') as f: 16 | color = json.load(f) 17 | 18 | for c in color: 19 | c[0] = int(c[0]*255) 20 | c[1] = int(c[1]*255) 21 | c[2] = int(c[2]*255) 22 | 23 | def writeply(savedir, data, label): 24 | path = os.path.dirname(savedir) 25 | if not os.path.exists(path): 26 | os.makedirs(path) 27 | if data.size(0) == 0: 28 | n_vertex = 0 29 | else: 30 | n_vertex = data.size(1) 31 | with open(savedir, 'w') as f: 32 | f.write('ply\n') 33 | f.write('format ascii 1.0\n') 34 | f.write('comment 111231\n') 35 | f.write('element vertex %d\n' % n_vertex) 36 | f.write('property float x\n') 37 | f.write('property float y\n') 38 | f.write('property float z\n') 39 | f.write('property float nx\n') 40 | f.write('property float ny\n') 41 | f.write('property float nz\n') 42 | f.write('property uchar red\n') 43 | f.write('property uchar green\n') 44 | f.write('property uchar blue\n') 45 | f.write('property uchar label\n') 46 | f.write('end_header\n') 47 | for j in range(n_vertex): 48 | f.write('%g %g %g %g %g %g %d %d %d %d\n' % (*data[0, j], 
*color[label[j]], label[j])) 49 | 50 | def normalize_shape(shape): 51 | shape_min, _ = torch.min(shape[:,:,:3], 1) 52 | shape_max, _ = torch.max(shape[:,:,:3], 1) 53 | x_length = shape_max[0, 0] - shape_min[0, 0] 54 | y_length = shape_max[0, 1] - shape_min[0, 1] 55 | z_length = shape_max[0, 2] - shape_min[0, 2] 56 | length = shape_max - shape_min 57 | max_length = torch.max(length) 58 | scale = 1/max_length 59 | center = shape_min + length/2 60 | print(center[0, 0],center[0, 1],center[0, 2], scale) 61 | for i in range(2048): 62 | shape[0, i, 0] = (shape[0, i, 0] - center[0, 0])* scale 63 | shape[0, i, 1] = (shape[0, i, 1] - center[0, 1])* scale 64 | shape[0, i, 2] = (shape[0, i, 2] - center[0, 2])* scale 65 | return shape 66 | 67 | def decode_structure(model, feature, points_f, shape): 68 | """ 69 | segment shape 70 | """ 71 | global m 72 | n_points = shape.size(1) 73 | loc_f = model.pcEncoder(shape) 74 | if feature is None: 75 | feature = loc_f 76 | f_c1 = torch.cat([feature, loc_f], 1) 77 | label_prob = model.nodeClassifier(f_c1) 78 | _, label = torch.max(label_prob, 1) 79 | label = label.item() 80 | if label == 1 or label == 3: # ADJ 81 | left, right = model.adjDecoder(f_c1) 82 | f_max, _ = torch.max(points_f, 2) 83 | f_c2 = torch.cat([f_max, f_c1], 1) 84 | point_label_prob, _ = model.decoder.loc_points_predictor(points_f, f_c2) 85 | point_label_prob = point_label_prob.cpu() 86 | _, point_label=torch.max(point_label_prob, 1) 87 | left_list=[] 88 | right_list=[] 89 | for i in range(n_points): 90 | if point_label[0, i].item() == 0: 91 | left_list.append(i) 92 | else: 93 | right_list.append(i) 94 | left_idx=torch.LongTensor(left_list).cuda() 95 | right_idx=torch.LongTensor(right_list).cuda() 96 | if left_idx.size(0) > 20 and right_idx.size(0) > 20: 97 | l_pc=torch.index_select(shape, 1, left_idx) 98 | l_feature=torch.index_select(points_f, 2, left_idx) 99 | r_pc=torch.index_select(shape, 1, right_idx) 100 | r_feature=torch.index_select(points_f, 2, right_idx) 101 | l_labeling = decode_structure(model, left, l_feature, l_pc) 102 | r_labeling = decode_structure(model, right, r_feature, r_pc) 103 | prediction = torch.LongTensor(n_points).zero_() 104 | for i, j in enumerate(left_idx): 105 | prediction[j.item()]=l_labeling[i] 106 | for i, j in enumerate(right_idx): 107 | prediction[j.item()]=r_labeling[i] 108 | return prediction 109 | else: 110 | prediction = torch.LongTensor(n_points).fill_(m) 111 | m += 1 112 | return prediction 113 | elif label == 2: 114 | f_max, _ = torch.max(points_f, 2) 115 | f_c2 = torch.cat([f_max, f_c1], 1) 116 | point_label_prob, _ = model.decoder.loc_points_predictor_multi(points_f, f_c2) 117 | point_label_prob = point_label_prob.cpu() 118 | _, point_label = torch.max(point_label_prob, 1) 119 | prediction = torch.LongTensor(n_points) 120 | for i in range(point_label.size(1)): 121 | prediction[i] = m+point_label[0, i].item() 122 | mx = torch.max(point_label).item() 123 | m += mx + 1 124 | return prediction 125 | else: 126 | prediction=torch.LongTensor(n_points).fill_(m) 127 | m += 1 128 | return prediction 129 | 130 | m = 0 131 | def main(): 132 | config = util.get_args() 133 | config.cuda = not config.no_cuda 134 | torch.cuda.set_device(config.gpu) 135 | if config.cuda and torch.cuda.is_available(): 136 | print("Using CUDA on GPU ", config.gpu) 137 | else: 138 | print("Not using CUDA.") 139 | net = partnet_model.PARTNET(config) 140 | net.load_state_dict(torch.load(config.save_path + '/partnet_final.pkl', map_location=lambda storage, loc: storage.cuda(config.gpu))) 
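# map_location moves every tensor in the checkpoint onto the selected GPU,
# so weights saved from a different device id still load correctly;
# net.eval() below switches BatchNorm layers to inference mode.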
141 | if config.cuda:
142 | net.cuda()
143 | net.eval()
144 | 
145 | if not os.path.exists(config.output_path + 'segmented'):
146 | os.makedirs(config.output_path + 'segmented')
147 | print("Loading data ...... ", end='\n', flush=True)
148 | 
149 | shape = torch.from_numpy(sio.loadmat(config.data_path + 'demo.mat')['pc']).float()
150 | ## for your own new shape, normalize it first:
151 | ##shape = normalize_shape(shape)
152 | with torch.no_grad():
153 | shape = shape.cuda()
154 | points_feature = net.pointnet(shape)
155 | root_feature = net.pcEncoder(shape)
156 | global m
157 | m = 0
158 | label = decode_structure(net, root_feature, points_feature, shape)
159 | 
160 | # segmented results
161 | writeply(config.output_path + 'segmented/demo.ply', shape, label)
162 | print('Successfully output result!')
163 | 
164 | if __name__ == '__main__':
165 | main()
166 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # CVPR 2019-PartNet
2 | Code for PartNet: A Recursive Part Decomposition Network for Fine-grained and Hierarchical Shape Segmentation
3 | 
4 | ### Introduction
5 | 
6 | Deep learning approaches to 3D shape segmentation are
7 | typically formulated as a multi-class labeling problem. Existing
8 | models are trained for a fixed set of labels, which
9 | greatly limits their flexibility and adaptivity. We opt for top-down
10 | recursive decomposition and develop the first deep
11 | learning model for hierarchical segmentation of 3D shapes,
12 | based on recursive neural networks. Starting from a full
13 | shape represented as a point cloud, our model performs
14 | recursive binary decomposition, where the decomposition
15 | networks at all nodes in the hierarchy share weights. At each
16 | node, a node classifier is trained to determine the type (adjacency
17 | or symmetry) and stopping criteria of its decomposition.
18 | The features extracted in higher-level nodes are
19 | recursively propagated to lower-level ones. Thus, the meaningful
20 | decompositions in higher levels provide strong contextual
21 | cues constraining the segmentations in lower levels.
22 | Meanwhile, to increase the segmentation accuracy at each
23 | node, we enhance the recursive contextual feature with the
24 | shape feature extracted for the corresponding part. Our
25 | method segments a 3D shape represented as a point cloud into an unfixed
26 | number of parts, depending on the shape complexity, showing
27 | strong generality and flexibility. It achieves state-of-
28 | the-art performance, both for fine-grained and semantic
29 | segmentation, on the public benchmark and a new benchmark
30 | of fine-grained segmentation proposed in this work.
31 | We also demonstrate its application to fine-grained part
32 | refinement in image-to-shape reconstruction.
33 | 
34 | ### Dependencies
35 | 
36 | Requirements:
37 | - Python 3.5 with numpy, scipy, torchfold, tensorboard, etc.
38 | - [PyTorch](https://pytorch.org/resources)
39 | 
40 | Our code has been tested with Python 3.5, PyTorch 0.4.0, CUDA 8.0 on Ubuntu 16.04.
41 | ## News
42 | ***July 3, 2022.*** If you use PyTorch later than 1.6, please use this [new version](https://github.com/FENGGENYU/PartNet/tree/torch1.6-and-later).
43 | 
44 | ***Dec 3, 2019.*** Our structure hierarchies are the same as in [GRASS](https://github.com/kevin-kaixu/grass_pytorch), but the symmetry parameters differ slightly. Specifically, our symmetry-type labels are 0 = reflected, -1 = rotational, 1 = translational, while GRASS uses 1 = reflected, -1 = rotational, 0 = translational.
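For reference, here is a minimal sketch of that label remapping (a hypothetical helper, not part of this repo; it only converts the symmetry-type label, and since the translational parameter values themselves also differ, you should still use the dataloaders mentioned below for real GRASS data):
```
# Hypothetical helper: remap a GRASS symmetry-type label to this repo's convention.
# GRASS: 1 = reflected, -1 = rotational, 0 = translational
# Ours:  0 = reflected, -1 = rotational, 1 = translational
GRASS_TO_PARTNET = {1: 0, -1: -1, 0: 1}

def convert_sym_type(grass_label):
    return GRASS_TO_PARTNET[int(grass_label)]
```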
In addition, our translational parameters are set differently from GRASS.
45 | If you want to use symmetric parameters from [GRASS](https://github.com/kevin-kaixu/grass_pytorch), please use dataloader_grass.py.
46 | If you want to use symmetric parameters from [PartNet_symh](https://github.com/FoggYu/PartNet_symh), please use dataloader_symh.py.
47 | 
48 | ***Oct 30, 2019.*** Our extended datasets have been released [here](https://github.com/FoggYu/PartNet_symh).
49 | 
50 | ## Datasets and Pre-trained weights
51 | The input point clouds and training hierarchical trees are available [here](https://www.dropbox.com/sh/7nuqb9wphsjkzko/AAAgy8zzmeRFsNuGuYCxUUWTa?dl=0).
52 | Each category contains the following folders:
53 | - models_2048_points_normals: normalized input shape point clouds with normals
54 | - training_data_models_segment_2048_normals: GT part point clouds with normals. (Note that we only store half of the parts of each shape; the remaining parts need to be recovered from the symmetry parameters.)
55 | - training_trees: hierarchical structures (ops.mat) with symmetry parameters (syms.mat). The labels (labels.mat) are for testing only.
56 | 
57 | Our dataset has recently been extended and updated:
58 | #### Datasets information
59 | | category_name | chair | airplane | table | sofa | helicopter | bike |
60 | | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
61 | | number of shapes | 999 | 630 | 583 | 630 | 100 | 155 |
62 | | number of parts | 9697 | 5234 | 3207 | 4747 | 1415 | 1238 |
63 | | maximum parts per shape | 25 | 14 | 17 | 27 | 21 | 9 |
64 | | minimum parts per shape | 3 | 4 | 2 | 2 | 6 | 6 |
65 | 
66 | The pre-trained models are available [here](https://www.dropbox.com/sh/um1li37bnbkpuck/AAAaCAuXWaY050E7W5b42XT1a?dl=0).
67 | 
68 | ### Usage: Demo
69 | Requires about 3 GB of GPU memory and roughly 5 seconds to run.
70 | 
71 | This script takes as input a normalized 2048-point cloud with normals (sampled from a ShapeNet model using [virtualscanner](https://github.com/Microsoft/O-CNN)).
72 | 
73 | Please download the pre-trained weights for airplane first and put them at ./models/airplane.
74 | 
75 | Build the extension for each op in ./pytorch_ops/*** (e.g. ./pytorch_ops/sampling/) using build.py.
76 | 
77 | PS: torch 0.4 is required; the extensions will not build or run with later torch versions.
78 | ```
79 | python build.py
80 | ```
81 | Then run
82 | ```
83 | python test_demo.py
84 | ```
85 | ![input](./picture/airplane.png)
86 | 
87 | ### Usage: Training
88 | 
89 | Put the data for each category in ./data/category_name (e.g. ./data/airplane).
90 | 
91 | Build the extension for each op in ./pytorch_ops/*** (e.g. ./pytorch_ops/sampling/) using build.py.
92 | 
93 | PS: torch 0.4 is required; the extensions will not build or run with later torch versions.
94 | ```
95 | python build.py
96 | ```
97 | Then run the training process
98 | ```
99 | python train.py
100 | ```
101 | 
102 | More training and testing arguments are set in util.py:
103 | ```
104 | '--epochs' (number of epochs; default=1000)
105 | '--batch_size' (batch size; default=10)
106 | '--show_log_every' (show the training log every X batches; default=3)
107 | '--no_cuda' (don't use cuda)
108 | '--gpu' (device id of the GPU to run cuda on)
109 | '--data_path' (dataset path, default='data')
110 | '--save_path' (trained model path, default='models')
111 | '--output_path' (segmented result path, default='results')
112 | '--training' (training or testing, default=False)
113 | '--split_num' (training data size for each category)
114 | '--total_num' (full data size for each category, only for testing)
115 | '--label_category' (number of semantic labels for each category, only for testing)
116 | ```
117 | ### Usage: Testing
118 | To evaluate AP correctly, you need to set label_category for each category in util.py
119 | ```
120 | '--label_category' (number of semantic labels for each category, only for testing)
121 | ```
122 | We measure AP (%) with IoU thresholds of 0.25 and 0.5, respectively.
123 | ```
124 | python ap_evaluate.py
125 | ```
126 | Segmentation results and their corresponding GT can also be found in ./results/category_name (e.g. ./results/airplane).
127 | 
128 | #### AP information of our method on the updated datasets and training/testing split.
129 | 
130 | | category name | chair | airplane | table | sofa | helicopter | bike |
131 | | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
132 | | IoU > 0.25 | 93.83 | 96.33 | 78.49 | 76.07 | 83.0 | 98.22 |
133 | | IoU > 0.5 | 84.23 | 88.41 | 63.2 | 55.76 | 69.4 | 97.60 |
134 | 
135 | PS: If you want to try new shapes, please make sure they are oriented and normalized in the same way as our shapes.
136 | 
137 | ## Citation
138 | If you use this code, please cite the following paper.
139 | ``` 140 | @inproceedings{yu2019partnet, 141 | title = {PartNet: A Recursive Part Decomposition Network for Fine-grained and Hierarchical Shape Segmentation}, 142 | author = {Fenggen Yu and Kun Liu and Yan Zhang and Chenyang Zhu and Kai Xu}, 143 | booktitle = {CVPR}, 144 | pages = {to appear}, 145 | year = {2019} 146 | } 147 | ``` 148 | -------------------------------------------------------------------------------- /pytorch_ops/losses/cd/cd_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | #include "cd_cuda_kernel.h" 6 | 7 | 8 | __global__ void cd_cuda_forward_kernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 9 | const int batch=512; 10 | __shared__ float buf[batch*3]; 11 | for (int i=blockIdx.x;ibest){ 123 | result[(i*n+j)]=best; 124 | result_i[(i*n+j)]=best_i; 125 | } 126 | } 127 | __syncthreads(); 128 | } 129 | } 130 | } 131 | int cd_forward_Launcher(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i){ 132 | cd_cuda_forward_kernel<<>>(b,n,xyz,m,xyz2,result,result_i); 133 | cd_cuda_forward_kernel<<>>(b,m,xyz2,n,xyz,result2,result2_i); 134 | return 1; 135 | } 136 | 137 | __global__ void cd_cuda_backward_kernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 138 | for (int i=blockIdx.x;i>>(b,n,xyz1,m,xyz2,grad_dist1,idx1,grad_xyz1,grad_xyz2); 160 | cd_cuda_backward_kernel<<>>(b,m,xyz2,n,xyz1,grad_dist2,idx2,grad_xyz2,grad_xyz1); 161 | return 1; 162 | } 163 | 164 | 165 | #ifdef __cplusplus 166 | } 167 | #endif 168 | -------------------------------------------------------------------------------- /ap_evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | import scipy.io as sio 4 | import os 5 | import torch 6 | from torch import nn 7 | from dataloader import Data_Loader 8 | import util 9 | import torch.utils.data 10 | from torchfoldext import FoldExt 11 | import sys 12 | import json 13 | import partnet as partnet_model 14 | 15 | with open('./part_color_mapping.json', 'r') as f: 16 | color = json.load(f) 17 | 18 | for c in color: 19 | c[0] = int(c[0]*255) 20 | c[1] = int(c[1]*255) 21 | c[2] = int(c[2]*255) 22 | 23 | def writeply(savedir, data, label): 24 | path = os.path.dirname(savedir) 25 | if not os.path.exists(path): 26 | os.makedirs(path) 27 | if data.size(0) == 0: 28 | n_vertex = 0 29 | else: 30 | n_vertex = data.size(1) 31 | with open(savedir, 'w') as f: 32 | f.write('ply\n') 33 | f.write('format ascii 1.0\n') 34 | f.write('comment 111231\n') 35 | f.write('element vertex %d\n' % n_vertex) 36 | f.write('property float x\n') 37 | f.write('property float y\n') 38 | f.write('property float z\n') 39 | f.write('property float nx\n') 40 | f.write('property float ny\n') 41 | f.write('property float nz\n') 42 | f.write('property uchar red\n') 43 | f.write('property uchar green\n') 44 | f.write('property uchar blue\n') 45 | f.write('property uchar label\n') 46 | f.write('end_header\n') 47 | for j in range(n_vertex): 48 | f.write('%g %g %g %g %g %g %d %d %d %d\n' % (*data[0, j], *color[label[j]], label[j])) 49 | 50 | 51 | def evaluate(pred_grp, gt_grp, seg, n_category=4, at=0.5): 52 | total_sgpn = np.zeros(n_category) 53 | tpsins = [[] for itmp in range(n_category)] 54 | fpsins = [[] for itmp in 
range(n_category)] 55 | pts_in_pred = [[] for itmp in range(n_category)] 56 | pts_in_gt = [[] for itmp in range(n_category)] 57 | un = np.unique(pred_grp) 58 | for i, g in enumerate(un): 59 | tmp = (pred_grp == g) 60 | sem_seg_g = int(stats.mode(seg[tmp])[0]) 61 | pts_in_pred[sem_seg_g] += [tmp] 62 | un = np.unique(gt_grp) 63 | for ig, g in enumerate(un): 64 | tmp = (gt_grp == g) 65 | sem_seg_g = int(stats.mode(seg[tmp])[0]) 66 | pts_in_gt[sem_seg_g] += [tmp] 67 | total_sgpn[sem_seg_g] += 1 68 | for i_sem in range(n_category): 69 | tp = [0.] * len(pts_in_pred[i_sem]) 70 | fp = [0.] * len(pts_in_pred[i_sem]) 71 | gtflag = np.zeros(len(pts_in_gt[i_sem])) 72 | 73 | for ip, ins_pred in enumerate(pts_in_pred[i_sem]): 74 | ovmax = -1. 75 | 76 | for ig, ins_gt in enumerate(pts_in_gt[i_sem]): 77 | union = (ins_pred | ins_gt) 78 | intersect = (ins_pred & ins_gt) 79 | iou = float(np.sum(intersect)) / np.sum(union) 80 | 81 | if iou > ovmax: 82 | ovmax = iou 83 | igmax = ig 84 | 85 | if ovmax >= at: 86 | if gtflag[igmax] == 0: 87 | tp[ip] = 1 # true 88 | gtflag[igmax] = 1 89 | else: 90 | fp[ip] = 1 # multiple det 91 | else: 92 | fp[ip] = 1 # false positive 93 | tpsins[i_sem] += tp 94 | fpsins[i_sem] += fp 95 | return tpsins, fpsins, total_sgpn 96 | 97 | 98 | def eval_3d_perclass(tp, fp, npos): 99 | tp = np.asarray(tp).astype(np.float) 100 | fp = np.asarray(fp).astype(np.float) 101 | tp = np.cumsum(tp) 102 | fp = np.cumsum(fp) 103 | rec = tp / npos 104 | prec = tp / (fp+tp) 105 | 106 | ap = 0. 107 | for t in np.arange(0, 1, 0.1): 108 | prec1 = prec[rec>=t] 109 | prec1 = prec1[~np.isnan(prec1)] 110 | if len(prec1) == 0: 111 | p = 0. 112 | else: 113 | p = max(prec1) 114 | if not p: 115 | p = 0. 116 | ap = ap + p / 10 117 | return ap, rec, prec 118 | 119 | 120 | def decode_structure(model, feature, points_f, shape): 121 | """ 122 | Decode a root code into a tree structure of boxes 123 | """ 124 | global m 125 | n_points = shape.size(1) 126 | loc_f = model.pcEncoder(shape) 127 | if feature is None: 128 | feature = loc_f 129 | f_c1 = torch.cat([feature, loc_f], 1) 130 | label_prob = model.nodeClassifier(f_c1) 131 | _, label = torch.max(label_prob, 1) 132 | label = label.item() 133 | if label == 1 or label == 3: # ADJ 134 | left, right = model.adjDecoder(f_c1) 135 | f_max, _ = torch.max(points_f, 2) 136 | f_c2 = torch.cat([f_max, f_c1], 1) 137 | bbox, _ = model.decoder.loc_points_predictor(points_f, f_c2) 138 | bbox = bbox.cpu() 139 | _, point_label=torch.max(bbox, 1) 140 | left_list=[] 141 | right_list=[] 142 | for i in range(n_points): 143 | if point_label[0, i].item() == 0: 144 | left_list.append(i) 145 | else: 146 | right_list.append(i) 147 | left_idx=torch.LongTensor(left_list).cuda() 148 | right_idx=torch.LongTensor(right_list).cuda() 149 | if left_idx.size(0) > 20 and right_idx.size(0) > 20: 150 | l_pc=torch.index_select(shape, 1, left_idx) 151 | l_feature=torch.index_select(points_f, 2, left_idx) 152 | r_pc=torch.index_select(shape, 1, right_idx) 153 | r_feature=torch.index_select(points_f, 2, right_idx) 154 | l_labeling = decode_structure(model, left, l_feature, l_pc) 155 | r_labeling = decode_structure(model, right, r_feature, r_pc) 156 | prediction = torch.LongTensor(n_points).zero_() 157 | for i, j in enumerate(left_idx): 158 | prediction[j.item()]=l_labeling[i] 159 | for i, j in enumerate(right_idx): 160 | prediction[j.item()]=r_labeling[i] 161 | return prediction 162 | else: 163 | prediction = torch.LongTensor(n_points).fill_(m) 164 | m += 1 165 | return prediction 166 | elif label == 2: 167 | 
f_max, _ = torch.max(points_f, 2) 168 | f_c2 = torch.cat([f_max, f_c1], 1) 169 | bbox, _ = model.decoder.loc_points_predictor_multi(points_f, f_c2) 170 | bbox = bbox.cpu() 171 | _, point_label = torch.max(bbox, 1) 172 | prediction = torch.LongTensor(n_points) 173 | for i in range(point_label.size(1)): 174 | prediction[i] = m+point_label[0, i].item() 175 | mx = torch.max(point_label).item() 176 | m += mx + 1 177 | return prediction 178 | else: 179 | prediction=torch.LongTensor(n_points).fill_(m) 180 | m += 1 181 | return prediction 182 | 183 | m = 0 184 | def main(): 185 | config = util.get_args() 186 | config.cuda = not config.no_cuda 187 | torch.cuda.set_device(config.gpu) 188 | if config.cuda and torch.cuda.is_available(): 189 | print("Using CUDA on GPU ", config.gpu) 190 | else: 191 | print("Not using CUDA.") 192 | net = partnet_model.PARTNET(config) 193 | net.load_state_dict(torch.load(config.save_path + '/partnet_final.pkl', map_location=lambda storage, loc: storage.cuda(config.gpu))) 194 | if config.cuda: 195 | net.cuda() 196 | net.eval() 197 | 198 | if not os.path.exists(config.output_path + 'gt_grp'): 199 | os.makedirs(config.output_path + 'gt_grp') 200 | if not os.path.exists(config.output_path + 'gt'): 201 | os.makedirs(config.output_path + 'gt') 202 | if not os.path.exists(config.output_path + 'segmented'): 203 | os.makedirs(config.output_path + 'segmented') 204 | print("Loading data ...... ", end='\n', flush=True) 205 | data_loader_batch = Data_Loader(config.data_path, config.training, config.split_num, config.total_num) 206 | NUM_CATEGORY = config.label_category 207 | recall_all = 0 208 | with torch.no_grad(): 209 | tpsins = [[] for itmp in range(NUM_CATEGORY)] 210 | fpsins = [[] for itmp in range(NUM_CATEGORY)] 211 | total_sgpn = np.zeros(NUM_CATEGORY) 212 | tpsins_2 = [[] for itmp in range(NUM_CATEGORY)] 213 | fpsins_2 = [[] for itmp in range(NUM_CATEGORY)] 214 | total_sgpn_2 = np.zeros(NUM_CATEGORY) 215 | bad_shape = torch.zeros(len(data_loader_batch), 1) 216 | bad_num = 0 217 | for n in range(len(data_loader_batch)): 218 | shape = data_loader_batch[n].shape.cuda() 219 | points_feature = net.pointnet(shape) 220 | root_feature = net.pcEncoder(shape) 221 | print('index : ', n) 222 | global m 223 | m = 0 224 | label = decode_structure(net, root_feature, points_feature, shape) 225 | seg_gt = data_loader_batch[n].shape_label 226 | grp_gt = data_loader_batch[n].grp 227 | 228 | #ground truth fine-grained groups 229 | writeply(config.output_path + 'gt_grp/%d.ply' % (n), data_loader_batch[n].shape, grp_gt) 230 | #ground truth semantic group 231 | writeply(config.output_path + 'gt/%d.ply' % (n), data_loader_batch[n].shape, seg_gt) 232 | #segmented results 233 | writeply(config.output_path + 'segmented/%d.ply' % (n), data_loader_batch[n].shape, label) 234 | tp, fp, groups = evaluate(label.numpy(), grp_gt.numpy(), seg_gt.numpy(), NUM_CATEGORY, at=0.25) 235 | tp2, fp2, groups2 = evaluate(label.numpy(), grp_gt.numpy(), seg_gt.numpy(), NUM_CATEGORY, at=0.5) 236 | 237 | for i in range(NUM_CATEGORY): 238 | tpsins[i] += tp[i] 239 | fpsins[i] += fp[i] 240 | total_sgpn[i] += groups[i] 241 | tpsins_2[i] += tp2[i] 242 | fpsins_2[i] += fp2[i] 243 | total_sgpn_2[i] += groups2[i] 244 | 245 | ap = np.zeros(NUM_CATEGORY) 246 | ap2 = np.zeros(NUM_CATEGORY) 247 | for i_sem in range(NUM_CATEGORY): 248 | ap[i_sem], _, _ = eval_3d_perclass(tpsins[i_sem], fpsins[i_sem], total_sgpn[i_sem]) 249 | ap2[i_sem], _, _ = eval_3d_perclass(tpsins_2[i_sem], fpsins_2[i_sem], total_sgpn_2[i_sem]) 250 | 251 | 
print('Instance Segmentation AP(IoU 0.25):', ap) 252 | print('Instance Segmentation mAP(IoU 0.25:', np.mean(ap)) 253 | print('Instance Segmentation AP(IoU 0.5):', ap2) 254 | print('Instance Segmentation mAP(IoU 0.5):', np.mean(ap2)) 255 | 256 | if __name__ == '__main__': 257 | main() 258 | -------------------------------------------------------------------------------- /pytorch_ops/losses/emd/emd_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | #include "emd_cuda_kernel.h" 6 | 7 | __global__ void approxmatch(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,float * __restrict__ match,float * temp){ 8 | float * remainL=temp+blockIdx.x*(n+m)*2, * remainR=temp+blockIdx.x*(n+m)*2+n,*ratioL=temp+blockIdx.x*(n+m)*2+n+m,*ratioR=temp+blockIdx.x*(n+m)*2+n+m+n; 9 | float multiL,multiR; 10 | if (n>=m){ 11 | multiL=1; 12 | multiR=n/m; 13 | }else{ 14 | multiL=m/n; 15 | multiR=1; 16 | } 17 | const int Block=1024; 18 | __shared__ float buf[Block*4]; 19 | for (int i=blockIdx.x;i=-2;j--){ 28 | float level=-powf(4.0f,j); 29 | if (j==-2){ 30 | level=0; 31 | } 32 | for (int k0=0;k0>>(b,n,m,xyz1,xyz2,match,temp); 188 | return 1; 189 | } 190 | 191 | __global__ void matchcost(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ out){ 192 | __shared__ float allsum[512]; 193 | const int Block=1024; 194 | __shared__ float buf[Block*3]; 195 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,out); 236 | return 1; 237 | } 238 | __global__ void matchcostgrad2(int b,int n,int m,const float * __restrict__ xyz1,const float * __restrict__ xyz2,const float * __restrict__ match,float * __restrict__ grad2){ 239 | __shared__ float sum_grad[256*3]; 240 | for (int i=blockIdx.x;i>>(b,n,m,xyz1,xyz2,match,grad1); 303 | matchcostgrad2<<>>(b,n,m,xyz1,xyz2,match,grad2); 304 | return 1; 305 | } 306 | 307 | #ifdef __cplusplus 308 | } 309 | #endif 310 | -------------------------------------------------------------------------------- /dataloader_grass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from scipy.io import loadmat 4 | from enum import Enum 5 | from torch.autograd import Variable 6 | import math 7 | from pytorch_ops.sampling.sample import FarthestSample 8 | from pytorch_ops.losses.cd.cd import CDModule 9 | 10 | m_grp = 0 11 | 12 | def vrrotvec2mat(rotvector, angle): 13 | s = math.sin(angle) 14 | c = math.cos(angle) 15 | t = 1 - c 16 | x = rotvector[0] 17 | y = rotvector[1] 18 | z = rotvector[2] 19 | m = torch.FloatTensor( 20 | [[t * x * x + c, t * x * y - s * z, t * x * z + s * y], 21 | [t * x * y + s * z, t * y * y + c, t * y * z - s * x], 22 | [t * x * z - s * y, t * y * z + s * x, t * z * z + c]]) 23 | return m 24 | 25 | #segmentation for symmetric node 26 | def multilabel(points, shape, cdloss): 27 | c = torch.LongTensor(1, 2048).zero_() 28 | c = c - 1 29 | for i in range(points.size(0)): 30 | a = points[i].unsqueeze(0).cuda() 31 | _, index, _, _ = cdloss(a, shape) 32 | b = torch.unique(index.cpu()) 33 | for k in range(b.size(0)): 34 | c[0, b[k].item()] = i 35 | return c 36 | 37 | class Tree(object): 38 | class NodeType(Enum): 39 | LEAF = 0 # leaf node 40 | ADJ = 1 # adjacency (adjacent part assembly) node 41 | SYM = 2 # symmetry (symmetric part grouping) node 42 | SYM_ADJ = 3 #reflect 43 | 44 | class Node(object): 
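# Each Node stores either leaf geometry or links to its children:
#   leaf_points    -- (1, 2048, 6) part points with normals (LEAF nodes only)
#   left / right   -- child nodes (ADJ), or the single generator child (SYM)
#   sym_p          -- 6-dim symmetry parameters: axis/normal direction (first 3) and a point on it (last 3)
#   sym_a          -- symmetry type (also aliased as sym_type below)
#   sym_t          -- number of folds for rotational/translational symmetry
#   semantic_label -- GT part label, carried only for evaluation (AP computation)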
45 | def __init__(self, 46 | leaf_points=None, 47 | left=None, 48 | right=None, 49 | node_type=None, 50 | sym_p=None, 51 | sym_a=None, 52 | sym_t=None, 53 | semantic_label=None): 54 | self.leaf_points = leaf_points # node points 55 | if isinstance(sym_t, int): 56 | self.sym_t = torch.LongTensor([sym_t]) 57 | else: 58 | self.sym_t = None 59 | if isinstance(sym_a, int): 60 | self.sym_a = torch.LongTensor([sym_a]) 61 | else: 62 | self.sym_a = None 63 | self.sym_p = sym_p 64 | self.sym_type = self.sym_a 65 | self.left = left # left child for ADJ or SYM (a symmeter generator) 66 | self.right = right # right child 67 | self.node_type = node_type 68 | self.label = torch.LongTensor([self.node_type.value]) 69 | self.is_root = False 70 | self.semantic_label = semantic_label 71 | 72 | def is_leaf(self): 73 | return self.node_type == Tree.NodeType.LEAF 74 | 75 | def is_adj(self): 76 | return self.node_type == Tree.NodeType.ADJ 77 | 78 | def is_sym(self): 79 | return self.node_type == Tree.NodeType.SYM 80 | 81 | def is_sym_adj(self): 82 | return self.node_type == Tree.NodeType.SYM_ADJ 83 | 84 | def __init__(self, parts, ops, syms, labels, shape): 85 | parts_list = [p for p in torch.split(parts, 1, 0)] 86 | sym_param = [s for s in torch.split(syms, 1, 0)] 87 | part_labels = [s for s in torch.split(labels, 1, 0)] 88 | parts_list.reverse() 89 | sym_param.reverse() 90 | part_labels.reverse() 91 | queue = [] 92 | sym_node_num = 0 93 | for id in range(ops.size()[1]): 94 | if ops[0, id] == Tree.NodeType.LEAF.value: 95 | queue.append( 96 | Tree.Node(leaf_points=parts_list.pop(), node_type=Tree.NodeType.LEAF, semantic_label=part_labels.pop())) 97 | elif ops[0, id] == Tree.NodeType.ADJ.value: 98 | left_node = queue.pop() 99 | right_node = queue.pop() 100 | queue.append( 101 | Tree.Node( 102 | left=left_node, 103 | right=right_node, 104 | node_type=Tree.NodeType.ADJ)) 105 | elif ops[0, id] == Tree.NodeType.SYM.value: 106 | node = queue.pop() 107 | s = sym_param.pop() 108 | b = s[0, 0] + 1 109 | t = s[0, 7].item() 110 | p = s[0, 1:7] 111 | if t > 0: 112 | t = round(1.0/t) 113 | queue.append( 114 | Tree.Node( 115 | left=node, 116 | sym_p=p.unsqueeze(0), 117 | sym_a=int(b), 118 | sym_t=int(t), 119 | node_type=Tree.NodeType.SYM)) 120 | if b != 1: 121 | sym_node_num += 1 122 | assert len(queue) == 1 123 | self.root = queue[0] 124 | self.root.is_root = True 125 | assert self.root.is_adj() 126 | self.shape = shape 127 | if sym_node_num == 0: 128 | self.n_syms = torch.Tensor([sym_node_num]).cuda() 129 | else: 130 | self.n_syms = torch.Tensor([1/sym_node_num]).cuda() 131 | 132 | #find GT label's index in input 133 | def Attention(feature2048, shape): 134 | index = [] 135 | for i in range(shape.size(1)): 136 | if feature2048[0, i] > -1: 137 | index.append(i) 138 | pad_index = [] 139 | while len(pad_index) < 2048: 140 | pad_index.extend(index) 141 | pad_index = torch.LongTensor(pad_index[:2048]) 142 | return pad_index.unsqueeze(0).cpu() 143 | 144 | #construct groundtruth for pointcloud segmentation 145 | 146 | def dfs_fix(node, shape, cdloss, shape_normal, seg, grp, reflect=None): 147 | global m_grp 148 | if node.is_leaf(): 149 | # find node's corresponding points on input 150 | _, index, _ , _ = cdloss(node.leaf_points[:, :, :3].cuda(), shape) 151 | b = torch.unique(index.cpu()) 152 | c = torch.LongTensor(1, 2048).zero_() 153 | c = c - 1 154 | for i in range(b.size(0)): 155 | c[0, b[i].item()] = 0 156 | node.index = c #segmentation GT binary label 157 | idx = Attention(c, shape) #node's corresponding idx 158 | #node's 
corresponding points 159 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 160 | #node's corresponding idx 161 | node.pad_index = idx 162 | for i in range(node.pad_index.size(1)): 163 | seg[node.pad_index[0, i].item()] = node.semantic_label 164 | grp[node.pad_index[0, i].item()] = m_grp 165 | m_grp += 1 166 | if reflect is not None: 167 | #recover reflect's children 168 | re_leaf_points = torch.cat([node.leaf_points[:, :, :3], node.leaf_points[:, :, :3]+node.leaf_points[:, :, 3:]], 1) 169 | re_leaf_points = re_leaf_points.squeeze(0).cpu() 170 | sList = torch.split(reflect, 1, 0) 171 | ref_normal = torch.cat([sList[0], sList[1], sList[2]]) 172 | ref_normal = ref_normal / torch.norm(ref_normal) 173 | ref_point = torch.cat([sList[3], sList[4], sList[5]]) 174 | new_points = 2 * ref_point.add(-re_leaf_points).matmul(ref_normal) 175 | new_points = new_points.unsqueeze(-1) 176 | new_points = new_points.repeat(1, 3) 177 | new_points = ref_normal.mul(new_points).add(re_leaf_points) 178 | new_points = torch.cat([new_points[:2048, :], new_points[2048:, :] - new_points[:2048, :]], 1) 179 | New_node = Tree.Node(leaf_points=new_points.unsqueeze(0), node_type=Tree.NodeType.LEAF) 180 | #build node for reflect node's children 181 | _, index, _ , _ = cdloss(New_node.leaf_points[:, :, :3].cuda(), shape) 182 | b = torch.unique(index.cpu()) 183 | reflect_c = torch.LongTensor(1, 2048).zero_() 184 | reflect_c = reflect_c - 1 185 | for i in range(b.size(0)): 186 | reflect_c[0, b[i].item()] = 0 187 | New_node.index = reflect_c 188 | idx = Attention(reflect_c, shape) 189 | New_node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 190 | New_node.pad_index = idx 191 | New_node.semantic_label = node.semantic_label 192 | for i in range(New_node.pad_index.size(1)): 193 | seg[New_node.pad_index[0, i].item()] = New_node.semantic_label 194 | grp[New_node.pad_index[0, i].item()] = m_grp 195 | m_grp += 1 196 | return torch.Tensor([0]).cuda(), New_node 197 | else: 198 | return torch.Tensor([0]).cuda(), node 199 | if node.is_adj(): 200 | l_num, new_node_l = dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, reflect) 201 | r_num, new_node_r = dfs_fix(node.right, shape, cdloss, shape_normal, seg, grp, reflect) 202 | #build adj node 203 | c = torch.LongTensor(1, 2048).zero_() 204 | c = c - 1 205 | for i in range(2048): 206 | if node.left.index[0, i].item() > -1: 207 | c[0, i] = 0 208 | for i in range(2048): 209 | if node.right.index[0, i].item() > -1: 210 | c[0, i] = 1 211 | node.index = c 212 | idx = Attention(c, shape) 213 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 214 | node.pad_index = idx 215 | 216 | if reflect is not None: 217 | New_node = Tree.Node(left=new_node_l, right=new_node_r, node_type=Tree.NodeType.ADJ) 218 | reflect_c = torch.LongTensor(1, 2048).zero_() 219 | reflect_c = reflect_c - 1 220 | for i in range(2048): 221 | if new_node_l.index[0, i].item() > -1: 222 | reflect_c[0, i] = 0 223 | for i in range(2048): 224 | if new_node_r.index[0, i].item() > -1: 225 | reflect_c[0, i] = 1 226 | New_node.index = reflect_c 227 | idx = Attention(reflect_c, shape) 228 | New_node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 229 | New_node.pad_index = idx 230 | 231 | return l_num + r_num + torch.Tensor([2]).cuda(), New_node 232 | else: 233 | return l_num + r_num + torch.Tensor([1]).cuda(), node 234 | if node.is_sym(): 235 | #build symmetric node 236 | t = node.sym_t.item() 237 | p = 
node.sym_p.squeeze(0) 238 | 239 | if node.sym_type.item() == 2: #reflect node 240 | child_num, new_node = dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, p) 241 | 242 | c = torch.LongTensor(1, 2048).zero_() 243 | c = c - 1 244 | for i in range(2048): 245 | if node.left.index[0, i].item() > -1: 246 | c[0, i] = 0 247 | for i in range(2048): 248 | if new_node.index[0, i].item() > -1: 249 | c[0, i] = 1 250 | node.index = c 251 | node.right = new_node 252 | idx = Attention(c, shape) 253 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 254 | node.node_type = Tree.NodeType.SYM_ADJ 255 | node.label = torch.LongTensor([node.node_type.value]) 256 | node.pad_index = idx 257 | 258 | return child_num + torch.Tensor([1]).cuda(), node 259 | else: 260 | child_num, _= dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, None) 261 | new_leaf_points = node.left.leaf_points.squeeze(0) 262 | leaf_points_list = [new_leaf_points.unsqueeze(0)] 263 | 264 | new_leaf_points = torch.cat([new_leaf_points[:, :3] , new_leaf_points[:, :3] + new_leaf_points[:, 3:]], 0) 265 | 266 | if node.sym_type.item() == 0:#rotate symmetry 267 | sList = torch.split(p, 1, 0) 268 | f1 = torch.cat([sList[0], sList[1], sList[2]]) 269 | if f1[1] < 0: 270 | f1 = - f1 271 | f1 = f1 / torch.norm(f1) 272 | f2 = torch.cat([sList[3], sList[4], sList[5]]) 273 | folds = int(t) 274 | a = 1.0 / float(folds) 275 | for i in range(folds - 1): 276 | angle = a * 2 * 3.1415 * (i + 1) 277 | rotm = vrrotvec2mat(f1, angle) 278 | sym_leaf_points = rotm.matmul(new_leaf_points.add(-f2).t()).t().add(f2) 279 | sym_leaf_points = torch.cat([sym_leaf_points[:2048, :] , sym_leaf_points[2048:, :] - sym_leaf_points[:2048, :]], 1) 280 | leaf_points_list.append(sym_leaf_points.unsqueeze(0)) 281 | elif node.sym_type.item() == 1: #translate symmetry 282 | sList = torch.split(p, 1, 0) 283 | trans = torch.cat([sList[0], sList[1], sList[2]]) 284 | folds = t - 1 285 | for i in range(folds): 286 | sym_leaf_points = new_leaf_points.add(trans.mul(i + 1)) 287 | sym_leaf_points = torch.cat([sym_leaf_points[:2048, :] , sym_leaf_points[2048:, :] - sym_leaf_points[:2048, :]], 1) 288 | leaf_points_list.append(sym_leaf_points.unsqueeze(0)) 289 | 290 | a = torch.cat(leaf_points_list, 0) 291 | node.index = multilabel(a[:, :, :3], shape, cdloss) 292 | idx = Attention(node.index, shape) 293 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 294 | node.pad_index = Attention(node.index, shape) 295 | for i in range(node.pad_index.size(1)): 296 | seg[node.pad_index[0, i].item()] = node.left.semantic_label 297 | for i in range(2048): 298 | if node.index[0, i].item() > -1: 299 | grp[i] = m_grp + node.index[0, i] 300 | m_grp = m_grp + torch.max(node.index) + 1 301 | return torch.Tensor([1]).cuda(), node 302 | 303 | class Data_Loader(data.Dataset): 304 | def __init__(self, dir, is_train, split_num, total_num): 305 | self.dir = dir 306 | op_data = torch.from_numpy(loadmat(self.dir + 'training_trees/ops.mat')['ops']).int() 307 | label_data = torch.from_numpy(loadmat(self.dir + 'training_trees/labels.mat')['labels']).int() 308 | sym_data = torch.from_numpy(loadmat(self.dir + 'training_trees/syms.mat')['syms']).float() 309 | num_examples = op_data.size()[1] 310 | op_data = torch.chunk(op_data, num_examples, 1) 311 | label_data = torch.chunk(label_data, num_examples, 1) 312 | sym_data = torch.chunk(sym_data, num_examples, 1) 313 | self.trees = [] 314 | self.training = is_train 315 | if is_train: 316 | begin = 0 317 | end = 
split_num 318 | else: 319 | begin = split_num 320 | end = total_num 321 | for i in range(begin, end): 322 | parts = torch.from_numpy(loadmat(self.dir + 'training_data_models_segment_2048_normals/%d.mat' % i)['pc']).float() 323 | shape = torch.from_numpy(loadmat(self.dir + 'models_2048_points_normals/%d.mat' % i)['pc']).float() 324 | ops = torch.t(op_data[i]) 325 | syms = torch.t(sym_data[i]) 326 | labels = torch.t(label_data[i]) 327 | tree = Tree(parts, ops, syms, labels, shape) 328 | cdloss = CDModule() 329 | seg = torch.LongTensor(2048).zero_() # for ap calculation 330 | grp = torch.LongTensor(2048).zero_() 331 | global m_grp 332 | m_grp = 0 333 | num_node, _ = dfs_fix(tree.root, shape[0, :, :3].unsqueeze(0).cuda(), cdloss, shape, seg, grp) 334 | tree.n_nodes = num_node 335 | tree.shape_label = seg 336 | tree.grp = grp 337 | self.trees.append(tree) 338 | print('load data', i) 339 | print(len(self.trees)) 340 | 341 | def __getitem__(self, index): 342 | tree = self.trees[index] 343 | return tree 344 | 345 | def __len__(self): 346 | return len(self.trees) 347 | -------------------------------------------------------------------------------- /dataloader_symh.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from scipy.io import loadmat 4 | from enum import Enum 5 | from torch.autograd import Variable 6 | import math 7 | from pytorch_ops.sampling.sample import FarthestSample 8 | from pytorch_ops.losses.cd.cd import CDModule 9 | 10 | m_grp = 0 11 | 12 | def vrrotvec2mat(rotvector, angle): 13 | s = math.sin(angle) 14 | c = math.cos(angle) 15 | t = 1 - c 16 | x = rotvector[0] 17 | y = rotvector[1] 18 | z = rotvector[2] 19 | m = torch.FloatTensor( 20 | [[t * x * x + c, t * x * y - s * z, t * x * z + s * y], 21 | [t * x * y + s * z, t * y * y + c, t * y * z - s * x], 22 | [t * x * z - s * y, t * y * z + s * x, t * z * z + c]]) 23 | return m 24 | 25 | #segmentation for symmetric node 26 | def multilabel(points, shape, cdloss): 27 | c = torch.LongTensor(1, 2048).zero_() 28 | c = c - 1 29 | for i in range(points.size(0)): 30 | a = points[i].unsqueeze(0).cuda() 31 | _, index, _, _ = cdloss(a, shape) 32 | b = torch.unique(index.cpu()) 33 | for k in range(b.size(0)): 34 | c[0, b[k].item()] = i 35 | return c 36 | 37 | class Tree(object): 38 | class NodeType(Enum): 39 | LEAF = 0 # leaf node 40 | ADJ = 1 # adjacency (adjacent part assembly) node 41 | SYM = 2 # symmetry (symmetric part grouping) node 42 | SYM_ADJ = 3 #reflect 43 | 44 | class Node(object): 45 | def __init__(self, 46 | leaf_points=None, 47 | left=None, 48 | right=None, 49 | node_type=None, 50 | sym_p=None, 51 | sym_a=None, 52 | sym_t=None, 53 | semantic_label=None): 54 | self.leaf_points = leaf_points # node points 55 | if isinstance(sym_t, int): 56 | self.sym_t = torch.LongTensor([sym_t]) 57 | else: 58 | self.sym_t = None 59 | if isinstance(sym_a, int): 60 | self.sym_a = torch.LongTensor([sym_a]) 61 | else: 62 | self.sym_a = None 63 | self.sym_p = sym_p 64 | self.sym_type = self.sym_a 65 | self.left = left # left child for ADJ or SYM (a symmeter generator) 66 | self.right = right # right child 67 | self.node_type = node_type 68 | self.label = torch.LongTensor([self.node_type.value]) 69 | self.is_root = False 70 | self.semantic_label = semantic_label 71 | 72 | def is_leaf(self): 73 | return self.node_type == Tree.NodeType.LEAF 74 | 75 | def is_adj(self): 76 | return self.node_type == Tree.NodeType.ADJ 77 | 78 | def is_sym(self): 79 | return self.node_type 
== Tree.NodeType.SYM 80 | 81 | def is_sym_adj(self): 82 | return self.node_type == Tree.NodeType.SYM_ADJ 83 | 84 | def __init__(self, parts, ops, syms, labels, shape): 85 | parts_list = [p for p in torch.split(parts, 1, 0)] 86 | sym_param = [s for s in torch.split(syms, 1, 0)] 87 | part_labels = [s for s in torch.split(labels, 1, 0)] 88 | parts_list.reverse() 89 | sym_param.reverse() 90 | part_labels.reverse() 91 | queue = [] 92 | sym_node_num = 0 93 | for id in range(ops.size()[1]): 94 | if ops[0, id] == Tree.NodeType.LEAF.value: 95 | queue.append( 96 | Tree.Node(leaf_points=parts_list.pop(), node_type=Tree.NodeType.LEAF, semantic_label=part_labels.pop())) 97 | elif ops[0, id] == Tree.NodeType.ADJ.value: 98 | left_node = queue.pop() 99 | right_node = queue.pop() 100 | queue.append( 101 | Tree.Node( 102 | left=left_node, 103 | right=right_node, 104 | node_type=Tree.NodeType.ADJ)) 105 | elif ops[0, id] == Tree.NodeType.SYM.value: 106 | node = queue.pop() 107 | s = sym_param.pop() 108 | b = s[0, 0] + 1 109 | t = s[0, 7].item() 110 | p = s[0, 1:7] 111 | if t > 0: 112 | t = round(1.0/t) 113 | queue.append( 114 | Tree.Node( 115 | left=node, 116 | sym_p=p.unsqueeze(0), 117 | sym_a=int(b), 118 | sym_t=int(t), 119 | node_type=Tree.NodeType.SYM)) 120 | if b != 1: 121 | sym_node_num += 1 122 | assert len(queue) == 1 123 | self.root = queue[0] 124 | self.root.is_root = True 125 | assert self.root.is_adj() 126 | self.shape = shape 127 | if sym_node_num == 0: 128 | self.n_syms = torch.Tensor([sym_node_num]).cuda() 129 | else: 130 | self.n_syms = torch.Tensor([1/sym_node_num]).cuda() 131 | 132 | #find GT label's index in input 133 | def Attention(feature2048, shape): 134 | index = [] 135 | for i in range(shape.size(1)): 136 | if feature2048[0, i] > -1: 137 | index.append(i) 138 | pad_index = [] 139 | while len(pad_index) < 2048: 140 | pad_index.extend(index) 141 | pad_index = torch.LongTensor(pad_index[:2048]) 142 | return pad_index.unsqueeze(0).cpu() 143 | 144 | #construct groundtruth for pointcloud segmentation 145 | 146 | def dfs_fix(node, shape, cdloss, shape_normal, seg, grp, reflect=None): 147 | global m_grp 148 | if node.is_leaf(): 149 | # find node's corresponding points on input 150 | _, index, _ , _ = cdloss(node.leaf_points[:, :, :3].cuda(), shape) 151 | b = torch.unique(index.cpu()) 152 | c = torch.LongTensor(1, 2048).zero_() 153 | c = c - 1 154 | for i in range(b.size(0)): 155 | c[0, b[i].item()] = 0 156 | node.index = c #segmentation GT binary label 157 | idx = Attention(c, shape) #node's corresponding idx 158 | #node's corresponding points 159 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 160 | #node's corresponding idx 161 | node.pad_index = idx 162 | for i in range(node.pad_index.size(1)): 163 | seg[node.pad_index[0, i].item()] = node.semantic_label 164 | grp[node.pad_index[0, i].item()] = m_grp 165 | m_grp += 1 166 | if reflect is not None: 167 | #recover reflect's children 168 | re_leaf_points = torch.cat([node.leaf_points[:, :, :3], node.leaf_points[:, :, :3]+node.leaf_points[:, :, 3:]], 1) 169 | re_leaf_points = re_leaf_points.squeeze(0).cpu() 170 | sList = torch.split(reflect, 1, 0) 171 | ref_normal = torch.cat([sList[0], sList[1], sList[2]]) 172 | ref_normal = ref_normal / torch.norm(ref_normal) 173 | ref_point = torch.cat([sList[3], sList[4], sList[5]]) 174 | new_points = 2 * ref_point.add(-re_leaf_points).matmul(ref_normal) 175 | new_points = new_points.unsqueeze(-1) 176 | new_points = new_points.repeat(1, 3) 177 | new_points = 
ref_normal.mul(new_points).add(re_leaf_points) 178 | new_points = torch.cat([new_points[:2048, :], new_points[2048:, :] - new_points[:2048, :]], 1) 179 | New_node = Tree.Node(leaf_points=new_points.unsqueeze(0), node_type=Tree.NodeType.LEAF) 180 | #build node for reflect node's children 181 | _, index, _ , _ = cdloss(New_node.leaf_points[:, :, :3].cuda(), shape) 182 | b = torch.unique(index.cpu()) 183 | reflect_c = torch.LongTensor(1, 2048).zero_() 184 | reflect_c = reflect_c - 1 185 | for i in range(b.size(0)): 186 | reflect_c[0, b[i].item()] = 0 187 | New_node.index = reflect_c 188 | idx = Attention(reflect_c, shape) 189 | New_node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 190 | New_node.pad_index = idx 191 | New_node.semantic_label = node.semantic_label 192 | for i in range(New_node.pad_index.size(1)): 193 | seg[New_node.pad_index[0, i].item()] = New_node.semantic_label 194 | grp[New_node.pad_index[0, i].item()] = m_grp 195 | m_grp += 1 196 | return torch.Tensor([0]).cuda(), New_node 197 | else: 198 | return torch.Tensor([0]).cuda(), node 199 | if node.is_adj(): 200 | l_num, new_node_l = dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, reflect) 201 | r_num, new_node_r = dfs_fix(node.right, shape, cdloss, shape_normal, seg, grp, reflect) 202 | #build adj node 203 | c = torch.LongTensor(1, 2048).zero_() 204 | c = c - 1 205 | for i in range(2048): 206 | if node.left.index[0, i].item() > -1: 207 | c[0, i] = 0 208 | for i in range(2048): 209 | if node.right.index[0, i].item() > -1: 210 | c[0, i] = 1 211 | node.index = c 212 | idx = Attention(c, shape) 213 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 214 | node.pad_index = idx 215 | 216 | if reflect is not None: 217 | New_node = Tree.Node(left=new_node_l, right=new_node_r, node_type=Tree.NodeType.ADJ) 218 | reflect_c = torch.LongTensor(1, 2048).zero_() 219 | reflect_c = reflect_c - 1 220 | for i in range(2048): 221 | if new_node_l.index[0, i].item() > -1: 222 | reflect_c[0, i] = 0 223 | for i in range(2048): 224 | if new_node_r.index[0, i].item() > -1: 225 | reflect_c[0, i] = 1 226 | New_node.index = reflect_c 227 | idx = Attention(reflect_c, shape) 228 | New_node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 229 | New_node.pad_index = idx 230 | 231 | return l_num + r_num + torch.Tensor([2]).cuda(), New_node 232 | else: 233 | return l_num + r_num + torch.Tensor([1]).cuda(), node 234 | if node.is_sym(): 235 | #build symmetric node 236 | t = node.sym_t.item() 237 | p = node.sym_p.squeeze(0) 238 | 239 | if node.sym_type.item() == 1: #reflect node 240 | child_num, new_node = dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, p) 241 | 242 | c = torch.LongTensor(1, 2048).zero_() 243 | c = c - 1 244 | for i in range(2048): 245 | if node.left.index[0, i].item() > -1: 246 | c[0, i] = 0 247 | for i in range(2048): 248 | if new_node.index[0, i].item() > -1: 249 | c[0, i] = 1 250 | node.index = c 251 | node.right = new_node 252 | idx = Attention(c, shape) 253 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 254 | node.node_type = Tree.NodeType.SYM_ADJ 255 | node.label = torch.LongTensor([node.node_type.value]) 256 | node.pad_index = idx 257 | 258 | return child_num + torch.Tensor([1]).cuda(), node 259 | else: 260 | child_num, _= dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, None) 261 | new_leaf_points = node.left.leaf_points.squeeze(0) 262 | leaf_points_list = [new_leaf_points.unsqueeze(0)] 263 | 264 
| new_leaf_points = torch.cat([new_leaf_points[:, :3] , new_leaf_points[:, :3] + new_leaf_points[:, 3:]], 0) 265 | 266 | if node.sym_type.item() == 0:#rotate symmetry 267 | sList = torch.split(p, 1, 0) 268 | f1 = torch.cat([sList[0], sList[1], sList[2]]) 269 | if f1[1] < 0: 270 | f1 = - f1 271 | f1 = f1 / torch.norm(f1) 272 | f2 = torch.cat([sList[3], sList[4], sList[5]]) 273 | folds = int(t) 274 | a = 1.0 / float(folds) 275 | for i in range(folds - 1): 276 | angle = a * 2 * 3.1415 * (i + 1) 277 | rotm = vrrotvec2mat(f1, angle) 278 | sym_leaf_points = rotm.matmul(new_leaf_points.add(-f2).t()).t().add(f2) 279 | sym_leaf_points = torch.cat([sym_leaf_points[:2048, :] , sym_leaf_points[2048:, :] - sym_leaf_points[:2048, :]], 1) 280 | leaf_points_list.append(sym_leaf_points.unsqueeze(0)) 281 | elif node.sym_type.item() == 2: #translate symmetry 282 | sList = torch.split(p, 1, 0) 283 | trans = torch.cat([sList[0], sList[1], sList[2]]) 284 | folds = t - 1 285 | for i in range(folds): 286 | sym_leaf_points = new_leaf_points.add(trans.mul(i + 1)) 287 | sym_leaf_points = torch.cat([sym_leaf_points[:2048, :] , sym_leaf_points[2048:, :] - sym_leaf_points[:2048, :]], 1) 288 | leaf_points_list.append(sym_leaf_points.unsqueeze(0)) 289 | 290 | a = torch.cat(leaf_points_list, 0) 291 | node.index = multilabel(a[:, :, :3], shape, cdloss) 292 | idx = Attention(node.index, shape) 293 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 294 | node.pad_index = Attention(node.index, shape) 295 | for i in range(node.pad_index.size(1)): 296 | seg[node.pad_index[0, i].item()] = node.left.semantic_label 297 | for i in range(2048): 298 | if node.index[0, i].item() > -1: 299 | grp[i] = m_grp + node.index[0, i] 300 | m_grp = m_grp + torch.max(node.index) + 1 301 | return torch.Tensor([1]).cuda(), node 302 | 303 | class Data_Loader(data.Dataset): 304 | def __init__(self, dir, is_train, split_num, total_num): 305 | self.dir = dir 306 | op_data = torch.from_numpy(loadmat(self.dir + 'training_trees/ops.mat')['ops']).int() 307 | label_data = torch.from_numpy(loadmat(self.dir + 'training_trees/labels.mat')['labels']).int() 308 | sym_data = torch.from_numpy(loadmat(self.dir + 'training_trees/syms.mat')['syms']).float() 309 | num_examples = op_data.size()[1] 310 | op_data = torch.chunk(op_data, num_examples, 1) 311 | label_data = torch.chunk(label_data, num_examples, 1) 312 | sym_data = torch.chunk(sym_data, num_examples, 1) 313 | self.trees = [] 314 | self.training = is_train 315 | if is_train: 316 | begin = 0 317 | end = split_num 318 | else: 319 | begin = split_num 320 | end = total_num 321 | for i in range(begin, end): 322 | parts = torch.from_numpy(loadmat(self.dir + 'training_data_models_segment_2048_normals/%d.mat' % i)['pc']).float() 323 | shape = torch.from_numpy(loadmat(self.dir + 'models_2048_points_normals/%d.mat' % i)['pc']).float() 324 | ops = torch.t(op_data[i]) 325 | syms = torch.t(sym_data[i]) 326 | labels = torch.t(label_data[i]) 327 | tree = Tree(parts, ops, syms, labels, shape) 328 | cdloss = CDModule() 329 | seg = torch.LongTensor(2048).zero_() # for ap calculation 330 | grp = torch.LongTensor(2048).zero_() 331 | global m_grp 332 | m_grp = 0 333 | num_node, _ = dfs_fix(tree.root, shape[0, :, :3].unsqueeze(0).cuda(), cdloss, shape, seg, grp) 334 | tree.n_nodes = num_node 335 | tree.shape_label = seg 336 | tree.grp = grp 337 | self.trees.append(tree) 338 | print('load data', i) 339 | print(len(self.trees)) 340 | 341 | def __getitem__(self, index): 342 | tree = self.trees[index] 343 
| return tree 344 | 345 | def __len__(self): 346 | return len(self.trees) 347 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | from scipy.io import loadmat 4 | from enum import Enum 5 | from torch.autograd import Variable 6 | import math 7 | from pytorch_ops.sampling.sample import FarthestSample 8 | from pytorch_ops.losses.cd.cd import CDModule 9 | 10 | m_grp = 0 11 | 12 | def vrrotvec2mat(rotvector, angle): 13 | s = math.sin(angle) 14 | c = math.cos(angle) 15 | t = 1 - c 16 | x = rotvector[0] 17 | y = rotvector[1] 18 | z = rotvector[2] 19 | m = torch.FloatTensor( 20 | [[t * x * x + c, t * x * y - s * z, t * x * z + s * y], 21 | [t * x * y + s * z, t * y * y + c, t * y * z - s * x], 22 | [t * x * z - s * y, t * y * z + s * x, t * z * z + c]]) 23 | return m 24 | 25 | # segmentation for a symmetric node 26 | def multilabel(points, shape, cdloss): 27 | c = torch.LongTensor(1, 2048).zero_() 28 | c = c - 1 29 | for i in range(points.size(0)): 30 | a = points[i].unsqueeze(0).cuda() 31 | _, index, _, _ = cdloss(a, shape) 32 | b = torch.unique(index.cpu()) 33 | for k in range(b.size(0)): 34 | c[0, b[k].item()] = i 35 | return c 36 | 37 | class Tree(object): 38 | class NodeType(Enum): 39 | LEAF = 0 # leaf node 40 | ADJ = 1 # adjacency (adjacent part assembly) node 41 | SYM = 2 # symmetry (symmetric part grouping) node 42 | SYM_ADJ = 3 # reflective symmetry (a generator paired with its mirrored copy) 43 | 44 | class Node(object): 45 | def __init__(self, 46 | leaf_points=None, 47 | left=None, 48 | right=None, 49 | node_type=None, 50 | sym_p=None, 51 | sym_a=None, 52 | sym_t=None, 53 | semantic_label=None): 54 | self.leaf_points = leaf_points # node points 55 | if isinstance(sym_t, int): 56 | self.sym_t = torch.LongTensor([sym_t]) 57 | else: 58 | self.sym_t = None 59 | if isinstance(sym_a, int): 60 | self.sym_a = torch.LongTensor([sym_a]) 61 | else: 62 | self.sym_a = None 63 | self.sym_p = sym_p 64 | self.sym_type = self.sym_a 65 | self.left = left # left child for ADJ or SYM (the symmetry generator) 66 | self.right = right # right child 67 | self.node_type = node_type 68 | self.label = torch.LongTensor([self.node_type.value]) 69 | self.is_root = False 70 | self.semantic_label = semantic_label 71 | 72 | def is_leaf(self): 73 | return self.node_type == Tree.NodeType.LEAF 74 | 75 | def is_adj(self): 76 | return self.node_type == Tree.NodeType.ADJ 77 | 78 | def is_sym(self): 79 | return self.node_type == Tree.NodeType.SYM 80 | 81 | def is_sym_adj(self): 82 | return self.node_type == Tree.NodeType.SYM_ADJ 83 | 84 | def __init__(self, parts, ops, syms, labels, shape): 85 | parts_list = [p for p in torch.split(parts, 1, 0)] 86 | sym_param = [s for s in torch.split(syms, 1, 0)] 87 | part_labels = [s for s in torch.split(labels, 1, 0)] 88 | parts_list.reverse() 89 | sym_param.reverse() 90 | part_labels.reverse() 91 | queue = [] 92 | sym_node_num = 0 93 | for id in range(ops.size()[1]): 94 | if ops[0, id] == Tree.NodeType.LEAF.value: 95 | queue.append( 96 | Tree.Node(leaf_points=parts_list.pop(), node_type=Tree.NodeType.LEAF, semantic_label=part_labels.pop())) 97 | elif ops[0, id] == Tree.NodeType.ADJ.value: 98 | left_node = queue.pop() 99 | right_node = queue.pop() 100 | queue.append( 101 | Tree.Node( 102 | left=left_node, 103 | right=right_node, 104 | node_type=Tree.NodeType.ADJ)) 105 | elif ops[0, id] == Tree.NodeType.SYM.value: 106 | node = queue.pop() 107 | s = sym_param.pop() 108 | b
= s[0, 0] + 1 109 | t = s[0, 7].item() 110 | p = s[0, 1:7] 111 | if t > 0: 112 | t = round(1.0/t) 113 | queue.append( 114 | Tree.Node( 115 | left=node, 116 | sym_p=p.unsqueeze(0), 117 | sym_a=int(b), 118 | sym_t=int(t), 119 | node_type=Tree.NodeType.SYM)) 120 | if b != 1: 121 | sym_node_num += 1 122 | assert len(queue) == 1 123 | self.root = queue[0] 124 | self.root.is_root = True 125 | assert self.root.is_adj() 126 | self.shape = shape 127 | if sym_node_num == 0: 128 | self.n_syms = torch.Tensor([sym_node_num]).cuda() 129 | else: 130 | self.n_syms = torch.Tensor([1/sym_node_num]).cuda() 131 | 132 | #find GT label's index in input 133 | def Attention(feature2048, shape): 134 | index = [] 135 | for i in range(shape.size(1)): 136 | if feature2048[0, i] > -1: 137 | index.append(i) 138 | pad_index = [] 139 | while len(pad_index) < 2048: 140 | pad_index.extend(index) 141 | pad_index = torch.LongTensor(pad_index[:2048]) 142 | return pad_index.unsqueeze(0).cpu() 143 | 144 | #construct groundtruth for pointcloud segmentation 145 | 146 | def dfs_fix(node, shape, cdloss, shape_normal, seg, grp, reflect=None): 147 | global m_grp 148 | if node.is_leaf(): 149 | # find node's corresponding points on input 150 | _, index, _ , _ = cdloss(node.leaf_points[:, :, :3].cuda(), shape) 151 | b = torch.unique(index.cpu()) 152 | c = torch.LongTensor(1, 2048).zero_() 153 | c = c - 1 154 | for i in range(b.size(0)): 155 | c[0, b[i].item()] = 0 156 | node.index = c #segmentation GT binary label 157 | idx = Attention(c, shape) #node's corresponding idx 158 | #node's corresponding points 159 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 160 | #node's corresponding idx 161 | node.pad_index = idx 162 | for i in range(node.pad_index.size(1)): 163 | seg[node.pad_index[0, i].item()] = node.semantic_label 164 | grp[node.pad_index[0, i].item()] = m_grp 165 | m_grp += 1 166 | if reflect is not None: 167 | #recover reflect's children 168 | re_leaf_points = torch.cat([node.leaf_points[:, :, :3], node.leaf_points[:, :, :3]+node.leaf_points[:, :, 3:]], 1) 169 | re_leaf_points = re_leaf_points.squeeze(0).cpu() 170 | sList = torch.split(reflect, 1, 0) 171 | ref_normal = torch.cat([sList[0], sList[1], sList[2]]) 172 | ref_normal = ref_normal / torch.norm(ref_normal) 173 | ref_point = torch.cat([sList[3], sList[4], sList[5]]) 174 | new_points = 2 * ref_point.add(-re_leaf_points).matmul(ref_normal) 175 | new_points = new_points.unsqueeze(-1) 176 | new_points = new_points.repeat(1, 3) 177 | new_points = ref_normal.mul(new_points).add(re_leaf_points) 178 | new_points = torch.cat([new_points[:2048, :], new_points[2048:, :] - new_points[:2048, :]], 1) 179 | New_node = Tree.Node(leaf_points=new_points.unsqueeze(0), node_type=Tree.NodeType.LEAF) 180 | #build node for reflect node's children 181 | _, index, _ , _ = cdloss(New_node.leaf_points[:, :, :3].cuda(), shape) 182 | b = torch.unique(index.cpu()) 183 | reflect_c = torch.LongTensor(1, 2048).zero_() 184 | reflect_c = reflect_c - 1 185 | for i in range(b.size(0)): 186 | reflect_c[0, b[i].item()] = 0 187 | New_node.index = reflect_c 188 | idx = Attention(reflect_c, shape) 189 | New_node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 190 | New_node.pad_index = idx 191 | New_node.semantic_label = node.semantic_label 192 | for i in range(New_node.pad_index.size(1)): 193 | seg[New_node.pad_index[0, i].item()] = New_node.semantic_label 194 | grp[New_node.pad_index[0, i].item()] = m_grp 195 | m_grp += 1 196 | return 
torch.Tensor([0]).cuda(), New_node 197 | else: 198 | return torch.Tensor([0]).cuda(), node 199 | if node.is_adj(): 200 | l_num, new_node_l = dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, reflect) 201 | r_num, new_node_r = dfs_fix(node.right, shape, cdloss, shape_normal, seg, grp, reflect) 202 | # build adj node 203 | c = torch.LongTensor(1, 2048).zero_() 204 | c = c - 1 205 | for i in range(2048): 206 | if node.left.index[0, i].item() > -1: 207 | c[0, i] = 0 208 | for i in range(2048): 209 | if node.right.index[0, i].item() > -1: 210 | c[0, i] = 1 211 | node.index = c 212 | idx = Attention(c, shape) 213 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 214 | node.pad_index = idx 215 | 216 | if reflect is not None: 217 | New_node = Tree.Node(left=new_node_l, right=new_node_r, node_type=Tree.NodeType.ADJ) 218 | reflect_c = torch.LongTensor(1, 2048).zero_() 219 | reflect_c = reflect_c - 1 220 | for i in range(2048): 221 | if new_node_l.index[0, i].item() > -1: 222 | reflect_c[0, i] = 0 223 | for i in range(2048): 224 | if new_node_r.index[0, i].item() > -1: 225 | reflect_c[0, i] = 1 226 | New_node.index = reflect_c 227 | idx = Attention(reflect_c, shape) 228 | New_node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 229 | New_node.pad_index = idx 230 | 231 | return l_num + r_num + torch.Tensor([2]).cuda(), New_node 232 | else: 233 | return l_num + r_num + torch.Tensor([1]).cuda(), node 234 | if node.is_sym(): 235 | # build symmetric node 236 | t = node.sym_t.item() 237 | p = node.sym_p.squeeze(0) 238 | 239 | if node.sym_type.item() == 1: # reflect node 240 | child_num, new_node = dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, p) 241 | 242 | c = torch.LongTensor(1, 2048).zero_() 243 | c = c - 1 244 | for i in range(2048): 245 | if node.left.index[0, i].item() > -1: 246 | c[0, i] = 0 247 | for i in range(2048): 248 | if new_node.index[0, i].item() > -1: 249 | c[0, i] = 1 250 | node.index = c 251 | node.right = new_node 252 | idx = Attention(c, shape) 253 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 254 | node.node_type = Tree.NodeType.SYM_ADJ 255 | node.label = torch.LongTensor([node.node_type.value]) 256 | node.pad_index = idx 257 | 258 | return child_num + torch.Tensor([1]).cuda(), node 259 | else: 260 | child_num, _ = dfs_fix(node.left, shape, cdloss, shape_normal, seg, grp, None) 261 | new_leaf_points = node.left.leaf_points.squeeze(0) 262 | leaf_points_list = [new_leaf_points.unsqueeze(0)] 263 | 264 | new_leaf_points = torch.cat([new_leaf_points[:, :3], new_leaf_points[:, :3] + new_leaf_points[:, 3:]], 0) 265 | 266 | if node.sym_type.item() == 0: # rotate symmetry 267 | sList = torch.split(p, 1, 0) 268 | f1 = torch.cat([sList[0], sList[1], sList[2]]) 269 | if f1[1] < 0: 270 | f1 = -f1 271 | f1 = f1 / torch.norm(f1) 272 | f2 = torch.cat([sList[3], sList[4], sList[5]]) 273 | folds = int(t) 274 | a = 1.0 / float(folds) 275 | for i in range(folds - 1): 276 | angle = a * 2 * math.pi * (i + 1) 277 | rotm = vrrotvec2mat(f1, angle) 278 | sym_leaf_points = rotm.matmul(new_leaf_points.add(-f2).t()).t().add(f2) 279 | sym_leaf_points = torch.cat([sym_leaf_points[:2048, :], sym_leaf_points[2048:, :] - sym_leaf_points[:2048, :]], 1) 280 | leaf_points_list.append(sym_leaf_points.unsqueeze(0)) 281 | elif node.sym_type.item() == 2: # translate symmetry 282 | sList = torch.split(p, 1, 0) 283 | trans = torch.cat([sList[0], sList[1], sList[2]]) 284 | folds = t - 1 285 | trans = trans / float(folds)
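# p[0:3] stores the total translation span of the symmetry group; dividing it by folds (= t - 1) yields the per-step offset, so the loop below emits the remaining t - 1 equally spaced copies of the generator part.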
286 | for i in range(folds): 287 | sym_leaf_points = new_leaf_points.add(trans.mul(i + 1)) 288 | sym_leaf_points = torch.cat([sym_leaf_points[:2048, :] , sym_leaf_points[2048:, :] - sym_leaf_points[:2048, :]], 1) 289 | leaf_points_list.append(sym_leaf_points.unsqueeze(0)) 290 | 291 | a = torch.cat(leaf_points_list, 0) 292 | node.index = multilabel(a[:, :, :3], shape, cdloss) 293 | idx = Attention(node.index, shape) 294 | node.points = torch.index_select(shape_normal, 1, idx.squeeze(0).long().cpu()) 295 | node.pad_index = Attention(node.index, shape) 296 | for i in range(node.pad_index.size(1)): 297 | seg[node.pad_index[0, i].item()] = node.left.semantic_label 298 | for i in range(2048): 299 | if node.index[0, i].item() > -1: 300 | grp[i] = m_grp + node.index[0, i] 301 | m_grp = m_grp + torch.max(node.index) + 1 302 | return torch.Tensor([1]).cuda(), node 303 | 304 | class Data_Loader(data.Dataset): 305 | def __init__(self, dir, is_train, split_num, total_num): 306 | self.dir = dir 307 | op_data = torch.from_numpy(loadmat(self.dir + 'training_trees/ops.mat')['ops']).int() 308 | label_data = torch.from_numpy(loadmat(self.dir + 'training_trees/labels.mat')['labels']).int() 309 | sym_data = torch.from_numpy(loadmat(self.dir + 'training_trees/syms.mat')['syms']).float() 310 | num_examples = op_data.size()[1] 311 | op_data = torch.chunk(op_data, num_examples, 1) 312 | label_data = torch.chunk(label_data, num_examples, 1) 313 | sym_data = torch.chunk(sym_data, num_examples, 1) 314 | self.trees = [] 315 | self.training = is_train 316 | if is_train: 317 | begin = 0 318 | end = split_num 319 | else: 320 | begin = split_num 321 | end = total_num 322 | for i in range(begin, end): 323 | parts = torch.from_numpy(loadmat(self.dir + 'training_data_models_segment_2048_normals/%d.mat' % i)['pc']).float() 324 | shape = torch.from_numpy(loadmat(self.dir + 'models_2048_points_normals/%d.mat' % i)['pc']).float() 325 | ops = torch.t(op_data[i]) 326 | syms = torch.t(sym_data[i]) 327 | labels = torch.t(label_data[i]) 328 | tree = Tree(parts, ops, syms, labels, shape) 329 | cdloss = CDModule() 330 | seg = torch.LongTensor(2048).zero_() # for ap calculation 331 | grp = torch.LongTensor(2048).zero_() 332 | global m_grp 333 | m_grp = 0 334 | num_node, _ = dfs_fix(tree.root, shape[0, :, :3].unsqueeze(0).cuda(), cdloss, shape, seg, grp) 335 | tree.n_nodes = num_node 336 | tree.shape_label = seg 337 | tree.grp = grp 338 | self.trees.append(tree) 339 | print('load data', i) 340 | print(len(self.trees)) 341 | 342 | def __getitem__(self, index): 343 | tree = self.trees[index] 344 | return tree 345 | 346 | def __len__(self): 347 | return len(self.trees) 348 | -------------------------------------------------------------------------------- /partnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.autograd import Variable 5 | from time import time 6 | import torch.nn.functional as F 7 | from pytorch_ops.losses.cd.cd import CDModule 8 | from pytorch_ops.losses.emd.emd import EMDModule 9 | from pytorch_ops.sampling.sample import FarthestSample 10 | import pointnet2 as Pointnet 11 | ######################################################################################### 12 | # Encoder 13 | ######################################################################################### 14 | 15 | 16 | class PCEncoder(nn.Module): 17 | def __init__(self, bottleneck=128): 18 | super(PCEncoder, self).__init__() 19 | self.conv1 = 
nn.Conv2d(1, 64, (1, 6), 1) 20 | self.bn1 = nn.BatchNorm2d(64) 21 | self.conv2 = nn.Conv2d(64, 128, 1, 1) 22 | self.bn2 = nn.BatchNorm2d(128) 23 | self.conv3 = nn.Conv2d(128, 128, 1, 1) 24 | self.bn3 = nn.BatchNorm2d(128) 25 | self.conv4 = nn.Conv2d(128, 256, 1, 1) 26 | self.bn4 = nn.BatchNorm2d(256) 27 | self.conv5 = nn.Conv2d(256, bottleneck, 1, 1) 28 | self.bn5 = nn.BatchNorm2d(bottleneck) 29 | self.dropout2d = nn.Dropout2d(p=0.2) 30 | 31 | def forward(self, input): 32 | input = input.unsqueeze(1) 33 | input = self.conv1(input) 34 | input = self.bn1(input) 35 | input = F.relu(input) 36 | input = self.conv2(input) 37 | input = self.bn2(input) 38 | input = F.relu(input) 39 | input = self.conv3(input) 40 | input = self.bn3(input) 41 | input = F.relu(input) 42 | input = self.conv4(input) 43 | input = self.dropout2d(input) 44 | input = self.bn4(input) 45 | input = F.relu(input) 46 | input = self.conv5(input) 47 | input = self.dropout2d(input) 48 | input = self.bn5(input) 49 | input = F.relu(input) 50 | input, _ = torch.max(input, 2) 51 | input = input.squeeze(-1) 52 | return input 53 | 54 | 55 | ######################################################################################### 56 | # Decoder 57 | ######################################################################################### 58 | 59 | class NodeClassifier(nn.Module): 60 | def __init__(self, feature_size, hidden_size): 61 | super(NodeClassifier, self).__init__() 62 | self.mlp1 = nn.Linear(feature_size*2, hidden_size) 63 | self.tanh = nn.Tanh() 64 | self.mlp2 = nn.Linear(hidden_size, 4) 65 | 66 | def forward(self, input_feature): 67 | output = self.mlp1(input_feature) 68 | output = self.tanh(output) 69 | output = self.mlp2(output) 70 | return output 71 | 72 | 73 | class AdjDecoder(nn.Module): 74 | """ Decode an input (parent) feature into a left-child and a right-child feature """ 75 | 76 | def __init__(self, feature_size, hidden_size): 77 | super(AdjDecoder, self).__init__() 78 | self.mlp = nn.Linear(feature_size*2, hidden_size) 79 | self.mlp_left = nn.Linear(hidden_size, feature_size) 80 | self.mlp_right = nn.Linear(hidden_size, feature_size) 81 | self.tanh = nn.Tanh() 82 | 83 | def forward(self, parent_feature): 84 | vector = self.mlp(parent_feature) 85 | vector = self.tanh(vector) 86 | left_feature = self.mlp_left(vector) 87 | left_feature = self.tanh(left_feature) 88 | right_feature = self.mlp_right(vector) 89 | right_feature = self.tanh(right_feature) 90 | return left_feature, right_feature 91 | 92 | 93 | class SymDecoder(nn.Module): 94 | def __init__(self, feature_size, hidden_size): 95 | super(SymDecoder, self).__init__() 96 | self.mlp = nn.Linear(feature_size*2, hidden_size) # layer for decoding a feature vector 97 | self.tanh = nn.Tanh() 98 | self.mlp_sg = nn.Linear(hidden_size, feature_size) # layer for outputting the feature of the symmetry generator 99 | self.mlp_sp = nn.Linear(hidden_size, hidden_size) # layer for outputting the vector of symmetry parameters 100 | self.mlp_s1 = nn.Linear(hidden_size, 3) # symmetry type label 101 | self.mlp_s2 = nn.Linear(hidden_size, 6) # symmetry parameters 102 | self.mlp_s3 = nn.Linear(hidden_size, 9) # max number of symmetric repetitions 103 | 104 | def forward(self, parent_feature): 105 | vector = self.mlp(parent_feature) 106 | vector = self.tanh(vector) 107 | sym_gen_vector = self.mlp_sg(vector) 108 | sym_gen_vector = self.tanh(sym_gen_vector) 109 | sym_param_vector = self.mlp_sp(vector) 110 | sym_param_vector = self.tanh(sym_param_vector) 111 | sym_label_vector = self.mlp_s1(sym_param_vector) 112 |
sym_vector = self.mlp_s2(sym_param_vector) 113 | sym_time_vector = self.mlp_s3(sym_param_vector) 114 | return sym_gen_vector, sym_label_vector, sym_vector, sym_time_vector 115 | 116 | class PointPrediction(nn.Module): 117 | def __init__(self, feature_size=128): 118 | super().__init__() 119 | self.conv1 = nn.Conv2d(128+feature_size*3, 256, 1, 1) 120 | #self.bn1 = nn.BatchNorm2d(256) 121 | self.conv2 = nn.Conv2d(256, 128, 1, 1) 122 | #self.bn2 = nn.BatchNorm2d(128) 123 | self.conv3 = nn.Conv2d(128, 64, 1, 1) 124 | #self.bn3 = nn.BatchNorm2d(64) 125 | self.conv4 = nn.Conv2d(64, 64, 1, 1) 126 | #self.bn4 = nn.BatchNorm2d(64) 127 | self.conv5 = nn.Conv2d(64, 2, 1, 1) 128 | #self.bn5 = nn.BatchNorm2d(2) 129 | self.relu = nn.ReLU() 130 | self.logsoftmax = nn.LogSoftmax(1) 131 | 132 | def forward(self, points_feature, inp_feature): 133 | #points_feature = points.transpose(1, 2) 134 | output = torch.cat([points_feature, inp_feature.unsqueeze(-1).repeat(1, 1, points_feature.size(2))], 1).unsqueeze(-1) 135 | output = self.conv1(output) 136 | #output = self.bn1(output) 137 | output = self.relu(output) 138 | output = self.conv2(output) 139 | # output = self.bn2(output) 140 | output = self.relu(output) 141 | output = self.conv3(output) 142 | #output = self.bn3(output) 143 | output = self.relu(output) 144 | output = self.conv4(output) 145 | #output = self.bn4(output) 146 | output4 = self.relu(output) 147 | output = self.conv5(output4) 148 | #output = self.bn5(output) 149 | output = self.logsoftmax(output) 150 | maxpool, _ = torch.max(output4, 2) 151 | return output.squeeze(-1), maxpool.squeeze(-1) 152 | 153 | class PointPredictionMulti(nn.Module): 154 | def __init__(self, feature_size=128): 155 | super().__init__() 156 | self.conv1 = nn.Conv2d(128+feature_size*3, 256, 1, 1) 157 | #self.bn1 = nn.BatchNorm2d(256) 158 | self.conv2 = nn.Conv2d(256, 128, 1, 1) 159 | #self.bn2 = nn.BatchNorm2d(128) 160 | self.conv3 = nn.Conv2d(128, 64, 1, 1) 161 | #self.bn3 = nn.BatchNorm2d(64) 162 | self.conv4 = nn.Conv2d(64, 64, 1, 1) 163 | #self.bn4 = nn.BatchNorm2d(64) 164 | self.conv5 = nn.Conv2d(64, 10, 1, 1) 165 | #self.bn5 = nn.BatchNorm2d(2) 166 | self.relu = nn.ReLU() 167 | self.logsoftmax = nn.LogSoftmax(1) 168 | 169 | def forward(self, points_feature, inp_feature): 170 | #points_feature = points.transpose(1, 2) 171 | output = torch.cat([points_feature, inp_feature.unsqueeze(-1).repeat(1, 1, points_feature.size(2))], 1).unsqueeze(-1) 172 | output = self.conv1(output) 173 | #output = self.bn1(output) 174 | output = self.relu(output) 175 | output = self.conv2(output) 176 | # output = self.bn2(output) 177 | output = self.relu(output) 178 | output = self.conv3(output) 179 | #output = self.bn3(output) 180 | output = self.relu(output) 181 | output = self.conv4(output) 182 | #output = self.bn4(output) 183 | output4 = self.relu(output) 184 | output = self.conv5(output4) 185 | #output = self.bn5(output) 186 | output = self.logsoftmax(output) 187 | maxpool, _ = torch.max(output4, 2) 188 | return output.squeeze(-1), maxpool.squeeze(-1) 189 | 190 | class PARTNETDecoder(nn.Module): 191 | def __init__(self, config, decoder_param_path=None): 192 | super(PARTNETDecoder, self).__init__() 193 | self.loc_points_predictor = PointPrediction(feature_size=config.feature_size) 194 | self.loc_points_predictor_multi = PointPredictionMulti(feature_size=config.feature_size) 195 | self.pc_encoder = PCEncoder(bottleneck=128) 196 | self.adj_decoder = AdjDecoder(feature_size=config.feature_size, hidden_size=config.hidden_size) 197 | 
self.sym_adj_decoder = AdjDecoder(feature_size=config.feature_size, hidden_size=config.hidden_size) 198 | self.sym_decoder = SymDecoder(feature_size=config.feature_size, hidden_size=config.hidden_size) 199 | self.node_classifier = NodeClassifier(feature_size=config.feature_size, hidden_size=config.hidden_size) 200 | 201 | class PARTNET(nn.Module): 202 | def __init__(self, config, encoder_param_path=None, decoder_param_path=None): 203 | super(PARTNET, self).__init__() 204 | self.pointnet = Pointnet.Encoder() 205 | self.decoder = PARTNETDecoder(config, decoder_param_path) 206 | # pytorch's mean squared error loss 207 | self.mseLoss = nn.MSELoss(reduce=False) 208 | self.nllloss = nn.NLLLoss(reduce=False) 209 | self.creLoss = nn.CrossEntropyLoss(reduce=False) 210 | self.cdloss = CDModule() 211 | self.emdloss = EMDModule() 212 | self.sample = FarthestSample(256) 213 | 214 | def pcEncoder(self, points): 215 | return self.decoder.pc_encoder(points) 216 | 217 | def adjDecoder(self, feature): 218 | return self.decoder.adj_decoder(feature) 219 | 220 | def symadjDecoder(self, feature): 221 | return self.decoder.sym_adj_decoder(feature) 222 | 223 | def symDecoder(self, feature): 224 | return self.decoder.sym_decoder(feature) 225 | 226 | def nodeClassifier(self, feature): 227 | return self.decoder.node_classifier(feature) 228 | 229 | def symTimeLossEstimator(self, sym_time, gt_sym_time): 230 | return self.creLoss(sym_time, gt_sym_time).mul_(30) 231 | 232 | def symLabelLossEstimator(self, sym_label, gt_sym_label): 233 | return self.creLoss(sym_label, gt_sym_label).mul_(30) 234 | 235 | def symLossEstimator(self, sym_param, gt_sym_param): 236 | return torch.mean(self.mseLoss(sym_param, gt_sym_param).mul_(30), 1) 237 | 238 | def classifyLossEstimator(self, label_vector, gt_label_vector): 239 | a = self.creLoss(label_vector, gt_label_vector).mul(30) # 20 240 | return a 241 | 242 | def vectorAdder(self, v1, v2): 243 | return v1.add_(v2) 244 | 245 | def vectorAdder3(self, v1, v2, v3): 246 | return v1.add_(v2).add_(v3) 247 | 248 | def vectorAdder4(self, v1, v2, v3, v4): 249 | return v1.add_(v2).add_(v3).add_(v4) 250 | 251 | def vectorzero(self): 252 | temp = torch.zeros(1) 253 | return Variable(temp.cuda()) 254 | 255 | def feature_concat2(self, f1, f2): 256 | return torch.cat([f1, f2], 1) 257 | 258 | def feature_concat3(self, f1, f2, f3): 259 | return torch.cat([f1, f2, f3], 1) 260 | 261 | def locPointsPredic(self, shape, feature, pad_index, gt): 262 | newf = [] 263 | newl = [] 264 | for i in range(pad_index.size(0)): 265 | new_feature2= torch.index_select(shape[i].unsqueeze(0), 2, pad_index[i]) 266 | new_node_label2 = torch.index_select(gt[i].unsqueeze(0), 1, pad_index[i]) 267 | newf.append(new_feature2) 268 | newl.append(new_node_label2) 269 | 270 | newf = torch.cat(newf, 0) 271 | newf_max, _ = torch.max(newf, 2) 272 | feature = torch.cat([newf_max, feature], 1) 273 | gene_label, last2feature = self.decoder.loc_points_predictor(newf, feature) 274 | newl = torch.cat(newl, 0) 275 | 276 | loss = torch.mean(self.nllloss(gene_label, newl), 1).mul_(30) 277 | _, index = torch.max(gene_label, 1) 278 | acc = torch.sum(torch.eq(index, newl), 1).float() 279 | return loss, acc, last2feature 280 | 281 | def locPointsPredic_multi(self, shape, feature, pad_index, gt): 282 | newf = [] 283 | newl = [] 284 | for i in range(pad_index.size(0)): 285 | new_feature2= torch.index_select(shape[i].unsqueeze(0), 2, pad_index[i]) 286 | new_node_label2 = torch.index_select(gt[i].unsqueeze(0), 1, pad_index[i]) 287 | 
newf.append(new_feature2) 288 | newl.append(new_node_label2) 289 | 290 | newf = torch.cat(newf, 0) 291 | newf_max, _ = torch.max(newf, 2) 292 | feature = torch.cat([newf_max, feature], 1) 293 | gene_label, last2feature = self.decoder.loc_points_predictor_multi(newf, feature) 294 | newl = torch.cat(newl, 0) 295 | 296 | loss = torch.mean(self.nllloss(gene_label, newl), 1).mul_(30) 297 | _, index = torch.max(gene_label, 1) 298 | acc = torch.sum(torch.eq(index, newl), 1).float() 299 | 300 | return loss, acc, last2feature 301 | 302 | def jitter(shape): 303 | input_data = shape 304 | jitter_points = torch.randn(input_data.size()) 305 | jitter_points = torch.clamp(0.01*jitter_points, min=-0.05, max=0.05) 306 | jitter_points += input_data 307 | return jitter_points 308 | 309 | def decode_structure_fold(fold, tree, points_f): 310 | def decode_node(node, feature): 311 | if node.is_leaf(): 312 | input_data = jitter(node.points) 313 | local_f = fold.add('pcEncoder', input_data) 314 | feature_c = fold.add('feature_concat2', feature, local_f) 315 | label = fold.add('nodeClassifier', feature_c) 316 | label_loss = fold.add('classifyLossEstimator', label, node.label) 317 | return label_loss, fold.add('vectorzero'), fold.add('vectorzero') 318 | 319 | elif node.is_adj(): 320 | input_data = jitter(node.points) 321 | local_f = fold.add('pcEncoder', input_data) 322 | feature_c = fold.add('feature_concat2', feature, local_f) 323 | node_segloss, acc, last2feature = fold.add('locPointsPredic', points_f, feature_c, node.pad_index, node.index).split(3) 324 | left, right = fold.add('adjDecoder', feature_c).split(2) 325 | left_label_loss, left_segloss, left_acc = decode_node(node.left, left) 326 | right_label_loss, right_segloss, right_acc = decode_node(node.right, right) 327 | label = fold.add('nodeClassifier', feature_c) 328 | label_loss = fold.add('classifyLossEstimator', label, node.label) 329 | child_label_loss = fold.add('vectorAdder', left_label_loss, right_label_loss) 330 | node_label_loss = fold.add('vectorAdder', child_label_loss, label_loss) 331 | child_segloss = fold.add('vectorAdder', left_segloss, right_segloss) 332 | child_acc = fold.add('vectorAdder', left_acc, right_acc) 333 | node_segloss = fold.add('vectorAdder', child_segloss, node_segloss) 334 | acc = fold.add('vectorAdder', acc, child_acc) 335 | return node_label_loss, node_segloss, acc 336 | 337 | elif node.is_sym(): 338 | input_data = jitter(node.points) 339 | local_f = fold.add('pcEncoder', input_data) 340 | feature_c = fold.add('feature_concat2', feature, local_f) 341 | node_segloss, acc, last2feature = fold.add('locPointsPredic_multi', points_f, feature_c, node.pad_index, node.index).split(3) 342 | label = fold.add('nodeClassifier', feature_c) 343 | node_label_loss = fold.add('classifyLossEstimator', label, node.label) 344 | return node_label_loss, node_segloss, acc 345 | 346 | elif node.is_sym_adj(): 347 | input_data = jitter(node.points) 348 | local_f = fold.add('pcEncoder', input_data) 349 | feature_c = fold.add('feature_concat2', feature, local_f) 350 | node_segloss, acc, last2feature = fold.add('locPointsPredic', points_f, feature_c, node.pad_index, node.index).split(3) 351 | left, right = fold.add('adjDecoder', feature_c).split(2) 352 | left_label_loss, left_segloss, left_acc = decode_node(node.left, left) 353 | right_label_loss, right_segloss, right_acc = decode_node(node.right, right) 354 | label = fold.add('nodeClassifier', feature_c) 355 | label_loss = fold.add('classifyLossEstimator', label, node.label) 356 | child_label_loss = 
fold.add('vectorAdder', left_label_loss, right_label_loss) 357 | node_label_loss = fold.add('vectorAdder', child_label_loss, label_loss) 358 | child_segloss = fold.add('vectorAdder', left_segloss, right_segloss) 359 | child_acc = fold.add('vectorAdder', left_acc, right_acc) 360 | node_segloss = fold.add('vectorAdder', child_segloss, node_segloss) 361 | acc = fold.add('vectorAdder', acc, child_acc) 362 | return node_label_loss, node_segloss, acc 363 | 364 | input_data = jitter(tree.root.points) 365 | local_f = fold.add('pcEncoder', input_data) 366 | node_label_loss, node_segloss, acc = decode_node(tree.root, local_f) 367 | return node_label_loss, node_segloss, acc --------------------------------------------------------------------------------
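For reference, a minimal usage sketch of the data loader above (illustrative, not part of the repository): it builds one Tree per shape and reads back the ground truth that dfs_fix attaches during loading. The 'data/airplane/' path and the split sizes are assumptions; the directory must contain training_trees/{ops,labels,syms}.mat plus the two point-cloud folders, and a CUDA device is required because dfs_fix moves the shape and the Chamfer-distance module (CDModule) to the GPU.

from dataloader import Data_Loader

# Path and split sizes below are illustrative assumptions.
dataset = Data_Loader('data/airplane/', is_train=True, split_num=100, total_num=120)

tree = dataset[0]
print(tree.n_nodes)      # node count accumulated by dfs_fix
print(tree.shape_label)  # per-point semantic labels (LongTensor of size 2048)
print(tree.grp)          # per-point group/instance ids used for AP evaluation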