├── Chamfer3D ├── __pycache__ │ └── dist_chamfer_3D.cpython-37.pyc ├── chamfer3D.cu ├── chamfer_cuda.cpp ├── dist_chamfer_3D.py └── setup.py ├── LICENSE ├── PMPPlus-Jittor ├── README.md ├── config_c3d.py ├── config_pcn.py ├── core │ ├── __init__.py │ ├── chamfer.py │ ├── inference_c3d.py │ ├── inference_pcn.py │ ├── test_c3d.py │ ├── test_pcn.py │ ├── train_c3d.py │ └── train_pcn.py ├── datasets │ ├── Completion3D.json │ ├── KITTI.json │ └── ShapeNet.json ├── main_c3d.py ├── main_pcn.py ├── models │ ├── __init__.py │ ├── misc │ │ ├── __pycache__ │ │ │ └── ops.cpython-37.pyc │ │ ├── layers.py │ │ ├── ops.py │ │ ├── pointconv_utils.py │ │ └── utils.py │ ├── model.py │ ├── pointnet2_partseg.py │ └── transformers.py ├── requirements.txt └── utils │ ├── __init__.py │ ├── average_meter.py │ ├── data_loaders.py │ ├── data_transforms.py │ ├── helpers.py │ ├── io.py │ └── metrics.py ├── README.md ├── config_c3d.py ├── config_pcn.py ├── core ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── inference_c3d.cpython-37.pyc │ ├── inference_pcn.cpython-37.pyc │ ├── test_c3d.cpython-37.pyc │ ├── test_pcn.cpython-37.pyc │ ├── train_c3d.cpython-37.pyc │ └── train_pcn.cpython-37.pyc ├── inference_c3d.py ├── inference_pcn.py ├── test_c3d.py ├── test_pcn.py ├── train_c3d.py └── train_pcn.py ├── datasets ├── Completion3D.json ├── KITTI.json └── ShapeNet.json ├── main_c3d.py ├── main_pcn.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── model.cpython-37.pyc │ ├── transformer.cpython-37.pyc │ └── utils.cpython-37.pyc ├── model.py ├── transformer.py └── utils.py ├── pics └── network.png ├── pointnet2_ops_lib ├── pointnet2_ops │ ├── __init__.py │ ├── _ext-src │ │ ├── include │ │ │ ├── ball_query.h │ │ │ ├── cuda_utils.h │ │ │ ├── group_points.h │ │ │ ├── interpolate.h │ │ │ ├── sampling.h │ │ │ └── utils.h │ │ └── src │ │ │ ├── ball_query.cpp │ │ │ ├── ball_query_gpu.cu │ │ │ ├── bindings.cpp │ │ │ ├── group_points.cpp │ │ │ ├── group_points_gpu.cu │ │ │ ├── interpolate.cpp │ │ │ ├── interpolate_gpu.cu │ │ │ ├── sampling.cpp │ │ │ └── sampling_gpu.cu │ ├── _version.py │ ├── pointnet2_modules.py │ └── pointnet2_utils.py └── setup.py ├── requirements.txt └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── average_meter.cpython-37.pyc ├── data_loaders.cpython-37.pyc ├── data_transforms.cpython-37.pyc ├── helpers.cpython-37.pyc ├── io.cpython-37.pyc └── metrics.cpython-37.pyc ├── average_meter.py ├── data_loaders.py ├── data_transforms.py ├── helpers.py ├── io.py └── metrics.py /Chamfer3D/__pycache__/dist_chamfer_3D.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/Chamfer3D/__pycache__/dist_chamfer_3D.cpython-37.pyc -------------------------------------------------------------------------------- /Chamfer3D/chamfer3D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*3]; 15 | for (int i=blockIdx.x;ibest){ 127 | result[(i*n+j)]=best; 128 | result_i[(i*n+j)]=best_i; 129 | } 130 | } 131 | __syncthreads(); 132 | } 133 | } 134 | } 135 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * 
xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 136 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 137 | 138 | const auto batch_size = xyz1.size(0); 139 | const auto n = xyz1.size(1); //num_points point cloud A 140 | const auto m = xyz2.size(1); //num_points point cloud B 141 | 142 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 143 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 144 | 145 | cudaError_t err = cudaGetLastError(); 146 | if (err != cudaSuccess) { 147 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 148 | //THError("aborting"); 149 | return 0; 150 | } 151 | return 1; 152 | 153 | 154 | } 155 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 156 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 185 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 186 | 187 | cudaError_t err = cudaGetLastError(); 188 | if (err != cudaSuccess) { 189 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 190 | //THError("aborting"); 191 | return 0; 192 | } 193 | return 1; 194 | 195 | } 196 | 197 | -------------------------------------------------------------------------------- /Chamfer3D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /Chamfer3D/dist_chamfer_3D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | chamfer_found = importlib.find_loader("chamfer_3D") is not None 7 | if not chamfer_found: 8 | ## Cool trick from https://github.com/chrdiller 9 | print("Jitting Chamfer 3D") 10 | 11 | from torch.utils.cpp_extension import load 12 | chamfer_3D = load(name="chamfer_3D", 13 | sources=[ 14 | "/".join(os.path.abspath(__file__).split('/')[:-1] 
+ ["chamfer_cuda.cpp"]), 15 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), 16 | ]) 17 | print("Loaded JIT 3D CUDA chamfer distance") 18 | 19 | else: 20 | import chamfer_3D 21 | print("Loaded compiled 3D CUDA chamfer distance") 22 | 23 | 24 | # Chamfer's distance module @thibaultgroueix 25 | # GPU tensors only 26 | class chamfer_3DFunction(Function): 27 | @staticmethod 28 | def forward(ctx, xyz1, xyz2): 29 | batchsize, n, _ = xyz1.size() 30 | _, m, _ = xyz2.size() 31 | device = xyz1.device 32 | 33 | dist1 = torch.zeros(batchsize, n) 34 | dist2 = torch.zeros(batchsize, m) 35 | 36 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 37 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 38 | 39 | dist1 = dist1.to(device) 40 | dist2 = dist2.to(device) 41 | idx1 = idx1.to(device) 42 | idx2 = idx2.to(device) 43 | torch.cuda.set_device(device) 44 | 45 | chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 46 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 47 | return dist1, dist2, idx1, idx2 48 | 49 | @staticmethod 50 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 51 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 52 | graddist1 = graddist1.contiguous() 53 | graddist2 = graddist2.contiguous() 54 | device = graddist1.device 55 | 56 | gradxyz1 = torch.zeros(xyz1.size()) 57 | gradxyz2 = torch.zeros(xyz2.size()) 58 | 59 | gradxyz1 = gradxyz1.to(device) 60 | gradxyz2 = gradxyz2.to(device) 61 | chamfer_3D.backward( 62 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 63 | ) 64 | return gradxyz1, gradxyz2 65 | 66 | 67 | class chamfer_3DDist(nn.Module): 68 | def __init__(self): 69 | super(chamfer_3DDist, self).__init__() 70 | 71 | def forward(self, input1, input2): 72 | input1 = input1.contiguous() 73 | input2 = input2.contiguous() 74 | return chamfer_3DFunction.apply(input1, input2) 75 | 76 | -------------------------------------------------------------------------------- /Chamfer3D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_3D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_3D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 X.Wen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/README.md: -------------------------------------------------------------------------------- 1 | # PMP-Net++: Point Cloud Completion by Transformer-Enhanced Multi-step Point Moving Paths (TPAMI 2022) (Jittor Implementation) 2 | 3 | ![Intro pic](../pics/network.png) 4 | 5 | 6 | ## [PMP-Net++] 7 | This repository contains the [Jittor](https://cg.cs.tsinghua.edu.cn/jittor/) implementation of the papers: 8 | 9 | **1. PMP-Net++: Point Cloud Completion by Transformer-Enhanced Multi-step Point Moving Paths, TPAMI 2022** 10 | 11 | **2. PMP-Net: Point Cloud Completion by Learning Multi-step Point Moving Paths, CVPR 2021** 12 | 13 | The **Jittor** implementation covers the following datasets: 14 | | Model | Completion3D | PCN dataset | 15 | | ----------- | -------------- | ------------ | 16 | | PMP-Net | √ | √ | 17 | | PMP-Net++ | √ | √ | 18 | 19 | [ [PMP-Net](https://arxiv.org/abs/2012.03408) | [PMP-Net++](https://arxiv.org/abs/2012.03408) | [IEEE Xplore](https://ieeexplore.ieee.org/document/9735342) | [Webpage]() | [Jittor](https://cg.cs.tsinghua.edu.cn/jittor/) ] 20 | 21 | > Point cloud completion concerns predicting the missing parts of incomplete 3D shapes. A common strategy is to generate the 22 | complete shape from the incomplete input. However, the unordered nature of point clouds degrades the generation of high-quality 3D 23 | shapes, as the detailed topology and structure of unordered points are hard to capture during the generative process using an 24 | extracted latent code. We address this problem by formulating completion as a point cloud deformation process. Specifically, we design a 25 | novel neural network, named PMP-Net++, to mimic the behavior of an earth mover. It moves each point of the incomplete input to obtain a 26 | complete point cloud, where the total distance of the point moving paths (PMPs) should be the shortest. Therefore, PMP-Net++ predicts a 27 | unique PMP for each point according to the constraint of point moving distances. The network learns a strict and unique correspondence 28 | at the point level, and thus improves the quality of the predicted complete shape. Moreover, since moving points relies heavily on the per-point 29 | features learned by the network, we further introduce a transformer-enhanced representation learning network, which significantly 30 | improves the completion performance of PMP-Net++. We conduct comprehensive experiments on shape completion, and further explore its 31 | application to point cloud up-sampling, which demonstrate the non-trivial improvement of PMP-Net++ over state-of-the-art point cloud 32 | completion/up-sampling methods.
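
For readers navigating the code, this formulation maps onto the training scripts in `core/train_*.py` roughly as follows (a minimal sketch of the objective only, with names taken from `train_c3d.py`; it is not the full training loop):

```
# The model returns the point cloud after each of the three moving steps
# (`pcds`) and the per-point displacement of each step (`deltas`).
pcds, deltas = model(partial)                            # partial: (B, N, 3)
loss_cd = sum(chamfer(pcd, gt) for pcd in pcds)          # completion quality at every step
loss_pmd = sum(jittor.sum(d ** 2) for d in deltas) / 3   # total point-moving distance (PMD)
loss = loss_cd * cfg.TRAIN.LAMBDA_CD + loss_pmd * cfg.TRAIN.LAMBDA_PMD
```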
33 | 34 | ## [Cite this work] 35 | 36 | ``` 37 | @ARTICLE{pmpnet++, 38 | author={Wen, Xin and Xiang, Peng and Han, Zhizhong and Cao, Yan-Pei and Wan, Pengfei and Zheng, Wen and Liu, Yu-Shen}, 39 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 40 | title={PMP-Net++: Point Cloud Completion by Transformer-Enhanced Multi-step Point Moving Paths}, 41 | year={2022}, 42 | volume={}, 43 | number={}, 44 | pages={1-1}, 45 | doi={10.1109/TPAMI.2022.3159003} 46 | } 47 | 48 | @inproceedings{wen2021pmp, 49 | title={PMP-Net: Point cloud completion by learning multi-step point moving paths}, 50 | author={Wen, Xin and Xiang, Peng and Han, Zhizhong and Cao, Yan-Pei and Wan, Pengfei and Zheng, Wen and Liu, Yu-Shen}, 51 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 52 | year={2021} 53 | } 54 | ``` 55 | 56 | ## [Getting Started] 57 | #### Datasets and Pretrained Models 58 | 59 | We use the [PCN](https://www.shapenet.org/) and [Completion3D](http://completion3d.stanford.edu/) datasets in our experiments, which are available below: 60 | 61 | - [PCN](https://drive.google.com/drive/folders/1P_W1tz5Q4ZLapUifuOE4rFAZp6L1XTJz) 62 | - [Completion3D](http://download.cs.stanford.edu/downloads/completion3d/dataset2019.zip) 63 | 64 | 65 | #### Install Python Dependencies 66 | 67 | ``` 68 | conda create -n pmp python=3.7 69 | conda activate pmp 70 | pip3 install -r requirements.txt 71 | 72 | python3.7 -m pip install jittor 73 | python3.7 -m jittor.test.test_example 74 | python3.7 -m jittor.test.test_cudnn_op 75 | # more information about Jittor can be found at https://cg.cs.tsinghua.edu.cn/jittor/ 76 | ``` 77 | 78 | You need to update the dataset file paths in the config files: 79 | 80 | ``` 81 | __C.DATASETS.COMPLETION3D.PARTIAL_POINTS_PATH = '/path/to/datasets/Completion3D/%s/partial/%s/%s.h5' 82 | __C.DATASETS.COMPLETION3D.COMPLETE_POINTS_PATH = '/path/to/datasets/Completion3D/%s/gt/%s/%s.h5' 83 | __C.DATASETS.SHAPENET.PARTIAL_POINTS_PATH = '/path/to/datasets/ShapeNet/ShapeNetCompletion/%s/partial/%s/%s/%02d.pcd' 84 | __C.DATASETS.SHAPENET.COMPLETE_POINTS_PATH = '/path/to/datasets/ShapeNet/ShapeNetCompletion/%s/complete/%s/%s.pcd' 85 | 86 | # Dataset Options: Completion3D, Completion3DPCCT, ShapeNet, ShapeNetCars 87 | __C.DATASET.TRAIN_DATASET = 'ShapeNet' 88 | __C.DATASET.TEST_DATASET = 'ShapeNet' 89 | ``` 90 | 91 | #### Training, Testing and Inference 92 | 93 | To train PMP-Net++ or PMP-Net, you can simply use the following command: 94 | 95 | ``` 96 | python main_*.py # remember to change '*' to 'c3d' or 'pcn', and change between 'import PMPNetPlus' and 'import PMPNet' 97 | ``` 98 | 99 | To test or run inference, you should specify the path to the checkpoint in the config_*.py file: 100 | ``` 101 | __C.CONST.WEIGHTS = "path to your checkpoint" 102 | ``` 103 | 104 | then use the following command: 105 | 106 | ``` 107 | python main_*.py --test 108 | python main_*.py --inference 109 | ``` 110 | 111 | ## [Acknowledgements] 112 | 113 | Some of the code in this repo is borrowed from [GRNet](https://github.com/hzxie/GRNet) and [PointCloudLib](https://github.com/Jittor/PointCloudLib). We thank the authors for their wonderful work! 114 | 115 | ## [License] 116 | 117 | This project is open-sourced under the MIT license. 
118 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/config_c3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | from easydict import EasyDict as edict 5 | 6 | __C = edict() 7 | cfg = __C 8 | 9 | # 10 | # Dataset Config 11 | # 12 | __C.DATASETS = edict() 13 | __C.DATASETS.COMPLETION3D = edict() 14 | __C.DATASETS.COMPLETION3D.CATEGORY_FILE_PATH = './datasets/Completion3D.json' 15 | __C.DATASETS.COMPLETION3D.PARTIAL_POINTS_PATH = '/data/shapenet/%s/partial/%s/%s.h5' 16 | __C.DATASETS.COMPLETION3D.COMPLETE_POINTS_PATH = '/data/shapenet/%s/gt/%s/%s.h5' 17 | __C.DATASETS.SHAPENET = edict() 18 | __C.DATASETS.SHAPENET.CATEGORY_FILE_PATH = './datasets/ShapeNet.json' 19 | __C.DATASETS.SHAPENET.N_RENDERINGS = 8 20 | __C.DATASETS.SHAPENET.N_POINTS = 2048 21 | __C.DATASETS.SHAPENET.PARTIAL_POINTS_PATH = '/PCN/%s/partial/%s/%s/%02d.pcd' 22 | __C.DATASETS.SHAPENET.COMPLETE_POINTS_PATH = '/PCN/%s/complete/%s/%s.pcd' 23 | 24 | # 25 | # Dataset 26 | # 27 | __C.DATASET = edict() 28 | # Dataset Options: Completion3D, ShapeNet, ShapeNetCars, Completion3DPCCT 29 | __C.DATASET.TRAIN_DATASET = 'Completion3D' 30 | __C.DATASET.TEST_DATASET = 'Completion3D' 31 | 32 | # 33 | # Constants 34 | # 35 | __C.CONST = edict() 36 | 37 | __C.CONST.NUM_WORKERS = 4 38 | __C.CONST.N_INPUT_POINTS = 2048 39 | 40 | # 41 | # Directories 42 | # 43 | 44 | __C.DIR = edict() 45 | __C.DIR.OUT_PATH = './exp/c3d' 46 | __C.CONST.DEVICE = '0' 47 | __C.CONST.WEIGHTS = '' #'/data1/xp/pmp_jittor/checkpoints/2021-07-07T18:01:01.811422/ckpt-best.pkl' # 'ckpt-best.pth' # specify a path to run test and inference 48 | 49 | # 50 | # Memcached 51 | # 52 | __C.MEMCACHED = edict() 53 | __C.MEMCACHED.ENABLED = False 54 | __C.MEMCACHED.LIBRARY_PATH = '/mnt/lustre/share/pymc/py3' 55 | __C.MEMCACHED.SERVER_CONFIG = '/mnt/lustre/share/memcached_client/server_list.conf' 56 | __C.MEMCACHED.CLIENT_CONFIG = '/mnt/lustre/share/memcached_client/client.conf' 57 | 58 | # 59 | # Network 60 | # 61 | __C.NETWORK = edict() 62 | __C.NETWORK.N_SAMPLING_POINTS = 2048 63 | 64 | # 65 | # Train 66 | # 67 | __C.TRAIN = edict() 68 | __C.TRAIN.LAMBDA_CD = 1000 69 | __C.TRAIN.LAMBDA_PMD = 1e-3 70 | __C.TRAIN.BATCH_SIZE = 48 71 | __C.TRAIN.N_EPOCHS = 300 72 | __C.TRAIN.SAVE_FREQ = 25 73 | __C.TRAIN.LEARNING_RATE = 0.001 74 | __C.TRAIN.LR_MILESTONES = [50, 100, 150, 200, 250] 75 | __C.TRAIN.GAMMA = .5 76 | __C.TRAIN.BETAS = (.9, .999) 77 | __C.TRAIN.WEIGHT_DECAY = 0 78 | 79 | 80 | # 81 | # Test 82 | # 83 | __C.TEST = edict() 84 | __C.TEST.METRIC_NAME = 'ChamferDistance' 85 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/config_pcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | from easydict import EasyDict as edict 5 | 6 | __C = edict() 7 | cfg = __C 8 | 9 | # 10 | # Dataset Config 11 | # 12 | __C.DATASETS = edict() 13 | __C.DATASETS.COMPLETION3D = edict() 14 | __C.DATASETS.COMPLETION3D.CATEGORY_FILE_PATH = './datasets/Completion3D.json' 15 | __C.DATASETS.COMPLETION3D.PARTIAL_POINTS_PATH = '/data/shapenet/%s/partial/%s/%s.h5' 16 | __C.DATASETS.COMPLETION3D.COMPLETE_POINTS_PATH = '/data/shapenet/%s/gt/%s/%s.h5' 17 | __C.DATASETS.SHAPENET = edict() 18 | __C.DATASETS.SHAPENET.CATEGORY_FILE_PATH = './datasets/ShapeNet.json' 19 | __C.DATASETS.SHAPENET.N_RENDERINGS = 8 20 | __C.DATASETS.SHAPENET.N_POINTS = 
16384 21 | __C.DATASETS.SHAPENET.PARTIAL_POINTS_PATH = '/data/PCN/%s/partial/%s/%s/%02d.pcd' 22 | __C.DATASETS.SHAPENET.COMPLETE_POINTS_PATH = '/data/PCN/%s/complete/%s/%s.pcd' 23 | 24 | # 25 | # Dataset 26 | # 27 | __C.DATASET = edict() 28 | # Dataset Options: Completion3D, ShapeNet, ShapeNetCars, Completion3DPCCT 29 | __C.DATASET.TRAIN_DATASET = 'ShapeNet' 30 | __C.DATASET.TEST_DATASET = 'ShapeNet' 31 | 32 | # 33 | # Constants 34 | # 35 | __C.CONST = edict() 36 | 37 | __C.CONST.NUM_WORKERS = 8 38 | __C.CONST.N_INPUT_POINTS = 2048 39 | 40 | # 41 | # Directories 42 | # 43 | 44 | __C.DIR = edict() 45 | __C.DIR.OUT_PATH = './exp/pcn' 46 | __C.CONST.DEVICE = '0' 47 | __C.CONST.WEIGHTS = '' # specify a path to run test and inference 48 | 49 | # 50 | # Memcached 51 | # 52 | __C.MEMCACHED = edict() 53 | __C.MEMCACHED.ENABLED = False 54 | __C.MEMCACHED.LIBRARY_PATH = '/mnt/lustre/share/pymc/py3' 55 | __C.MEMCACHED.SERVER_CONFIG = '/mnt/lustre/share/memcached_client/server_list.conf' 56 | __C.MEMCACHED.CLIENT_CONFIG = '/mnt/lustre/share/memcached_client/client.conf' 57 | 58 | # 59 | # Network 60 | # 61 | __C.NETWORK = edict() 62 | __C.NETWORK.N_SAMPLING_POINTS = 2048 63 | 64 | 65 | # 66 | # Train 67 | # 68 | __C.TRAIN = edict() 69 | __C.TRAIN.LAMBDA_CD = 1000 70 | __C.TRAIN.LAMBDA_PMD = 1e-3 71 | __C.TRAIN.BATCH_SIZE = 32 72 | __C.TRAIN.N_EPOCHS = 150 73 | __C.TRAIN.SAVE_FREQ = 25 74 | __C.TRAIN.LEARNING_RATE = 0.001 75 | __C.TRAIN.LR_MILESTONES = [50, 100, 150, 200, 250] 76 | __C.TRAIN.GAMMA = .5 77 | __C.TRAIN.BETAS = (.9, .999) 78 | __C.TRAIN.WEIGHT_DECAY = 0 79 | 80 | # 81 | # Test 82 | # 83 | __C.TEST = edict() 84 | __C.TEST.METRIC_NAME = 'ChamferDistance' 85 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/PMPPlus-Jittor/core/__init__.py -------------------------------------------------------------------------------- /PMPPlus-Jittor/core/chamfer.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | def select_vertices(verts, idxs): 4 | batch_size = verts.shape[0] 5 | assert idxs.shape[0] == batch_size 6 | 7 | verts = verts.reindex([batch_size, idxs.shape[1], 3], [ 8 | 'i0', 9 | '@e0(i0, i1)', 10 | 'i2' 11 | ], extras=[idxs]) 12 | return verts 13 | 14 | 15 | cpu_src = ''' 16 | for (int bs = 0; bs < in0_shape0; ++bs) 17 | for (int i = 0; i < in0_shape1; ++i) { 18 | float min_dis = (@in0(bs, i, 0) - @in1(bs, 0, 0)) * (@in0(bs, i, 0) - @in1(bs, 0, 0)) + 19 | (@in0(bs, i, 1) - @in1(bs, 0, 1)) * (@in0(bs, i, 1) - @in1(bs, 0, 1)) + 20 | (@in0(bs, i, 2) - @in1(bs, 0, 2)) * (@in0(bs, i, 2) - @in1(bs, 0, 2)); 21 | @out(bs, i) = 0; 22 | for (int j = 1; j < in1_shape1; ++j) { 23 | float dis = (@in0(bs, i, 0) - @in1(bs, j, 0)) * (@in0(bs, i, 0) - @in1(bs, j, 0)) + 24 | (@in0(bs, i, 1) - @in1(bs, j, 1)) * (@in0(bs, i, 1) - @in1(bs, j, 1)) + 25 | (@in0(bs, i, 2) - @in1(bs, j, 2)) * (@in0(bs, i, 2) - @in1(bs, j, 2)); 26 | if (dis < min_dis) { 27 | min_dis = dis; 28 | @out(bs, i) = j; 29 | } 30 | } 31 | } 32 | ''' 33 | 34 | cuda_src = ''' 35 | __global__ void chamfer_loss_min_idx_kernel(@ARGS_DEF) { 36 | @PRECALC 37 | int bs = blockIdx.x; 38 | int n = in0_shape1; 39 | int m = in1_shape1; 40 | 41 | for (int i = threadIdx.x; i < n; i += blockDim.x) { 42 | float min_dis = (@in0(bs, i, 0) - @in1(bs, 0, 0)) * 
(@in0(bs, i, 0) - @in1(bs, 0, 0)) + 43 | (@in0(bs, i, 1) - @in1(bs, 0, 1)) * (@in0(bs, i, 1) - @in1(bs, 0, 1)) + 44 | (@in0(bs, i, 2) - @in1(bs, 0, 2)) * (@in0(bs, i, 2) - @in1(bs, 0, 2)); 45 | @out(bs, i) = 0; 46 | for (int j = 1; j < m; ++j) { 47 | float dis = (@in0(bs, i, 0) - @in1(bs, j, 0)) * (@in0(bs, i, 0) - @in1(bs, j, 0)) + 48 | (@in0(bs, i, 1) - @in1(bs, j, 1)) * (@in0(bs, i, 1) - @in1(bs, j, 1)) + 49 | (@in0(bs, i, 2) - @in1(bs, j, 2)) * (@in0(bs, i, 2) - @in1(bs, j, 2)); 50 | if (dis < min_dis) { 51 | min_dis = dis; 52 | @out(bs, i) = j; 53 | } 54 | } 55 | } 56 | } 57 | 58 | chamfer_loss_min_idx_kernel<<>>(@ARGS); 59 | ''' 60 | 61 | 62 | def chamfer_loss(pc1, pc2, reduction='mean', sqrt=True): 63 | ''' 64 | return the chamfer loss from pc1 to pc2. 65 | 66 | Parameters: 67 | =========== 68 | pc1: [B, N, xyz] 69 | pc2: [B, N, xyz] 70 | reduction: 'mean', 'sum', or None 71 | ''' 72 | batch_size_1, n_samples_pc1, _ = pc1.shape 73 | batch_size_2, n_samples_pc2, _ = pc2.shape 74 | 75 | assert batch_size_1 == batch_size_2 76 | batch_size = batch_size_1 77 | 78 | idx = jt.code([batch_size, n_samples_pc1], 'int32', [pc1, pc2], 79 | cpu_src=cpu_src, 80 | cuda_src=cuda_src) 81 | 82 | nearest_pts = select_vertices(pc2, idx) 83 | if sqrt: 84 | chamfer_distance = (((pc1 - nearest_pts) ** 2).sum(dim=-1)).sqrt() 85 | else: 86 | chamfer_distance = (((pc1 - nearest_pts) ** 2).sum(dim=-1)) 87 | 88 | if reduction is None: 89 | return chamfer_distance 90 | elif reduction == 'sum': 91 | return jt.sum(chamfer_distance) 92 | elif reduction == 'mean': 93 | return jt.mean(chamfer_distance) 94 | 95 | 96 | def chamfer_loss_bidirectional_sqrt(pc1, pc2): 97 | ''' 98 | return the chamfer loss between two point clouds. 99 | ''' 100 | return (chamfer_loss(pc1, pc2, sqrt=True) + chamfer_loss(pc2, pc1, sqrt=True)) / 2 101 | 102 | 103 | def chamfer_loss_bidirectional(pc1, pc2): 104 | ''' 105 | return the chamfer loss between two point clouds. 
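(squared distances, no square root; note that unlike chamfer_loss_bidirectional_sqrt above, the two directions are summed rather than averaged).
Illustrative usage, assuming the (B, N, 3) layout used throughout this repo:
    loss = chamfer_loss_bidirectional(jt.rand(4, 2048, 3), jt.rand(4, 2048, 3))  # scalar jt.Var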
106 | ''' 107 | l = chamfer_loss(pc1, pc2, sqrt=False) 108 | r = chamfer_loss(pc2, pc1, sqrt=False) 109 | return l + r -------------------------------------------------------------------------------- /PMPPlus-Jittor/core/inference_c3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import os 6 | import jittor 7 | import utils.helpers 8 | import utils.io 9 | import utils.data_loaders as dataloader_jt 10 | from tqdm import tqdm 11 | from models.model import PMPNetPlus as Model 12 | 13 | 14 | def inference_net(cfg): 15 | dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 16 | test_data_loader = dataset_loader.get_dataset(dataloader_jt.DatasetSubset.TEST, 17 | batch_size=1, 18 | shuffle=False) 19 | 20 | model = Model(dataset=cfg.DATASET.TEST_DATASET) 21 | 22 | assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS 23 | print('loading: ', cfg.CONST.WEIGHTS) 24 | model.load(cfg.CONST.WEIGHTS) 25 | 26 | # Switch models to evaluation mode 27 | model.eval() 28 | 29 | # The inference loop 30 | n_samples = len(test_data_loader) 31 | t_obj = tqdm(test_data_loader) 32 | 33 | 34 | for model_idx, (taxonomy_id, model_id, data) in enumerate(t_obj): 35 | taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item() 36 | model_id = model_id[0] 37 | 38 | 39 | partial = jittor.array(data['partial_cloud']) 40 | 41 | pcds = model(partial)[0] 42 | pcd1, pcd2, pcd3 = pcds 43 | 44 | 45 | output_folder = os.path.join(cfg.DIR.OUT_PATH, 'benchmark', taxonomy_id) 46 | if not os.path.exists(output_folder): 47 | os.makedirs(output_folder) 48 | output_folder_pcd1 = os.path.join(output_folder, 'pcd1') 49 | output_folder_pcd2 = os.path.join(output_folder, 'pcd2') 50 | output_folder_pcd3 = os.path.join(output_folder, 'pcd3') 51 | if not os.path.exists(output_folder_pcd1): 52 | os.makedirs(output_folder_pcd1) 53 | os.makedirs(output_folder_pcd2) 54 | os.makedirs(output_folder_pcd3) 55 | 56 | # save the output of each moving step into its own folder 57 | output_file_path = os.path.join(output_folder, 'pcd1', '%s.h5' % model_id) 58 | utils.io.IO.put(output_file_path, pcd1.squeeze(0).detach().numpy()) 59 | 60 | output_file_path = os.path.join(output_folder, 'pcd2', '%s.h5' % model_id) 61 | utils.io.IO.put(output_file_path, pcd2.squeeze(0).detach().numpy()) 62 | 63 | output_file_path = os.path.join(output_folder, 'pcd3', '%s.h5' % model_id) 64 | utils.io.IO.put(output_file_path, pcd3.squeeze(0).detach().numpy()) 65 | 66 | t_obj.set_description('Test[%d/%d] Taxonomy = %s Sample = %s File = %s' % 67 | (model_idx + 1, n_samples, taxonomy_id, model_id, output_file_path)) 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/core/inference_pcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import os 6 | import jittor 7 | import utils.data_loaders as dataloader_jt 8 | import utils.helpers 9 | import utils.io 10 | from tqdm import tqdm 11 | from models.model import PMPNetPlus as Model 12 | 13 | 14 | def random_subsample(pcd, n_points=2048): 15 | """ 16 | Args: 17 | pcd: (B, N, 3) 18 | 19 | returns: 20 | new_pcd: (B, n_points, 3) 21 | """ 22 | b, n, _ = pcd.shape 23 | batch_idx = jittor.arange(b,).reshape((-1, 1)).repeat(1, n_points) 24 | idx = jittor.concat([jittor.randperm(n,)[:n_points].reshape((1, -1)) for i in range(b)], 0) 25 | return 
pcd[batch_idx, idx, :] 26 | 27 | def inference_net(cfg): 28 | dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 29 | test_data_loader = dataset_loader.get_dataset(dataloader_jt.DatasetSubset.TEST, 30 | batch_size=1, 31 | shuffle=False) 32 | 33 | model = Model(dataset=cfg.DATASET.TEST_DATASET) 34 | 35 | assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS 36 | print('loading: ', cfg.CONST.WEIGHTS) 37 | model.load(cfg.CONST.WEIGHTS) 38 | 39 | # Switch models to evaluation mode 40 | model.eval() 41 | 42 | # The inference loop 43 | n_samples = len(test_data_loader) 44 | t_obj = tqdm(test_data_loader) 45 | 46 | 47 | for model_idx, (taxonomy_id, model_id, data) in enumerate(t_obj): 48 | taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item() 49 | model_id = model_id[0] 50 | 51 | partial = jittor.array(data['partial_cloud']) 52 | partial = random_subsample(partial.repeat((1, 8, 1)).reshape(-1, 16384, 3)) # b*8, 2048, 3 53 | pcds = model(partial)[0] 54 | 55 | pcd1 = pcds[0].reshape(-1, 16384, 3) 56 | pcd2 = pcds[1].reshape(-1, 16384, 3) 57 | pcd3 = pcds[2].reshape(-1, 16384, 3) 58 | 59 | output_folder = os.path.join(cfg.DIR.OUT_PATH, 'benchmark', taxonomy_id) 60 | if not os.path.exists(output_folder): 61 | os.makedirs(output_folder) 62 | output_folder_pcd1 = os.path.join(output_folder, 'pcd1') 63 | output_folder_pcd2 = os.path.join(output_folder, 'pcd2') 64 | output_folder_pcd3 = os.path.join(output_folder, 'pcd3') 65 | if not os.path.exists(output_folder_pcd1): 66 | os.makedirs(output_folder_pcd1) 67 | os.makedirs(output_folder_pcd2) 68 | os.makedirs(output_folder_pcd3) 69 | 70 | output_file_path = os.path.join(output_folder, 'pcd1', '%s.h5' % model_id) 71 | utils.io.IO.put(output_file_path, pcd1.squeeze(0).detach().numpy()) 72 | 73 | output_file_path = os.path.join(output_folder, 'pcd2', '%s.h5' % model_id) 74 | utils.io.IO.put(output_file_path, pcd2.squeeze(0).detach().numpy()) 75 | 76 | output_file_path = os.path.join(output_folder, 'pcd3', '%s.h5' % model_id) 77 | utils.io.IO.put(output_file_path, pcd3.squeeze(0).detach().numpy()) 78 | 79 | t_obj.set_description('Test[%d/%d] Taxonomy = %s Sample = %s File = %s' % 80 | (model_idx + 1, n_samples, taxonomy_id, model_id, output_file_path)) 81 | 82 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/core/test_c3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import jittor 5 | import utils.data_loaders as dataloader_jt 6 | from tqdm import tqdm 7 | from utils.average_meter import AverageMeter 8 | from utils.metrics import Metrics 9 | from models.model import PMPNetPlus as Model 10 | from core.chamfer import chamfer_loss_bidirectional, chamfer_loss_bidirectional_sqrt 11 | chamfer = chamfer_loss_bidirectional 12 | chamfer_sqrt = chamfer_loss_bidirectional_sqrt 13 | 14 | def test_net(cfg, epoch_idx=-1, test_data_loader=None, test_writer=None, model=None): 15 | 16 | if test_data_loader is None: 17 | # Set up data loader 18 | dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 19 | test_data_loader = dataset_loader.get_dataset(dataloader_jt.DatasetSubset.VAL, 20 | batch_size=1, 21 | shuffle=False) 22 | 23 | 24 | # Setup networks and initialize networks 25 | if model is None: 26 | model = Model(dataset=cfg.DATASET.TEST_DATASET) 27 | 28 | assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS 29 | print('loading: ', 
cfg.CONST.WEIGHTS) 30 | model.load(cfg.CONST.WEIGHTS) 31 | 32 | # Switch models to evaluation mode 33 | model.eval() 34 | 35 | n_samples = len(test_data_loader) 36 | test_losses = AverageMeter(['cd1', 'cd2', 'cd3', 'pmd']) 37 | test_metrics = AverageMeter(Metrics.names()) 38 | category_metrics = dict() 39 | 40 | # Testing loop 41 | with tqdm(test_data_loader) as t: 42 | # print('repeating') 43 | for model_idx, (taxonomy_id, model_id, data) in enumerate(t): 44 | taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item() 45 | model_id = model_id[0] 46 | 47 | 48 | # for k, v in data.items(): 49 | # data[k] = utils.helpers.var_or_cuda(v) 50 | 51 | partial = jittor.array(data['partial_cloud']) 52 | gt = jittor.array(data['gtcloud']) 53 | 54 | b, n, _ = partial.shape 55 | 56 | pcds, deltas = model(partial) 57 | 58 | cd1 = chamfer(pcds[0], gt).item() * 1e3 59 | cd2 = chamfer(pcds[1], gt).item() * 1e3 60 | cd3 = chamfer(pcds[2], gt).item() * 1e3 61 | 62 | # pmd loss 63 | pmd_losses = [] 64 | for delta in deltas: 65 | pmd_losses.append(jittor.sum(delta ** 2)) 66 | 67 | 68 | pmd = jittor.sum(jittor.stack(pmd_losses)) / 3 69 | 70 | pmd_item = pmd.item() 71 | 72 | _metrics = [pmd_item, cd3] 73 | test_losses.update([cd1, cd2, cd3, pmd_item]) 74 | 75 | test_metrics.update(_metrics) 76 | if taxonomy_id not in category_metrics: 77 | category_metrics[taxonomy_id] = AverageMeter(Metrics.names()) 78 | category_metrics[taxonomy_id].update(_metrics) 79 | 80 | t.set_description('Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s' % 81 | (model_idx + 1, n_samples, taxonomy_id, model_id, ['%.4f' % l for l in test_losses.val() 82 | ], ['%.4f' % m for m in _metrics])) 83 | 84 | # Print testing results 85 | print('============================ TEST RESULTS ============================') 86 | print('Taxonomy', end='\t') 87 | print('#Sample', end='\t') 88 | for metric in test_metrics.items: 89 | print(metric, end='\t') 90 | print() 91 | 92 | for taxonomy_id in category_metrics: 93 | print(taxonomy_id, end='\t') 94 | print(category_metrics[taxonomy_id].count(0), end='\t') 95 | for value in category_metrics[taxonomy_id].avg(): 96 | print('%.4f' % value, end='\t') 97 | print() 98 | 99 | print('Overall', end='\t\t\t') 100 | for value in test_metrics.avg(): 101 | print('%.4f' % value, end='\t') 102 | print('\n') 103 | 104 | # Add testing results to TensorBoard 105 | if test_writer is not None: 106 | test_writer.add_scalar('Loss/Epoch/cd1', test_losses.avg(0), epoch_idx) 107 | test_writer.add_scalar('Loss/Epoch/cd2', test_losses.avg(1), epoch_idx) 108 | test_writer.add_scalar('Loss/Epoch/cd3', test_losses.avg(2), epoch_idx) 109 | test_writer.add_scalar('Loss/Epoch/delta', test_losses.avg(3), epoch_idx) 110 | for i, metric in enumerate(test_metrics.items): 111 | test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i), epoch_idx) 112 | model.train() 113 | return test_losses.avg(2) 114 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/core/test_pcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import jittor 6 | import utils.helpers 7 | from tqdm import tqdm 8 | import utils.data_loaders as dataloader_jt 9 | from utils.average_meter import AverageMeter 10 | from utils.metrics import Metrics 11 | from models.model import PMPNetPlus as Model 12 | from core.chamfer import chamfer_loss_bidirectional_sqrt as chamfer 13 | 14 
| def random_subsample(pcd, n_points=2048): 15 | """ 16 | Args: 17 | pcd: (B, N, 3) 18 | 19 | returns: 20 | new_pcd: (B, n_points, 3) 21 | """ 22 | b, n, _ = pcd.shape 23 | batch_idx = jittor.arange(b,).reshape((-1, 1)).repeat(1, n_points) 24 | idx = jittor.concat([jittor.randperm(n,)[:n_points].reshape((1, -1)) for i in range(b)], 0) 25 | return pcd[batch_idx, idx, :] 26 | 27 | def test_net(cfg, epoch_idx=-1, test_data_loader=None, test_writer=None, model=None): 28 | 29 | if test_data_loader is None: 30 | # Set up data loader 31 | dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 32 | test_data_loader = dataset_loader.get_dataset(dataloader_jt.DatasetSubset.TEST, 33 | batch_size=4, 34 | shuffle=False) 35 | 36 | # Setup networks and initialize networks 37 | if model is None: 38 | model = Model(dataset=cfg.DATASET.TEST_DATASET) 39 | 40 | assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS 41 | print('loading: ', cfg.CONST.WEIGHTS) 42 | model.load(cfg.CONST.WEIGHTS) 43 | # Switch models to evaluation mode 44 | model.eval() 45 | 46 | n_samples = len(test_data_loader) 47 | test_losses = AverageMeter(['cd1', 'cd2', 'cd3', 'pmd']) 48 | test_metrics = AverageMeter(Metrics.names()) 49 | category_metrics = dict() 50 | 51 | # Testing loop 52 | with tqdm(test_data_loader) as t: 53 | # print('repeating') 54 | for model_idx, (taxonomy_id, model_id, data) in enumerate(t): 55 | taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item() 56 | model_id = model_id[0] 57 | 58 | partial = jittor.array(data['partial_cloud']) 59 | gt = jittor.array(data['gtcloud']) 60 | partial = random_subsample(partial.repeat((1, 8, 1)).reshape(-1, 16384, 3)) # b*8, 2048, 3 61 | 62 | b, n, _ = partial.shape 63 | 64 | pcds, deltas = model(partial) 65 | 66 | cd1 = chamfer(pcds[0].reshape(-1, 16384, 3), gt).item() * 1e3 67 | cd2 = chamfer(pcds[1].reshape(-1, 16384, 3), gt).item() * 1e3 68 | cd3 = chamfer(pcds[2].reshape(-1, 16384, 3), gt).item() * 1e3 69 | 70 | # pmd loss 71 | pmd_losses = [] 72 | for delta in deltas: 73 | pmd_losses.append(jittor.sum(delta ** 2)) 74 | 75 | pmd = jittor.sum(jittor.stack(pmd_losses)) / 3 76 | 77 | pmd_item = pmd.item() 78 | 79 | _metrics = [pmd_item, cd3] 80 | test_losses.update([cd1, cd2, cd3, pmd_item]) 81 | 82 | test_metrics.update(_metrics) 83 | if taxonomy_id not in category_metrics: 84 | category_metrics[taxonomy_id] = AverageMeter(Metrics.names()) 85 | category_metrics[taxonomy_id].update(_metrics) 86 | 87 | t.set_description('Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s' % 88 | (model_idx + 1, n_samples, taxonomy_id, model_id, ['%.4f' % l for l in test_losses.val() 89 | ], ['%.4f' % m for m in _metrics])) 90 | 91 | # Print testing results 92 | print('============================ TEST RESULTS ============================') 93 | print('Taxonomy', end='\t') 94 | print('#Sample', end='\t') 95 | for metric in test_metrics.items: 96 | print(metric, end='\t') 97 | print() 98 | 99 | for taxonomy_id in category_metrics: 100 | print(taxonomy_id, end='\t') 101 | print(category_metrics[taxonomy_id].count(0), end='\t') 102 | for value in category_metrics[taxonomy_id].avg(): 103 | print('%.4f' % value, end='\t') 104 | print() 105 | 106 | print('Overall', end='\t\t\t') 107 | for value in test_metrics.avg(): 108 | print('%.4f' % value, end='\t') 109 | print('\n') 110 | 111 | # Add testing results to TensorBoard 112 | if test_writer is not None: 113 | test_writer.add_scalar('Loss/Epoch/cd1', test_losses.avg(0), epoch_idx) 
114 | test_writer.add_scalar('Loss/Epoch/cd2', test_losses.avg(1), epoch_idx) 115 | test_writer.add_scalar('Loss/Epoch/cd3', test_losses.avg(2), epoch_idx) 116 | test_writer.add_scalar('Loss/Epoch/delta', test_losses.avg(3), epoch_idx) 117 | for i, metric in enumerate(test_metrics.items): 118 | test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i), epoch_idx) 119 | 120 | return test_losses.avg(2) 121 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/core/train_c3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import os 6 | import jittor 7 | import utils.data_loaders as dataloader_jt 8 | from jittor import nn 9 | from datetime import datetime 10 | from tqdm import tqdm 11 | from time import time 12 | from tensorboardX import SummaryWriter 13 | from core.test_c3d import test_net 14 | from utils.average_meter import AverageMeter 15 | from models.model import PMPNetPlus as Model 16 | from core.chamfer import chamfer_loss_bidirectional as chamfer 17 | from jittor.utils.nvtx import nvtx_scope 18 | 19 | def lr_lambda(epoch): 20 | if 0 <= epoch <= 100: 21 | return 1 22 | elif 100 < epoch <= 150: 23 | return 0.5 24 | elif 150 < epoch <= 250: 25 | return 0.1 26 | else: 27 | return 0.5 28 | 29 | 30 | def train_net(cfg): 31 | # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use 32 | 33 | # train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg) 34 | # test_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 35 | 36 | train_dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg) 37 | test_dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 38 | 39 | train_data_loader = train_dataset_loader.get_dataset(dataloader_jt.DatasetSubset.TRAIN, 40 | batch_size=cfg.TRAIN.BATCH_SIZE, 41 | num_workers=cfg.CONST.NUM_WORKERS, 42 | shuffle=True) 43 | val_data_loader = test_dataset_loader.get_dataset(dataloader_jt.DatasetSubset.VAL, 44 | batch_size=cfg.TRAIN.BATCH_SIZE, 45 | num_workers=cfg.CONST.NUM_WORKERS, 46 | shuffle=False) 47 | 48 | # Set up folders for logs and checkpoints 49 | output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', datetime.now().isoformat()) 50 | cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints' 51 | cfg.DIR.LOGS = output_dir % 'logs' 52 | if not os.path.exists(cfg.DIR.CHECKPOINTS): 53 | os.makedirs(cfg.DIR.CHECKPOINTS) 54 | 55 | # Create tensorboard writers 56 | train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train')) 57 | val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test')) 58 | model = Model(dataset=cfg.DATASET.TRAIN_DATASET) 59 | init_epoch = 0 60 | best_metrics = float('inf') 61 | 62 | optimizer = nn.Adam(model.parameters(), 63 | lr=cfg.TRAIN.LEARNING_RATE, 64 | weight_decay=cfg.TRAIN.WEIGHT_DECAY, 65 | betas=cfg.TRAIN.BETAS) 66 | lr_scheduler = jittor.lr_scheduler.MultiStepLR(optimizer, 67 | milestones=cfg.TRAIN.LR_MILESTONES, 68 | gamma=cfg.TRAIN.GAMMA, 69 | last_epoch=init_epoch) 70 | 71 | 72 | 73 | # Training/Testing the network 74 | for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1): 75 | epoch_start_time = time() 76 | 77 | model.train() 78 | 79 | loss_metric = AverageMeter() 80 | n_batches = len(train_data_loader) 81 | print('epoch: ', epoch_idx, 'optimizer: ', lr_scheduler.get_lr()) 82 | with tqdm(train_data_loader) as t: 83 | for 
batch_idx, (taxonomy_ids, model_ids, data) in enumerate(t): 84 | partial = jittor.array(data['partial_cloud']) 85 | gt = jittor.array(data['gtcloud']) 86 | pcds, deltas = model(partial) 87 | 88 | cd1 = chamfer(pcds[0], gt) 89 | cd2 = chamfer(pcds[1], gt) 90 | cd3 = chamfer(pcds[2], gt) 91 | loss_cd = cd1 + cd2 + cd3 92 | 93 | delta_losses = [] 94 | for delta in deltas: 95 | delta_losses.append(jittor.sum(delta ** 2)) 96 | 97 | loss_pmd = jittor.sum(jittor.stack(delta_losses)) / 3 98 | 99 | loss = loss_cd * cfg.TRAIN.LAMBDA_CD + loss_pmd * cfg.TRAIN.LAMBDA_PMD 100 | optimizer.step(loss) 101 | 102 | loss_item = loss.item() 103 | loss_metric.update(loss_item) 104 | 105 | jittor.sync_all() 106 | 107 | t.set_description( 108 | '[Epoch %d/%d][Batch %d/%d]' % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches)) 109 | t.set_postfix(loss='%s' % ['%.4f' % l for l in [loss_item]]) 110 | 111 | 112 | lr_scheduler.step() 113 | epoch_end_time = time() 114 | train_writer.add_scalar('Loss/Epoch/loss', loss_metric.avg(), epoch_idx) 115 | logging.info( 116 | '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' % 117 | (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time - epoch_start_time, 118 | ['%.4f' % l for l in [loss_metric.avg()]])) 119 | 120 | # Validate the current model 121 | cd_eval = test_net(cfg, epoch_idx, val_data_loader, val_writer, model) 122 | 123 | # Save checkpoints 124 | if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or cd_eval < best_metrics: 125 | file_name = 'ckpt-best.pkl' if cd_eval < best_metrics else 'ckpt-epoch-%03d.pkl' % epoch_idx 126 | output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name) 127 | 128 | model.save(output_path) 129 | 130 | logging.info('Saved checkpoint to %s ...' % output_path) 131 | if cd_eval < best_metrics: 132 | best_metrics = cd_eval 133 | 134 | train_writer.close() 135 | val_writer.close() 136 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/core/train_pcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import os 6 | import jittor 7 | import utils.data_loaders as dataloader_jt 8 | from jittor import nn 9 | from datetime import datetime 10 | from tqdm import tqdm 11 | from time import time 12 | from tensorboardX import SummaryWriter 13 | from core.test_pcn import test_net 14 | from utils.average_meter import AverageMeter 15 | from models.model import PMPNetPlus as Model 16 | from core.chamfer import chamfer_loss_bidirectional as chamfer 17 | from jittor.utils.nvtx import nvtx_scope 18 | 19 | def random_subsample(pcd, n_points=2048): 20 | """ 21 | Args: 22 | pcd: (B, N, 3) 23 | 24 | returns: 25 | new_pcd: (B, n_points, 3) 26 | """ 27 | b, n, _ = pcd.shape 28 | batch_idx = jittor.arange(b).reshape((-1, 1)).repeat(1, n_points) 29 | idx = jittor.concat([jittor.randperm(n)[:n_points].reshape((1, -1)) for i in range(b)], 0) 30 | return pcd[batch_idx, idx, :] 31 | 32 | 33 | 34 | def lr_lambda(epoch): 35 | if 0 <= epoch <= 100: 36 | return 1 37 | elif 100 < epoch <= 150: 38 | return 0.5 39 | elif 150 < epoch <= 250: 40 | return 0.1 41 | else: 42 | return 0.5 43 | 44 | 45 | def train_net(cfg): 46 | 47 | train_dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg) 48 | test_dataset_loader = dataloader_jt.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 49 | 50 | train_data_loader = train_dataset_loader.get_dataset(dataloader_jt.DatasetSubset.TRAIN, 51 | 
batch_size=cfg.TRAIN.BATCH_SIZE, 52 | num_workers=cfg.CONST.NUM_WORKERS, 53 | shuffle=True) 54 | val_data_loader = test_dataset_loader.get_dataset(dataloader_jt.DatasetSubset.TEST, 55 | num_workers=cfg.CONST.NUM_WORKERS, 56 | batch_size=4, 57 | shuffle=False) 58 | 59 | # Set up folders for logs and checkpoints 60 | output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', datetime.now().isoformat()) 61 | cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints' 62 | cfg.DIR.LOGS = output_dir % 'logs' 63 | if not os.path.exists(cfg.DIR.CHECKPOINTS): 64 | os.makedirs(cfg.DIR.CHECKPOINTS) 65 | 66 | # Create tensorboard writers 67 | train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train')) 68 | val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test')) 69 | 70 | model = Model(dataset=cfg.DATASET.TRAIN_DATASET) 71 | init_epoch = 0 72 | best_metrics = float('inf') 73 | 74 | # Create the optimizers 75 | optimizer = nn.Adam(model.parameters(), 76 | lr=cfg.TRAIN.LEARNING_RATE, 77 | weight_decay=cfg.TRAIN.WEIGHT_DECAY, 78 | betas=cfg.TRAIN.BETAS) 79 | lr_scheduler = jittor.lr_scheduler.MultiStepLR(optimizer, 80 | milestones=cfg.TRAIN.LR_MILESTONES, 81 | gamma=cfg.TRAIN.GAMMA, 82 | last_epoch=init_epoch) 83 | 84 | # Training/Testing the network 85 | for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1): 86 | epoch_start_time = time() 87 | 88 | batch_time = AverageMeter() 89 | data_time = AverageMeter() 90 | 91 | model.train() 92 | 93 | loss_metric = AverageMeter() 94 | 95 | batch_end_time = time() 96 | n_batches = len(train_data_loader) 97 | # cd_eval = test_net(cfg, epoch_idx, val_data_loader, val_writer, model) 98 | with tqdm(train_data_loader) as t: 99 | for batch_idx, (taxonomy_ids, model_ids, data) in enumerate(t): 100 | data_time.update(time() - batch_end_time) 101 | partial = random_subsample(jittor.array(data['partial_cloud'])) 102 | gt = random_subsample(jittor.array(data['gtcloud'])) 103 | 104 | pcds, deltas = model(partial) 105 | 106 | cd1 = chamfer(pcds[0], gt) 107 | cd2 = chamfer(pcds[1], gt) 108 | cd3 = chamfer(pcds[2], gt) 109 | loss_cd = cd1 + cd2 + cd3 110 | 111 | delta_losses = [] 112 | for delta in deltas: 113 | delta_losses.append(jittor.sum(delta ** 2)) 114 | 115 | loss_pmd = jittor.sum(jittor.stack(delta_losses)) / 3 116 | 117 | loss = loss_cd * cfg.TRAIN.LAMBDA_CD + loss_pmd * cfg.TRAIN.LAMBDA_PMD 118 | 119 | optimizer.step(loss) 120 | jittor.sync_all() 121 | 122 | loss_item = loss.item() 123 | loss_metric.update(loss_item) 124 | batch_time.update(time() - batch_end_time) 125 | batch_end_time = time() 126 | t.set_description( 127 | '[Epoch %d/%d][Batch %d/%d]' % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches)) 128 | t.set_postfix(loss='%s' % ['%.4f' % l for l in [loss_item]]) 129 | 130 | 131 | lr_scheduler.step() 132 | epoch_end_time = time() 133 | train_writer.add_scalar('Loss/Epoch/loss', loss_metric.avg(), epoch_idx) 134 | logging.info( 135 | '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' % 136 | (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time - epoch_start_time, ['%.4f' % l for l in [loss_metric.avg()]])) 137 | 138 | # Validate the current model 139 | cd_eval = test_net(cfg, epoch_idx, val_data_loader, val_writer, model) 140 | 141 | # Save checkpoints 142 | if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or cd_eval < best_metrics: 143 | file_name = 'ckpt-best.pkl' if cd_eval < best_metrics else 'ckpt-epoch-%03d.pkl' % epoch_idx 144 | output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name) 145 | model.save(output_path) 146 | 147 | logging.info('Saved checkpoint to 
%s ...' % output_path) 148 | if cd_eval < best_metrics: 149 | best_metrics = cd_eval 150 | 151 | train_writer.close() 152 | val_writer.close() 153 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/main_c3d.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Peng Xiang 4 | 5 | import argparse 6 | import logging 7 | import os 8 | import numpy as np 9 | from pprint import pprint 10 | from config_c3d import cfg 11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 12 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.DEVICE 13 | 14 | import jittor 15 | jittor.flags.use_cuda = 1 16 | from core.train_c3d import train_net 17 | from core.test_c3d import test_net 18 | from core.inference_c3d import inference_net 19 | 20 | def set_seed(seed): 21 | np.random.seed(seed) 22 | jittor.set_global_seed(seed) 23 | 24 | 25 | def get_args_from_command_line(): 26 | parser = argparse.ArgumentParser(description='The argument parser of PMP-Net') 27 | parser.add_argument('--test', dest='test', help='Test neural networks', action='store_true') 28 | parser.add_argument('--inference', dest='inference', help='Inference for benchmark', action='store_true') 29 | args = parser.parse_args() 30 | 31 | return args 32 | 33 | 34 | def main(): 35 | # Get args from command line 36 | args = get_args_from_command_line() 37 | 38 | # Print config 39 | print('Use config:') 40 | pprint(cfg) 41 | 42 | if not args.test and not args.inference: 43 | train_net(cfg) 44 | else: 45 | if cfg.CONST.WEIGHTS is None: 46 | raise Exception('Please specify the path to checkpoint in the configuration file!') 47 | 48 | if args.test: 49 | test_net(cfg) 50 | else: 51 | inference_net(cfg) 52 | 53 | if __name__ == '__main__': 54 | # Check python version 55 | 56 | seed = 1 57 | set_seed(seed) 58 | logging.basicConfig(format='[%(levelname)s] %(asctime)s %(message)s', level=logging.DEBUG) 59 | main() 60 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/main_pcn.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Peng Xiang 4 | import os 5 | 6 | import argparse 7 | import logging 8 | import numpy as np 9 | from pprint import pprint 10 | from config_pcn import cfg 11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 12 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.DEVICE 13 | 14 | import jittor 15 | jittor.flags.use_cuda = 1 16 | from core.train_pcn import train_net 17 | from core.test_pcn import test_net 18 | from core.inference_pcn import inference_net 19 | 20 | def set_seed(seed): 21 | np.random.seed(seed) 22 | jittor.set_global_seed(seed) 23 | 24 | def get_args_from_command_line(): 25 | parser = argparse.ArgumentParser(description='The argument parser of PMP-Net') 26 | parser.add_argument('--test', dest='test', help='Test neural networks', action='store_true') 27 | parser.add_argument('--inference', dest='inference', help='Inference for benchmark', action='store_true') 28 | args = parser.parse_args() 29 | 30 | return args 31 | 32 | 33 | def main(): 34 | # Get args from command line 35 | args = get_args_from_command_line() 36 | 37 | # Print config 38 | print('Use config:') 39 | pprint(cfg) 40 | 41 | if not args.test and not args.inference: 42 | train_net(cfg) 43 | else: 44 | if cfg.CONST.WEIGHTS is None: 45 | raise Exception('Please specify the path to checkpoint in the configuration file!') 46 | 47 | if args.test: 48 | test_net(cfg) 49 | else: 50 | inference_net(cfg) 51 | 52 | if __name__ == '__main__': 53 | # Check python version 54 | seed = 1 55 | set_seed(seed) 56 | logging.basicConfig(format='[%(levelname)s] %(asctime)s %(message)s', level=logging.DEBUG) 57 | main() 58 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/models/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../pointnet2_ops_lib') 3 | sys.path.append('..') -------------------------------------------------------------------------------- /PMPPlus-Jittor/models/misc/__pycache__/ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/PMPPlus-Jittor/models/misc/__pycache__/ops.cpython-37.pyc -------------------------------------------------------------------------------- /PMPPlus-Jittor/models/misc/utils.py: -------------------------------------------------------------------------------- 1 | from jittor import nn 2 | import jittor as jt 3 | from sklearn.neighbors import NearestNeighbors 4 | import math 5 | import numpy as np 6 | 7 | 8 | class LRScheduler: 9 | def __init__(self, optimizer, base_lr): 10 | self.optimizer = optimizer 11 | 12 | self.basic_lr = base_lr 13 | self.lr_decay = 0.6 14 | self.decay_step = 15000 15 | 16 | def step(self, step): 17 | lr_decay = self.lr_decay ** int(step / self.decay_step) 18 | lr_decay = max(lr_decay, 2e-5) 19 | self.optimizer.lr = lr_decay * self.basic_lr 20 | 21 | 22 | def knn_indices_func_cpu(rep_pts, # (N, pts, dim) 23 | pts, # (N, x, dim) 24 | K : int, 25 | D : int): 26 | """ 27 | CPU-based Indexing function based on K-Nearest Neighbors search. 28 | :param rep_pts: Representative points. 29 | :param pts: Point cloud to get indices from. 30 | :param K: Number of nearest neighbors to collect. 31 | :param D: "Spread" of neighboring points. 
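A value of D > 1 dilates the neighborhood: D*K + 1 neighbors are searched and every D-th one is kept (index 0, the query point itself, is dropped), so e.g. K=8, D=2 draws 8 neighbors from the 16 nearest.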
32 | :return: Array of indices, P_idx, into pts such that pts[n][P_idx[n],:] 33 | is the set of K nearest neighbors for the representative points in rep_pts[n]. 34 | """ 35 | rep_pts = rep_pts.data 36 | pts = pts.data 37 | region_idx = [] 38 | 39 | for n, p in enumerate(rep_pts): 40 | P_particular = pts[n] 41 | nbrs = NearestNeighbors(n_neighbors = D*K + 1, algorithm = "ball_tree").fit(P_particular) 42 | indices = nbrs.kneighbors(p)[1] 43 | region_idx.append(indices[:,1::D]) 44 | 45 | region_idx = jt.array(np.stack(region_idx, axis = 0)) 46 | return region_idx 47 | 48 | def knn_indices_func_gpu(rep_pts, # (N, pts, dim) 49 | pts, # (N, x, dim) 50 | k : int, d : int ): # (N, pts, K) 51 | """ 52 | GPU-based Indexing function based on K-Nearest Neighbors search. 53 | Very memory intensive, and thus suboptimal for large numbers of points. 54 | :param rep_pts: Representative points. 55 | :param pts: Point cloud to get indices from. 56 | :param k: Number of nearest neighbors to collect. 57 | :param d: "Spread" (dilation) of neighboring points. 58 | :return: Array of indices, P_idx, into pts such that pts[n][P_idx[n],:] 59 | is the set of k nearest neighbors for the representative points in rep_pts[n]. 60 | """ 61 | region_idx = [] 62 | batch_size = rep_pts.shape[0] 63 | for idx in range(batch_size): 64 | qry = rep_pts[idx] 65 | ref = pts[idx] 66 | n, dim = ref.shape # keep the coordinate dimension distinct from the spread parameter `d` 67 | m, _ = qry.shape 68 | mref = ref.view(1, n, dim).repeat(m, 1, 1) 69 | mqry = qry.view(m, 1, dim).repeat(1, n, 1) 70 | 71 | dist2 = jt.sum((mqry - mref)**2, 2) # (m, n) squared distances 72 | _, inds = topk(dist2, k*d + 1, dim = 1, largest = False) 73 | 74 | region_idx.append(inds[:,1::d]) 75 | 76 | region_idx = jt.stack(region_idx, dim = 0) 77 | 78 | return region_idx 79 | 80 | 81 | 82 | def expand(x,shape): 83 | r''' 84 | Returns a new view of the input tensor with singleton dimensions expanded to a larger size. 85 | The tensor can also be expanded to a larger number of dimensions; the new ones will be appended at the front. 86 | Args: 87 | x-the input tensor. 88 | shape-the target shape of the expanded tensor. 
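Example (illustrative):
    expand(jt.array([[1., 2., 3.]]), (2, 4, 3)) broadcasts the (1, 3) input
    to shape (2, 4, 3), repeating the row along both broadcast dimensions.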
89 | ''' 90 | x_shape = x.shape 91 | x_l = len(x_shape) 92 | rest_shape=shape[:-x_l] 93 | expand_shape = shape[-x_l:] 94 | indexs=[] 95 | ii = len(rest_shape) 96 | for i,j in zip(expand_shape,x_shape): 97 | if i!=j: 98 | assert j==1 99 | indexs.append(f'i{ii}' if j>1 else f'0') 100 | ii+=1 101 | return x.reindex(shape,indexs) 102 | 103 | 104 | def topk(input, k, dim=None, largest=True, sorted=True): 105 | if dim is None: 106 | dim = -1 107 | if dim<0: 108 | dim+=input.ndim 109 | 110 | transpose_dims = [i for i in range(input.ndim)] 111 | transpose_dims[0] = dim 112 | transpose_dims[dim] = 0 113 | input = input.transpose(transpose_dims) 114 | index,values = jt.argsort(input,dim=0,descending=largest) 115 | indices = index[:k] 116 | values = values[:k] 117 | indices = indices.transpose(transpose_dims) 118 | values = values.transpose(transpose_dims) 119 | return [values,indices] 120 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/models/pointnet2_partseg.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import jittor as jt 4 | import jittor.nn as nn 5 | from jittor import init 6 | from jittor.contrib import concat 7 | 8 | from models.misc.ops import FurthestPointSampler 9 | from models.misc.ops import BallQueryGrouper 10 | from models.misc.ops import GroupAll 11 | from models.misc.ops import PointNetFeaturePropagation 12 | 13 | 14 | class PointNetModuleBase(nn.Module): 15 | def __init__(self): 16 | self.n_points = None 17 | self.sampler = None 18 | self.groupers = None 19 | self.mlps = None 20 | 21 | def build_mlps(self, mlp_spec: List[int], use_xyz: bool = True, 22 | bn: bool = True) -> nn.Sequential: 23 | layers = [] 24 | 25 | if use_xyz: 26 | mlp_spec[0] += 3 27 | 28 | for i in range(1, len(mlp_spec)): 29 | layers.append(nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1)) 30 | if bn: 31 | layers.append(nn.BatchNorm2d(mlp_spec[i])) 32 | layers.append(nn.ReLU()) 33 | 34 | return nn.Sequential(*layers) 35 | 36 | def execute(self, xyz: jt.Var, feature: Optional[jt.Var]) -> Tuple[jt.Var, jt.Var]: 37 | ''' 38 | Parameters 39 | ---------- 40 | xyz: jt.Var, (B, N, 3) 41 | feature: jt.Var, (B, C, N) 42 | 43 | Returns 44 | ------- 45 | new_xyz: jt.Var, (B, n_points, 3) 46 | new_feature: jt.Var, (B, C', n_points) 47 | ''' 48 | B, _, _ = xyz.shape 49 | new_xyz = self.sampler(xyz) if self.n_points is not None else jt.zeros((B, 1, 3)) 50 | 51 | new_feature_list = [] 52 | # print (self.groupers) 53 | for i, grouper in self.groupers.layers.items(): 54 | new_feature = grouper(new_xyz, xyz, feature.transpose(0, 2, 1)) 55 | # [B, n_points, n_samples, C] -> [B, C, n_points, n_samples] 56 | # print('\n ---------{}--------------.'.format(i+1)) 57 | new_feature = new_feature.transpose(0, 3, 1, 2) 58 | # print('1. new_feature.shape', new_feature.shape) 59 | new_feature = self.mlps[i](new_feature) 60 | # print('2. new_feature.shape', new_feature.shape) 61 | # [B, C, n_points, n_samples] -> [B, n_points, n_samples, C] 62 | #new_feature = new_feature.transpose(0, 2, 3, 1) 63 | # print('3. new_feature.shape', new_feature.shape) 64 | new_feature = new_feature.argmax(dim=-1)[1] 65 | # print('4. 
new_feature.shape', new_feature.shape) 66 | 67 | new_feature_list.append(new_feature) 68 | 69 | new_feature = jt.contrib.concat(new_feature_list, dim=-1) 70 | # print('len(new_feature):', len(new_feature)) 71 | return new_xyz, new_feature 72 | 73 | 74 | class PointnetModule(PointNetModuleBase): 75 | def __init__(self, mlp: List[int], n_points=None, radius=None, 76 | n_samples=None, bn=True, use_xyz=True): 77 | super().__init__() 78 | 79 | self.n_points = n_points 80 | 81 | self.groupers = nn.ModuleList() 82 | if self.n_points is not None: 83 | self.sampler = FurthestPointSampler(n_points) 84 | self.groupers.append(BallQueryGrouper(radius, n_samples, use_xyz)) 85 | else: 86 | self.groupers.append(GroupAll(use_xyz)) 87 | 88 | self.mlps = nn.ModuleList() 89 | self.mlps.append(self.build_mlps(mlp, use_xyz)) 90 | 91 | 92 | class PointnetModuleMSG(PointNetModuleBase): 93 | def __init__(self, n_points: int, radius: List[float], n_samples: List[int], 94 | mlps: List[List[int]], bn=True, use_xyz=True): 95 | super().__init__() 96 | 97 | self.n_points = n_points 98 | self.sampler = FurthestPointSampler(n_points) 99 | 100 | self.groupers = nn.ModuleList() 101 | for r, s in zip(radius, n_samples): 102 | self.groupers.append(BallQueryGrouper(r, s, use_xyz)) 103 | 104 | self.mlps = nn.ModuleList() 105 | for mlp in mlps: 106 | self.mlps.append(self.build_mlps(mlp, use_xyz)) 107 | 108 | 109 | class PointNet2_partseg(nn.Module): 110 | def __init__(self, part_num=50, use_xyz=True): 111 | super().__init__() 112 | self.part_num = part_num 113 | self.use_xyz = use_xyz 114 | self.build_model() 115 | 116 | def build_model(self): 117 | self.pointnet_modules = nn.ModuleList() 118 | self.pointnet_modules.append( 119 | PointnetModule( 120 | n_points=512, 121 | radius=0.2, 122 | n_samples=64, 123 | mlp=[3, 64, 64, 128], 124 | use_xyz=self.use_xyz, 125 | ) 126 | ) 127 | 128 | self.pointnet_modules.append( 129 | PointnetModule( 130 | n_points=128, 131 | radius=0.4, 132 | n_samples=64, 133 | mlp=[128, 128, 128, 256], 134 | use_xyz=self.use_xyz, 135 | ) 136 | ) 137 | 138 | self.pointnet_modules.append( 139 | PointnetModule( 140 | mlp=[256, 256, 512, 1024], 141 | use_xyz=self.use_xyz, 142 | ) 143 | ) 144 | 145 | self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256]) 146 | self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128]) 147 | self.fp1 = PointNetFeaturePropagation(in_channel=128 + 16 + 6, mlp=[128, 128, 128]) 148 | 149 | self.fc_layer = nn.Sequential( 150 | nn.Conv1d(128, 128, 1), 151 | nn.BatchNorm1d(128), 152 | nn.Dropout(0.5), 153 | nn.Conv1d(128, self.part_num, 1) 154 | ) 155 | 156 | def execute(self, xyz, feature, cls_label): 157 | # for module in self.pointnet_modules: 158 | # xyz, feature = module(xyz, feature) 159 | 160 | B, N, _ = xyz.shape 161 | l1_xyz, l1_feature = self.pointnet_modules[0](xyz, feature) 162 | l2_xyz, l2_feature = self.pointnet_modules[1](l1_xyz, l1_feature) 163 | l3_xyz, l3_feature = self.pointnet_modules[2](l2_xyz, l2_feature) 164 | # print ('before interpolate shape') 165 | # print (l2_xyz.shape, l2_feature.shape, l3_xyz.shape, l3_feature.shape) 166 | l2_feature = self.fp3(l2_xyz, l3_xyz, l2_feature, l3_feature) 167 | l1_feature = self.fp2(l1_xyz, l2_xyz, l1_feature, l2_feature) 168 | cls_label_one_hot = cls_label.view(B, 16, 1).repeat(1, 1, N).permute(0, 2, 1) 169 | # print ('before concat size ') 170 | # print (cls_label_one_hot.size(),xyz.size(),feature.size()) 171 | feature = self.fp1(xyz, l1_xyz, concat([cls_label_one_hot, xyz, feature], 2), l1_feature) 
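# the propagated features come back as one vector per input point; the permute below reorders them to channels-first for the Conv1d layers in fc_layer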
172 | feature = feature.permute(0, 2, 1) 173 | # print (feature.shape) 174 | return self.fc_layer(feature) 175 | 176 | 177 | class PointNetMSG(PointNet2_partseg): 178 | def build_model(self): 179 | super().build_model() 180 | 181 | self.pointnet_modules = nn.ModuleList() 182 | self.pointnet_modules.append( 183 | PointnetModuleMSG( 184 | n_points=512, 185 | radius=[0.1, 0.2, 0.4], 186 | n_samples=[16, 32, 128], 187 | mlps=[[3, 32, 32, 64], [3, 64, 64, 128], [3, 64, 96, 128]], 188 | use_xyz=self.use_xyz, 189 | ) 190 | ) 191 | 192 | input_channels = 64 + 128 + 128 193 | self.pointnet_modules.append( 194 | PointnetModuleMSG( 195 | n_points=128, 196 | radius=[0.2, 0.4, 0.8], 197 | n_samples=[32, 64, 128], 198 | mlps=[ 199 | [input_channels, 64, 64, 128], 200 | [input_channels, 128, 128, 256], 201 | [input_channels, 128, 128, 256], 202 | ], 203 | use_xyz=self.use_xyz, 204 | ) 205 | ) 206 | 207 | self.pointnet_modules.append( 208 | PointnetModule( 209 | mlp=[128 + 256 + 256, 256, 512, 1024], 210 | use_xyz=self.use_xyz, 211 | ) 212 | ) 213 | 214 | 215 | def main(): 216 | model = PointNet2_partseg() 217 | input_point = init.gauss([2, 1024, 3], 'float', mean=0.0) 218 | input_feature = init.gauss([2, 1024, 3], 'float', mean=0.0) 219 | cls_label = init.gauss([2, 16], 'float', mean=0.0) 220 | 221 | print(input_point.shape) 222 | print(input_feature.shape) 223 | print(cls_label.shape) 224 | outputs = model(input_point, input_feature, cls_label) 225 | print(outputs.shape) 226 | 227 | 228 | if __name__ == '__main__': 229 | main() 230 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/models/transformers.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import init, nn 3 | from models.misc.ops import knn, index_points, gather_operation, grouping_operation 4 | 5 | class Transformer(nn.Module): 6 | def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4): 7 | super(Transformer, self).__init__() 8 | self.n_knn = n_knn 9 | self.conv_key = nn.Conv1d(dim, dim, 1) 10 | self.conv_query = nn.Conv1d(dim, dim, 1) 11 | self.conv_value = nn.Conv1d(dim, dim, 1) 12 | 13 | self.pos_mlp = nn.Sequential( 14 | nn.Conv2d(3, pos_hidden_dim, 1), 15 | nn.BatchNorm2d(pos_hidden_dim), 16 | nn.ReLU(), 17 | nn.Conv2d(pos_hidden_dim, dim, 1) 18 | ) 19 | 20 | self.attn_mlp = nn.Sequential( 21 | nn.Conv2d(dim, dim * attn_hidden_multiplier, 1), 22 | nn.BatchNorm2d(dim * attn_hidden_multiplier), 23 | nn.ReLU(), 24 | nn.Conv2d(dim * attn_hidden_multiplier, dim, 1) 25 | ) 26 | 27 | self.linear_start = nn.Conv1d(in_channel, dim, 1) 28 | self.linear_end = nn.Conv1d(dim, in_channel, 1) 29 | 30 | def execute(self, x, pos): 31 | """ 32 | Args: 33 | x: Tensor, (B, c, 2048) 34 | pos: Tensor, (B, 2048, 3) 35 | """ 36 | identity = x 37 | x_bcn = self.linear_start(x) 38 | b, dim, n = x_bcn.shape 39 | pos_bcn = pos.transpose(0, 2, 1) 40 | _, idx_knn = knn(pos, pos, self.n_knn) 41 | # idx_knn = knn(pos_bcn, self.n_knn) 42 | 43 | key = self.conv_key(x_bcn) 44 | value = self.conv_value(x_bcn) 45 | query = self.conv_query(x_bcn) 46 | 47 | # key = index_points(key.transpose(0, 2, 1), idx_knn).transpose(0, 3, 1, 2) # (b, c, n, n_knn) 48 | key = grouping_operation(key, idx_knn) 49 | # print('key.shape', key.shape) 50 | qk_rel = query.reshape((b, -1, n, 1)) - key 51 | 52 | 53 | pos_rel = pos_bcn.reshape((b, -1, n, 1)) - \ 54 | grouping_operation(pos_bcn, idx_knn) 55 | # index_points(pos, 
idx_knn).transpose(0, 3, 1, 2) 56 | pos_embedding = self.pos_mlp(pos_rel) 57 | 58 | attention = self.attn_mlp(qk_rel + pos_embedding) 59 | attention = nn.softmax(attention, dim=-1) 60 | 61 | value = value.reshape((b, -1, n, 1)) + pos_embedding 62 | 63 | agg = (value * attention).sum(dim=-1) 64 | y = self.linear_end(agg) 65 | 66 | return y+identity 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | easydict 3 | h5py 4 | matplotlib 5 | numpy 6 | open3d==0.9.0.0 7 | opencv-python 8 | pyexr 9 | scipy 10 | tensorboardX==1.2 11 | transforms3d 12 | tqdm -------------------------------------------------------------------------------- /PMPPlus-Jittor/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/PMPPlus-Jittor/utils/__init__.py -------------------------------------------------------------------------------- /PMPPlus-Jittor/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-06 22:50:12 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-03 21:50:38 6 | # @Email: cshzxie@gmail.com 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self, items=None): 12 | self.items = items 13 | self.n_items = 1 if items is None else len(items) 14 | self.reset() 15 | 16 | def reset(self): 17 | self._val = [0] * self.n_items 18 | self._sum = [0] * self.n_items 19 | self._count = [0] * self.n_items 20 | 21 | def update(self, values): 22 | if type(values).__name__ == 'list': 23 | for idx, v in enumerate(values): 24 | self._val[idx] = v 25 | self._sum[idx] += v 26 | self._count[idx] += 1 27 | else: 28 | self._val[0] = values 29 | self._sum[0] += values 30 | self._count[0] += 1 31 | 32 | def val(self, idx=None): 33 | if idx is None: 34 | return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)] 35 | else: 36 | return self._val[idx] 37 | 38 | def count(self, idx=None): 39 | if idx is None: 40 | return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)] 41 | else: 42 | return self._count[idx] 43 | 44 | def avg(self, idx=None): 45 | if idx is None: 46 | return self._sum[0] / self._count[0] if self.items is None else [ 47 | self._sum[i] / self._count[i] for i in range(self.n_items) 48 | ] 49 | else: 50 | return self._sum[idx] / self._count[idx] 51 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/utils/helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-07-31 16:57:15 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-02-22 18:34:19 6 | # @Email: cshzxie@gmail.com 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | from mpl_toolkits.mplot3d import Axes3D 12 | 13 | 14 | def var_or_cuda(x): 15 | if torch.cuda.is_available(): 16 | x = x.cuda(non_blocking=True) 17 | 18 | return x 19 | 20 | 21 | def init_weights(m): 22 | if type(m) == torch.nn.Conv2d or type(m) == torch.nn.ConvTranspose2d or \ 23 | type(m) == 
torch.nn.Conv3d or type(m) == torch.nn.ConvTranspose3d: 24 | torch.nn.init.kaiming_normal_(m.weight) 25 | if m.bias is not None: 26 | torch.nn.init.constant_(m.bias, 0) 27 | elif type(m) == torch.nn.BatchNorm2d or type(m) == torch.nn.BatchNorm3d: 28 | torch.nn.init.constant_(m.weight, 1) 29 | torch.nn.init.constant_(m.bias, 0) 30 | elif type(m) == torch.nn.Linear: 31 | torch.nn.init.normal_(m.weight, 0, 0.01) 32 | torch.nn.init.constant_(m.bias, 0) 33 | 34 | 35 | def count_parameters(network): 36 | return sum(p.numel() for p in network.parameters()) 37 | 38 | 39 | def get_ptcloud_img(ptcloud): 40 | fig = plt.figure(figsize=(8, 8)) 41 | 42 | x, z, y = ptcloud.transpose(1, 0) 43 | ax = fig.gca(projection=Axes3D.name, adjustable='box') 44 | ax.axis('off') 45 | ax.axis('scaled') 46 | ax.view_init(30, 45) 47 | 48 | max, min = np.max(ptcloud), np.min(ptcloud) 49 | ax.set_xbound(min, max) 50 | ax.set_ybound(min, max) 51 | ax.set_zbound(min, max) 52 | ax.scatter(x, y, z, zdir='z', c=x, cmap='jet') 53 | 54 | fig.canvas.draw() 55 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 56 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, )) 57 | return img 58 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/utils/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-02 10:22:03 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-02-22 19:13:01 6 | # @Email: cshzxie@gmail.com 7 | 8 | import cv2 9 | import h5py 10 | import numpy as np 11 | import open3d 12 | import os 13 | import sys 14 | 15 | from io import BytesIO 16 | 17 | # References: http://confluence.sensetime.com/pages/viewpage.action?pageId=44650315 18 | from config_c3d import cfg 19 | sys.path.append(cfg.MEMCACHED.LIBRARY_PATH) 20 | 21 | mc_client = None 22 | if cfg.MEMCACHED.ENABLED: 23 | import mc 24 | mc_client = mc.MemcachedClient.GetInstance(cfg.MEMCACHED.SERVER_CONFIG, cfg.MEMCACHED.CLIENT_CONFIG) 25 | 26 | 27 | class IO: 28 | @classmethod 29 | def get(cls, file_path): 30 | _, file_extension = os.path.splitext(file_path) 31 | 32 | if file_extension in ['.png', '.jpg']: 33 | return cls._read_img(file_path) 34 | elif file_extension in ['.npy']: 35 | return cls._read_npy(file_path) 36 | elif file_extension in ['.exr']: 37 | return cls._read_exr(file_path) 38 | elif file_extension in ['.pcd']: 39 | return cls._read_pcd(file_path) 40 | elif file_extension in ['.h5']: 41 | return cls._read_h5(file_path) 42 | elif file_extension in ['.txt']: 43 | return cls._read_txt(file_path) 44 | else: 45 | raise Exception('Unsupported file extension: %s' % file_extension) 46 | 47 | @classmethod 48 | def put(cls, file_path, file_content): 49 | _, file_extension = os.path.splitext(file_path) 50 | 51 | if file_extension in ['.pcd']: 52 | return cls._write_pcd(file_path, file_content) 53 | elif file_extension in ['.h5']: 54 | return cls._write_h5(file_path, file_content) 55 | else: 56 | raise Exception('Unsupported file extension: %s' % file_extension) 57 | 58 | @classmethod 59 | def _read_img(cls, file_path): 60 | if mc_client is None: 61 | return cv2.imread(file_path, cv2.IMREAD_UNCHANGED) / 255. 62 | else: 63 | pyvector = mc.pyvector() 64 | mc_client.Get(file_path, pyvector) 65 | buf = mc.ConvertBuffer(pyvector) 66 | img_array = np.frombuffer(buf, np.uint8) 67 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) 68 | return img / 255. 
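# (both the direct-read and the memcached branches return float images scaled to [0, 1])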
69 | 70 | # References: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py 71 | @classmethod 72 | def _read_npy(cls, file_path): 73 | if mc_client is None: 74 | return np.load(file_path) 75 | else: 76 | pyvector = mc.pyvector() 77 | mc_client.Get(file_path, pyvector) 78 | buf = mc.ConvertBuffer(pyvector) 79 | buf_bytes = buf.tobytes() 80 | if not buf_bytes[:6] == b'\x93NUMPY': 81 | raise Exception('Invalid npy file format.') 82 | 83 | header_size = int.from_bytes(buf_bytes[8:10], byteorder='little') 84 | header = eval(buf_bytes[10:header_size + 10]) 85 | dtype = np.dtype(header['descr']) 86 | nd_array = np.frombuffer(buf[header_size + 10:], dtype).reshape(header['shape']) 87 | 88 | return nd_array 89 | 90 | # @classmethod 91 | # def _read_exr(cls, file_path): 92 | # return 1.0 / pyexr.open(file_path).get("Depth.Z").astype(np.float32) 93 | 94 | # References: https://github.com/dimatura/pypcd/blob/master/pypcd/pypcd.py#L275 95 | # Support PCD files without compression ONLY! 96 | @classmethod 97 | def _read_pcd(cls, file_path): 98 | if mc_client is None: 99 | pc = open3d.io.read_point_cloud(file_path) 100 | ptcloud = np.array(pc.points) 101 | else: 102 | pyvector = mc.pyvector() 103 | mc_client.Get(file_path, pyvector) 104 | text = mc.ConvertString(pyvector).split('\n') 105 | start_line_idx = len(text) - 1 106 | for idx, line in enumerate(text): 107 | if line == 'DATA ascii': 108 | start_line_idx = idx + 1 109 | break 110 | 111 | ptcloud = text[start_line_idx:] 112 | ptcloud = np.genfromtxt(BytesIO('\n'.join(ptcloud).encode()), dtype=np.float32) 113 | 114 | # ptcloud = np.concatenate((ptcloud, np.array([[0, 0, 0]])), axis=0) 115 | return ptcloud 116 | 117 | @classmethod 118 | def _read_h5(cls, file_path): 119 | f = h5py.File(file_path, 'r') 120 | # Avoid overflow while gridding 121 | return f['data'][()] 122 | 123 | @classmethod 124 | def _read_txt(cls, file_path): 125 | return np.loadtxt(file_path) 126 | 127 | @classmethod 128 | def _write_pcd(cls, file_path, file_content): 129 | pc = open3d.geometry.PointCloud() 130 | pc.points = open3d.utility.Vector3dVector(file_content) 131 | open3d.io.write_point_cloud(file_path, pc) 132 | 133 | @classmethod 134 | def _write_h5(cls, file_path, file_content): 135 | with h5py.File(file_path, 'w') as f: 136 | f.create_dataset('data', data=file_content) 137 | -------------------------------------------------------------------------------- /PMPPlus-Jittor/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-08 14:31:30 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-05-25 09:13:32 6 | # @Email: cshzxie@gmail.com 7 | 8 | import logging 9 | import open3d 10 | import torch 11 | 12 | # from Chamfer3D.dist_chamfer_3D import chamfer_3DDist 13 | 14 | 15 | class Metrics(object): 16 | ITEMS = [{ 17 | 'name': 'pmd', 18 | 'enabled': True, 19 | 'eval_func': 'cls._get_emd_distance', 20 | 'eval_object': None, 21 | 'is_greater_better': False, 22 | 'init_value': 32767 23 | },{ 24 | 'name': 'ChamferDistance', 25 | 'enabled': True, 26 | 'eval_func': 'cls._get_chamfer_distance', 27 | 'eval_object': None, 28 | # 'eval_object': ChamferDistance(ignore_zeros=True), 29 | 'is_greater_better': False, 30 | 'init_value': 32767 31 | }] 32 | 33 | @classmethod 34 | def get(cls, pred, gt): 35 | _items = cls.items() 36 | _values = [0] * len(_items) 37 | for i, item in enumerate(_items): 38 | eval_func = eval(item['eval_func']) 39 | 
_values[i] = eval_func(pred, gt) 40 | 41 | return _values 42 | 43 | @classmethod 44 | def items(cls): 45 | return [i for i in cls.ITEMS if i['enabled']] 46 | 47 | @classmethod 48 | def names(cls): 49 | _items = cls.items() 50 | return [i['name'] for i in _items] 51 | 52 | @classmethod 53 | def _get_f_score(cls, pred, gt, th=0.01): 54 | """References: https://github.com/lmb-freiburg/what3d/blob/master/util.py""" 55 | pred = cls._get_open3d_ptcloud(pred) 56 | gt = cls._get_open3d_ptcloud(gt) 57 | 58 | dist1 = pred.compute_point_cloud_distance(gt) 59 | dist2 = gt.compute_point_cloud_distance(pred) 60 | 61 | recall = float(sum(d < th for d in dist2)) / float(len(dist2)) 62 | precision = float(sum(d < th for d in dist1)) / float(len(dist1)) 63 | return 2 * recall * precision / (recall + precision) if recall + precision else 0 64 | 65 | @classmethod 66 | def _get_open3d_ptcloud(cls, tensor): 67 | tensor = tensor.squeeze().cpu().numpy() 68 | ptcloud = open3d.geometry.PointCloud() 69 | ptcloud.points = open3d.utility.Vector3dVector(tensor) 70 | 71 | return ptcloud 72 | 73 | @classmethod 74 | def _get_chamfer_distance(cls, pred, gt): 75 | # chamfer_distance = cls.ITEMS[1]['eval_object'] 76 | chamfer_distance = cls.ITEMS[1]['eval_object'] 77 | d1, d2, _, _ = chamfer_distance(pred, gt) 78 | cd = torch.mean(d1) + torch.mean(d2) 79 | return cd.item() * 1000 80 | # return chamfer_distance(pred, gt).item() * 1000 81 | 82 | @classmethod 83 | def _get_emd_distance(cls, pred, gt): 84 | emd_distance = cls.ITEMS[0]['eval_object'] 85 | return torch.mean(emd_distance(pred, gt)).item() 86 | 87 | def __init__(self, metric_name, values): 88 | self._items = Metrics.items() 89 | self._values = [item['init_value'] for item in self._items] 90 | self.metric_name = metric_name 91 | 92 | if type(values).__name__ == 'list': 93 | self._values = values 94 | elif type(values).__name__ == 'dict': 95 | metric_indexes = {} 96 | for idx, item in enumerate(self._items): 97 | item_name = item['name'] 98 | metric_indexes[item_name] = idx 99 | for k, v in values.items(): 100 | if k not in metric_indexes: 101 | logging.warn('Ignore Metric[Name=%s] due to disability.' 
% k) 102 | continue 103 | self._values[metric_indexes[k]] = v 104 | else: 105 | raise Exception('Unsupported value type: %s' % type(values)) 106 | 107 | def state_dict(self): 108 | _dict = dict() 109 | for i in range(len(self._items)): 110 | item = self._items[i]['name'] 111 | value = self._values[i] 112 | _dict[item] = value 113 | 114 | return _dict 115 | 116 | def __repr__(self): 117 | return str(self.state_dict()) 118 | 119 | def better_than(self, other): 120 | if other is None: 121 | return True 122 | 123 | _index = -1 124 | for i, _item in enumerate(self._items): 125 | if _item['name'] == self.metric_name: 126 | _index = i 127 | break 128 | if _index == -1: 129 | raise Exception('Invalid metric name to compare.') 130 | 131 | _metric = self._items[_index] # use the resolved index rather than the loop variable 132 | _value = self._values[_index] 133 | other_value = other._values[_index] 134 | return _value > other_value if _metric['is_greater_better'] else _value < other_value 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PMP-Net++: Point Cloud Completion by Transformer-Enhanced Multi-step Point Moving Paths (TPAMI 2023) 2 | 3 | ![Intro pic](pics/network.png) 4 | 5 | ## [NEWS] 6 | 7 | - **2022-03 [NEW:tada:]** The [Jittor](https://cg.cs.tsinghua.edu.cn/jittor/) implementations of both PMP-Net and PMP-Net++ are released in the [PMPPlus-Jittor](https://github.com/diviswen/PMP-Net/tree/main/PMPPlus-Jittor) folder. 8 | - **2022-02 [NEW:tada:]** [PMP-Net++](https://arxiv.org/abs/2012.03408), the journal extension of PMP-Net, is accepted to [TPAMI](https://ieeexplore.ieee.org/document/9735342). This repository now contains the code of both PMP-Net and PMP-Net++! 9 | - **2021** [PMP-Net](https://arxiv.org/abs/2012.03408) is published at [CVPR 2021](https://openaccess.thecvf.com/content/CVPR2021/html/Wen_PMP-Net_Point_Cloud_Completion_by_Learning_Multi-Step_Point_Moving_Paths_CVPR_2021_paper.html), and the code is released! 10 | 11 | ## [PMP-Net++] 12 | This repository contains the PyTorch implementation and the Jittor implementation of the papers: 13 | 14 | **1. PMP-Net++: Point Cloud Completion by Transformer-Enhanced Multi-step Point Moving Paths, TPAMI 2023** 15 | 16 | **2. PMP-Net: Point Cloud Completion by Learning Multi-step Point Moving Paths, CVPR 2021** 17 | 18 | [ [PMP-Net](https://arxiv.org/abs/2012.03408) | [PMP-Net++](https://arxiv.org/abs/2012.03408) | [IEEE Xplore](https://ieeexplore.ieee.org/document/9735342) | [Webpage]() | [Jittor](https://cg.cs.tsinghua.edu.cn/jittor/) ] 19 | 20 | > Point cloud completion aims to predict the missing parts of incomplete 3D shapes. A common strategy is to generate 21 | the complete shape from the incomplete input. However, the unordered nature of point clouds degrades the generation of high-quality 3D 22 | shapes, as the detailed topology and structure of unordered points are hard to capture during the generative process using an 23 | extracted latent code. We address this problem by formulating completion as a point cloud deformation process. Specifically, we design a 24 | novel neural network, named PMP-Net++, to mimic the behavior of an earth mover. It moves each point of the incomplete input to obtain a 25 | complete point cloud, where the total distance of the point moving paths (PMPs) should be the shortest. Therefore, PMP-Net++ predicts a 26 | unique PMP for each point according to the constraint of point moving distances. 
The network learns a strict and unique correspondence 27 | at the point level, and thus improves the quality of the predicted complete shape. Moreover, since moving points heavily relies on the per-point 28 | features learned by the network, we further introduce a transformer-enhanced representation learning network, which significantly 29 | improves the completion performance of PMP-Net++. We conduct comprehensive experiments on shape completion, and further explore 30 | the application to point cloud up-sampling, which demonstrate the non-trivial improvement of PMP-Net++ over state-of-the-art point cloud 31 | completion/up-sampling methods. 32 | 33 | ## [Cite this work] 34 | 35 | ``` 36 | @ARTICLE{pmpnet++, 37 | author={Wen, Xin and Xiang, Peng and Han, Zhizhong and Cao, Yan-Pei and Wan, Pengfei and Zheng, Wen and Liu, Yu-Shen}, 38 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 39 | title={PMP-Net++: Point Cloud Completion by Transformer-Enhanced Multi-Step Point Moving Paths}, 40 | year={2023}, 41 | volume={45}, 42 | number={1}, 43 | pages={852-867}, 44 | doi={10.1109/TPAMI.2022.3159003}} 45 | 46 | @inproceedings{wen2021pmp, 47 | title={PMP-Net: Point cloud completion by learning multi-step point moving paths}, 48 | author={Wen, Xin and Xiang, Peng and Han, Zhizhong and Cao, Yan-Pei and Wan, Pengfei and Zheng, Wen and Liu, Yu-Shen}, 49 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 50 | year={2021} 51 | } 52 | ``` 53 | 54 | ## [Getting Started] 55 | #### Datasets and Pretrained Models 56 | 57 | We use the [PCN](https://www.shapenet.org/) and [Completion3D](http://completion3d.stanford.edu/) datasets in our experiments, which are available below: 58 | 59 | - [PCN](https://drive.google.com/drive/folders/1P_W1tz5Q4ZLapUifuOE4rFAZp6L1XTJz) 60 | - [Completion3D](http://download.cs.stanford.edu/downloads/completion3d/dataset2019.zip) 61 | 62 | The pretrained models on the Completion3D and PCN datasets are available as follows: 63 | 64 | - [PMP-Net_pre-trained](https://drive.google.com/drive/folders/1emGsfdnIj1eUtUxZlfiWiuJ0QJag4nOn?usp=sharing) 65 | 66 | Backup Links: 67 | 68 | - [PMP-Net_pre-trained](https://pan.baidu.com/s/1oQbaVI7yN9NmI_2E9tztGQ) (pwd: n7t4) 69 | 70 | #### Install Python Dependencies 71 | 72 | ``` 73 | cd PMP-Net 74 | conda create -n pmp python=3.7 75 | conda activate pmp 76 | pip3 install -r requirements.txt 77 | ``` 78 | 79 | #### Build PyTorch Extensions 80 | 81 | **NOTE:** PyTorch >= 1.4 with CUDA support is required. 82 | 83 | ``` 84 | cd pointnet2_ops_lib 85 | python setup.py install 86 | 87 | cd .. 88 | 89 | cd Chamfer3D 90 | python setup.py install 91 | ``` 92 | 
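If both builds succeed, you can sanity-check the compiled extensions from the repository root (a minimal sketch; the module names follow the imports used by the training scripts in this repo):

```
python -c "from Chamfer3D.dist_chamfer_3D import chamfer_3DDist; import pointnet2_ops.pointnet2_utils; print('extensions OK')"
```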
93 | You need to update the dataset file paths in the config files: 94 | 95 | ``` 96 | __C.DATASETS.COMPLETION3D.PARTIAL_POINTS_PATH = '/path/to/datasets/Completion3D/%s/partial/%s/%s.h5' 97 | __C.DATASETS.COMPLETION3D.COMPLETE_POINTS_PATH = '/path/to/datasets/Completion3D/%s/gt/%s/%s.h5' 98 | __C.DATASETS.SHAPENET.PARTIAL_POINTS_PATH = '/path/to/datasets/ShapeNet/ShapeNetCompletion/%s/partial/%s/%s/%02d.pcd' 99 | __C.DATASETS.SHAPENET.COMPLETE_POINTS_PATH = '/path/to/datasets/ShapeNet/ShapeNetCompletion/%s/complete/%s/%s.pcd' 100 | 101 | # Dataset Options: Completion3D, Completion3DPCCT, ShapeNet, ShapeNetCars 102 | __C.DATASET.TRAIN_DATASET = 'ShapeNet' 103 | __C.DATASET.TEST_DATASET = 'ShapeNet' 104 | ``` 105 | 106 | #### Training, Testing and Inference 107 | 108 | To train PMP-Net++ or PMP-Net, you can simply use the following command: 109 | 110 | ``` 111 | python main_*.py # remember to replace '*' with 'c3d' or 'pcn', and to switch between 'import PMPNetPlus' and 'import PMPNet' 112 | ``` 113 | 114 | To test or run inference, you should specify the checkpoint path in the config_*.py file 115 | ``` 116 | __C.CONST.WEIGHTS = "path to your checkpoint" 117 | ``` 118 | 119 | then use the following command: 120 | 121 | ``` 122 | python main_*.py --test 123 | python main_*.py --inference 124 | ``` 125 | 126 | ## [Acknowledgements] 127 | 128 | Some of the code in this repo is borrowed from [GRNet](https://github.com/hzxie/GRNet), [pytorchpointnet++](https://github.com/erikwijmans/Pointnet2_PyTorch) and [ChamferDistancePytorch](https://github.com/ThibaultGROUEIX/ChamferDistancePytorch). We thank the authors for their wonderful work! 129 | 130 | ## [License] 131 | 132 | This project is open-sourced under the MIT license. 
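Note on the training command above: switching between PMP-Net and PMP-Net++ is a one-line import change at the top of the core scripts. The class names below are the ones defined in `models/model.py` and imported in `core/train_pcn.py` and `core/train_c3d.py`:

```
# PMP-Net++ (as imported in core/train_pcn.py and core/test_pcn.py)
from models.model import PMPNetPlus as Model

# PMP-Net (as imported in core/train_c3d.py and core/test_c3d.py)
# from models.model import PMPNet as Model
```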
133 | -------------------------------------------------------------------------------- /config_c3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | from easydict import EasyDict as edict 5 | 6 | __C = edict() 7 | cfg = __C 8 | 9 | # 10 | # Dataset Config 11 | # 12 | __C.DATASETS = edict() 13 | __C.DATASETS.COMPLETION3D = edict() 14 | __C.DATASETS.COMPLETION3D.CATEGORY_FILE_PATH = './datasets/Completion3D.json' 15 | __C.DATASETS.COMPLETION3D.PARTIAL_POINTS_PATH = '/data/xp/data/shapenet/shapenet/%s/partial/%s/%s.h5' 16 | __C.DATASETS.COMPLETION3D.COMPLETE_POINTS_PATH = '/data/xp/data/shapenet/shapenet/%s/gt/%s/%s.h5' 17 | __C.DATASETS.SHAPENET = edict() 18 | __C.DATASETS.SHAPENET.CATEGORY_FILE_PATH = './datasets/ShapeNet.json' 19 | __C.DATASETS.SHAPENET.N_RENDERINGS = 8 20 | __C.DATASETS.SHAPENET.N_POINTS = 16384 21 | __C.DATASETS.SHAPENET.PARTIAL_POINTS_PATH = '/data/PCN/%s/partial/%s/%s/%02d.pcd' 22 | __C.DATASETS.SHAPENET.COMPLETE_POINTS_PATH = '/data/PCN/%s/complete/%s/%s.pcd' 23 | 24 | # 25 | # Dataset 26 | # 27 | __C.DATASET = edict() 28 | # Dataset Options: Completion3D, ShapeNet, ShapeNetCars, Completion3DPCCT 29 | __C.DATASET.TRAIN_DATASET = 'Completion3D' 30 | __C.DATASET.TEST_DATASET = 'Completion3D' 31 | 32 | # 33 | # Constants 34 | # 35 | __C.CONST = edict() 36 | 37 | __C.CONST.NUM_WORKERS = 8 38 | __C.CONST.N_INPUT_POINTS = 2048 39 | 40 | # 41 | # Directories 42 | # 43 | 44 | __C.DIR = edict() 45 | __C.DIR.OUT_PATH = './exp/output' 46 | __C.CONST.DEVICE = '0' 47 | __C.CONST.WEIGHTS = '' # ./pretrained/completion3d/ckpt-best-plus.pth 48 | 49 | # 50 | # Memcached 51 | # 52 | __C.MEMCACHED = edict() 53 | __C.MEMCACHED.ENABLED = False 54 | __C.MEMCACHED.LIBRARY_PATH = '/mnt/lustre/share/pymc/py3' 55 | __C.MEMCACHED.SERVER_CONFIG = '/mnt/lustre/share/memcached_client/server_list.conf' 56 | __C.MEMCACHED.CLIENT_CONFIG = '/mnt/lustre/share/memcached_client/client.conf' 57 | 58 | # 59 | # Network 60 | # 61 | __C.NETWORK = edict() 62 | __C.NETWORK.N_SAMPLING_POINTS = 2048 63 | 64 | # 65 | # Train 66 | # 67 | __C.TRAIN = edict() 68 | __C.TRAIN.LAMBDA_CD = 1000 69 | __C.TRAIN.LAMBDA_PMD = 1e-2 70 | __C.TRAIN.BATCH_SIZE = 16 71 | __C.TRAIN.N_EPOCHS = 150 72 | __C.TRAIN.SAVE_FREQ = 25 73 | __C.TRAIN.LEARNING_RATE = 0.001 74 | __C.TRAIN.LR_MILESTONES = [50, 100, 150, 200, 250] 75 | __C.TRAIN.GAMMA = .5 76 | __C.TRAIN.BETAS = (.9, .999) 77 | __C.TRAIN.WEIGHT_DECAY = 0 78 | 79 | # 80 | # Test 81 | # 82 | __C.TEST = edict() 83 | __C.TEST.METRIC_NAME = 'ChamferDistance' 84 | -------------------------------------------------------------------------------- /config_pcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | from easydict import EasyDict as edict 5 | 6 | __C = edict() 7 | cfg = __C 8 | 9 | # 10 | # Dataset Config 11 | # 12 | __C.DATASETS = edict() 13 | __C.DATASETS.COMPLETION3D = edict() 14 | __C.DATASETS.COMPLETION3D.CATEGORY_FILE_PATH = './datasets/Completion3D.json' 15 | __C.DATASETS.COMPLETION3D.PARTIAL_POINTS_PATH = '/data/shapenet/%s/partial/%s/%s.h5' 16 | __C.DATASETS.COMPLETION3D.COMPLETE_POINTS_PATH = '/data/shapenet/%s/gt/%s/%s.h5' 17 | __C.DATASETS.SHAPENET = edict() 18 | __C.DATASETS.SHAPENET.CATEGORY_FILE_PATH = './datasets/ShapeNet.json' 19 | __C.DATASETS.SHAPENET.N_RENDERINGS = 8 20 | __C.DATASETS.SHAPENET.N_POINTS = 16384 21 | __C.DATASETS.SHAPENET.PARTIAL_POINTS_PATH = 
'/data/xp/data/PCN/%s/partial/%s/%s/%02d.pcd' 22 | __C.DATASETS.SHAPENET.COMPLETE_POINTS_PATH = '/data/xp/data/PCN/%s/complete/%s/%s.pcd' 23 | 24 | # 25 | # Dataset 26 | # 27 | __C.DATASET = edict() 28 | # Dataset Options: Completion3D, ShapeNet, ShapeNetCars, Completion3DPCCT 29 | __C.DATASET.TRAIN_DATASET = 'ShapeNet' 30 | __C.DATASET.TEST_DATASET = 'ShapeNet' 31 | 32 | # 33 | # Constants 34 | # 35 | __C.CONST = edict() 36 | 37 | __C.CONST.NUM_WORKERS = 8 38 | __C.CONST.N_INPUT_POINTS = 2048 39 | 40 | # 41 | # Directories 42 | # 43 | 44 | __C.DIR = edict() 45 | __C.DIR.OUT_PATH = './exp/output' 46 | __C.CONST.DEVICE = '0' 47 | __C.CONST.WEIGHTS = '' # specify a checkpoint path, e.g. './pretrained/pcn/ckpt-best-pmpplus.pth', to run test and inference 48 | 49 | # 50 | # Memcached 51 | # 52 | __C.MEMCACHED = edict() 53 | __C.MEMCACHED.ENABLED = False 54 | __C.MEMCACHED.LIBRARY_PATH = '/mnt/lustre/share/pymc/py3' 55 | __C.MEMCACHED.SERVER_CONFIG = '/mnt/lustre/share/memcached_client/server_list.conf' 56 | __C.MEMCACHED.CLIENT_CONFIG = '/mnt/lustre/share/memcached_client/client.conf' 57 | 58 | # 59 | # Network 60 | # 61 | __C.NETWORK = edict() 62 | __C.NETWORK.N_SAMPLING_POINTS = 2048 63 | 64 | 65 | # 66 | # Train 67 | # 68 | __C.TRAIN = edict() 69 | __C.TRAIN.LAMBDA_CD = 1000 70 | __C.TRAIN.LAMBDA_PMD = 1e-2 71 | __C.TRAIN.BATCH_SIZE = 16 72 | __C.TRAIN.N_EPOCHS = 400 73 | __C.TRAIN.SAVE_FREQ = 25 74 | __C.TRAIN.LEARNING_RATE = 0.001 75 | __C.TRAIN.LR_MILESTONES = [50, 100, 150, 200, 250] 76 | __C.TRAIN.GAMMA = .5 77 | __C.TRAIN.BETAS = (.9, .999) 78 | __C.TRAIN.WEIGHT_DECAY = 0 79 | 80 | # 81 | # Test 82 | # 83 | __C.TEST = edict() 84 | __C.TEST.METRIC_NAME = 'ChamferDistance' 85 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/core/__init__.py -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/inference_c3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/core/__pycache__/inference_c3d.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/inference_pcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/core/__pycache__/inference_pcn.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/test_c3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/core/__pycache__/test_c3d.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/test_pcn.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/core/__pycache__/test_pcn.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/train_c3d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/core/__pycache__/train_c3d.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/train_pcn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/core/__pycache__/train_pcn.cpython-37.pyc -------------------------------------------------------------------------------- /core/inference_c3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import os 6 | import torch 7 | import utils.data_loaders 8 | import utils.helpers 9 | import utils.io 10 | from tqdm import tqdm 11 | from models.model import PMPNetPlus as Model 12 | 13 | 14 | def inference_net(cfg): 15 | # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use 16 | torch.backends.cudnn.benchmark = True 17 | 18 | # Set up data loader 19 | dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 20 | test_data_loader = torch.utils.data.DataLoader(dataset=dataset_loader.get_dataset( 21 | utils.data_loaders.DatasetSubset.TEST), 22 | batch_size=1, 23 | num_workers=cfg.CONST.NUM_WORKERS, 24 | collate_fn=utils.data_loaders.collate_fn, 25 | pin_memory=True, 26 | shuffle=False) 27 | 28 | model = Model(dataset=cfg.DATASET.TRAIN_DATASET) 29 | 30 | if torch.cuda.is_available(): 31 | model = torch.nn.DataParallel(model).cuda() 32 | 33 | # Load the pretrained model from a checkpoint 34 | assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS 35 | logging.info('Recovering from %s ...' 
% (cfg.CONST.WEIGHTS)) 36 | checkpoint = torch.load(cfg.CONST.WEIGHTS) 37 | model.load_state_dict(checkpoint['model']) 38 | 39 | # Switch models to evaluation mode 40 | model.eval() 41 | 42 | # The inference loop 43 | n_samples = len(test_data_loader) 44 | t_obj = tqdm(test_data_loader) 45 | 46 | 47 | for model_idx, (taxonomy_id, model_id, data) in enumerate(t_obj): 48 | taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item() 49 | model_id = model_id[0] 50 | 51 | with torch.no_grad(): 52 | for k, v in data.items(): 53 | data[k] = utils.helpers.var_or_cuda(v) 54 | 55 | partial = data['partial_cloud'] 56 | 57 | pcds = model(partial)[0] 58 | pcd1, pcd2, pcd3 = pcds 59 | 60 | 61 | output_folder = os.path.join(cfg.DIR.OUT_PATH, 'benchmark', taxonomy_id) 62 | if not os.path.exists(output_folder): 63 | os.makedirs(output_folder) 64 | output_folder_pcd1 = os.path.join(output_folder, 'pcd1') 65 | output_folder_pcd2 = os.path.join(output_folder, 'pcd2') 66 | output_folder_pcd3 = os.path.join(output_folder, 'pcd3') 67 | if not os.path.exists(output_folder_pcd1): 68 | os.makedirs(output_folder_pcd1) 69 | os.makedirs(output_folder_pcd2) 70 | os.makedirs(output_folder_pcd3) 71 | 72 | output_file_path = os.path.join(output_folder, 'pcd1', '%s.h5' % model_id) 73 | utils.io.IO.put(output_file_path, pcd1.squeeze().cpu().numpy()) 74 | 75 | output_file_path = os.path.join(output_folder, 'pcd2', '%s.h5' % model_id) 76 | utils.io.IO.put(output_file_path, pcd2.squeeze().cpu().numpy()) 77 | 78 | output_file_path = os.path.join(output_folder, 'pcd3', '%s.h5' % model_id) 79 | utils.io.IO.put(output_file_path, pcd3.squeeze().cpu().numpy()) 80 | 81 | t_obj.set_description('Test[%d/%d] Taxonomy = %s Sample = %s File = %s' % 82 | (model_idx + 1, n_samples, taxonomy_id, model_id, output_file_path)) 83 | 84 | -------------------------------------------------------------------------------- /core/inference_pcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import os 6 | import torch 7 | import utils.data_loaders 8 | import utils.helpers 9 | import utils.io 10 | from tqdm import tqdm 11 | from models.model import PMPNetPlus as Model 12 | 13 | 14 | def random_subsample(pcd, n_points=2048): 15 | """ 16 | Args: 17 | pcd: (B, N, 3) 18 | 19 | returns: 20 | new_pcd: (B, n_points, 3) 21 | """ 22 | b, n, _ = pcd.shape 23 | device = pcd.device 24 | batch_idx = torch.arange(b, dtype=torch.long, device=device).reshape((-1, 1)).repeat(1, n_points) 25 | idx = torch.cat([torch.randperm(n, dtype=torch.long, device=device)[:n_points].reshape((1, -1)) for i in range(b)], 0) 26 | return pcd[batch_idx, idx, :] 27 | 28 | def inference_net(cfg): 29 | # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use 30 | torch.backends.cudnn.benchmark = True 31 | 32 | # Set up data loader 33 | dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 34 | test_data_loader = torch.utils.data.DataLoader(dataset=dataset_loader.get_dataset( 35 | utils.data_loaders.DatasetSubset.TEST), 36 | batch_size=1, 37 | num_workers=cfg.CONST.NUM_WORKERS, 38 | collate_fn=utils.data_loaders.collate_fn, 39 | pin_memory=True, 40 | shuffle=False) 41 | 42 | model = Model(dataset=cfg.DATASET.TRAIN_DATASET) 43 | 44 | if torch.cuda.is_available(): 45 | model = torch.nn.DataParallel(model).cuda() 46 | 47 | assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS 48 | # Load the pretrained model 
from a checkpoint 49 | logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS)) 50 | checkpoint = torch.load(cfg.CONST.WEIGHTS) 51 | model.load_state_dict(checkpoint['model']) 52 | 53 | # Switch models to evaluation mode 54 | model.eval() 55 | 56 | # The inference loop 57 | n_samples = len(test_data_loader) 58 | t_obj = tqdm(test_data_loader) 59 | 60 | 61 | for model_idx, (taxonomy_id, model_id, data) in enumerate(t_obj): 62 | taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item() 63 | model_id = model_id[0] 64 | 65 | with torch.no_grad(): 66 | for k, v in data.items(): 67 | data[k] = utils.helpers.var_or_cuda(v) 68 | 69 | partial = data['partial_cloud'] 70 | partial = random_subsample(partial.repeat((1, 8, 1)).reshape(-1, 16384, 3)) # b*8, 2048, 3 71 | pcds = model(partial)[0] 72 | 73 | pcd1 = pcds[0].reshape(-1, 16384, 3).contiguous() 74 | pcd2 = pcds[1].reshape(-1, 16384, 3).contiguous() 75 | pcd3 = pcds[2].reshape(-1, 16384, 3).contiguous() 76 | 77 | output_folder = os.path.join(cfg.DIR.OUT_PATH, 'benchmark', taxonomy_id) 78 | if not os.path.exists(output_folder): 79 | os.makedirs(output_folder) 80 | output_folder_pcd1 = os.path.join(output_folder, 'pcd1') 81 | output_folder_pcd2 = os.path.join(output_folder, 'pcd2') 82 | output_folder_pcd3 = os.path.join(output_folder, 'pcd3') 83 | if not os.path.exists(output_folder_pcd1): 84 | os.makedirs(output_folder_pcd1) 85 | os.makedirs(output_folder_pcd2) 86 | os.makedirs(output_folder_pcd3) 87 | 88 | output_file_path = os.path.join(output_folder, 'pcd1', '%s.h5' % model_id) 89 | utils.io.IO.put(output_file_path, pcd1.squeeze().cpu().numpy()) 90 | 91 | output_file_path = os.path.join(output_folder, 'pcd2', '%s.h5' % model_id) 92 | utils.io.IO.put(output_file_path, pcd2.squeeze().cpu().numpy()) 93 | 94 | output_file_path = os.path.join(output_folder, 'pcd3', '%s.h5' % model_id) 95 | utils.io.IO.put(output_file_path, pcd3.squeeze().cpu().numpy()) 96 | 97 | t_obj.set_description('Test[%d/%d] Taxonomy = %s Sample = %s File = %s' % 98 | (model_idx + 1, n_samples, taxonomy_id, model_id, output_file_path)) 99 | 100 | -------------------------------------------------------------------------------- /core/test_c3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import torch 6 | import utils.data_loaders 7 | import utils.helpers 8 | from tqdm import tqdm 9 | from Chamfer3D.dist_chamfer_3D import chamfer_3DDist 10 | from utils.average_meter import AverageMeter 11 | from utils.metrics import Metrics 12 | from models.model import PMPNet as Model 13 | chamfer_dist = chamfer_3DDist() 14 | 15 | 16 | def chamfer(p1, p2): 17 | d1, d2, _, _ = chamfer_dist(p1, p2) 18 | return torch.mean(d1) + torch.mean(d2) 19 | 20 | 21 | def chamfer_sqrt(p1, p2): 22 | d1, d2, _, _ = chamfer_dist(p1, p2) 23 | d1 = torch.mean(torch.sqrt(d1)) 24 | d2 = torch.mean(torch.sqrt(d2)) 25 | return (d1 + d2) / 2 26 | 27 | 28 | def test_net(cfg, epoch_idx=-1, test_data_loader=None, test_writer=None, model=None): 29 | # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use 30 | torch.backends.cudnn.benchmark = True 31 | 32 | if test_data_loader is None: 33 | # Set up data loader 34 | dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 35 | test_data_loader = torch.utils.data.DataLoader(dataset=dataset_loader.get_dataset( 36 | utils.data_loaders.DatasetSubset.VAL), 37 | batch_size=1, 38 | 
num_workers=cfg.CONST.NUM_WORKERS, 39 | collate_fn=utils.data_loaders.collate_fn, 40 | pin_memory=True, 41 | shuffle=False) 42 | 43 | # Setup networks and initialize networks 44 | if model is None: 45 | model = Model(dataset=cfg.DATASET.TRAIN_DATASET) 46 | if torch.cuda.is_available(): 47 | model = torch.nn.DataParallel(model).cuda() 48 | 49 | assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS 50 | logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS)) 51 | checkpoint = torch.load(cfg.CONST.WEIGHTS) 52 | model.load_state_dict(checkpoint['model']) 53 | 54 | # Switch models to evaluation mode 55 | model.eval() 56 | 57 | n_samples = len(test_data_loader) 58 | test_losses = AverageMeter(['cd1', 'cd2', 'cd3', 'pmd']) 59 | test_metrics = AverageMeter(Metrics.names()) 60 | category_metrics = dict() 61 | 62 | # Testing loop 63 | with tqdm(test_data_loader) as t: 64 | # print('repeating') 65 | for model_idx, (taxonomy_id, model_id, data) in enumerate(t): 66 | taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item() 67 | model_id = model_id[0] 68 | 69 | with torch.no_grad(): 70 | for k, v in data.items(): 71 | data[k] = utils.helpers.var_or_cuda(v) 72 | 73 | partial = data['partial_cloud'] 74 | gt = data['gtcloud'] 75 | 76 | b, n, _ = partial.shape 77 | 78 | pcds, deltas = model(partial.contiguous()) 79 | 80 | cd1 = chamfer(pcds[0], gt).item() * 1e3 81 | cd2 = chamfer(pcds[1], gt).item() * 1e3 82 | cd3 = chamfer(pcds[2], gt).item() * 1e3 83 | 84 | # pmd loss 85 | pmd_losses = [] 86 | for delta in deltas: 87 | pmd_losses.append(torch.sum(delta ** 2)) 88 | 89 | pmd = torch.sum(torch.stack(pmd_losses)) / 3 90 | 91 | pmd_item = pmd.item() 92 | 93 | _metrics = [pmd_item, cd3] 94 | test_losses.update([cd1, cd2, cd3, pmd_item]) 95 | 96 | test_metrics.update(_metrics) 97 | if taxonomy_id not in category_metrics: 98 | category_metrics[taxonomy_id] = AverageMeter(Metrics.names()) 99 | category_metrics[taxonomy_id].update(_metrics) 100 | 101 | t.set_description('Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s' % 102 | (model_idx + 1, n_samples, taxonomy_id, model_id, ['%.4f' % l for l in test_losses.val() 103 | ], ['%.4f' % m for m in _metrics])) 104 | 105 | # Print testing results 106 | print('============================ TEST RESULTS ============================') 107 | print('Taxonomy', end='\t') 108 | print('#Sample', end='\t') 109 | for metric in test_metrics.items: 110 | print(metric, end='\t') 111 | print() 112 | 113 | for taxonomy_id in category_metrics: 114 | print(taxonomy_id, end='\t') 115 | print(category_metrics[taxonomy_id].count(0), end='\t') 116 | for value in category_metrics[taxonomy_id].avg(): 117 | print('%.4f' % value, end='\t') 118 | print() 119 | 120 | print('Overall', end='\t\t\t') 121 | for value in test_metrics.avg(): 122 | print('%.4f' % value, end='\t') 123 | print('\n') 124 | 125 | # Add testing results to TensorBoard 126 | if test_writer is not None: 127 | test_writer.add_scalar('Loss/Epoch/cd1', test_losses.avg(0), epoch_idx) 128 | test_writer.add_scalar('Loss/Epoch/cd2', test_losses.avg(1), epoch_idx) 129 | test_writer.add_scalar('Loss/Epoch/cd3', test_losses.avg(2), epoch_idx) 130 | test_writer.add_scalar('Loss/Epoch/delta', test_losses.avg(3), epoch_idx) 131 | for i, metric in enumerate(test_metrics.items): 132 | test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i), epoch_idx) 133 | 134 | return test_losses.avg(2) 135 | -------------------------------------------------------------------------------- 
/core/test_pcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import torch 6 | import utils.data_loaders 7 | import utils.helpers 8 | from tqdm import tqdm 9 | from Chamfer3D.dist_chamfer_3D import chamfer_3DDist 10 | from utils.average_meter import AverageMeter 11 | from utils.metrics import Metrics 12 | from models.model import PMPNetPlus as Model 13 | chamfer_dist = chamfer_3DDist() 14 | 15 | 16 | def chamfer(p1, p2): 17 | d1, d2, _, _ = chamfer_dist(p1, p2) 18 | return torch.mean(d1) + torch.mean(d2) 19 | 20 | 21 | def chamfer_sqrt(p1, p2): 22 | d1, d2, _, _ = chamfer_dist(p1, p2) 23 | d1 = torch.mean(torch.sqrt(d1)) 24 | d2 = torch.mean(torch.sqrt(d2)) 25 | return (d1 + d2) / 2 26 | 27 | 28 | def random_subsample(pcd, n_points=2048): 29 | """ 30 | Args: 31 | pcd: (B, N, 3) 32 | 33 | returns: 34 | new_pcd: (B, n_points, 3) 35 | """ 36 | b, n, _ = pcd.shape 37 | device = pcd.device 38 | batch_idx = torch.arange(b, dtype=torch.long, device=device).reshape((-1, 1)).repeat(1, n_points) 39 | idx = torch.cat([torch.randperm(n, dtype=torch.long, device=device)[:n_points].reshape((1, -1)) for i in range(b)], 0) 40 | return pcd[batch_idx, idx, :] 41 | 42 | def test_net(cfg, epoch_idx=-1, test_data_loader=None, test_writer=None, model=None): 43 | # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use 44 | torch.backends.cudnn.benchmark = True 45 | 46 | if test_data_loader is None: 47 | # Set up data loader 48 | dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 49 | test_data_loader = torch.utils.data.DataLoader(dataset=dataset_loader.get_dataset( 50 | utils.data_loaders.DatasetSubset.TEST), 51 | batch_size=1, 52 | num_workers=cfg.CONST.NUM_WORKERS, 53 | collate_fn=utils.data_loaders.collate_fn, 54 | pin_memory=True, 55 | shuffle=False) 56 | 57 | # Setup networks and initialize networks 58 | if model is None: 59 | model = Model(dataset=cfg.DATASET.TRAIN_DATASET) 60 | if torch.cuda.is_available(): 61 | model = torch.nn.DataParallel(model).cuda() 62 | 63 | assert 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS 64 | logging.info('Recovering from %s ...' 
% (cfg.CONST.WEIGHTS)) 65 | checkpoint = torch.load(cfg.CONST.WEIGHTS) 66 | model.load_state_dict(checkpoint['model']) 67 | 68 | # Switch models to evaluation mode 69 | model.eval() 70 | 71 | n_samples = len(test_data_loader) 72 | test_losses = AverageMeter(['cd1', 'cd2', 'cd3', 'pmd']) 73 | test_metrics = AverageMeter(Metrics.names()) 74 | category_metrics = dict() 75 | 76 | # Testing loop 77 | with tqdm(test_data_loader) as t: 78 | # print('repeating') 79 | for model_idx, (taxonomy_id, model_id, data) in enumerate(t): 80 | taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item() 81 | model_id = model_id[0] 82 | 83 | with torch.no_grad(): 84 | for k, v in data.items(): 85 | data[k] = utils.helpers.var_or_cuda(v) 86 | 87 | partial = data['partial_cloud'] 88 | gt = data['gtcloud'] 89 | partial = random_subsample(partial.repeat((1, 8, 1)).reshape(-1, 16384, 3)) # b*8, 2048, 3 90 | 91 | b, n, _ = partial.shape 92 | 93 | pcds, deltas = model(partial.contiguous()) 94 | 95 | cd1 = chamfer_sqrt(pcds[0].reshape(-1, 16384, 3).contiguous(), gt).item() * 1e3 96 | cd2 = chamfer_sqrt(pcds[1].reshape(-1, 16384, 3).contiguous(), gt).item() * 1e3 97 | cd3 = chamfer_sqrt(pcds[2].reshape(-1, 16384, 3).contiguous(), gt).item() * 1e3 98 | 99 | # pmd loss 100 | pmd_losses = [] 101 | for delta in deltas: 102 | pmd_losses.append(torch.sum(delta ** 2)) 103 | 104 | pmd = torch.sum(torch.stack(pmd_losses)) / 3 105 | 106 | pmd_item = pmd.item() 107 | 108 | _metrics = [pmd_item, cd3] 109 | test_losses.update([cd1, cd2, cd3, pmd_item]) 110 | 111 | test_metrics.update(_metrics) 112 | if taxonomy_id not in category_metrics: 113 | category_metrics[taxonomy_id] = AverageMeter(Metrics.names()) 114 | category_metrics[taxonomy_id].update(_metrics) 115 | 116 | t.set_description('Test[%d/%d] Taxonomy = %s Sample = %s Losses = %s Metrics = %s' % 117 | (model_idx + 1, n_samples, taxonomy_id, model_id, ['%.4f' % l for l in test_losses.val() 118 | ], ['%.4f' % m for m in _metrics])) 119 | 120 | # Print testing results 121 | print('============================ TEST RESULTS ============================') 122 | print('Taxonomy', end='\t') 123 | print('#Sample', end='\t') 124 | for metric in test_metrics.items: 125 | print(metric, end='\t') 126 | print() 127 | 128 | for taxonomy_id in category_metrics: 129 | print(taxonomy_id, end='\t') 130 | print(category_metrics[taxonomy_id].count(0), end='\t') 131 | for value in category_metrics[taxonomy_id].avg(): 132 | print('%.4f' % value, end='\t') 133 | print() 134 | 135 | print('Overall', end='\t\t\t') 136 | for value in test_metrics.avg(): 137 | print('%.4f' % value, end='\t') 138 | print('\n') 139 | 140 | # Add testing results to TensorBoard 141 | if test_writer is not None: 142 | test_writer.add_scalar('Loss/Epoch/cd1', test_losses.avg(0), epoch_idx) 143 | test_writer.add_scalar('Loss/Epoch/cd2', test_losses.avg(1), epoch_idx) 144 | test_writer.add_scalar('Loss/Epoch/cd3', test_losses.avg(2), epoch_idx) 145 | test_writer.add_scalar('Loss/Epoch/delta', test_losses.avg(3), epoch_idx) 146 | for i, metric in enumerate(test_metrics.items): 147 | test_writer.add_scalar('Metric/%s' % metric, test_metrics.avg(i), epoch_idx) 148 | 149 | return test_losses.avg(2) 150 | -------------------------------------------------------------------------------- /core/train_c3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import os 6 | import torch 7 | import 
utils.data_loaders 8 | import utils.helpers 9 | from datetime import datetime 10 | from tqdm import tqdm 11 | from time import time 12 | from tensorboardX import SummaryWriter 13 | from core.test_c3d import test_net 14 | from utils.average_meter import AverageMeter 15 | from models.model import PMPNet as Model 16 | from Chamfer3D.dist_chamfer_3D import chamfer_3DDist 17 | chamfer_dist = chamfer_3DDist() 18 | 19 | 20 | def chamfer(p1, p2): 21 | d1, d2, _, _ = chamfer_dist(p1, p2) 22 | return torch.mean(d1) + torch.mean(d2) 23 | 24 | 25 | def chamfer_sqrt(p1, p2): 26 | d1, d2, _, _ = chamfer_dist(p1, p2) 27 | d1 = torch.mean(torch.sqrt(d1)) 28 | d2 = torch.mean(torch.sqrt(d2)) 29 | return (d1 + d2) / 2 30 | 31 | 32 | def lr_lambda(epoch): 33 | if 0 <= epoch <= 100: 34 | return 1 35 | elif 100 < epoch <= 150: 36 | return 0.5 37 | elif 150 < epoch <= 250: 38 | return 0.1 39 | else: 40 | return 0.5 41 | 42 | 43 | def train_net(cfg): 44 | # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use 45 | torch.backends.cudnn.benchmark = True 46 | 47 | train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg) 48 | test_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 49 | 50 | train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset_loader.get_dataset( 51 | utils.data_loaders.DatasetSubset.TRAIN), 52 | batch_size=cfg.TRAIN.BATCH_SIZE, 53 | num_workers=cfg.CONST.NUM_WORKERS, 54 | collate_fn=utils.data_loaders.collate_fn, 55 | pin_memory=True, 56 | shuffle=True, 57 | drop_last=True) 58 | val_data_loader = torch.utils.data.DataLoader(dataset=test_dataset_loader.get_dataset( 59 | utils.data_loaders.DatasetSubset.VAL), 60 | batch_size=cfg.TRAIN.BATCH_SIZE, 61 | num_workers=cfg.CONST.NUM_WORKERS//2, 62 | collate_fn=utils.data_loaders.collate_fn, 63 | pin_memory=True, 64 | shuffle=False) 65 | 66 | # Set up folders for logs and checkpoints 67 | output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', datetime.now().isoformat()) 68 | cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints' 69 | cfg.DIR.LOGS = output_dir % 'logs' 70 | if not os.path.exists(cfg.DIR.CHECKPOINTS): 71 | os.makedirs(cfg.DIR.CHECKPOINTS) 72 | 73 | # Create tensorboard writers 74 | train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train')) 75 | val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test')) 76 | 77 | model = Model(dataset=cfg.DATASET.TRAIN_DATASET) 78 | if torch.cuda.is_available(): 79 | model = torch.nn.DataParallel(model).cuda() 80 | 81 | # Create the optimizers 82 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 83 | lr=cfg.TRAIN.LEARNING_RATE, 84 | weight_decay=cfg.TRAIN.WEIGHT_DECAY, 85 | betas=cfg.TRAIN.BETAS) 86 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 87 | lr_lambda=lr_lambda) 88 | 89 | init_epoch = 0 90 | best_metrics = float('inf') 91 | 92 | if 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS: 93 | logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS)) 94 | checkpoint = torch.load(cfg.CONST.WEIGHTS) 95 | best_metrics = checkpoint['best_metrics'] 96 | model.load_state_dict(checkpoint['model']) 97 | logging.info('Recover complete. Current epoch = #%d; best metrics = %s.' 
% (init_epoch, best_metrics)) 98 | 99 | # Training/Testing the network 100 | for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1): 101 | epoch_start_time = time() 102 | 103 | batch_time = AverageMeter() 104 | data_time = AverageMeter() 105 | 106 | model.train() 107 | 108 | total_cd1 = 0 109 | total_cd2 = 0 110 | total_cd3 = 0 111 | total_pmd = 0 112 | 113 | batch_end_time = time() 114 | n_batches = len(train_data_loader) 115 | with tqdm(train_data_loader) as t: 116 | for batch_idx, (taxonomy_ids, model_ids, data) in enumerate(t): 117 | data_time.update(time() - batch_end_time) 118 | for k, v in data.items(): 119 | data[k] = utils.helpers.var_or_cuda(v) 120 | partial = data['partial_cloud'] 121 | gt = data['gtcloud'] 122 | 123 | pcds, deltas = model(partial) 124 | 125 | cd1 = chamfer(pcds[0], gt) 126 | cd2 = chamfer(pcds[1], gt) 127 | cd3 = chamfer(pcds[2], gt) 128 | loss_cd = cd1 + cd2 + cd3 129 | 130 | delta_losses = [] 131 | for delta in deltas: 132 | delta_losses.append(torch.sum(delta ** 2)) 133 | 134 | loss_pmd = torch.sum(torch.stack(delta_losses)) / 3 135 | 136 | loss = loss_cd * cfg.TRAIN.LAMBDA_CD + loss_pmd * cfg.TRAIN.LAMBDA_PMD 137 | 138 | optimizer.zero_grad() 139 | loss.backward() 140 | optimizer.step() 141 | 142 | cd1_item = cd1.item() * 1e3 143 | total_cd1 += cd1_item 144 | cd2_item = cd2.item() * 1e3 145 | total_cd2 += cd2_item 146 | cd3_item = cd3.item() * 1e3 147 | total_cd3 += cd3_item 148 | pmd_item = loss_pmd.item() 149 | total_pmd += pmd_item 150 | n_itr = (epoch_idx - 1) * n_batches + batch_idx 151 | train_writer.add_scalar('Loss/Batch/cd1', cd1_item, n_itr) 152 | train_writer.add_scalar('Loss/Batch/cd2', cd2_item, n_itr) 153 | train_writer.add_scalar('Loss/Batch/cd3', cd3_item, n_itr) 154 | train_writer.add_scalar('Loss/Batch/pmd', pmd_item, n_itr) 155 | batch_time.update(time() - batch_end_time) 156 | batch_end_time = time() 157 | t.set_description('[Epoch %d/%d][Batch %d/%d]' % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches)) 158 | t.set_postfix(loss='%s' % ['%.4f' % l for l in [cd1_item, cd2_item, cd3_item, pmd_item]]) 159 | 160 | avg_cd1 = total_cd1 / n_batches 161 | avg_cd2 = total_cd2 / n_batches 162 | avg_cd3 = total_cd3 / n_batches 163 | avg_pmd = total_pmd / n_batches 164 | 165 | lr_scheduler.step() 166 | epoch_end_time = time() 167 | train_writer.add_scalar('Loss/Epoch/cd1', avg_cd1, epoch_idx) 168 | train_writer.add_scalar('Loss/Epoch/cd2', avg_cd2, epoch_idx) 169 | train_writer.add_scalar('Loss/Epoch/cd3', avg_cd3, epoch_idx) 170 | train_writer.add_scalar('Loss/Epoch/pmd', avg_pmd, epoch_idx) 171 | logging.info( 172 | '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' % 173 | (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time - epoch_start_time, ['%.4f' % l for l in [avg_cd1, avg_cd2, avg_cd3, avg_pmd]])) 174 | 175 | # Validate the current model 176 | cd_eval = test_net(cfg, epoch_idx, val_data_loader, val_writer, model) 177 | 178 | # Save checkpoints 179 | if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or cd_eval < best_metrics: 180 | file_name = 'ckpt-best.pth' if cd_eval < best_metrics else 'ckpt-epoch-%03d.pth' % epoch_idx 181 | output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name) 182 | torch.save({ 183 | 'epoch_index': epoch_idx, 184 | 'best_metrics': best_metrics, 185 | 'model': model.state_dict() 186 | }, output_path) 187 | 188 | logging.info('Saved checkpoint to %s ...' 
% output_path) 189 | if cd_eval < best_metrics: 190 | best_metrics = cd_eval 191 | 192 | train_writer.close() 193 | val_writer.close() 194 | -------------------------------------------------------------------------------- /core/train_pcn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: XP 3 | 4 | import logging 5 | import os 6 | import torch 7 | import utils.data_loaders 8 | import utils.helpers 9 | from datetime import datetime 10 | from tqdm import tqdm 11 | from time import time 12 | from tensorboardX import SummaryWriter 13 | from core.test_pcn import test_net 14 | from utils.average_meter import AverageMeter 15 | from models.model import PMPNetPlus as Model 16 | from Chamfer3D.dist_chamfer_3D import chamfer_3DDist 17 | chamfer_dist = chamfer_3DDist() 18 | 19 | def random_subsample(pcd, n_points=2048): 20 | """ 21 | Args: 22 | pcd: (B, N, 3) 23 | 24 | returns: 25 | new_pcd: (B, n_points, 3) 26 | """ 27 | b, n, _ = pcd.shape 28 | device = pcd.device 29 | batch_idx = torch.arange(b, dtype=torch.long, device=device).reshape((-1, 1)).repeat(1, n_points) 30 | idx = torch.cat([torch.randperm(n, dtype=torch.long, device=device)[:n_points].reshape((1, -1)) for i in range(b)], 0) 31 | return pcd[batch_idx, idx, :] 32 | 33 | 34 | def chamfer(p1, p2): 35 | d1, d2, _, _ = chamfer_dist(p1, p2) 36 | return torch.mean(d1) + torch.mean(d2) 37 | 38 | 39 | def chamfer_sqrt(p1, p2): 40 | d1, d2, _, _ = chamfer_dist(p1, p2) 41 | d1 = torch.mean(torch.sqrt(d1)) 42 | d2 = torch.mean(torch.sqrt(d2)) 43 | return (d1 + d2) / 2 44 | 45 | 46 | def lr_lambda(epoch): 47 | if 0 <= epoch <= 100: 48 | return 1 49 | elif 100 < epoch <= 150: 50 | return 0.5 51 | elif 150 < epoch <= 250: 52 | return 0.1 53 | else: 54 | return 0.5 55 | 56 | 57 | def train_net(cfg): 58 | # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use 59 | torch.backends.cudnn.benchmark = True 60 | 61 | train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg) 62 | test_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg) 63 | 64 | train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset_loader.get_dataset( 65 | utils.data_loaders.DatasetSubset.TRAIN), 66 | batch_size=cfg.TRAIN.BATCH_SIZE, 67 | num_workers=cfg.CONST.NUM_WORKERS, 68 | collate_fn=utils.data_loaders.collate_fn, 69 | pin_memory=True, 70 | shuffle=True, 71 | drop_last=True) 72 | val_data_loader = torch.utils.data.DataLoader(dataset=test_dataset_loader.get_dataset( 73 | utils.data_loaders.DatasetSubset.TEST), 74 | batch_size=cfg.TRAIN.BATCH_SIZE, 75 | num_workers=cfg.CONST.NUM_WORKERS//2, 76 | collate_fn=utils.data_loaders.collate_fn, 77 | pin_memory=True, 78 | shuffle=False) 79 | 80 | # Set up folders for logs and checkpoints 81 | output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', datetime.now().isoformat()) 82 | cfg.DIR.CHECKPOINTS = output_dir % 'checkpoints' 83 | cfg.DIR.LOGS = output_dir % 'logs' 84 | if not os.path.exists(cfg.DIR.CHECKPOINTS): 85 | os.makedirs(cfg.DIR.CHECKPOINTS) 86 | 87 | # Create tensorboard writers 88 | train_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'train')) 89 | val_writer = SummaryWriter(os.path.join(cfg.DIR.LOGS, 'test')) 90 | 91 | model = Model(dataset=cfg.DATASET.TRAIN_DATASET) 92 | if torch.cuda.is_available(): 93 | model = torch.nn.DataParallel(model).cuda() 94 | 95 | # Create the optimizers 96 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, 
model.parameters()), 97 | lr=cfg.TRAIN.LEARNING_RATE, 98 | weight_decay=cfg.TRAIN.WEIGHT_DECAY, 99 | betas=cfg.TRAIN.BETAS) 100 | lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 101 | lr_lambda=lr_lambda) 102 | 103 | init_epoch = 0 104 | best_metrics = float('inf') 105 | 106 | if 'WEIGHTS' in cfg.CONST and cfg.CONST.WEIGHTS: 107 | logging.info('Recovering from %s ...' % (cfg.CONST.WEIGHTS)) 108 | checkpoint = torch.load(cfg.CONST.WEIGHTS) 109 | best_metrics = checkpoint['best_metrics'] 110 | model.load_state_dict(checkpoint['model']) 111 | logging.info('Recover complete. Current epoch = #%d; best metrics = %s.' % (init_epoch, best_metrics)) 112 | 113 | # Training/Testing the network 114 | for epoch_idx in range(init_epoch + 1, cfg.TRAIN.N_EPOCHS + 1): 115 | epoch_start_time = time() 116 | 117 | batch_time = AverageMeter() 118 | data_time = AverageMeter() 119 | 120 | model.train() 121 | 122 | total_cd1 = 0 123 | total_cd2 = 0 124 | total_cd3 = 0 125 | total_pmd = 0 126 | 127 | batch_end_time = time() 128 | n_batches = len(train_data_loader) 129 | with tqdm(train_data_loader) as t: 130 | for batch_idx, (taxonomy_ids, model_ids, data) in enumerate(t): 131 | data_time.update(time() - batch_end_time) 132 | for k, v in data.items(): 133 | data[k] = utils.helpers.var_or_cuda(v) 134 | partial = random_subsample(data['partial_cloud']) 135 | gt = random_subsample(data['gtcloud']) 136 | 137 | pcds, deltas = model(partial) 138 | 139 | cd1 = chamfer(pcds[0], gt) 140 | cd2 = chamfer(pcds[1], gt) 141 | cd3 = chamfer(pcds[2], gt) 142 | loss_cd = cd1 + cd2 + cd3 143 | 144 | delta_losses = [] 145 | for delta in deltas: 146 | delta_losses.append(torch.sum(delta ** 2)) 147 | 148 | loss_pmd = torch.sum(torch.stack(delta_losses)) / 3 149 | 150 | loss = loss_cd * cfg.TRAIN.LAMBDA_CD + loss_pmd * cfg.TRAIN.LAMBDA_PMD 151 | 152 | optimizer.zero_grad() 153 | loss.backward() 154 | optimizer.step() 155 | 156 | cd1_item = cd1.item() * 1e3 157 | total_cd1 += cd1_item 158 | cd2_item = cd2.item() * 1e3 159 | total_cd2 += cd2_item 160 | cd3_item = cd3.item() * 1e3 161 | total_cd3 += cd3_item 162 | pmd_item = loss_pmd.item() 163 | total_pmd += pmd_item 164 | n_itr = (epoch_idx - 1) * n_batches + batch_idx 165 | train_writer.add_scalar('Loss/Batch/cd1', cd1_item, n_itr) 166 | train_writer.add_scalar('Loss/Batch/cd2', cd2_item, n_itr) 167 | train_writer.add_scalar('Loss/Batch/cd3', cd3_item, n_itr) 168 | train_writer.add_scalar('Loss/Batch/pmd', pmd_item, n_itr) 169 | batch_time.update(time() - batch_end_time) 170 | batch_end_time = time() 171 | t.set_description('[Epoch %d/%d][Batch %d/%d]' % (epoch_idx, cfg.TRAIN.N_EPOCHS, batch_idx + 1, n_batches)) 172 | t.set_postfix(loss='%s' % ['%.4f' % l for l in [cd1_item, cd2_item, cd3_item, pmd_item]]) 173 | 174 | avg_cd1 = total_cd1 / n_batches 175 | avg_cd2 = total_cd2 / n_batches 176 | avg_cd3 = total_cd3 / n_batches 177 | avg_pmd = total_pmd / n_batches 178 | 179 | lr_scheduler.step() 180 | epoch_end_time = time() 181 | train_writer.add_scalar('Loss/Epoch/cd1', avg_cd1, epoch_idx) 182 | train_writer.add_scalar('Loss/Epoch/cd2', avg_cd2, epoch_idx) 183 | train_writer.add_scalar('Loss/Epoch/cd3', avg_cd3, epoch_idx) 184 | train_writer.add_scalar('Loss/Epoch/pmd', avg_pmd, epoch_idx) 185 | logging.info( 186 | '[Epoch %d/%d] EpochTime = %.3f (s) Losses = %s' % 187 | (epoch_idx, cfg.TRAIN.N_EPOCHS, epoch_end_time - epoch_start_time, ['%.4f' % l for l in [avg_cd1, avg_cd2, avg_cd3, avg_pmd]])) 188 | 189 | # Validate the current model 190 | cd_eval = test_net(cfg, 
epoch_idx, val_data_loader, val_writer, model) 191 | 192 | # Save checkpoints 193 | if epoch_idx % cfg.TRAIN.SAVE_FREQ == 0 or cd_eval < best_metrics: 194 | file_name = 'ckpt-best.pth' if cd_eval < best_metrics else 'ckpt-epoch-%03d.pth' % epoch_idx 195 | output_path = os.path.join(cfg.DIR.CHECKPOINTS, file_name) 196 | torch.save({ 197 | 'epoch_index': epoch_idx, 198 | 'best_metrics': best_metrics, 199 | 'model': model.state_dict() 200 | }, output_path) 201 | 202 | logging.info('Saved checkpoint to %s ...' % output_path) 203 | if cd_eval < best_metrics: 204 | best_metrics = cd_eval 205 | 206 | train_writer.close() 207 | val_writer.close() 208 | -------------------------------------------------------------------------------- /main_c3d.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Peng Xiang 4 | 5 | import argparse 6 | import logging 7 | import os 8 | import numpy as np 9 | import sys 10 | import torch 11 | from pprint import pprint 12 | from config_c3d import cfg 13 | from core.train_c3d import train_net 14 | from core.test_c3d import test_net 15 | from core.inference_c3d import inference_net 16 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 17 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.DEVICE 18 | 19 | def set_seed(seed): 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) 24 | 25 | 26 | def get_args_from_command_line(): 27 | parser = argparse.ArgumentParser(description='The argument parser of PMP-Net') 28 | parser.add_argument('--test', dest='test', help='Test neural networks', action='store_true') 29 | parser.add_argument('--inference', dest='inference', help='Inference for benchmark', action='store_true') 30 | args = parser.parse_args() 31 | 32 | return args 33 | 34 | 35 | def main(): 36 | # Get args from command line 37 | args = get_args_from_command_line() 38 | print('cuda available ', torch.cuda.is_available()) 39 | 40 | # Print config 41 | print('Use config:') 42 | pprint(cfg) 43 | 44 | if not args.test and not args.inference: 45 | train_net(cfg) 46 | else: 47 | if cfg.CONST.WEIGHTS is None: 48 | raise Exception('Please specify the path to checkpoint in the configuration file!') 49 | 50 | if args.test: 51 | test_net(cfg) 52 | else: 53 | inference_net(cfg) 54 | 55 | if __name__ == '__main__': 56 | # Check python version 57 | seed = 1 58 | set_seed(seed) 59 | logging.basicConfig(format='[%(levelname)s] %(asctime)s %(message)s', level=logging.DEBUG) 60 | main() 61 | -------------------------------------------------------------------------------- /main_pcn.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | # @Author: Peng Xiang 4 | 5 | import argparse 6 | import logging 7 | import os 8 | import numpy as np 9 | import sys 10 | import torch 11 | from pprint import pprint 12 | from config_pcn import cfg 13 | from core.train_pcn import train_net 14 | from core.test_pcn import test_net 15 | from core.inference_pcn import inference_net 16 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 17 | os.environ["CUDA_VISIBLE_DEVICES"] = cfg.CONST.DEVICE 18 | 19 | 20 | def set_seed(seed): 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | 27 | def get_args_from_command_line(): 28 | parser = argparse.ArgumentParser(description='The argument parser of PMP-Net') 29 | parser.add_argument('--test', dest='test', help='Test neural networks', action='store_true') 30 | parser.add_argument('--inference', dest='inference', help='Inference for benchmark', action='store_true') 31 | args = parser.parse_args() 32 | 33 | return args 34 | 35 | 36 | def main(): 37 | # Get args from command line 38 | args = get_args_from_command_line() 39 | print('cuda available ', torch.cuda.is_available()) 40 | 41 | # Print config 42 | print('Use config:') 43 | pprint(cfg) 44 | 45 | if not args.test and not args.inference: 46 | train_net(cfg) 47 | else: 48 | if cfg.CONST.WEIGHTS is None: 49 | raise Exception('Please specify the path to checkpoint in the configuration file!') 50 | 51 | if args.test: 52 | test_net(cfg) 53 | else: 54 | inference_net(cfg) 55 | 56 | if __name__ == '__main__': 57 | # Check python version 58 | seed = 1 59 | set_seed(seed) 60 | logging.basicConfig(format='[%(levelname)s] %(asctime)s %(message)s', level=logging.DEBUG) 61 | main() 62 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../pointnet2_ops_lib') 3 | sys.path.append('..') -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/models/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/models/__pycache__/transformer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/models/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /models/transformer.py: 
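Note: the Transformer module below follows the Point Transformer pattern: each point attends over its k nearest neighbors with per-channel (vector) attention weights, and a learned relative-position embedding is added to both the attention logits and the values. A minimal, self-contained sketch of the final aggregation step it performs; the shapes and tensor names here are illustrative, not the module's API:

import torch
from torch import einsum

b, c, n, k = 2, 8, 16, 4                                     # batch, channels, points, neighbors per point
attention = torch.softmax(torch.randn(b, c, n, k), dim=-1)   # per-channel weights over each point's k neighbors
value = torch.randn(b, c, n, k)                              # grouped neighbor features plus position embedding
agg = einsum('b c i j, b c i j -> b c i', attention, value)  # weighted sum over neighbors -> (b, c, n)
assert agg.shape == (b, c, n)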
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from models.utils import query_knn, grouping_operation 4 | 5 | class Transformer(nn.Module): 6 | def __init__(self, in_channel, dim=256, n_knn=16, pos_hidden_dim=64, attn_hidden_multiplier=4): 7 | super(Transformer, self).__init__() 8 | self.n_knn = n_knn 9 | self.conv_key = nn.Conv1d(dim, dim, 1) 10 | self.conv_query = nn.Conv1d(dim, dim, 1) 11 | self.conv_value = nn.Conv1d(dim, dim, 1) 12 | 13 | self.pos_mlp = nn.Sequential( 14 | nn.Conv2d(3, pos_hidden_dim, 1), 15 | nn.BatchNorm2d(pos_hidden_dim), 16 | nn.ReLU(), 17 | nn.Conv2d(pos_hidden_dim, dim, 1) 18 | ) 19 | 20 | self.attn_mlp = nn.Sequential( 21 | nn.Conv2d(dim, dim * attn_hidden_multiplier, 1), 22 | nn.BatchNorm2d(dim * attn_hidden_multiplier), 23 | nn.ReLU(), 24 | nn.Conv2d(dim * attn_hidden_multiplier, dim, 1) 25 | ) 26 | 27 | self.linear_start = nn.Conv1d(in_channel, dim, 1) 28 | self.linear_end = nn.Conv1d(dim, in_channel, 1) 29 | 30 | def forward(self, x, pos): 31 | """feed forward of transformer 32 | Args: 33 | x: Tensor of features, (B, in_channel, n) 34 | pos: Tensor of positions, (B, 3, n) 35 | 36 | Returns: 37 | y: Tensor of features with attention, (B, in_channel, n) 38 | """ 39 | 40 | identity = x 41 | 42 | x = self.linear_start(x) 43 | b, dim, n = x.shape 44 | 45 | pos_flipped = pos.permute(0, 2, 1).contiguous() 46 | idx_knn = query_knn(self.n_knn, pos_flipped, pos_flipped) 47 | key = self.conv_key(x) 48 | value = self.conv_value(x) 49 | query = self.conv_query(x) 50 | 51 | key = grouping_operation(key, idx_knn) # b, dim, n, n_knn 52 | qk_rel = query.reshape((b, -1, n, 1)) - key # b, dim, n, n_knn 53 | 54 | pos_rel = pos.reshape((b, -1, n, 1)) - grouping_operation(pos, idx_knn) # b, 3, n, n_knn 55 | pos_embedding = self.pos_mlp(pos_rel) # b, dim, n, n_knn 56 | 57 | attention = self.attn_mlp(qk_rel + pos_embedding) # b, n, n_knn 58 | attention = torch.softmax(attention, -1) # b, dim, n, n_knn 59 | 60 | value = value.reshape((b, -1, n, 1)) + pos_embedding 61 | 62 | agg = einsum('b c i j, b c i j -> b c i', attention, value) # b, dim, n 63 | y = self.linear_end(agg) 64 | 65 | return y+identity -------------------------------------------------------------------------------- /pics/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/pics/network.png -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/__init__.py: -------------------------------------------------------------------------------- 1 | import pointnet2_ops.pointnet2_modules 2 | import pointnet2_ops.pointnet2_utils 3 | from pointnet2_ops._version import __version__ 4 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/ball_query.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 5 | const int nsample); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include 5 | #include 6 | 
#include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #define TOTAL_THREADS 512 14 | 15 | inline int opt_n_threads(int work_size) { 16 | const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); 17 | 18 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 19 | } 20 | 21 | inline dim3 opt_block_config(int x, int y) { 22 | const int x_threads = opt_n_threads(x); 23 | const int y_threads = 24 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 25 | dim3 block_config(x_threads, y_threads, 1); 26 | 27 | return block_config; 28 | } 29 | 30 | #define CUDA_CHECK_ERRORS() \ 31 | do { \ 32 | cudaError_t err = cudaGetLastError(); \ 33 | if (cudaSuccess != err) { \ 34 | fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \ 35 | cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \ 36 | __FILE__); \ 37 | exit(-1); \ 38 | } \ 39 | } while (0) 40 | 41 | #endif 42 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/group_points.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor group_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/interpolate.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows); 7 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 8 | at::Tensor weight); 9 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 10 | at::Tensor weight, const int m); 11 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/sampling.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | at::Tensor gather_points(at::Tensor points, at::Tensor idx); 5 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, const int n); 6 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples); 7 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | #define CHECK_CUDA(x) \ 6 | do { \ 7 | AT_ASSERT(x.is_cuda(), #x " must be a CUDA tensor"); \ 8 | } while (0) 9 | 10 | #define CHECK_CONTIGUOUS(x) \ 11 | do { \ 12 | AT_ASSERT(x.is_contiguous(), #x " must be a contiguous tensor"); \ 13 | } while (0) 14 | 15 | #define CHECK_IS_INT(x) \ 16 | do { \ 17 | AT_ASSERT(x.scalar_type() == at::ScalarType::Int, \ 18 | #x " must be an int tensor"); \ 19 | } while (0) 20 | 21 | #define CHECK_IS_FLOAT(x) \ 22 | do { \ 23 | AT_ASSERT(x.scalar_type() == at::ScalarType::Float, \ 24 | #x " must be a float tensor"); \ 25 | } while (0) 26 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "utils.h" 3 | 4 | void query_ball_point_kernel_wrapper(int b, int n, int m, 
float radius, 5 | int nsample, const float *new_xyz, 6 | const float *xyz, int *idx); 7 | 8 | at::Tensor ball_query(at::Tensor new_xyz, at::Tensor xyz, const float radius, 9 | const int nsample) { 10 | CHECK_CONTIGUOUS(new_xyz); 11 | CHECK_CONTIGUOUS(xyz); 12 | CHECK_IS_FLOAT(new_xyz); 13 | CHECK_IS_FLOAT(xyz); 14 | 15 | if (new_xyz.is_cuda()) { 16 | CHECK_CUDA(xyz); 17 | } 18 | 19 | at::Tensor idx = 20 | torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample}, 21 | at::device(new_xyz.device()).dtype(at::ScalarType::Int)); 22 | 23 | if (new_xyz.is_cuda()) { 24 | query_ball_point_kernel_wrapper(xyz.size(0), xyz.size(1), new_xyz.size(1), 25 | radius, nsample, new_xyz.data_ptr(), 26 | xyz.data_ptr(), idx.data_ptr()); 27 | } else { 28 | AT_ASSERT(false, "CPU not supported"); 29 | } 30 | 31 | return idx; 32 | } 33 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/ball_query_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: new_xyz(b, m, 3) xyz(b, n, 3) 8 | // output: idx(b, m, nsample) 9 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 10 | int nsample, 11 | const float *__restrict__ new_xyz, 12 | const float *__restrict__ xyz, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | xyz += batch_index * n * 3; 16 | new_xyz += batch_index * m * 3; 17 | idx += m * nsample * batch_index; 18 | 19 | int index = threadIdx.x; 20 | int stride = blockDim.x; 21 | 22 | float radius2 = radius * radius; 23 | for (int j = index; j < m; j += stride) { 24 | float new_x = new_xyz[j * 3 + 0]; 25 | float new_y = new_xyz[j * 3 + 1]; 26 | float new_z = new_xyz[j * 3 + 2]; 27 | for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) { 28 | float x = xyz[k * 3 + 0]; 29 | float y = xyz[k * 3 + 1]; 30 | float z = xyz[k * 3 + 2]; 31 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 32 | (new_z - z) * (new_z - z); 33 | if (d2 < radius2) { 34 | if (cnt == 0) { 35 | for (int l = 0; l < nsample; ++l) { 36 | idx[j * nsample + l] = k; 37 | } 38 | } 39 | idx[j * nsample + cnt] = k; 40 | ++cnt; 41 | } 42 | } 43 | } 44 | } 45 | 46 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 47 | int nsample, const float *new_xyz, 48 | const float *xyz, int *idx) { 49 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 50 | query_ball_point_kernel<<>>( 51 | b, n, m, radius, nsample, new_xyz, xyz, idx); 52 | 53 | CUDA_CHECK_ERRORS(); 54 | } 55 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include "ball_query.h" 2 | #include "group_points.h" 3 | #include "interpolate.h" 4 | #include "sampling.h" 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("gather_points", &gather_points); 8 | m.def("gather_points_grad", &gather_points_grad); 9 | m.def("furthest_point_sampling", &furthest_point_sampling); 10 | 11 | m.def("three_nn", &three_nn); 12 | m.def("three_interpolate", &three_interpolate); 13 | m.def("three_interpolate_grad", &three_interpolate_grad); 14 | 15 | m.def("ball_query", &ball_query); 16 | 17 | m.def("group_points", &group_points); 18 | m.def("group_points_grad", &group_points_grad); 19 | } 20 | 
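Note: the extension compiled from these bindings is importable as pointnet2_ops._ext (built by the setup.py further below). A hedged usage sketch of two of the ops bound above, assuming the extension has been built and a CUDA device is available; the real Python-side wrappers with autograd support live in pointnet2_utils.py:

import torch
from pointnet2_ops import _ext  # compiled by pointnet2_ops_lib/setup.py

xyz = torch.rand(2, 1024, 3, device='cuda')      # (B, N, 3) reference point cloud
new_xyz = torch.rand(2, 256, 3, device='cuda')   # (B, M, 3) query centers
idx = _ext.ball_query(new_xyz, xyz, 0.2, 32)     # (B, M, 32) int32 neighbor indices within radius 0.2
features = xyz.transpose(1, 2).contiguous()      # (B, 3, N) channel-first features (here just coordinates)
grouped = _ext.group_points(features, idx)       # (B, 3, M, 32) neighbor features gathered per query center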
-------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points.cpp: -------------------------------------------------------------------------------- 1 | #include "group_points.h" 2 | #include "utils.h" 3 | 4 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 5 | const float *points, const int *idx, 6 | float *out); 7 | 8 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 9 | int nsample, const float *grad_out, 10 | const int *idx, float *grad_points); 11 | 12 | at::Tensor group_points(at::Tensor points, at::Tensor idx) { 13 | CHECK_CONTIGUOUS(points); 14 | CHECK_CONTIGUOUS(idx); 15 | CHECK_IS_FLOAT(points); 16 | CHECK_IS_INT(idx); 17 | 18 | if (points.is_cuda()) { 19 | CHECK_CUDA(idx); 20 | } 21 | 22 | at::Tensor output = 23 | torch::zeros({points.size(0), points.size(1), idx.size(1), idx.size(2)}, 24 | at::device(points.device()).dtype(at::ScalarType::Float)); 25 | 26 | if (points.is_cuda()) { 27 | group_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 28 | idx.size(1), idx.size(2), 29 | points.data_ptr(), idx.data_ptr(), 30 | output.data_ptr()); 31 | } else { 32 | AT_ASSERT(false, "CPU not supported"); 33 | } 34 | 35 | return output; 36 | } 37 | 38 | at::Tensor group_points_grad(at::Tensor grad_out, at::Tensor idx, const int n) { 39 | CHECK_CONTIGUOUS(grad_out); 40 | CHECK_CONTIGUOUS(idx); 41 | CHECK_IS_FLOAT(grad_out); 42 | CHECK_IS_INT(idx); 43 | 44 | if (grad_out.is_cuda()) { 45 | CHECK_CUDA(idx); 46 | } 47 | 48 | at::Tensor output = 49 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 50 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 51 | 52 | if (grad_out.is_cuda()) { 53 | group_points_grad_kernel_wrapper( 54 | grad_out.size(0), grad_out.size(1), n, idx.size(1), idx.size(2), 55 | grad_out.data_ptr(), idx.data_ptr(), 56 | output.data_ptr()); 57 | } else { 58 | AT_ASSERT(false, "CPU not supported"); 59 | } 60 | 61 | return output; 62 | } 63 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/group_points_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, npoints, nsample) 7 | // output: out(b, c, npoints, nsample) 8 | __global__ void group_points_kernel(int b, int c, int n, int npoints, 9 | int nsample, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | int batch_index = blockIdx.x; 14 | points += batch_index * n * c; 15 | idx += batch_index * npoints * nsample; 16 | out += batch_index * npoints * nsample * c; 17 | 18 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 19 | const int stride = blockDim.y * blockDim.x; 20 | for (int i = index; i < c * npoints; i += stride) { 21 | const int l = i / npoints; 22 | const int j = i % npoints; 23 | for (int k = 0; k < nsample; ++k) { 24 | int ii = idx[j * nsample + k]; 25 | out[(l * npoints + j) * nsample + k] = points[l * n + ii]; 26 | } 27 | } 28 | } 29 | 30 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 31 | const float *points, const int *idx, 32 | float *out) { 33 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 34 | 35 | group_points_kernel<<>>( 36 | b, c, n, npoints, nsample, points, idx, out); 37 | 38 | CUDA_CHECK_ERRORS(); 39 | } 40 | 41 | // 
input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample) 42 | // output: grad_points(b, c, n) 43 | __global__ void group_points_grad_kernel(int b, int c, int n, int npoints, 44 | int nsample, 45 | const float *__restrict__ grad_out, 46 | const int *__restrict__ idx, 47 | float *__restrict__ grad_points) { 48 | int batch_index = blockIdx.x; 49 | grad_out += batch_index * npoints * nsample * c; 50 | idx += batch_index * npoints * nsample; 51 | grad_points += batch_index * n * c; 52 | 53 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 54 | const int stride = blockDim.y * blockDim.x; 55 | for (int i = index; i < c * npoints; i += stride) { 56 | const int l = i / npoints; 57 | const int j = i % npoints; 58 | for (int k = 0; k < nsample; ++k) { 59 | int ii = idx[j * nsample + k]; 60 | atomicAdd(grad_points + l * n + ii, 61 | grad_out[(l * npoints + j) * nsample + k]); 62 | } 63 | } 64 | } 65 | 66 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 67 | int nsample, const float *grad_out, 68 | const int *idx, float *grad_points) { 69 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 70 | 71 | group_points_grad_kernel<<>>( 72 | b, c, n, npoints, nsample, grad_out, idx, grad_points); 73 | 74 | CUDA_CHECK_ERRORS(); 75 | } 76 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate.cpp: -------------------------------------------------------------------------------- 1 | #include "interpolate.h" 2 | #include "utils.h" 3 | 4 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 5 | const float *known, float *dist2, int *idx); 6 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 7 | const float *points, const int *idx, 8 | const float *weight, float *out); 9 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 10 | const float *grad_out, 11 | const int *idx, const float *weight, 12 | float *grad_points); 13 | 14 | std::vector three_nn(at::Tensor unknowns, at::Tensor knows) { 15 | CHECK_CONTIGUOUS(unknowns); 16 | CHECK_CONTIGUOUS(knows); 17 | CHECK_IS_FLOAT(unknowns); 18 | CHECK_IS_FLOAT(knows); 19 | 20 | if (unknowns.is_cuda()) { 21 | CHECK_CUDA(knows); 22 | } 23 | 24 | at::Tensor idx = 25 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 26 | at::device(unknowns.device()).dtype(at::ScalarType::Int)); 27 | at::Tensor dist2 = 28 | torch::zeros({unknowns.size(0), unknowns.size(1), 3}, 29 | at::device(unknowns.device()).dtype(at::ScalarType::Float)); 30 | 31 | if (unknowns.is_cuda()) { 32 | three_nn_kernel_wrapper(unknowns.size(0), unknowns.size(1), knows.size(1), 33 | unknowns.data_ptr(), knows.data_ptr(), 34 | dist2.data_ptr(), idx.data_ptr()); 35 | } else { 36 | AT_ASSERT(false, "CPU not supported"); 37 | } 38 | 39 | return {dist2, idx}; 40 | } 41 | 42 | at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, 43 | at::Tensor weight) { 44 | CHECK_CONTIGUOUS(points); 45 | CHECK_CONTIGUOUS(idx); 46 | CHECK_CONTIGUOUS(weight); 47 | CHECK_IS_FLOAT(points); 48 | CHECK_IS_INT(idx); 49 | CHECK_IS_FLOAT(weight); 50 | 51 | if (points.is_cuda()) { 52 | CHECK_CUDA(idx); 53 | CHECK_CUDA(weight); 54 | } 55 | 56 | at::Tensor output = 57 | torch::zeros({points.size(0), points.size(1), idx.size(1)}, 58 | at::device(points.device()).dtype(at::ScalarType::Float)); 59 | 60 | if (points.is_cuda()) { 61 | three_interpolate_kernel_wrapper( 62 | points.size(0), points.size(1), points.size(2), idx.size(1), 63 | 
points.data_ptr(), idx.data_ptr(), weight.data_ptr(), 64 | output.data_ptr()); 65 | } else { 66 | AT_ASSERT(false, "CPU not supported"); 67 | } 68 | 69 | return output; 70 | } 71 | at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, 72 | at::Tensor weight, const int m) { 73 | CHECK_CONTIGUOUS(grad_out); 74 | CHECK_CONTIGUOUS(idx); 75 | CHECK_CONTIGUOUS(weight); 76 | CHECK_IS_FLOAT(grad_out); 77 | CHECK_IS_INT(idx); 78 | CHECK_IS_FLOAT(weight); 79 | 80 | if (grad_out.is_cuda()) { 81 | CHECK_CUDA(idx); 82 | CHECK_CUDA(weight); 83 | } 84 | 85 | at::Tensor output = 86 | torch::zeros({grad_out.size(0), grad_out.size(1), m}, 87 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 88 | 89 | if (grad_out.is_cuda()) { 90 | three_interpolate_grad_kernel_wrapper( 91 | grad_out.size(0), grad_out.size(1), grad_out.size(2), m, 92 | grad_out.data_ptr(), idx.data_ptr(), 93 | weight.data_ptr(), output.data_ptr()); 94 | } else { 95 | AT_ASSERT(false, "CPU not supported"); 96 | } 97 | 98 | return output; 99 | } 100 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/interpolate_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "cuda_utils.h" 6 | 7 | // input: unknown(b, n, 3) known(b, m, 3) 8 | // output: dist2(b, n, 3), idx(b, n, 3) 9 | __global__ void three_nn_kernel(int b, int n, int m, 10 | const float *__restrict__ unknown, 11 | const float *__restrict__ known, 12 | float *__restrict__ dist2, 13 | int *__restrict__ idx) { 14 | int batch_index = blockIdx.x; 15 | unknown += batch_index * n * 3; 16 | known += batch_index * m * 3; 17 | dist2 += batch_index * n * 3; 18 | idx += batch_index * n * 3; 19 | 20 | int index = threadIdx.x; 21 | int stride = blockDim.x; 22 | for (int j = index; j < n; j += stride) { 23 | float ux = unknown[j * 3 + 0]; 24 | float uy = unknown[j * 3 + 1]; 25 | float uz = unknown[j * 3 + 2]; 26 | 27 | double best1 = 1e40, best2 = 1e40, best3 = 1e40; 28 | int besti1 = 0, besti2 = 0, besti3 = 0; 29 | for (int k = 0; k < m; ++k) { 30 | float x = known[k * 3 + 0]; 31 | float y = known[k * 3 + 1]; 32 | float z = known[k * 3 + 2]; 33 | float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z); 34 | if (d < best1) { 35 | best3 = best2; 36 | besti3 = besti2; 37 | best2 = best1; 38 | besti2 = besti1; 39 | best1 = d; 40 | besti1 = k; 41 | } else if (d < best2) { 42 | best3 = best2; 43 | besti3 = besti2; 44 | best2 = d; 45 | besti2 = k; 46 | } else if (d < best3) { 47 | best3 = d; 48 | besti3 = k; 49 | } 50 | } 51 | dist2[j * 3 + 0] = best1; 52 | dist2[j * 3 + 1] = best2; 53 | dist2[j * 3 + 2] = best3; 54 | 55 | idx[j * 3 + 0] = besti1; 56 | idx[j * 3 + 1] = besti2; 57 | idx[j * 3 + 2] = besti3; 58 | } 59 | } 60 | 61 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 62 | const float *known, float *dist2, int *idx) { 63 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 64 | three_nn_kernel<<>>(b, n, m, unknown, known, 65 | dist2, idx); 66 | 67 | CUDA_CHECK_ERRORS(); 68 | } 69 | 70 | // input: points(b, c, m), idx(b, n, 3), weight(b, n, 3) 71 | // output: out(b, c, n) 72 | __global__ void three_interpolate_kernel(int b, int c, int m, int n, 73 | const float *__restrict__ points, 74 | const int *__restrict__ idx, 75 | const float *__restrict__ weight, 76 | float *__restrict__ out) { 77 | int batch_index = blockIdx.x; 78 | points += batch_index * m * c; 79 
| 80 | idx += batch_index * n * 3; 81 | weight += batch_index * n * 3; 82 | 83 | out += batch_index * n * c; 84 | 85 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 86 | const int stride = blockDim.y * blockDim.x; 87 | for (int i = index; i < c * n; i += stride) { 88 | const int l = i / n; 89 | const int j = i % n; 90 | float w1 = weight[j * 3 + 0]; 91 | float w2 = weight[j * 3 + 1]; 92 | float w3 = weight[j * 3 + 2]; 93 | 94 | int i1 = idx[j * 3 + 0]; 95 | int i2 = idx[j * 3 + 1]; 96 | int i3 = idx[j * 3 + 2]; 97 | 98 | out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 + 99 | points[l * m + i3] * w3; 100 | } 101 | } 102 | 103 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 104 | const float *points, const int *idx, 105 | const float *weight, float *out) { 106 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 107 | three_interpolate_kernel<<>>( 108 | b, c, m, n, points, idx, weight, out); 109 | 110 | CUDA_CHECK_ERRORS(); 111 | } 112 | 113 | // input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3) 114 | // output: grad_points(b, c, m) 115 | 116 | __global__ void three_interpolate_grad_kernel( 117 | int b, int c, int n, int m, const float *__restrict__ grad_out, 118 | const int *__restrict__ idx, const float *__restrict__ weight, 119 | float *__restrict__ grad_points) { 120 | int batch_index = blockIdx.x; 121 | grad_out += batch_index * n * c; 122 | idx += batch_index * n * 3; 123 | weight += batch_index * n * 3; 124 | grad_points += batch_index * m * c; 125 | 126 | const int index = threadIdx.y * blockDim.x + threadIdx.x; 127 | const int stride = blockDim.y * blockDim.x; 128 | for (int i = index; i < c * n; i += stride) { 129 | const int l = i / n; 130 | const int j = i % n; 131 | float w1 = weight[j * 3 + 0]; 132 | float w2 = weight[j * 3 + 1]; 133 | float w3 = weight[j * 3 + 2]; 134 | 135 | int i1 = idx[j * 3 + 0]; 136 | int i2 = idx[j * 3 + 1]; 137 | int i3 = idx[j * 3 + 2]; 138 | 139 | atomicAdd(grad_points + l * m + i1, grad_out[i] * w1); 140 | atomicAdd(grad_points + l * m + i2, grad_out[i] * w2); 141 | atomicAdd(grad_points + l * m + i3, grad_out[i] * w3); 142 | } 143 | } 144 | 145 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 146 | const float *grad_out, 147 | const int *idx, const float *weight, 148 | float *grad_points) { 149 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 150 | three_interpolate_grad_kernel<<>>( 151 | b, c, n, m, grad_out, idx, weight, grad_points); 152 | 153 | CUDA_CHECK_ERRORS(); 154 | } 155 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling.cpp: -------------------------------------------------------------------------------- 1 | #include "sampling.h" 2 | #include "utils.h" 3 | 4 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 5 | const float *points, const int *idx, 6 | float *out); 7 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 8 | const float *grad_out, const int *idx, 9 | float *grad_points); 10 | 11 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 12 | const float *dataset, float *temp, 13 | int *idxs); 14 | 15 | at::Tensor gather_points(at::Tensor points, at::Tensor idx) { 16 | CHECK_CONTIGUOUS(points); 17 | CHECK_CONTIGUOUS(idx); 18 | CHECK_IS_FLOAT(points); 19 | CHECK_IS_INT(idx); 20 | 21 | if (points.is_cuda()) { 22 | CHECK_CUDA(idx); 23 | } 24 | 25 | at::Tensor output = 26 | torch::zeros({points.size(0), 
points.size(1), idx.size(1)}, 27 | at::device(points.device()).dtype(at::ScalarType::Float)); 28 | 29 | if (points.is_cuda()) { 30 | gather_points_kernel_wrapper(points.size(0), points.size(1), points.size(2), 31 | idx.size(1), points.data_ptr(), 32 | idx.data_ptr(), output.data_ptr()); 33 | } else { 34 | AT_ASSERT(false, "CPU not supported"); 35 | } 36 | 37 | return output; 38 | } 39 | 40 | at::Tensor gather_points_grad(at::Tensor grad_out, at::Tensor idx, 41 | const int n) { 42 | CHECK_CONTIGUOUS(grad_out); 43 | CHECK_CONTIGUOUS(idx); 44 | CHECK_IS_FLOAT(grad_out); 45 | CHECK_IS_INT(idx); 46 | 47 | if (grad_out.is_cuda()) { 48 | CHECK_CUDA(idx); 49 | } 50 | 51 | at::Tensor output = 52 | torch::zeros({grad_out.size(0), grad_out.size(1), n}, 53 | at::device(grad_out.device()).dtype(at::ScalarType::Float)); 54 | 55 | if (grad_out.is_cuda()) { 56 | gather_points_grad_kernel_wrapper(grad_out.size(0), grad_out.size(1), n, 57 | idx.size(1), grad_out.data_ptr(), 58 | idx.data_ptr(), 59 | output.data_ptr()); 60 | } else { 61 | AT_ASSERT(false, "CPU not supported"); 62 | } 63 | 64 | return output; 65 | } 66 | at::Tensor furthest_point_sampling(at::Tensor points, const int nsamples) { 67 | CHECK_CONTIGUOUS(points); 68 | CHECK_IS_FLOAT(points); 69 | 70 | at::Tensor output = 71 | torch::zeros({points.size(0), nsamples}, 72 | at::device(points.device()).dtype(at::ScalarType::Int)); 73 | 74 | at::Tensor tmp = 75 | torch::full({points.size(0), points.size(1)}, 1e10, 76 | at::device(points.device()).dtype(at::ScalarType::Float)); 77 | 78 | if (points.is_cuda()) { 79 | furthest_point_sampling_kernel_wrapper( 80 | points.size(0), points.size(1), nsamples, points.data_ptr(), 81 | tmp.data_ptr(), output.data_ptr()); 82 | } else { 83 | AT_ASSERT(false, "CPU not supported"); 84 | } 85 | 86 | return output; 87 | } 88 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_ext-src/src/sampling_gpu.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_utils.h" 5 | 6 | // input: points(b, c, n) idx(b, m) 7 | // output: out(b, c, m) 8 | __global__ void gather_points_kernel(int b, int c, int n, int m, 9 | const float *__restrict__ points, 10 | const int *__restrict__ idx, 11 | float *__restrict__ out) { 12 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 13 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 14 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 15 | int a = idx[i * m + j]; 16 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 17 | } 18 | } 19 | } 20 | } 21 | 22 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 23 | const float *points, const int *idx, 24 | float *out) { 25 | gather_points_kernel<<>>(b, c, n, npoints, 27 | points, idx, out); 28 | 29 | CUDA_CHECK_ERRORS(); 30 | } 31 | 32 | // input: grad_out(b, c, m) idx(b, m) 33 | // output: grad_points(b, c, n) 34 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 35 | const float *__restrict__ grad_out, 36 | const int *__restrict__ idx, 37 | float *__restrict__ grad_points) { 38 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 39 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 40 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 41 | int a = idx[i * m + j]; 42 | atomicAdd(grad_points + (i * c + l) * n + a, 43 | grad_out[(i * c + l) * m + j]); 44 | } 45 | } 46 | } 47 | } 48 | 49 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int 
npoints, 50 | const float *grad_out, const int *idx, 51 | float *grad_points) { 52 | gather_points_grad_kernel<<>>( 54 | b, c, n, npoints, grad_out, idx, grad_points); 55 | 56 | CUDA_CHECK_ERRORS(); 57 | } 58 | 59 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 60 | int idx1, int idx2) { 61 | const float v1 = dists[idx1], v2 = dists[idx2]; 62 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 63 | dists[idx1] = max(v1, v2); 64 | dists_i[idx1] = v2 > v1 ? i2 : i1; 65 | } 66 | 67 | // Input dataset: (b, n, 3), tmp: (b, n) 68 | // Ouput idxs (b, m) 69 | template 70 | __global__ void furthest_point_sampling_kernel( 71 | int b, int n, int m, const float *__restrict__ dataset, 72 | float *__restrict__ temp, int *__restrict__ idxs) { 73 | if (m <= 0) return; 74 | __shared__ float dists[block_size]; 75 | __shared__ int dists_i[block_size]; 76 | 77 | int batch_index = blockIdx.x; 78 | dataset += batch_index * n * 3; 79 | temp += batch_index * n; 80 | idxs += batch_index * m; 81 | 82 | int tid = threadIdx.x; 83 | const int stride = block_size; 84 | 85 | int old = 0; 86 | if (threadIdx.x == 0) idxs[0] = old; 87 | 88 | __syncthreads(); 89 | for (int j = 1; j < m; j++) { 90 | int besti = 0; 91 | float best = -1; 92 | float x1 = dataset[old * 3 + 0]; 93 | float y1 = dataset[old * 3 + 1]; 94 | float z1 = dataset[old * 3 + 2]; 95 | for (int k = tid; k < n; k += stride) { 96 | float x2, y2, z2; 97 | x2 = dataset[k * 3 + 0]; 98 | y2 = dataset[k * 3 + 1]; 99 | z2 = dataset[k * 3 + 2]; 100 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 101 | if (mag <= 1e-3) continue; 102 | 103 | float d = 104 | (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 105 | 106 | float d2 = min(d, temp[k]); 107 | temp[k] = d2; 108 | besti = d2 > best ? k : besti; 109 | best = d2 > best ? 
d2 : best; 110 | } 111 | dists[tid] = best; 112 | dists_i[tid] = besti; 113 | __syncthreads(); 114 | 115 | if (block_size >= 512) { 116 | if (tid < 256) { 117 | __update(dists, dists_i, tid, tid + 256); 118 | } 119 | __syncthreads(); 120 | } 121 | if (block_size >= 256) { 122 | if (tid < 128) { 123 | __update(dists, dists_i, tid, tid + 128); 124 | } 125 | __syncthreads(); 126 | } 127 | if (block_size >= 128) { 128 | if (tid < 64) { 129 | __update(dists, dists_i, tid, tid + 64); 130 | } 131 | __syncthreads(); 132 | } 133 | if (block_size >= 64) { 134 | if (tid < 32) { 135 | __update(dists, dists_i, tid, tid + 32); 136 | } 137 | __syncthreads(); 138 | } 139 | if (block_size >= 32) { 140 | if (tid < 16) { 141 | __update(dists, dists_i, tid, tid + 16); 142 | } 143 | __syncthreads(); 144 | } 145 | if (block_size >= 16) { 146 | if (tid < 8) { 147 | __update(dists, dists_i, tid, tid + 8); 148 | } 149 | __syncthreads(); 150 | } 151 | if (block_size >= 8) { 152 | if (tid < 4) { 153 | __update(dists, dists_i, tid, tid + 4); 154 | } 155 | __syncthreads(); 156 | } 157 | if (block_size >= 4) { 158 | if (tid < 2) { 159 | __update(dists, dists_i, tid, tid + 2); 160 | } 161 | __syncthreads(); 162 | } 163 | if (block_size >= 2) { 164 | if (tid < 1) { 165 | __update(dists, dists_i, tid, tid + 1); 166 | } 167 | __syncthreads(); 168 | } 169 | 170 | old = dists_i[0]; 171 | if (tid == 0) idxs[j] = old; 172 | } 173 | } 174 | 175 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 176 | const float *dataset, float *temp, 177 | int *idxs) { 178 | unsigned int n_threads = opt_n_threads(n); 179 | 180 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 181 | 182 | switch (n_threads) { 183 | case 512: 184 | furthest_point_sampling_kernel<512> 185 | <<>>(b, n, m, dataset, temp, idxs); 186 | break; 187 | case 256: 188 | furthest_point_sampling_kernel<256> 189 | <<>>(b, n, m, dataset, temp, idxs); 190 | break; 191 | case 128: 192 | furthest_point_sampling_kernel<128> 193 | <<>>(b, n, m, dataset, temp, idxs); 194 | break; 195 | case 64: 196 | furthest_point_sampling_kernel<64> 197 | <<>>(b, n, m, dataset, temp, idxs); 198 | break; 199 | case 32: 200 | furthest_point_sampling_kernel<32> 201 | <<>>(b, n, m, dataset, temp, idxs); 202 | break; 203 | case 16: 204 | furthest_point_sampling_kernel<16> 205 | <<>>(b, n, m, dataset, temp, idxs); 206 | break; 207 | case 8: 208 | furthest_point_sampling_kernel<8> 209 | <<>>(b, n, m, dataset, temp, idxs); 210 | break; 211 | case 4: 212 | furthest_point_sampling_kernel<4> 213 | <<>>(b, n, m, dataset, temp, idxs); 214 | break; 215 | case 2: 216 | furthest_point_sampling_kernel<2> 217 | <<>>(b, n, m, dataset, temp, idxs); 218 | break; 219 | case 1: 220 | furthest_point_sampling_kernel<1> 221 | <<>>(b, n, m, dataset, temp, idxs); 222 | break; 223 | default: 224 | furthest_point_sampling_kernel<512> 225 | <<>>(b, n, m, dataset, temp, idxs); 226 | } 227 | 228 | CUDA_CHECK_ERRORS(); 229 | } 230 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.0.0" 2 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/pointnet2_ops/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 
| from pointnet2_ops import pointnet2_utils 7 | 8 | 9 | def build_shared_mlp(mlp_spec: List[int], bn: bool = True): 10 | layers = [] 11 | for i in range(1, len(mlp_spec)): 12 | layers.append( 13 | nn.Conv2d(mlp_spec[i - 1], mlp_spec[i], kernel_size=1, bias=not bn) 14 | ) 15 | if bn: 16 | layers.append(nn.BatchNorm2d(mlp_spec[i])) 17 | layers.append(nn.ReLU(True)) 18 | 19 | return nn.Sequential(*layers) 20 | 21 | 22 | class _PointnetSAModuleBase(nn.Module): 23 | def __init__(self): 24 | super(_PointnetSAModuleBase, self).__init__() 25 | self.npoint = None 26 | self.groupers = None 27 | self.mlps = None 28 | 29 | def forward( 30 | self, xyz: torch.Tensor, features: Optional[torch.Tensor] 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | r""" 33 | Parameters 34 | ---------- 35 | xyz : torch.Tensor 36 | (B, N, 3) tensor of the xyz coordinates of the features 37 | features : torch.Tensor 38 | (B, C, N) tensor of the descriptors of the features 39 | 40 | Returns 41 | ------- 42 | new_xyz : torch.Tensor 43 | (B, npoint, 3) tensor of the new features' xyz 44 | new_features : torch.Tensor 45 | (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features descriptors 46 | """ 47 | 48 | new_features_list = [] 49 | 50 | xyz_flipped = xyz.transpose(1, 2).contiguous() 51 | new_xyz = ( 52 | pointnet2_utils.gather_operation( 53 | xyz_flipped, pointnet2_utils.furthest_point_sample(xyz, self.npoint) 54 | ) 55 | .transpose(1, 2) 56 | .contiguous() 57 | if self.npoint is not None 58 | else None 59 | ) 60 | 61 | for i in range(len(self.groupers)): 62 | new_features = self.groupers[i]( 63 | xyz, new_xyz, features 64 | ) # (B, C, npoint, nsample) 65 | 66 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint, nsample) 67 | new_features = F.max_pool2d( 68 | new_features, kernel_size=[1, new_features.size(3)] 69 | ) # (B, mlp[-1], npoint, 1) 70 | new_features = new_features.squeeze(-1) # (B, mlp[-1], npoint) 71 | 72 | new_features_list.append(new_features) 73 | 74 | return new_xyz, torch.cat(new_features_list, dim=1) 75 | 76 | 77 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 78 | r"""Pointnet set abstraction layer with multiscale grouping 79 | 80 | Parameters 81 | ---------- 82 | npoint : int 83 | Number of features 84 | radii : list of float32 85 | list of radii to group with 86 | nsamples : list of int32 87 | Number of samples in each ball query 88 | mlps : list of list of int32 89 | Spec of the pointnet before the global max_pool for each scale 90 | bn : bool 91 | Use batchnorm 92 | """ 93 | 94 | def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True): 95 | # type: (PointnetSAModuleMSG, int, List[float], List[int], List[List[int]], bool, bool) -> None 96 | super(PointnetSAModuleMSG, self).__init__() 97 | 98 | assert len(radii) == len(nsamples) == len(mlps) 99 | 100 | self.npoint = npoint 101 | self.groupers = nn.ModuleList() 102 | self.mlps = nn.ModuleList() 103 | for i in range(len(radii)): 104 | radius = radii[i] 105 | nsample = nsamples[i] 106 | self.groupers.append( 107 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 108 | if npoint is not None 109 | else pointnet2_utils.GroupAll(use_xyz) 110 | ) 111 | mlp_spec = mlps[i] 112 | if use_xyz: 113 | mlp_spec[0] += 3 114 | 115 | self.mlps.append(build_shared_mlp(mlp_spec, bn)) 116 | 117 | 118 | class PointnetSAModule(PointnetSAModuleMSG): 119 | r"""Pointnet set abstraction layer 120 | 121 | Parameters 122 | ---------- 123 | npoint : int 124 | Number of features 125 | radius : float 126 | Radius of ball 127 | 
nsample : int 128 | Number of samples in the ball query 129 | mlp : list 130 | Spec of the pointnet before the global max_pool 131 | bn : bool 132 | Use batchnorm 133 | """ 134 | 135 | def __init__( 136 | self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True 137 | ): 138 | # type: (PointnetSAModule, List[int], int, float, int, bool, bool) -> None 139 | super(PointnetSAModule, self).__init__( 140 | mlps=[mlp], 141 | npoint=npoint, 142 | radii=[radius], 143 | nsamples=[nsample], 144 | bn=bn, 145 | use_xyz=use_xyz, 146 | ) 147 | 148 | 149 | class PointnetFPModule(nn.Module): 150 | r"""Propagates the features of one set to another 151 | 152 | Parameters 153 | ---------- 154 | mlp : list 155 | Pointnet module parameters 156 | bn : bool 157 | Use batchnorm 158 | """ 159 | 160 | def __init__(self, mlp, bn=True): 161 | # type: (PointnetFPModule, List[int], bool) -> None 162 | super(PointnetFPModule, self).__init__() 163 | self.mlp = build_shared_mlp(mlp, bn=bn) 164 | 165 | def forward(self, unknown, known, unknow_feats, known_feats): 166 | # type: (PointnetFPModule, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor 167 | r""" 168 | Parameters 169 | ---------- 170 | unknown : torch.Tensor 171 | (B, n, 3) tensor of the xyz positions of the unknown features 172 | known : torch.Tensor 173 | (B, m, 3) tensor of the xyz positions of the known features 174 | unknow_feats : torch.Tensor 175 | (B, C1, n) tensor of the features to be propagated to 176 | known_feats : torch.Tensor 177 | (B, C2, m) tensor of features to be propagated 178 | 179 | Returns 180 | ------- 181 | new_features : torch.Tensor 182 | (B, mlp[-1], n) tensor of the new features for the unknown points 183 | """ 184 | 185 | if known is not None: 186 | dist, idx = pointnet2_utils.three_nn(unknown, known) 187 | dist_recip = 1.0 / (dist + 1e-8) 188 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 189 | weight = dist_recip / norm 190 | 191 | interpolated_feats = pointnet2_utils.three_interpolate( 192 | known_feats, idx, weight 193 | ) 194 | else: 195 | interpolated_feats = known_feats.expand( 196 | *(known_feats.size()[0:2] + [unknown.size(1)]) 197 | ) 198 | 199 | if unknow_feats is not None: 200 | new_features = torch.cat( 201 | [interpolated_feats, unknow_feats], dim=1 202 | ) # (B, C2 + C1, n) 203 | else: 204 | new_features = interpolated_feats 205 | 206 | new_features = new_features.unsqueeze(-1) 207 | new_features = self.mlp(new_features) 208 | 209 | return new_features.squeeze(-1) 210 | -------------------------------------------------------------------------------- /pointnet2_ops_lib/setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | 5 | from setuptools import find_packages, setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | this_dir = osp.dirname(osp.abspath(__file__)) 9 | _ext_src_root = osp.join("pointnet2_ops", "_ext-src") 10 | _ext_sources = glob.glob(osp.join(_ext_src_root, "src", "*.cpp")) + glob.glob( 11 | osp.join(_ext_src_root, "src", "*.cu") 12 | ) 13 | _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*")) 14 | 15 | requirements = ["torch>=1.4"] 16 | 17 | exec(open(osp.join("pointnet2_ops", "_version.py")).read()) 18 | 19 | os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5" 20 | setup( 21 | name="pointnet2_ops", 22 | version=__version__, 23 | author="Erik Wijmans", 24 | packages=find_packages(), 25 | 
install_requires=requirements, 26 | ext_modules=[ 27 | CUDAExtension( 28 | name="pointnet2_ops._ext", 29 | sources=_ext_sources, 30 | extra_compile_args={ 31 | "cxx": ["-O3"], 32 | "nvcc": ["-O3", "-Xfatbin", "-compress-all"], 33 | }, 34 | include_dirs=[osp.join(this_dir, _ext_src_root, "include")], 35 | ) 36 | ], 37 | cmdclass={"build_ext": BuildExtension}, 38 | include_package_data=True, 39 | ) 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argparse 2 | easydict 3 | h5py 4 | matplotlib 5 | numpy 6 | open3d==0.9.0.0 7 | opencv-python 8 | pyexr 9 | scipy 10 | tensorboardX==1.2 11 | torch>=1.4.0 12 | transforms3d 13 | tqdm -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/average_meter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/utils/__pycache__/average_meter.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_loaders.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/utils/__pycache__/data_loaders.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/utils/__pycache__/data_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/helpers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/utils/__pycache__/helpers.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/io.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/utils/__pycache__/io.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/diviswen/PMP-Net/c524b93187978302239616237a19dbd6bc857721/utils/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- 
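A quick, hypothetical smoke test for the set-abstraction and feature-propagation modules defined in pointnet2_modules.py above. It assumes the extension has been built first (for example with `pip install pointnet2_ops_lib/`) and that a CUDA device is available, since `furthest_point_sample` and the grouping ops are CUDA kernels; the shapes and channel counts follow the docstrings above, and none of this code is part of the repository itself.

```python
# Minimal sketch: one SA (downsample + group + pointnet) step, then one FP
# (interpolate back to the dense cloud) step, on random data.
import torch

from pointnet2_ops.pointnet2_modules import PointnetFPModule, PointnetSAModule

sa = PointnetSAModule(mlp=[3, 64, 128], npoint=512, radius=0.2, nsample=32).cuda()
fp = PointnetFPModule(mlp=[128 + 3, 128]).cuda()  # 128 interpolated + 3 skip channels

xyz = torch.rand(2, 1024, 3, device="cuda")    # (B, N, 3) point coordinates
feats = xyz.transpose(1, 2).contiguous()       # (B, 3, N) per-point descriptors

new_xyz, new_feats = sa(xyz, feats)            # (2, 512, 3), (2, 128, 512)
up_feats = fp(xyz, new_xyz, feats, new_feats)  # (2, 128, 1024)
print(new_xyz.shape, new_feats.shape, up_feats.shape)
```

Note that with the default `use_xyz=True`, `PointnetSAModuleMSG.__init__` adds 3 to the first entry of each MLP spec (the grouped coordinates are concatenated to the features), which is why `mlp=[3, 64, 128]` matches 3-channel input features here.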
/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-06 22:50:12 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-03 21:50:38 6 | # @Email: cshzxie@gmail.com 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value""" 11 | def __init__(self, items=None): 12 | self.items = items 13 | self.n_items = 1 if items is None else len(items) 14 | self.reset() 15 | 16 | def reset(self): 17 | self._val = [0] * self.n_items 18 | self._sum = [0] * self.n_items 19 | self._count = [0] * self.n_items 20 | 21 | def update(self, values): 22 | if type(values).__name__ == 'list': 23 | for idx, v in enumerate(values): 24 | self._val[idx] = v 25 | self._sum[idx] += v 26 | self._count[idx] += 1 27 | else: 28 | self._val[0] = values 29 | self._sum[0] += values 30 | self._count[0] += 1 31 | 32 | def val(self, idx=None): 33 | if idx is None: 34 | return self._val[0] if self.items is None else [self._val[i] for i in range(self.n_items)] 35 | else: 36 | return self._val[idx] 37 | 38 | def count(self, idx=None): 39 | if idx is None: 40 | return self._count[0] if self.items is None else [self._count[i] for i in range(self.n_items)] 41 | else: 42 | return self._count[idx] 43 | 44 | def avg(self, idx=None): 45 | if idx is None: 46 | return self._sum[0] / self._count[0] if self.items is None else [ 47 | self._sum[i] / self._count[i] for i in range(self.n_items) 48 | ] 49 | else: 50 | return self._sum[idx] / self._count[idx] 51 | -------------------------------------------------------------------------------- /utils/data_transforms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-02 14:38:36 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-07-03 09:23:07 6 | # @Email: cshzxie@gmail.com 7 | 8 | import cv2 9 | import math 10 | import numpy as np 11 | import torch 12 | import transforms3d 13 | 14 | class Compose(object): 15 | def __init__(self, transforms): 16 | self.transformers = [] 17 | for tr in transforms: 18 | transformer = eval(tr['callback']) 19 | parameters = tr['parameters'] if 'parameters' in tr else None 20 | self.transformers.append({ 21 | 'callback': transformer(parameters), 22 | 'objects': tr['objects'] 23 | }) # yapf: disable 24 | 25 | def __call__(self, data): 26 | for tr in self.transformers: 27 | transform = tr['callback'] 28 | objects = tr['objects'] 29 | rnd_value = np.random.uniform(0, 1) 30 | if transform.__class__ in [NormalizeObjectPose]: 31 | data = transform(data) 32 | else: 33 | for k, v in data.items(): 34 | if k in objects and k in data: 35 | if transform.__class__ in [ 36 | RandomCrop, RandomFlip, RandomRotatePoints, ScalePoints, RandomMirrorPoints 37 | ]: 38 | data[k] = transform(v, rnd_value) 39 | else: 40 | data[k] = transform(v) 41 | 42 | return data 43 | 44 | 45 | class ToTensor(object): 46 | def __init__(self, parameters): 47 | pass 48 | 49 | def __call__(self, arr): 50 | shape = arr.shape 51 | if len(shape) == 3: # RGB/Depth Images 52 | arr = arr.transpose(2, 0, 1) 53 | 54 | # Ref: https://discuss.pytorch.org/t/torch-from-numpy-not-support-negative-strides/3663/2 55 | return torch.from_numpy(arr.copy()).float() 56 | 57 | 58 | class Normalize(object): 59 | def __init__(self, parameters): 60 | self.mean = parameters['mean'] 61 | self.std = parameters['std'] 62 | 63 | def 
__call__(self, arr): 64 | arr = arr.astype(np.float32) 65 | arr /= self.std 66 | arr -= self.mean 67 | 68 | return arr 69 | 70 | 71 | class CenterCrop(object): 72 | def __init__(self, parameters): 73 | self.img_size_h = parameters['img_size'][0] 74 | self.img_size_w = parameters['img_size'][1] 75 | self.crop_size_h = parameters['crop_size'][0] 76 | self.crop_size_w = parameters['crop_size'][1] 77 | 78 | def __call__(self, img): 79 | img_h, img_w, _ = img.shape # numpy image shape is (H, W, C) 80 | x_left = (img_w - self.crop_size_w) * .5 81 | x_right = x_left + self.crop_size_w 82 | y_top = (img_h - self.crop_size_h) * .5 83 | y_bottom = y_top + self.crop_size_h 84 | 85 | # Crop the image 86 | img = cv2.resize(img[int(y_top):int(y_bottom), int(x_left):int(x_right)], (self.img_size_w, self.img_size_h)) 87 | img = img[..., np.newaxis] if len(img.shape) == 2 else img 88 | 89 | return img 90 | 91 | 92 | class RandomCrop(object): 93 | def __init__(self, parameters): 94 | self.img_size_h = parameters['img_size'][0] 95 | self.img_size_w = parameters['img_size'][1] 96 | self.crop_size_h = parameters['crop_size'][0] 97 | self.crop_size_w = parameters['crop_size'][1] 98 | 99 | def __call__(self, img, rnd_value): 100 | img_h, img_w, _ = img.shape # numpy image shape is (H, W, C) 101 | x_left = (img_w - self.crop_size_w) * rnd_value 102 | x_right = x_left + self.crop_size_w 103 | y_top = (img_h - self.crop_size_h) * rnd_value 104 | y_bottom = y_top + self.crop_size_h 105 | 106 | # Crop the image 107 | img = cv2.resize(img[int(y_top):int(y_bottom), int(x_left):int(x_right)], (self.img_size_w, self.img_size_h)) 108 | img = img[..., np.newaxis] if len(img.shape) == 2 else img 109 | 110 | return img 111 | 112 | class ScalePoints(object): 113 | def __init__(self, parameters): 114 | self.scale = None 115 | if 'scale' in parameters: 116 | self.scale = parameters['scale'] 117 | 118 | def __call__(self, ptcloud, rnd_value): 119 | if self.scale is not None: 120 | scale = self.scale 121 | else: 122 | scale = np.random.randint(85, 95) * 0.01 123 | ptcloud = ptcloud * scale 124 | return ptcloud 125 | 126 | 127 | class RandomFlip(object): 128 | def __init__(self, parameters): 129 | pass 130 | 131 | def __call__(self, img, rnd_value): 132 | if rnd_value > 0.5: 133 | img = np.fliplr(img) 134 | 135 | return img 136 | 137 | 138 | class RandomPermuteRGB(object): 139 | def __init__(self, parameters): 140 | pass 141 | 142 | def __call__(self, img): 143 | rgb_permutation = np.random.permutation(3) 144 | return img[..., rgb_permutation] 145 | 146 | 147 | class RandomBackground(object): 148 | def __init__(self, parameters): 149 | self.random_bg_color_range = parameters['bg_color'] 150 | 151 | def __call__(self, img): 152 | img_h, img_w, img_c = img.shape 153 | if not img_c == 4: 154 | return img 155 | 156 | r, g, b = [ 157 | np.random.randint(self.random_bg_color_range[i][0], self.random_bg_color_range[i][1] + 1) for i in range(3) 158 | ] 159 | alpha = (np.expand_dims(img[:, :, 3], axis=2) == 0).astype(np.float32) 160 | img = img[:, :, :3] 161 | bg_color = np.array([[[r, g, b]]]) / 255.
162 | img = alpha * bg_color + (1 - alpha) * img 163 | 164 | return img 165 | 166 | 167 | class UpSamplePoints(object): 168 | def __init__(self, parameters): 169 | self.n_points = parameters['n_points'] 170 | 171 | def __call__(self, ptcloud): 172 | curr = ptcloud.shape[0] 173 | need = self.n_points - curr 174 | 175 | if need < 0: 176 | return ptcloud[np.random.permutation(self.n_points)] 177 | 178 | while curr <= need: 179 | ptcloud = np.tile(ptcloud, (2, 1)) 180 | need -= curr 181 | curr *= 2 182 | 183 | choice = np.random.permutation(need) 184 | ptcloud = np.concatenate((ptcloud, ptcloud[choice])) 185 | 186 | return ptcloud 187 | 188 | 189 | class RandomSamplePoints(object): 190 | def __init__(self, parameters): 191 | self.n_points = parameters['n_points'] 192 | 193 | def __call__(self, ptcloud): 194 | choice = np.random.permutation(ptcloud.shape[0]) 195 | ptcloud = ptcloud[choice[:self.n_points]] 196 | 197 | if ptcloud.shape[0] < self.n_points: 198 | zeros = np.zeros((self.n_points - ptcloud.shape[0], 3)) 199 | ptcloud = np.concatenate([ptcloud, zeros]) 200 | 201 | return ptcloud 202 | 203 | 204 | class RandomClipPoints(object): 205 | def __init__(self, parameters): 206 | self.sigma = parameters['sigma'] if 'sigma' in parameters else 0.01 207 | self.clip = parameters['clip'] if 'clip' in parameters else 0.05 208 | 209 | def __call__(self, ptcloud): 210 | ptcloud += np.clip(self.sigma * np.random.randn(*ptcloud.shape), -self.clip, self.clip).astype(np.float32) 211 | return ptcloud 212 | 213 | 214 | class RandomRotatePoints(object): 215 | def __init__(self, parameters): 216 | pass 217 | 218 | def __call__(self, ptcloud, rnd_value): 219 | trfm_mat = transforms3d.zooms.zfdir2mat(1) 220 | angle = 2 * math.pi * rnd_value 221 | trfm_mat = np.dot(transforms3d.axangles.axangle2mat([0, 1, 0], angle), trfm_mat) 222 | 223 | ptcloud[:, :3] = np.dot(ptcloud[:, :3], trfm_mat.T) 224 | return ptcloud 225 | 226 | 227 | class RandomScalePoints(object): 228 | def __init__(self, parameters): 229 | self.scale = parameters['scale'] 230 | 231 | def __call__(self, ptcloud, rnd_value): 232 | trfm_mat = transforms3d.zooms.zfdir2mat(1) 233 | scale = np.random.uniform(1.0 / self.scale * rnd_value, self.scale * rnd_value) 234 | trfm_mat = np.dot(transforms3d.zooms.zfdir2mat(scale), trfm_mat) 235 | 236 | ptcloud[:, :3] = np.dot(ptcloud[:, :3], trfm_mat.T) 237 | return ptcloud 238 | 239 | 240 | class RandomMirrorPoints(object): 241 | def __init__(self, parameters): 242 | pass 243 | 244 | def __call__(self, ptcloud, rnd_value): 245 | trfm_mat = transforms3d.zooms.zfdir2mat(1) 246 | trfm_mat_x = np.dot(transforms3d.zooms.zfdir2mat(-1, [1, 0, 0]), trfm_mat) 247 | trfm_mat_z = np.dot(transforms3d.zooms.zfdir2mat(-1, [0, 0, 1]), trfm_mat) 248 | if rnd_value <= 0.25: 249 | trfm_mat = np.dot(trfm_mat_x, trfm_mat) 250 | trfm_mat = np.dot(trfm_mat_z, trfm_mat) 251 | elif rnd_value > 0.25 and rnd_value <= 0.5: # lgtm [py/redundant-comparison] 252 | trfm_mat = np.dot(trfm_mat_x, trfm_mat) 253 | elif rnd_value > 0.5 and rnd_value <= 0.75: 254 | trfm_mat = np.dot(trfm_mat_z, trfm_mat) 255 | 256 | ptcloud[:, :3] = np.dot(ptcloud[:, :3], trfm_mat.T) 257 | return ptcloud 258 | 259 | 260 | 261 | 262 | class NormalizeObjectPose(object): 263 | def __init__(self, parameters): 264 | input_keys = parameters['input_keys'] 265 | self.ptcloud_key = input_keys['ptcloud'] 266 | self.bbox_key = input_keys['bbox'] 267 | 268 | def __call__(self, data): 269 | ptcloud = data[self.ptcloud_key] 270 | bbox = data[self.bbox_key] 271 | 272 | # Calculate 
center, rotation and scale 273 | # References: 274 | # - https://github.com/wentaoyuan/pcn/blob/master/test_kitti.py#L40-L52 275 | center = (bbox.min(0) + bbox.max(0)) / 2 276 | bbox -= center 277 | yaw = np.arctan2(bbox[3, 1] - bbox[0, 1], bbox[3, 0] - bbox[0, 0]) 278 | rotation = np.array([[np.cos(yaw), -np.sin(yaw), 0], [np.sin(yaw), np.cos(yaw), 0], [0, 0, 1]]) 279 | bbox = np.dot(bbox, rotation) 280 | scale = bbox[3, 0] - bbox[0, 0] 281 | bbox /= scale 282 | ptcloud = np.dot(ptcloud - center, rotation) / scale 283 | ptcloud = np.dot(ptcloud, [[1, 0, 0], [0, 0, 1], [0, 1, 0]]) 284 | 285 | data[self.ptcloud_key] = ptcloud 286 | return data 287 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-07-31 16:57:15 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-02-22 18:34:19 6 | # @Email: cshzxie@gmail.com 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | from mpl_toolkits.mplot3d import Axes3D 12 | 13 | 14 | def var_or_cuda(x): 15 | if torch.cuda.is_available(): 16 | x = x.cuda(non_blocking=True) 17 | 18 | return x 19 | 20 | 21 | def init_weights(m): 22 | if type(m) == torch.nn.Conv2d or type(m) == torch.nn.ConvTranspose2d or \ 23 | type(m) == torch.nn.Conv3d or type(m) == torch.nn.ConvTranspose3d: 24 | torch.nn.init.kaiming_normal_(m.weight) 25 | if m.bias is not None: 26 | torch.nn.init.constant_(m.bias, 0) 27 | elif type(m) == torch.nn.BatchNorm2d or type(m) == torch.nn.BatchNorm3d: 28 | torch.nn.init.constant_(m.weight, 1) 29 | torch.nn.init.constant_(m.bias, 0) 30 | elif type(m) == torch.nn.Linear: 31 | torch.nn.init.normal_(m.weight, 0, 0.01) 32 | torch.nn.init.constant_(m.bias, 0) 33 | 34 | 35 | def count_parameters(network): 36 | return sum(p.numel() for p in network.parameters()) 37 | 38 | 39 | def get_ptcloud_img(ptcloud): 40 | fig = plt.figure(figsize=(8, 8)) 41 | 42 | x, z, y = ptcloud.transpose(1, 0) 43 | ax = fig.gca(projection=Axes3D.name, adjustable='box') 44 | ax.axis('off') 45 | ax.axis('scaled') 46 | ax.view_init(30, 45) 47 | 48 | max, min = np.max(ptcloud), np.min(ptcloud) 49 | ax.set_xbound(min, max) 50 | ax.set_ybound(min, max) 51 | ax.set_zbound(min, max) 52 | ax.scatter(x, y, z, zdir='z', c=x, cmap='jet') 53 | 54 | fig.canvas.draw() 55 | img = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 56 | img = img.reshape(fig.canvas.get_width_height()[::-1] + (3, )) 57 | return img 58 | -------------------------------------------------------------------------------- /utils/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-02 10:22:03 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-02-22 19:13:01 6 | # @Email: cshzxie@gmail.com 7 | 8 | import cv2 9 | import h5py 10 | import numpy as np 11 | import pyexr 12 | import open3d 13 | import os 14 | import sys 15 | 16 | from io import BytesIO 17 | 18 | # References: http://confluence.sensetime.com/pages/viewpage.action?pageId=44650315 19 | from config_c3d import cfg 20 | sys.path.append(cfg.MEMCACHED.LIBRARY_PATH) 21 | 22 | mc_client = None 23 | if cfg.MEMCACHED.ENABLED: 24 | import mc 25 | mc_client = mc.MemcachedClient.GetInstance(cfg.MEMCACHED.SERVER_CONFIG, cfg.MEMCACHED.CLIENT_CONFIG) 26 | 27 | 28 | class IO: 29 | @classmethod 
30 | def get(cls, file_path): 31 | _, file_extension = os.path.splitext(file_path) 32 | 33 | if file_extension in ['.png', '.jpg']: 34 | return cls._read_img(file_path) 35 | elif file_extension in ['.npy']: 36 | return cls._read_npy(file_path) 37 | elif file_extension in ['.exr']: 38 | return cls._read_exr(file_path) 39 | elif file_extension in ['.pcd']: 40 | return cls._read_pcd(file_path) 41 | elif file_extension in ['.h5']: 42 | return cls._read_h5(file_path) 43 | elif file_extension in ['.txt']: 44 | return cls._read_txt(file_path) 45 | else: 46 | raise Exception('Unsupported file extension: %s' % file_extension) 47 | 48 | @classmethod 49 | def put(cls, file_path, file_content): 50 | _, file_extension = os.path.splitext(file_path) 51 | 52 | if file_extension in ['.pcd']: 53 | return cls._write_pcd(file_path, file_content) 54 | elif file_extension in ['.h5']: 55 | return cls._write_h5(file_path, file_content) 56 | else: 57 | raise Exception('Unsupported file extension: %s' % file_extension) 58 | 59 | @classmethod 60 | def _read_img(cls, file_path): 61 | if mc_client is None: 62 | return cv2.imread(file_path, cv2.IMREAD_UNCHANGED) / 255. 63 | else: 64 | pyvector = mc.pyvector() 65 | mc_client.Get(file_path, pyvector) 66 | buf = mc.ConvertBuffer(pyvector) 67 | img_array = np.frombuffer(buf, np.uint8) 68 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) 69 | return img / 255. 70 | 71 | # References: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py 72 | @classmethod 73 | def _read_npy(cls, file_path): 74 | if mc_client is None: 75 | return np.load(file_path) 76 | else: 77 | pyvector = mc.pyvector() 78 | mc_client.Get(file_path, pyvector) 79 | buf = mc.ConvertBuffer(pyvector) 80 | buf_bytes = buf.tobytes() 81 | if not buf_bytes[:6] == b'\x93NUMPY': 82 | raise Exception('Invalid npy file format.') 83 | 84 | header_size = int.from_bytes(buf_bytes[8:10], byteorder='little') 85 | header = eval(buf_bytes[10:header_size + 10]) 86 | dtype = np.dtype(header['descr']) 87 | nd_array = np.frombuffer(buf[header_size + 10:], dtype).reshape(header['shape']) 88 | 89 | return nd_array 90 | 91 | @classmethod 92 | def _read_exr(cls, file_path): 93 | return 1.0 / pyexr.open(file_path).get("Depth.Z").astype(np.float32) 94 | 95 | # References: https://github.com/dimatura/pypcd/blob/master/pypcd/pypcd.py#L275 96 | # Support PCD files without compression ONLY! 
97 | @classmethod 98 | def _read_pcd(cls, file_path): 99 | if mc_client is None: 100 | pc = open3d.io.read_point_cloud(file_path) 101 | ptcloud = np.array(pc.points) 102 | else: 103 | pyvector = mc.pyvector() 104 | mc_client.Get(file_path, pyvector) 105 | text = mc.ConvertString(pyvector).split('\n') 106 | start_line_idx = len(text) - 1 107 | for idx, line in enumerate(text): 108 | if line == 'DATA ascii': 109 | start_line_idx = idx + 1 110 | break 111 | 112 | ptcloud = text[start_line_idx:] 113 | ptcloud = np.genfromtxt(BytesIO('\n'.join(ptcloud).encode()), dtype=np.float32) 114 | 115 | # ptcloud = np.concatenate((ptcloud, np.array([[0, 0, 0]])), axis=0) 116 | return ptcloud 117 | 118 | @classmethod 119 | def _read_h5(cls, file_path): 120 | f = h5py.File(file_path, 'r') 121 | # Avoid overflow while gridding 122 | return f['data'][()] 123 | 124 | @classmethod 125 | def _read_txt(cls, file_path): 126 | return np.loadtxt(file_path) 127 | 128 | @classmethod 129 | def _write_pcd(cls, file_path, file_content): 130 | pc = open3d.geometry.PointCloud() 131 | pc.points = open3d.utility.Vector3dVector(file_content) 132 | open3d.io.write_point_cloud(file_path, pc) 133 | 134 | @classmethod 135 | def _write_h5(cls, file_path, file_content): 136 | with h5py.File(file_path, 'w') as f: 137 | f.create_dataset('data', data=file_content) 138 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-08-08 14:31:30 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2020-05-25 09:13:32 6 | # @Email: cshzxie@gmail.com 7 | 8 | import logging 9 | import open3d 10 | import torch 11 | 12 | from Chamfer3D.dist_chamfer_3D import chamfer_3DDist 13 | 14 | class Metrics(object): 15 | ITEMS = [{ 16 | 'name': 'pmd', 17 | 'enabled': True, 18 | 'eval_func': 'cls._get_emd_distance', 19 | 'eval_object': chamfer_3DDist(), 20 | 'is_greater_better': False, 21 | 'init_value': 32767 22 | },{ 23 | 'name': 'ChamferDistance', 24 | 'enabled': True, 25 | 'eval_func': 'cls._get_chamfer_distance', 26 | 'eval_object': chamfer_3DDist(), 27 | # 'eval_object': ChamferDistance(ignore_zeros=True), 28 | 'is_greater_better': False, 29 | 'init_value': 32767 30 | }] 31 | 32 | @classmethod 33 | def get(cls, pred, gt): 34 | _items = cls.items() 35 | _values = [0] * len(_items) 36 | for i, item in enumerate(_items): 37 | eval_func = eval(item['eval_func']) 38 | _values[i] = eval_func(pred, gt) 39 | 40 | return _values 41 | 42 | @classmethod 43 | def items(cls): 44 | return [i for i in cls.ITEMS if i['enabled']] 45 | 46 | @classmethod 47 | def names(cls): 48 | _items = cls.items() 49 | return [i['name'] for i in _items] 50 | 51 | @classmethod 52 | def _get_f_score(cls, pred, gt, th=0.01): 53 | """References: https://github.com/lmb-freiburg/what3d/blob/master/util.py""" 54 | pred = cls._get_open3d_ptcloud(pred) 55 | gt = cls._get_open3d_ptcloud(gt) 56 | 57 | dist1 = pred.compute_point_cloud_distance(gt) 58 | dist2 = gt.compute_point_cloud_distance(pred) 59 | 60 | recall = float(sum(d < th for d in dist2)) / float(len(dist2)) 61 | precision = float(sum(d < th for d in dist1)) / float(len(dist1)) 62 | return 2 * recall * precision / (recall + precision) if recall + precision else 0 63 | 64 | @classmethod 65 | def _get_open3d_ptcloud(cls, tensor): 66 | tensor = tensor.squeeze().cpu().numpy() 67 | ptcloud = open3d.geometry.PointCloud() 68 | 
ptcloud.points = open3d.utility.Vector3dVector(tensor) 69 | 70 | return ptcloud 71 | 72 | @classmethod 73 | def _get_chamfer_distance(cls, pred, gt): 74 | # chamfer_distance = cls.ITEMS[1]['eval_object'] 75 | chamfer_distance = cls.ITEMS[1]['eval_object'] 76 | d1, d2, _, _ = chamfer_distance(pred, gt) 77 | cd = torch.mean(d1) + torch.mean(d2) 78 | return cd.item() * 1000 79 | # return chamfer_distance(pred, gt).item() * 1000 80 | 81 | @classmethod 82 | def _get_emd_distance(cls, pred, gt): 83 | d1, d2, _, _ = cls.ITEMS[0]['eval_object'](pred, gt) # eval_object is chamfer_3DDist, which returns four tensors 84 | return (torch.mean(d1) + torch.mean(d2)).item() 85 | 86 | def __init__(self, metric_name, values): 87 | self._items = Metrics.items() 88 | self._values = [item['init_value'] for item in self._items] 89 | self.metric_name = metric_name 90 | 91 | if type(values).__name__ == 'list': 92 | self._values = values 93 | elif type(values).__name__ == 'dict': 94 | metric_indexes = {} 95 | for idx, item in enumerate(self._items): 96 | item_name = item['name'] 97 | metric_indexes[item_name] = idx 98 | for k, v in values.items(): 99 | if k not in metric_indexes: 100 | logging.warning('Ignoring Metric[Name=%s] because it is disabled.' % k) 101 | continue 102 | self._values[metric_indexes[k]] = v 103 | else: 104 | raise Exception('Unsupported value type: %s' % type(values)) 105 | 106 | def state_dict(self): 107 | _dict = dict() 108 | for i in range(len(self._items)): 109 | item = self._items[i]['name'] 110 | value = self._values[i] 111 | _dict[item] = value 112 | 113 | return _dict 114 | 115 | def __repr__(self): 116 | return str(self.state_dict()) 117 | 118 | def better_than(self, other): 119 | if other is None: 120 | return True 121 | 122 | _index = -1 123 | for i, _item in enumerate(self._items): 124 | if _item['name'] == self.metric_name: 125 | _index = i 126 | break 127 | if _index == -1: 128 | raise Exception('Invalid metric name to compare.') 129 | 130 | _metric = self._items[_index] 131 | _value = self._values[_index] 132 | other_value = other._values[_index] 133 | return _value > other_value if _metric['is_greater_better'] else _value < other_value 134 | --------------------------------------------------------------------------------
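To close the loop, here is a hypothetical evaluation sketch tying the utilities above together: `Compose` from utils/data_transforms.py prepares paired clouds, the Chamfer distance is computed exactly as `Metrics._get_chamfer_distance` does it, and `AverageMeter` plus `Metrics.better_than` handle the bookkeeping. It assumes the Chamfer3D extension is built, a CUDA device is available, and that the `partial_cloud`/`gtcloud` keys match the dataset JSON convention; the random clouds stand in for a real data loader and model.

```python
import numpy as np
import torch

from Chamfer3D.dist_chamfer_3D import chamfer_3DDist
from utils.average_meter import AverageMeter
from utils.data_transforms import Compose
from utils.metrics import Metrics

transforms = Compose([
    {'callback': 'RandomSamplePoints', 'parameters': {'n_points': 2048}, 'objects': ['partial_cloud']},
    {'callback': 'RandomMirrorPoints', 'objects': ['partial_cloud', 'gtcloud']},
    {'callback': 'ToTensor', 'objects': ['partial_cloud', 'gtcloud']},
])

chamfer = chamfer_3DDist()
meter = AverageMeter(['ChamferDistance'])

for _ in range(4):  # stand-in for a test data loader
    data = transforms({
        'partial_cloud': np.random.rand(5000, 3).astype(np.float32),
        'gtcloud': np.random.rand(16384, 3).astype(np.float32),
    })
    pred = data['partial_cloud'].unsqueeze(0).cuda()  # a real model's output would go here
    gt = data['gtcloud'].unsqueeze(0).cuda()
    d1, d2, _, _ = chamfer(pred, gt)                  # same call as Metrics._get_chamfer_distance
    meter.update([(torch.mean(d1) + torch.mean(d2)).item() * 1000])

print('mean CD (x1000):', meter.avg(0))

# Checkpoint selection: ChamferDistance has is_greater_better=False, so lower wins.
current = Metrics('ChamferDistance', {'ChamferDistance': meter.avg(0)})
print(current.better_than(None))                      # True: anything beats no previous best
```

Because `Compose.__call__` draws a single `rnd_value` per transform and applies it to every key listed under `objects`, the partial and complete clouds receive the same mirroring, which keeps each pair geometrically consistent.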