├── CMakeLists.txt ├── LICENSE ├── README.md ├── cfgs └── config_cls.yaml ├── cls └── model_cls_L6_iter_36567_acc_0.923825.pth ├── data ├── ModelNet40Loader.py ├── __init__.py └── data_utils.py ├── models ├── __init__.py └── densepoint_cls_L6_k24_g2.py ├── train_cls.py ├── train_cls.sh ├── utils ├── __init__.py ├── _ext │ ├── __init__.py │ └── pointnet2 │ │ └── __init__.py ├── build_ffi.py ├── cinclude │ ├── ball_query_gpu.h │ ├── ball_query_wrapper.h │ ├── cuda_utils.h │ ├── group_points_gpu.h │ ├── group_points_wrapper.h │ ├── interpolate_gpu.h │ ├── interpolate_wrapper.h │ ├── sampling_gpu.h │ └── sampling_wrapper.h ├── csrc │ ├── ball_query.c │ ├── ball_query_gpu.cu │ ├── group_points.c │ ├── group_points_gpu.cu │ ├── interpolate.c │ ├── interpolate_gpu.cu │ ├── sampling.c │ └── sampling_gpu.cu ├── linalg_utils.py ├── pointnet2_modules.py ├── pointnet2_utils.py └── pytorch_utils │ ├── __init__.py │ └── pytorch_utils.py └── voting_evaluate_cls.py /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8) # must precede project() so policy defaults are set first 2 | project(PointNet2) 3 | 4 | find_package(CUDA REQUIRED) 5 | 6 | include_directories("${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude") 7 | cuda_include_directories("${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude") 8 | file(GLOB cuda_kernels_src "${CMAKE_CURRENT_SOURCE_DIR}/utils/csrc/*.cu") 9 | cuda_compile(cuda_kernels SHARED ${cuda_kernels_src} OPTIONS -O3) 10 | 11 | set(BUILD_CMD python "${CMAKE_CURRENT_SOURCE_DIR}/utils/build_ffi.py") 12 | file(GLOB wrapper_headers "${CMAKE_CURRENT_SOURCE_DIR}/utils/cinclude/*wrapper.h") 13 | file(GLOB wrapper_sources "${CMAKE_CURRENT_SOURCE_DIR}/utils/csrc/*.c") # C wrappers live in utils/csrc; listed as DEPENDS so edits retrigger the FFI build 14 | add_custom_command(OUTPUT "${CMAKE_CURRENT_SOURCE_DIR}/utils/_ext/pointnet2/_pointnet2.so" 15 | WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/utils 16 | COMMAND ${BUILD_CMD} --build --objs ${cuda_kernels} 17 | DEPENDS ${cuda_kernels} 18 | DEPENDS ${wrapper_headers} 19 | DEPENDS 
${wrapper_sources} 20 | VERBATIM) 21 | 22 | add_custom_target(pointnet2_ext ALL 23 | DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/utils/_ext/pointnet2/_pointnet2.so") 24 | 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yongcheng Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | DensePoint 2 | === 3 | This repository contains the PyTorch code for the paper: 4 | 5 | __DensePoint: Learning Densely Contextual Representation for Efficient Point Cloud Processing__ [[arXiv](https://arxiv.org/abs/1909.03669)] [[CVF](http://openaccess.thecvf.com/content_ICCV_2019/papers/Liu_DensePoint_Learning_Densely_Contextual_Representation_for_Efficient_Point_Cloud_Processing_ICCV_2019_paper.pdf)] 6 |
7 | [Yongcheng Liu](https://yochengliu.github.io/), [Bin Fan](http://www.nlpr.ia.ac.cn/fanbin/), [Gaofeng Meng](http://www.escience.cn/people/menggaofeng/index.html;jsessionid=EE2E193290F516D1BA8E2E35A09A9A08-n1), [Jiwen Lu](http://ivg.au.tsinghua.edu.cn/Jiwen_Lu/), [Shiming Xiang](https://scholar.google.com/citations?user=0ggsACEAAAAJ&hl=zh-CN) and [Chunhong Pan](http://people.ucas.ac.cn/~0005314) 8 |
9 | [__ICCV 2019__](http://iccv2019.thecvf.com/) 10 | 11 | ## Citation 12 | 13 | If our paper is helpful for your research, please consider citing: 14 | 15 | @inproceedings{liu2019densepoint, 16 | author = {Yongcheng Liu and 17 | Bin Fan and 18 | Gaofeng Meng and 19 | Jiwen Lu and 20 | Shiming Xiang and 21 | Chunhong Pan}, 22 | title = {DensePoint: Learning Densely Contextual Representation for Efficient Point Cloud Processing}, 23 | booktitle = {IEEE International Conference on Computer Vision (ICCV)}, 24 | pages = {5239--5248}, 25 | year = {2019} 26 | } 27 | 28 | ## Usage: Preparation 29 | 30 | - Requirement 31 | 32 | - Ubuntu 14.04 33 | - Python 3 (recommend Anaconda3) 34 | - Pytorch 0.3.\* 35 | - CMake > 2.8 36 | - CUDA 8.0 + cuDNN 5.1 37 | 38 | - Building Kernel 39 | 40 | git clone https://github.com/Yochengliu/DensePoint.git 41 | cd DensePoint 42 | mkdir build && cd build 43 | cmake .. && make 44 | 45 | - Dataset 46 | - Shape Classification: download and unzip [ModelNet40](https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip) (415M). Replace `$data_root$` in `cfgs/config_cls.yaml` with the dataset parent path. 47 | 48 | ## Usage: Training 49 | - Shape Classification 50 | 51 | sh train_cls.sh 52 | 53 | We have trained a 6-layer classification model in `cls` folder, whose accuracy is 92.38%. 54 | 55 | ## Usage: Evaluation 56 | - Shape Classification 57 | 58 | Voting script: voting_evaluate_cls.py 59 | 60 | You can use our model `cls/model_cls_L6_iter_36567_acc_0.923825.pth` as the checkpoint in `config_cls.yaml`, and after this voting you will get an accuracy of 92.5% if all things go right. 61 | 62 | ## License 63 | 64 | The code is released under MIT License (see LICENSE file for details). 65 | 66 | ## Acknowledgement 67 | 68 | The code is heavily borrowed from [Pointnet2_PyTorch](https://github.com/erikwijmans/Pointnet2_PyTorch). 
69 | 70 | ## Contact 71 | 72 | If you have some ideas or questions about our research to share with us, please contact 73 | -------------------------------------------------------------------------------- /cfgs/config_cls.yaml: -------------------------------------------------------------------------------- 1 | common: 2 | workers: 4 3 | 4 | num_points: 1024 5 | num_classes: 40 6 | batch_size: 32 7 | 8 | base_lr: 0.001 9 | lr_clip: 0.00001 10 | lr_decay: 0.7 11 | decay_step: 21 12 | epochs: 200 13 | 14 | weight_decay: 0 15 | bn_momentum: 0.9 16 | bnm_clip: 0.01 17 | bn_decay: 0.5 18 | 19 | evaluate: 1 20 | val_freq_epoch: 0.5 # frequency in epoch for validation, can be decimal 21 | print_freq_iter: 20 # frequency in iteration for printing infomation 22 | 23 | input_channels: 0 # feature channels except (x, y, z) 24 | 25 | checkpoint: '' # the model to start from 26 | save_path: cls 27 | data_root: $data_root$ -------------------------------------------------------------------------------- /cls/model_cls_L6_iter_36567_acc_0.923825.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yochengliu/DensePoint/2a9393402f9f60d05a1735e78c4eced9f10015d9/cls/model_cls_L6_iter_36567_acc_0.923825.pth -------------------------------------------------------------------------------- /data/ModelNet40Loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | import os, sys, h5py 5 | 6 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 7 | sys.path.append(BASE_DIR) 8 | 9 | def _get_data_files(list_filename): 10 | with open(list_filename) as f: 11 | return [line.rstrip()[5:] for line in f] 12 | 13 | def _load_data_file(name): 14 | f = h5py.File(name) 15 | data = f['data'][:] 16 | label = f['label'][:] 17 | return data, label 18 | 19 | class ModelNet40Cls(data.Dataset): 20 | 21 | def __init__( 22 | self, 
num_points, root, transforms=None, train=True 23 | ): 24 | super().__init__() 25 | 26 | self.transforms = transforms 27 | 28 | root = os.path.abspath(root) 29 | self.folder = "modelnet40_ply_hdf5_2048" 30 | self.data_dir = os.path.join(root, self.folder) 31 | 32 | self.train, self.num_points = train, num_points 33 | if self.train: 34 | self.files = _get_data_files( \ 35 | os.path.join(self.data_dir, 'train_files.txt')) 36 | else: 37 | self.files = _get_data_files( \ 38 | os.path.join(self.data_dir, 'test_files.txt')) 39 | 40 | point_list, label_list = [], [] 41 | for f in self.files: 42 | points, labels = _load_data_file(os.path.join(root, f)) 43 | point_list.append(points) 44 | label_list.append(labels) 45 | 46 | self.points = np.concatenate(point_list, 0) 47 | self.labels = np.concatenate(label_list, 0) 48 | 49 | def __getitem__(self, idx): 50 | pt_idxs = np.arange(0, self.points.shape[1]) # 2048 51 | if self.train: 52 | np.random.shuffle(pt_idxs) 53 | 54 | current_points = self.points[idx, pt_idxs].copy() 55 | label = torch.from_numpy(self.labels[idx]).type(torch.LongTensor) 56 | 57 | if self.transforms is not None: 58 | current_points = self.transforms(current_points) 59 | 60 | return current_points, label 61 | 62 | def __len__(self): 63 | return self.points.shape[0] 64 | 65 | if __name__ == "__main__": 66 | from torchvision import transforms 67 | import data_utils as d_utils 68 | 69 | transforms = transforms.Compose([ 70 | d_utils.PointcloudToTensor(), 71 | d_utils.PointcloudRotate(axis=np.array([1,0,0])), 72 | d_utils.PointcloudScale(), 73 | d_utils.PointcloudTranslate(), 74 | d_utils.PointcloudJitter() 75 | ]) 76 | dset = ModelNet40Cls(16, "./", train=True, transforms=transforms) 77 | print(dset[0][0]) 78 | print(dset[0][1]) 79 | print(len(dset)) 80 | dloader = torch.utils.data.DataLoader(dset, batch_size=32, shuffle=True) 81 | -------------------------------------------------------------------------------- /data/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .ModelNet40Loader import ModelNet40Cls -------------------------------------------------------------------------------- /data/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class PointcloudToTensor(object): 5 | def __call__(self, points): 6 | return torch.from_numpy(points).float() 7 | 8 | def angle_axis(angle: float, axis: np.ndarray): 9 | r"""Returns a 3x3 rotation matrix that performs a rotation around axis by angle 10 | 11 | Parameters 12 | ---------- 13 | angle : float 14 | Angle to rotate by, in radians 15 | axis: np.ndarray 16 | Axis to rotate about (normalized internally) 17 | 18 | Returns 19 | ------- 20 | torch.Tensor 21 | 3x3 rotation matrix (float32) 22 | """ 23 | u = axis / np.linalg.norm(axis) 24 | cosval, sinval = np.cos(angle), np.sin(angle) 25 | 26 | # yapf: disable 27 | cross_prod_mat = np.array([[0.0, -u[2], u[1]], 28 | [u[2], 0.0, -u[0]], 29 | [-u[1], u[0], 0.0]]) 30 | 31 | R = torch.from_numpy( 32 | cosval * np.eye(3) 33 | + sinval * cross_prod_mat 34 | + (1.0 - cosval) * np.outer(u, u) 35 | ) 36 | # yapf: enable 37 | return R.float() 38 | 39 | class PointcloudRotatebyAngle(object): 40 | def __init__(self, rotation_angle = 0.0): 41 | self.rotation_angle = rotation_angle 42 | 43 | def __call__(self, pc): 44 | normals = pc.size(2) > 3 45 | bsize = pc.size(0) 46 | for i in range(bsize): 47 | cosval = np.cos(self.rotation_angle) 48 | sinval = np.sin(self.rotation_angle) 49 | rotation_matrix = np.array([[cosval, 0, sinval], 50 | [0, 1, 0], 51 | [-sinval, 0, cosval]]) 52 | rotation_matrix = torch.from_numpy(rotation_matrix).float().cuda() 53 | 54 | cur_pc = pc[i, :, :] 55 | if not normals: 56 | cur_pc = cur_pc @ rotation_matrix 57 | else: 58 | pc_xyz = cur_pc[:, 0:3] 59 | pc_normals = cur_pc[:, 3:] 60 | cur_pc[:, 0:3] = pc_xyz @ rotation_matrix 61 | cur_pc[:, 3:] = pc_normals @ rotation_matrix 62 | 63 | pc[i, :, 
:] = cur_pc 64 | 65 | return pc 66 | 67 | class PointcloudJitter(object): 68 | def __init__(self, std=0.01, clip=0.05): 69 | self.std, self.clip = std, clip 70 | 71 | def __call__(self, pc): 72 | bsize = pc.size(0) 73 | for i in range(bsize): 74 | jittered_data = pc.new(pc.size(1), 3).normal_( 75 | mean=0.0, std=self.std 76 | ).clamp_(-self.clip, self.clip) 77 | pc[i, :, 0:3] += jittered_data 78 | 79 | return pc 80 | 81 | class PointcloudScaleAndTranslate(object): 82 | def __init__(self, scale_low=2. / 3., scale_high=3. / 2., translate_range=0.2): 83 | self.scale_low = scale_low 84 | self.scale_high = scale_high 85 | self.translate_range = translate_range 86 | 87 | def __call__(self, pc): 88 | bsize = pc.size(0) 89 | for i in range(bsize): 90 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 91 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) 92 | 93 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) + torch.from_numpy(xyz2).float().cuda() 94 | 95 | return pc 96 | 97 | class PointcloudScale(object): 98 | def __init__(self, scale_low=2. / 3., scale_high=3. 
/ 2.): 99 | self.scale_low = scale_low 100 | self.scale_high = scale_high 101 | 102 | def __call__(self, pc): 103 | bsize = pc.size(0) 104 | for i in range(bsize): 105 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 106 | 107 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], torch.from_numpy(xyz1).float().cuda()) 108 | 109 | return pc 110 | 111 | class PointcloudTranslate(object): 112 | def __init__(self, translate_range=0.2): 113 | self.translate_range = translate_range 114 | 115 | def __call__(self, pc): 116 | bsize = pc.size(0) 117 | for i in range(bsize): 118 | xyz2 = np.random.uniform(low=-self.translate_range, high=self.translate_range, size=[3]) 119 | 120 | pc[i, :, 0:3] = pc[i, :, 0:3] + torch.from_numpy(xyz2).float().cuda() 121 | 122 | return pc 123 | 124 | class PointcloudRandomInputDropout(object): 125 | def __init__(self, max_dropout_ratio=0.875): 126 | assert max_dropout_ratio >= 0 and max_dropout_ratio < 1 127 | self.max_dropout_ratio = max_dropout_ratio 128 | 129 | def __call__(self, pc): 130 | bsize = pc.size(0) 131 | for i in range(bsize): 132 | dropout_ratio = np.random.random() * self.max_dropout_ratio # 0~0.875 133 | drop_idx = np.where(np.random.random((pc.size()[1])) <= dropout_ratio)[0] 134 | if len(drop_idx) > 0: 135 | cur_pc = pc[i, :, :] 136 | cur_pc[drop_idx.tolist(), 0:3] = cur_pc[0, 0:3].repeat(len(drop_idx), 1) # set to the first point 137 | pc[i, :, :] = cur_pc 138 | 139 | return pc 140 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .densepoint_cls_L6_k24_g2 import DensePoint as DensePointCls_L6 2 | 3 | -------------------------------------------------------------------------------- /models/densepoint_cls_L6_k24_g2.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | BASE_DIR = 
os.path.dirname(os.path.abspath(__file__)) 3 | sys.path.append(BASE_DIR) 4 | sys.path.append(os.path.join(BASE_DIR, "../utils")) 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import pytorch_utils as pt_utils 9 | from pointnet2_modules import PointnetSAModule, PointnetSAModuleMSG 10 | import numpy as np 11 | 12 | # DensePoint: 2 PPools + 3 PConvs + 1 global pool; narrowness k = 24; group number g = 2 13 | class DensePoint(nn.Module): 14 | r""" 15 | PointNet2 with multi-scale grouping 16 | Semantic segmentation network that uses feature propogation layers 17 | 18 | Parameters 19 | ---------- 20 | num_classes: int 21 | Number of semantics classes to predict over -- size of softmax classifier that run for each point 22 | input_channels: int = 6 23 | Number of input channels in the feature descriptor for each point. If the point cloud is Nx9, this 24 | value should be 6 as in an Nx9 point cloud, 3 of the channels are xyz, and 6 are feature descriptors 25 | use_xyz: bool = True 26 | Whether or not to use the xyz position of a point as a feature 27 | """ 28 | 29 | def __init__(self, num_classes, input_channels=0, use_xyz=True): 30 | super().__init__() 31 | 32 | self.SA_modules = nn.ModuleList() 33 | 34 | # stage 1 begin 35 | self.SA_modules.append( 36 | PointnetSAModuleMSG( 37 | npoint=512, 38 | radii=[0.25], 39 | nsamples=[64], 40 | mlps=[[input_channels, 96]], 41 | use_xyz=use_xyz, 42 | pool=True 43 | ) 44 | ) 45 | # stage 1 end 46 | 47 | # stage 2 begin 48 | input_channels = 96 49 | self.SA_modules.append( 50 | PointnetSAModuleMSG( 51 | npoint=128, 52 | radii=[0.32], 53 | nsamples=[64], 54 | mlps=[[input_channels, 93]], 55 | use_xyz=use_xyz, 56 | pool=True 57 | ) 58 | ) 59 | 60 | input_channels = 93 61 | self.SA_modules.append( 62 | PointnetSAModuleMSG( 63 | npoint=128, 64 | radii=[0.39], 65 | nsamples=[16], 66 | mlps=[[input_channels, 96]], 67 | group_number=2, 68 | use_xyz=use_xyz, 69 | after_pool=True 70 | ) 71 | ) 72 | 73 | 
input_channels = 117 74 | self.SA_modules.append( 75 | PointnetSAModuleMSG( 76 | npoint=128, 77 | radii=[0.39], 78 | nsamples=[16], 79 | mlps=[[input_channels, 96]], 80 | group_number=2, 81 | use_xyz=use_xyz 82 | ) 83 | ) 84 | 85 | input_channels = 141 86 | self.SA_modules.append( 87 | PointnetSAModuleMSG( 88 | npoint=128, 89 | radii=[0.39], 90 | nsamples=[16], 91 | mlps=[[input_channels, 96]], 92 | group_number=2, 93 | use_xyz=use_xyz, 94 | before_pool=True 95 | ) 96 | ) 97 | # stage 2 end 98 | 99 | # global pooling 100 | input_channels = 165 101 | self.SA_modules.append( 102 | PointnetSAModule( 103 | mlp=[input_channels, 512], use_xyz=use_xyz 104 | ) 105 | ) 106 | 107 | self.FC_layer = nn.Sequential( 108 | pt_utils.FC(512, 512, activation=nn.ReLU(inplace=True), bn=True), 109 | nn.Dropout(p=0.5), 110 | pt_utils.FC(512, 256, activation=nn.ReLU(inplace=True), bn=True), 111 | nn.Dropout(p=0.5), 112 | pt_utils.FC(256, num_classes, activation=None) 113 | ) 114 | 115 | def _break_up_pc(self, pc): 116 | xyz = pc[..., 0:3].contiguous() 117 | features = ( 118 | pc[..., 3:].transpose(1, 2).contiguous() 119 | if pc.size(-1) > 3 else None 120 | ) 121 | return xyz, features 122 | 123 | def forward(self, pointcloud: torch.cuda.FloatTensor): 124 | r""" 125 | Forward pass of the network 126 | 127 | Parameters 128 | ---------- 129 | pointcloud: Variable(torch.cuda.FloatTensor) 130 | (B, N, 3 + input_channels) tensor 131 | Point cloud to run predicts on 132 | Each point in the point-cloud MUST 133 | be formated as (x, y, z, features...) 
134 | """ 135 | xyz, features = self._break_up_pc(pointcloud) 136 | for module in self.SA_modules: 137 | xyz, features = module(xyz, features) 138 | 139 | return self.FC_layer(features.squeeze(-1)) 140 | 141 | if __name__ == "__main__": 142 | sim_data = Variable(torch.rand(32, 2048, 6)) 143 | sim_data = sim_data.cuda() 144 | sim_cls = Variable(torch.ones(32, 16)) 145 | sim_cls = sim_cls.cuda() 146 | 147 | seg = Pointnet2MSG(num_classes=50, input_channels=3, use_xyz=True) 148 | seg = seg.cuda() 149 | out = seg(sim_data, sim_cls) 150 | print('seg', out.size()) 151 | -------------------------------------------------------------------------------- /train_cls.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.optim.lr_scheduler as lr_sched 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import os 9 | from torchvision import transforms 10 | from models import DensePointCls_L6 as DensePoint 11 | from data import ModelNet40Cls 12 | import utils.pytorch_utils as pt_utils 13 | import utils.pointnet2_utils as pointnet2_utils 14 | import data.data_utils as d_utils 15 | import argparse 16 | import random 17 | import yaml 18 | 19 | torch.backends.cudnn.enabled = True 20 | torch.backends.cudnn.benchmark = True 21 | torch.backends.cudnn.deterministic = True 22 | 23 | seed = 123 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | 30 | parser = argparse.ArgumentParser(description='DensePoint Shape Classification Training') 31 | parser.add_argument('--config', default='cfgs/config_cls.yaml', type=str) 32 | 33 | def main(): 34 | args = parser.parse_args() 35 | with open(args.config) as f: 36 | config = yaml.load(f) 37 | print("\n**************************") 38 | for k, v in 
config['common'].items(): 39 | setattr(args, k, v) 40 | print('\n[%s]:'%(k), v) 41 | print("\n**************************\n") 42 | 43 | try: 44 | os.makedirs(args.save_path) 45 | except OSError: 46 | pass 47 | 48 | train_transforms = transforms.Compose([ 49 | d_utils.PointcloudToTensor() 50 | ]) 51 | test_transforms = transforms.Compose([ 52 | d_utils.PointcloudToTensor() 53 | ]) 54 | 55 | train_dataset = ModelNet40Cls(num_points = args.num_points, root = args.data_root, transforms=train_transforms) 56 | train_dataloader = DataLoader( 57 | train_dataset, 58 | batch_size=args.batch_size, 59 | shuffle=True, 60 | num_workers=int(args.workers), 61 | pin_memory=True 62 | ) 63 | 64 | test_dataset = ModelNet40Cls(num_points = args.num_points, root = args.data_root, transforms=test_transforms, train=False) 65 | test_dataloader = DataLoader( 66 | test_dataset, 67 | batch_size=args.batch_size, 68 | shuffle=False, 69 | num_workers=int(args.workers), 70 | pin_memory=True 71 | ) 72 | 73 | model = DensePoint(num_classes = args.num_classes, input_channels = args.input_channels, use_xyz = True) 74 | model.cuda() 75 | optimizer = optim.Adam( 76 | model.parameters(), lr=args.base_lr, weight_decay=args.weight_decay) 77 | 78 | lr_lbmd = lambda e: max(args.lr_decay**(e // args.decay_step), args.lr_clip / args.base_lr) 79 | bnm_lmbd = lambda e: max(args.bn_momentum * args.bn_decay**(e // args.decay_step), args.bnm_clip) 80 | lr_scheduler = lr_sched.LambdaLR(optimizer, lr_lbmd) 81 | bnm_scheduler = pt_utils.BNMomentumScheduler(model, bnm_lmbd) 82 | 83 | if args.checkpoint is not '': 84 | model.load_state_dict(torch.load(args.checkpoint)) 85 | print('Load model successfully: %s' % (args.checkpoint)) 86 | 87 | criterion = nn.CrossEntropyLoss() 88 | num_batch = len(train_dataset)/args.batch_size 89 | 90 | # training 91 | train(train_dataloader, test_dataloader, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch) 92 | 93 | 94 | def train(train_dataloader, 
test_dataloader, model, criterion, optimizer, lr_scheduler, bnm_scheduler, args, num_batch): 95 | PointcloudScaleAndTranslate = d_utils.PointcloudScaleAndTranslate() # initialize augmentation 96 | global g_acc 97 | g_acc = 0.91 # only save the model whose acc > 0.91 98 | batch_count = 0 99 | model.train() 100 | for epoch in range(args.epochs): 101 | for i, data in enumerate(train_dataloader, 0): 102 | if lr_scheduler is not None: 103 | lr_scheduler.step(epoch) 104 | if bnm_scheduler is not None: 105 | bnm_scheduler.step(epoch-1) 106 | points, target = data 107 | points, target = points.cuda(), target.cuda() 108 | points, target = Variable(points), Variable(target) 109 | 110 | # farthest point sampling 111 | fps_idx = pointnet2_utils.furthest_point_sample(points, 1200) # (B, npoint) 112 | fps_idx = fps_idx[:, np.random.choice(1200, args.num_points, False)] 113 | points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous() # (B, N, 3) 114 | 115 | # augmentation 116 | points.data = PointcloudScaleAndTranslate(points.data) 117 | 118 | optimizer.zero_grad() 119 | 120 | pred = model(points) 121 | target = target.view(-1) 122 | loss = criterion(pred, target) 123 | loss.backward() 124 | optimizer.step() 125 | if i % args.print_freq_iter == 0: 126 | print('[epoch %3d: %3d/%3d] \t train loss: %0.6f \t lr: %0.5f' %(epoch+1, i, num_batch, loss.data.clone(), lr_scheduler.get_lr()[0])) 127 | batch_count += 1 128 | 129 | # validation in between an epoch 130 | if args.evaluate and batch_count % int(args.val_freq_epoch * num_batch) == 0: 131 | validate(test_dataloader, model, criterion, args, batch_count) 132 | 133 | 134 | def validate(test_dataloader, model, criterion, args, iter): 135 | global g_acc 136 | model.eval() 137 | losses, preds, labels = [], [], [] 138 | for j, data in enumerate(test_dataloader, 0): 139 | points, target = data 140 | points, target = points.cuda(), target.cuda() 141 | points, target = 
Variable(points, volatile=True), Variable(target, volatile=True) 142 | 143 | # farthest point sampling 144 | fps_idx = pointnet2_utils.furthest_point_sample(points, args.num_points) # (B, npoint) 145 | # fps_idx = fps_idx[:, np.random.choice(1200, args.num_points, False)] 146 | points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous() 147 | 148 | pred = model(points) 149 | target = target.view(-1) 150 | loss = criterion(pred, target) 151 | losses.append(loss.data.clone()) 152 | _, pred_choice = torch.max(pred.data, -1) 153 | 154 | preds.append(pred_choice) 155 | labels.append(target.data) 156 | 157 | preds = torch.cat(preds, 0) 158 | labels = torch.cat(labels, 0) 159 | acc = (preds == labels).sum() / labels.numel() 160 | print('\nval loss: %0.6f \t acc: %0.6f\n' %(np.array(losses).mean(), acc)) 161 | if acc > g_acc: 162 | g_acc = acc 163 | torch.save(model.state_dict(), '%s/cls_iter_%d_acc_%0.6f.pth' % (args.save_path, iter, acc)) 164 | model.train() 165 | 166 | if __name__ == "__main__": 167 | main() 168 | -------------------------------------------------------------------------------- /train_cls.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | mkdir -p log 3 | now=$(date +"%Y%m%d_%H%M%S") 4 | log_name="Cls_LOG_"$now"" 5 | export CUDA_VISIBLE_DEVICES=0 6 | python -u train_cls.py \ 7 | --config cfgs/config_cls.yaml \ 8 | 2>&1|tee log/$log_name.log & 9 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yochengliu/DensePoint/2a9393402f9f60d05a1735e78c4eced9f10015d9/utils/__init__.py -------------------------------------------------------------------------------- /utils/_ext/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Yochengliu/DensePoint/2a9393402f9f60d05a1735e78c4eced9f10015d9/utils/_ext/__init__.py -------------------------------------------------------------------------------- /utils/_ext/pointnet2/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._pointnet2 import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /utils/build_ffi.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | import os.path as osp 4 | from torch.utils.ffi import create_extension 5 | import sys, argparse, shutil 6 | 7 | base_dir = osp.dirname(osp.abspath(__file__)) 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser( 12 | description="Arguments for building pointnet2 ffi extension" 13 | ) 14 | parser.add_argument("--objs", nargs="*") 15 | clean_arg = parser.add_mutually_exclusive_group() 16 | clean_arg.add_argument("--build", dest='build', action="store_true") 17 | clean_arg.add_argument("--clean", dest='clean', action="store_true") 18 | parser.set_defaults(build=False, clean=False) 19 | 20 | args = parser.parse_args() 21 | assert args.build or args.clean 22 | 23 | return args 24 | 25 | 26 | def build(args): 27 | extra_objects = args.objs 28 | extra_objects += [a for a in glob.glob('/usr/local/cuda/lib64/*.a')] 29 | 30 | ffi = create_extension( 31 | '_ext.pointnet2', 32 | headers=[a for a in glob.glob("cinclude/*_wrapper.h")], 33 | sources=[a for a in glob.glob("csrc/*.c")], 34 | define_macros=[('WITH_CUDA', None)], 35 | relative_to=__file__, 
36 | with_cuda=True, 37 | extra_objects=extra_objects, 38 | include_dirs=[osp.join(base_dir, 'cinclude')], 39 | verbose=False, 40 | package=False 41 | ) 42 | ffi.build() 43 | 44 | 45 | def clean(args): 46 | shutil.rmtree(osp.join(base_dir, "_ext")) 47 | 48 | 49 | if __name__ == "__main__": 50 | args = parse_args() 51 | if args.clean: 52 | clean(args) 53 | else: 54 | build(args) 55 | -------------------------------------------------------------------------------- /utils/cinclude/ball_query_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _BALL_QUERY_GPU 2 | #define _BALL_QUERY_GPU 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, 9 | int nsample, const float *xyz, 10 | const float *new_xyz, int *idx, 11 | cudaStream_t stream); 12 | 13 | #ifdef __cplusplus 14 | } 15 | #endif 16 | #endif 17 | -------------------------------------------------------------------------------- /utils/cinclude/ball_query_wrapper.h: -------------------------------------------------------------------------------- 1 | 2 | int ball_query_wrapper(int b, int n, int m, float radius, int nsample, 3 | THCudaTensor *new_xyz_tensor, THCudaTensor *xyz_tensor, 4 | THCudaIntTensor *idx_tensor); 5 | -------------------------------------------------------------------------------- /utils/cinclude/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef _CUDA_UTILS_H 2 | #define _CUDA_UTILS_H 3 | 4 | #include <cmath> 5 | 6 | #define TOTAL_THREADS 512 7 | 8 | inline int opt_n_threads(int work_size) { 9 | const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0); // truncated log2 of work_size 10 | 11 | return max(min(1 << pow_2, TOTAL_THREADS), 1); 12 | } 13 | 14 | inline dim3 opt_block_config(int x, int y) { 15 | const int x_threads = opt_n_threads(x); 16 | const int y_threads = 17 | max(min(opt_n_threads(y), TOTAL_THREADS / x_threads), 1); 18 | dim3 
block_config(x_threads, y_threads, 1); 19 | 20 | return block_config; 21 | } 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /utils/cinclude/group_points_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _GROUP_POINTS_GPU_H 2 | #define _GROUP_POINTS_GPU_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample, 9 | const float *points, const int *idx, 10 | float *out, cudaStream_t stream); 11 | 12 | void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | int nsample, const float *grad_out, 14 | const int *idx, float *grad_points, 15 | cudaStream_t stream); 16 | #ifdef __cplusplus 17 | } 18 | #endif 19 | #endif 20 | -------------------------------------------------------------------------------- /utils/cinclude/group_points_wrapper.h: -------------------------------------------------------------------------------- 1 | int group_points_wrapper(int b, int c, int n, int npoints, int nsample, 2 | THCudaTensor *points_tensor, 3 | THCudaIntTensor *idx_tensor, THCudaTensor *out); 4 | int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample, 5 | THCudaTensor *grad_out_tensor, 6 | THCudaIntTensor *idx_tensor, 7 | THCudaTensor *grad_points_tensor); 8 | -------------------------------------------------------------------------------- /utils/cinclude/interpolate_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _INTERPOLATE_GPU_H 2 | #define _INTERPOLATE_GPU_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown, 9 | const float *known, float *dist2, int *idx, 10 | cudaStream_t stream); 11 | 12 | void three_interpolate_kernel_wrapper(int b, int c, int m, int n, 13 | const float *points, const int *idx, 14 | const float *weight, 
float *out, 15 | cudaStream_t stream); 16 | 17 | void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m, 18 | const float *grad_out, 19 | const int *idx, const float *weight, 20 | float *grad_points, 21 | cudaStream_t stream); 22 | 23 | #ifdef __cplusplus 24 | } 25 | #endif 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /utils/cinclude/interpolate_wrapper.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor, 4 | THCudaTensor *known_tensor, THCudaTensor *dist2_tensor, 5 | THCudaIntTensor *idx_tensor); 6 | void three_interpolate_wrapper(int b, int c, int m, int n, 7 | THCudaTensor *points_tensor, 8 | THCudaIntTensor *idx_tensor, 9 | THCudaTensor *weight_tensor, 10 | THCudaTensor *out_tensor); 11 | 12 | void three_interpolate_grad_wrapper(int b, int c, int n, int m, 13 | THCudaTensor *grad_out_tensor, 14 | THCudaIntTensor *idx_tensor, 15 | THCudaTensor *weight_tensor, 16 | THCudaTensor *grad_points_tensor); 17 | -------------------------------------------------------------------------------- /utils/cinclude/sampling_gpu.h: -------------------------------------------------------------------------------- 1 | #ifndef _SAMPLING_GPU_H 2 | #define _SAMPLING_GPU_H 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 9 | const float *points, const int *idx, 10 | float *out, cudaStream_t stream); 11 | 12 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 13 | const float *grad_out, const int *idx, 14 | float *grad_points, cudaStream_t stream); 15 | 16 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 17 | const float *dataset, float *temp, 18 | int *idxs, cudaStream_t stream); 19 | 20 | #ifdef __cplusplus 21 | } 22 | #endif 23 | #endif 24 | 
--------------------------------------------------------------------------------
/utils/cinclude/sampling_wrapper.h:
--------------------------------------------------------------------------------

int gather_points_wrapper(int b, int c, int n, int npoints,
                          THCudaTensor *points_tensor,
                          THCudaIntTensor *idx_tensor,
                          THCudaTensor *out_tensor);
int gather_points_grad_wrapper(int b, int c, int n, int npoints,
                               THCudaTensor *grad_out_tensor,
                               THCudaIntTensor *idx_tensor,
                               THCudaTensor *grad_points_tensor);

int furthest_point_sampling_wrapper(int b, int n, int m,
                                    THCudaTensor *points_tensor,
                                    THCudaTensor *temp_tensor,
                                    THCudaIntTensor *idx_tensor);

--------------------------------------------------------------------------------
/utils/csrc/ball_query.c:
--------------------------------------------------------------------------------
// NOTE(review): the header name was lost in extraction; <THC/THC.h> is
// restored from the THC* types and functions used below -- confirm.
#include <THC/THC.h>

#include "ball_query_gpu.h"

extern THCState *state;

// Host-side entry point: unwraps the tensors' device pointers and launches the
// ball-query kernel on the current THC stream. Always returns 1.
int ball_query_wrapper(int b, int n, int m, float radius, int nsample,
                       THCudaTensor *new_xyz_tensor, THCudaTensor *xyz_tensor,
                       THCudaIntTensor *idx_tensor) {

    const float *new_xyz = THCudaTensor_data(state, new_xyz_tensor);
    const float *xyz = THCudaTensor_data(state, xyz_tensor);
    int *idx = THCudaIntTensor_data(state, idx_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);

    query_ball_point_kernel_wrapper(b, n, m, radius, nsample, new_xyz, xyz, idx,
                                    stream);
    return 1;
}

--------------------------------------------------------------------------------
/utils/csrc/ball_query_gpu.cu:
--------------------------------------------------------------------------------
// NOTE(review): header names were lost in extraction; restored from usage
// (fprintf -> stdio.h, exit -> stdlib.h; math.h assumed) -- confirm.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "ball_query_gpu.h"
#include "cuda_utils.h"

// input: new_xyz(b, m, 3) xyz(b, n, 3)
// output: idx(b, m, nsample)
//
// One block per batch element (blockIdx.x); the block's threads stride over
// the m query points. For each query point, xyz is scanned in index order and
// up to nsample indices with squared distance < radius^2 are recorded. On the
// first hit every one of the nsample slots is pre-filled with that index, so
// slots beyond the number of hits still hold a valid in-range index. If no
// point is in range the slots for that query are left untouched.
__global__ void query_ball_point_kernel(int b, int n, int m, float radius,
                                        int nsample,
                                        const float *__restrict__ new_xyz,
                                        const float *__restrict__ xyz,
                                        int *__restrict__ idx) {
    int batch_index = blockIdx.x;
    xyz += batch_index * n * 3;
    new_xyz += batch_index * m * 3;
    idx += m * nsample * batch_index;

    int index = threadIdx.x;
    int stride = blockDim.x;

    float radius2 = radius * radius;
    for (int j = index; j < m; j += stride) {
        float new_x = new_xyz[j * 3 + 0];
        float new_y = new_xyz[j * 3 + 1];
        float new_z = new_xyz[j * 3 + 2];
        for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
            float x = xyz[k * 3 + 0];
            float y = xyz[k * 3 + 1];
            float z = xyz[k * 3 + 2];
            float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
                       (new_z - z) * (new_z - z);
            if (d2 < radius2) {
                if (cnt == 0) {
                    // First neighbour found: use it as the default for all slots.
                    for (int l = 0; l < nsample; ++l) {
                        idx[j * nsample + l] = k;
                    }
                }
                idx[j * nsample + cnt] = k;
                ++cnt;
            }
        }
    }
}

// Launches query_ball_point_kernel on `stream`; prints the CUDA error and
// terminates the process if the launch fails.
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
                                     int nsample, const float *new_xyz,
                                     const float *xyz, int *idx,
                                     cudaStream_t stream) {

    cudaError_t err;
    // NOTE(review): the <<<...>>> arguments were lost in extraction; restored
    // to match the kernel's indexing (blockIdx.x = batch, threadIdx.x strides
    // over m) -- confirm against the original repository.
    query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
        b, n, m, radius, nsample, new_xyz, xyz, idx);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

--------------------------------------------------------------------------------
/utils/csrc/group_points.c:
--------------------------------------------------------------------------------
// NOTE(review): the header name was lost in extraction; <THC/THC.h> is
// restored from the THC* types and functions used below -- confirm.
#include <THC/THC.h>

#include "group_points_gpu.h"

extern THCState *state;

// Forward grouping: unwraps device pointers and launches the grouping kernel
// on the current THC stream. Always returns 1.
int group_points_wrapper(int b, int c, int n, int npoints, int nsample,
                         THCudaTensor *points_tensor,
                         THCudaIntTensor *idx_tensor,
                         THCudaTensor *out_tensor) {

    const float *points = THCudaTensor_data(state, points_tensor);
    const int *idx = THCudaIntTensor_data(state, idx_tensor);
    float *out = THCudaTensor_data(state, out_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);

    group_points_kernel_wrapper(b, c, n, npoints, nsample, points, idx, out,
                                stream);
    return 1;
}

// Backward grouping: accumulates grad_out into grad_points through the same
// indices. Always returns 1.
int group_points_grad_wrapper(int b, int c, int n, int npoints, int nsample,
                              THCudaTensor *grad_out_tensor,
                              THCudaIntTensor *idx_tensor,
                              THCudaTensor *grad_points_tensor) {

    float *grad_points = THCudaTensor_data(state, grad_points_tensor);
    const int *idx = THCudaIntTensor_data(state, idx_tensor);
    const float *grad_out = THCudaTensor_data(state, grad_out_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);

    group_points_grad_kernel_wrapper(b, c, n, npoints, nsample, grad_out, idx,
                                     grad_points, stream);
    return 1;
}

--------------------------------------------------------------------------------
/utils/csrc/group_points_gpu.cu:
--------------------------------------------------------------------------------
// NOTE(review): header names were lost in extraction; restored from usage
// (fprintf -> stdio.h, exit -> stdlib.h) -- confirm.
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "group_points_gpu.h"

// input: points(b, c, n) idx(b, npoints, nsample)
// output: out(b, c, npoints, nsample)
//
// One block per batch element; the block's threads (flattened over x and y)
// stride over the c * npoints (channel, centre) pairs and copy the nsample
// indexed values for each pair.
__global__ void group_points_kernel(int b, int c, int n, int npoints,
                                    int nsample,
                                    const float *__restrict__ points,
                                    const int *__restrict__ idx,
                                    float *__restrict__ out) {
    int batch_index = blockIdx.x;
    points += batch_index * n * c;
    idx += batch_index * npoints * nsample;
    out += batch_index * npoints * nsample * c;

    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * npoints; i += stride) {
        const int l = i / npoints;  // channel
        const int j = i % npoints;  // centre point
        for (int k = 0; k < nsample; ++k) {
            int ii = idx[j * nsample + k];
            out[(l * npoints + j) * nsample + k] = points[l * n + ii];
        }
    }
}

// Launches group_points_kernel on `stream`; prints the CUDA error and
// terminates the process if the launch fails.
void group_points_kernel_wrapper(int b, int c, int n, int npoints, int nsample,
                                 const float *points, const int *idx,
                                 float *out, cudaStream_t stream) {

    cudaError_t err;
    // NOTE(review): the <<<...>>> arguments were lost in extraction; restored
    // to match the kernel's indexing (blockIdx.x = batch, 2-D thread block)
    // -- confirm against the original repository.
    group_points_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
        b, c, n, npoints, nsample, points, idx, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

// input: grad_out(b, c, npoints, nsample), idx(b, npoints, nsample)
// output: grad_points(b, c, n)
//
// Same traversal as group_points_kernel; several (j, k) pairs may refer to the
// same source point, so the accumulation uses atomicAdd.
__global__ void group_points_grad_kernel(int b, int c, int n, int npoints,
                                         int nsample,
                                         const float *__restrict__ grad_out,
                                         const int *__restrict__ idx,
                                         float *__restrict__ grad_points) {
    int batch_index = blockIdx.x;
    grad_out += batch_index * npoints * nsample * c;
    idx += batch_index * npoints * nsample;
    grad_points += batch_index * n * c;

    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * npoints; i += stride) {
        const int l = i / npoints;  // channel
        const int j = i % npoints;  // centre point
        for (int k = 0; k < nsample; ++k) {
            int ii = idx[j * nsample + k];
            atomicAdd(grad_points + l * n + ii,
                      grad_out[(l * npoints + j) * nsample + k]);
        }
    }
}

// Launches group_points_grad_kernel on `stream`; prints the CUDA error and
// terminates the process if the launch fails.
void group_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
                                      int nsample, const float *grad_out,
                                      const int *idx, float *grad_points,
                                      cudaStream_t stream) {
    cudaError_t err;
    // NOTE(review): launch configuration restored after extraction dropped the
    // <<<...>>> arguments -- confirm against the original repository.
    group_points_grad_kernel<<<b, opt_block_config(npoints, c), 0, stream>>>(
        b, c, n, npoints, nsample, grad_out, idx, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}

--------------------------------------------------------------------------------
/utils/csrc/interpolate.c:
--------------------------------------------------------------------------------
// NOTE(review): header names were lost in extraction; <THC/THC.h> is restored
// from the THC* usage below, the remaining three are assumed -- confirm.
#include <THC/THC.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "interpolate_gpu.h"

extern THCState *state;

// Unwraps device pointers and launches the three-nearest-neighbour search on
// the current THC stream.
void three_nn_wrapper(int b, int n, int m, THCudaTensor *unknown_tensor,
                      THCudaTensor *known_tensor, THCudaTensor *dist2_tensor,
                      THCudaIntTensor *idx_tensor) {
    const float *unknown = THCudaTensor_data(state, unknown_tensor);
    const float *known = THCudaTensor_data(state, known_tensor);
    float *dist2 = THCudaTensor_data(state, dist2_tensor);
    int *idx = THCudaIntTensor_data(state, idx_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);
    three_nn_kernel_wrapper(b, n, m, unknown, known, dist2, idx, stream);
}

// Unwraps device pointers and launches the weighted three-point interpolation
// on the current THC stream.
void three_interpolate_wrapper(int b, int c, int m, int n,
                               THCudaTensor *points_tensor,
                               THCudaIntTensor *idx_tensor,
                               THCudaTensor *weight_tensor,
                               THCudaTensor *out_tensor) {

    const float *points = THCudaTensor_data(state, points_tensor);
    const float *weight = THCudaTensor_data(state, weight_tensor);
    float *out = THCudaTensor_data(state, out_tensor);
    const int *idx = THCudaIntTensor_data(state, idx_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);
    three_interpolate_kernel_wrapper(b, c, m, n, points, idx, weight, out,
                                     stream);
}

// Unwraps device pointers and launches the gradient of the interpolation on
// the current THC stream.
void three_interpolate_grad_wrapper(int b, int c, int n, int m,
                                    THCudaTensor *grad_out_tensor,
                                    THCudaIntTensor *idx_tensor,
                                    THCudaTensor *weight_tensor,
                                    THCudaTensor *grad_points_tensor) {

    const float *grad_out = THCudaTensor_data(state, grad_out_tensor);
    const float *weight = THCudaTensor_data(state, weight_tensor);
    float *grad_points = THCudaTensor_data(state, grad_points_tensor);
    const int *idx = THCudaIntTensor_data(state, idx_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);
    three_interpolate_grad_kernel_wrapper(b, c, n, m, grad_out, idx, weight,
                                          grad_points, stream);
}

--------------------------------------------------------------------------------
/utils/csrc/interpolate_gpu.cu:
--------------------------------------------------------------------------------
// NOTE(review): header names were lost in extraction; restored from usage
// (fprintf -> stdio.h, exit -> stdlib.h; math.h assumed) -- confirm.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "interpolate_gpu.h"

// input: unknown(b, n, 3) known(b, m, 3)
// output: dist2(b, n, 3), idx(b, n, 3)
//
// One block per batch element; threads stride over the n unknown points. For
// each one, a single pass over the m known points keeps the three smallest
// squared distances (ascending) and their indices. When m < 3 the unused
// entries keep their initial values (distance 1e40 narrowed to float on
// store, index 0).
__global__ void three_nn_kernel(int b, int n, int m,
                                const float *__restrict__ unknown,
                                const float *__restrict__ known,
                                float *__restrict__ dist2,
                                int *__restrict__ idx) {
    int batch_index = blockIdx.x;
    unknown += batch_index * n * 3;
    known += batch_index * m * 3;
    dist2 += batch_index * n * 3;
    idx += batch_index * n * 3;

    int index = threadIdx.x;
    int stride = blockDim.x;
    for (int j = index; j < n; j += stride) {
        float ux = unknown[j * 3 + 0];
        float uy = unknown[j * 3 + 1];
        float uz = unknown[j * 3 + 2];

        double best1 = 1e40, best2 = 1e40, best3 = 1e40;
        int besti1 = 0, besti2 = 0, besti3 = 0;
        for (int k = 0; k < m; ++k) {
            float x = known[k * 3 + 0];
            float y = known[k * 3 + 1];
            float z = known[k * 3 + 2];
            float d =
                (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
            // Insert d into the sorted triple (best1 <= best2 <= best3).
            if (d < best1) {
                best3 = best2;
                besti3 = besti2;
                best2 = best1;
                besti2 = besti1;
                best1 = d;
                besti1 = k;
            } else if (d < best2) {
                best3 = best2;
                besti3 = besti2;
                best2 = d;
                besti2 = k;
            } else if (d < best3) {
                best3 = d;
                besti3 = k;
            }
        }
        dist2[j * 3 + 0] = best1;
        dist2[j * 3 + 1] = best2;
        dist2[j * 3 + 2] = best3;

        idx[j * 3 + 0] = besti1;
        idx[j * 3 + 1] = besti2;
        idx[j * 3 + 2] = besti3;
    }
}

// Launches three_nn_kernel on `stream`; prints the CUDA error and terminates
// the process if the launch fails.
void three_nn_kernel_wrapper(int b, int n, int m, const float *unknown,
                             const float *known, float *dist2, int *idx,
                             cudaStream_t stream) {

    cudaError_t err;
    // NOTE(review): the <<<...>>> arguments were lost in extraction; restored
    // to match the kernel's indexing (blockIdx.x = batch, threadIdx.x strides
    // over n) -- confirm against the original repository.
    three_nn_kernel<<<b, opt_n_threads(n), 0, stream>>>(b, n, m, unknown, known,
                                                        dist2, idx);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel "
                        "failed : %s\n",
                cudaGetErrorString(err));
        exit(-1);
    }
}

// input: points(b, c, m), idx(b, n, 3), weight(b, n, 3)
// output: out(b, c, n)
//
// One block per batch element; the flattened thread index strides over the
// c * n (channel, point) pairs. Each output is the weighted sum of the three
// indexed source values.
__global__ void three_interpolate_kernel(int b, int c, int m, int n,
                                         const float *__restrict__ points,
                                         const int *__restrict__ idx,
                                         const float *__restrict__ weight,
                                         float *__restrict__ out) {
    int batch_index = blockIdx.x;
    points += batch_index * m * c;

    idx += batch_index * n * 3;
    weight += batch_index * n * 3;

    out += batch_index * n * c;

    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * n; i += stride) {
        const int l = i / n;  // channel
        const int j = i % n;  // target point
        float w1 = weight[j * 3 + 0];
        float w2 = weight[j * 3 + 1];
        float w3 = weight[j * 3 + 2];

        int i1 = idx[j * 3 + 0];
        int i2 = idx[j * 3 + 1];
        int i3 = idx[j * 3 + 2];

        out[i] = points[l * m + i1] * w1 + points[l * m + i2] * w2 +
                 points[l * m + i3] * w3;
    }
}

// Launches three_interpolate_kernel on `stream`; prints the CUDA error and
// terminates the process if the launch fails.
void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
                                      const float *points, const int *idx,
                                      const float *weight, float *out,
                                      cudaStream_t stream) {

    cudaError_t err;
    // NOTE(review): launch configuration restored after extraction dropped the
    // <<<...>>> arguments -- confirm against the original repository.
    three_interpolate_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
        b, c, m, n, points, idx, weight, out);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel "
                        "failed : %s\n",
                cudaGetErrorString(err));
        exit(-1);
    }
}

// input: grad_out(b, c, n), idx(b, n, 3), weight(b, n, 3)
// output: grad_points(b, c, m)
//
// Same traversal as three_interpolate_kernel; several target points may share
// a source point, so the accumulation uses atomicAdd.

__global__ void three_interpolate_grad_kernel(
    int b, int c, int n, int m, const float *__restrict__ grad_out,
    const int *__restrict__ idx, const float *__restrict__ weight,
    float *__restrict__ grad_points) {
    int batch_index = blockIdx.x;
    grad_out += batch_index * n * c;
    idx += batch_index * n * 3;
    weight += batch_index * n * 3;
    grad_points += batch_index * m * c;

    const int index = threadIdx.y * blockDim.x + threadIdx.x;
    const int stride = blockDim.y * blockDim.x;
    for (int i = index; i < c * n; i += stride) {
        const int l = i / n;  // channel
        const int j = i % n;  // target point
        float w1 = weight[j * 3 + 0];
        float w2 = weight[j * 3 + 1];
        float w3 = weight[j * 3 + 2];

        int i1 = idx[j * 3 + 0];
        int i2 = idx[j * 3 + 1];
        int i3 = idx[j * 3 + 2];

        atomicAdd(grad_points + l * m + i1, grad_out[i] * w1);
        atomicAdd(grad_points + l * m + i2, grad_out[i] * w2);
        atomicAdd(grad_points + l * m + i3, grad_out[i] * w3);
    }
}

// Launches three_interpolate_grad_kernel on `stream`; prints the CUDA error
// and terminates the process if the launch fails.
//
// The parameter order is (b, c, n, m), matching the prototype in
// interpolate_gpu.h and the kernel. The definition previously named them
// (b, n, c, m); the values still reached the kernel correctly because they
// were forwarded in the same swapped order, but the names were misleading and
// the block shape was computed from transposed sizes.
void three_interpolate_grad_kernel_wrapper(int b, int c, int n, int m,
                                           const float *grad_out,
                                           const int *idx, const float *weight,
                                           float *grad_points,
                                           cudaStream_t stream) {

    cudaError_t err;
    // NOTE(review): launch configuration restored after extraction dropped the
    // <<<...>>> arguments -- confirm against the original repository.
    three_interpolate_grad_kernel<<<b, opt_block_config(n, c), 0, stream>>>(
        b, c, n, m, grad_out, idx, weight, grad_points);

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel "
                        "failed : %s\n",
                cudaGetErrorString(err));
        exit(-1);
    }
}

--------------------------------------------------------------------------------
/utils/csrc/sampling.c:
--------------------------------------------------------------------------------
// NOTE(review): the header name was lost in extraction; <THC/THC.h> is
// restored from the THC* types and functions used below -- confirm.
#include <THC/THC.h>

#include "sampling_gpu.h"

extern THCState *state;

// Unwraps device pointers and launches the gather kernel on the current THC
// stream. Always returns 1.
int gather_points_wrapper(int b, int c, int n, int npoints,
                          THCudaTensor *points_tensor,
                          THCudaIntTensor *idx_tensor,
                          THCudaTensor *out_tensor) {

    const float *points = THCudaTensor_data(state, points_tensor);
    const int *idx = THCudaIntTensor_data(state, idx_tensor);
    float *out = THCudaTensor_data(state, out_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);

    gather_points_kernel_wrapper(b, c, n, npoints, points, idx, out, stream);
    return 1;
}

// Unwraps device pointers and launches the gather-gradient kernel on the
// current THC stream. Always returns 1.
int gather_points_grad_wrapper(int b, int c, int n, int npoints,
                               THCudaTensor *grad_out_tensor,
                               THCudaIntTensor *idx_tensor,
                               THCudaTensor *grad_points_tensor) {

    const float *grad_out = THCudaTensor_data(state, grad_out_tensor);
    const int *idx = THCudaIntTensor_data(state, idx_tensor);
    float *grad_points = THCudaTensor_data(state, grad_points_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);

    gather_points_grad_kernel_wrapper(b, c, n, npoints, grad_out, idx,
                                      grad_points, stream);
    return 1;
}

// Unwraps device pointers and launches furthest point sampling on the current
// THC stream. Always returns 1.
int furthest_point_sampling_wrapper(int b, int n, int m,
                                    THCudaTensor *points_tensor,
                                    THCudaTensor *temp_tensor,
                                    THCudaIntTensor *idx_tensor) {

    const float *points = THCudaTensor_data(state, points_tensor);
    float *temp = THCudaTensor_data(state, temp_tensor);
    int *idx = THCudaIntTensor_data(state, idx_tensor);

    cudaStream_t stream = THCState_getCurrentStream(state);

    furthest_point_sampling_kernel_wrapper(b, n, m, points, temp, idx, stream);
    return 1;
}

--------------------------------------------------------------------------------
/utils/csrc/sampling_gpu.cu:
--------------------------------------------------------------------------------
// NOTE(review): header names were lost in extraction; restored from usage
// (fprintf -> stdio.h, exit -> stdlib.h) -- confirm.
#include <stdio.h>
#include <stdlib.h>

#include "cuda_utils.h"
#include "sampling_gpu.h"

// input: points(b, c, n) idx(b, m)
//
output: out(b, c, m) 9 | __global__ void gather_points_kernel(int b, int c, int n, int m, 10 | const float *__restrict__ points, 11 | const int *__restrict__ idx, 12 | float *__restrict__ out) { 13 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 14 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 15 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 16 | int a = idx[i * m + j]; 17 | out[(i * c + l) * m + j] = points[(i * c + l) * n + a]; 18 | } 19 | } 20 | } 21 | } 22 | 23 | void gather_points_kernel_wrapper(int b, int c, int n, int npoints, 24 | const float *points, const int *idx, 25 | float *out, cudaStream_t stream) { 26 | 27 | cudaError_t err; 28 | gather_points_kernel<<>>( 29 | b, c, n, npoints, points, idx, out); 30 | 31 | err = cudaGetLastError(); 32 | if (cudaSuccess != err) { 33 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 34 | exit(-1); 35 | } 36 | } 37 | 38 | // input: grad_out(b, c, m) idx(b, m) 39 | // output: grad_points(b, c, n) 40 | __global__ void gather_points_grad_kernel(int b, int c, int n, int m, 41 | const float *__restrict__ grad_out, 42 | const int *__restrict__ idx, 43 | float *__restrict__ grad_points) { 44 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 45 | for (int l = blockIdx.y; l < c; l += gridDim.y) { 46 | for (int j = threadIdx.x; j < m; j += blockDim.x) { 47 | int a = idx[i * m + j]; 48 | atomicAdd(grad_points + (i * c + l) * n + a, 49 | grad_out[(i * c + l) * m + j]); 50 | } 51 | } 52 | } 53 | } 54 | 55 | void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints, 56 | const float *grad_out, const int *idx, 57 | float *grad_points, 58 | cudaStream_t stream) { 59 | 60 | cudaError_t err; 61 | gather_points_grad_kernel<<>>(b, c, n, npoints, grad_out, idx, 63 | grad_points); 64 | 65 | err = cudaGetLastError(); 66 | if (cudaSuccess != err) { 67 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 68 | exit(-1); 69 | } 70 | } 71 | 72 | __device__ void 
__update(float *__restrict__ dists, int *__restrict__ dists_i, 73 | int idx1, int idx2) { 74 | const float v1 = dists[idx1], v2 = dists[idx2]; 75 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 76 | dists[idx1] = max(v1, v2); 77 | dists_i[idx1] = v2 > v1 ? i2 : i1; 78 | } 79 | 80 | // Input dataset: (b, n, 3), tmp: (b, n) 81 | // Ouput idxs (b, m) 82 | template 83 | __global__ void furthest_point_sampling_kernel( 84 | int b, int n, int m, const float *__restrict__ dataset, 85 | float *__restrict__ temp, int *__restrict__ idxs) { 86 | if (m <= 0) 87 | return; 88 | __shared__ float dists[block_size]; 89 | __shared__ int dists_i[block_size]; 90 | 91 | int batch_index = blockIdx.x; 92 | dataset += batch_index * n * 3; 93 | temp += batch_index * n; 94 | idxs += batch_index * m; 95 | 96 | int tid = threadIdx.x; 97 | const int stride = block_size; 98 | 99 | int old = 0; 100 | if (threadIdx.x == 0) 101 | idxs[0] = old; 102 | 103 | __syncthreads(); 104 | for (int j = 1; j < m; j++) { 105 | int besti = 0; 106 | float best = -1; 107 | float x1 = dataset[old * 3 + 0]; 108 | float y1 = dataset[old * 3 + 1]; 109 | float z1 = dataset[old * 3 + 2]; 110 | for (int k = tid; k < n; k += stride) { 111 | float x2, y2, z2; 112 | x2 = dataset[k * 3 + 0]; 113 | y2 = dataset[k * 3 + 1]; 114 | z2 = dataset[k * 3 + 2]; 115 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 116 | if (mag <= 1e-3) 117 | continue; 118 | 119 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + 120 | (z2 - z1) * (z2 - z1); 121 | 122 | float d2 = min(d, temp[k]); 123 | temp[k] = d2; 124 | besti = d2 > best ? k : besti; 125 | best = d2 > best ? 
d2 : best; 126 | } 127 | dists[tid] = best; 128 | dists_i[tid] = besti; 129 | __syncthreads(); 130 | 131 | if (block_size >= 512) { 132 | if (tid < 256) { 133 | __update(dists, dists_i, tid, tid + 256); 134 | } 135 | __syncthreads(); 136 | } 137 | if (block_size >= 256) { 138 | if (tid < 128) { 139 | __update(dists, dists_i, tid, tid + 128); 140 | } 141 | __syncthreads(); 142 | } 143 | if (block_size >= 128) { 144 | if (tid < 64) { 145 | __update(dists, dists_i, tid, tid + 64); 146 | } 147 | __syncthreads(); 148 | } 149 | if (block_size >= 64) { 150 | if (tid < 32) { 151 | __update(dists, dists_i, tid, tid + 32); 152 | } 153 | __syncthreads(); 154 | } 155 | if (block_size >= 32) { 156 | if (tid < 16) { 157 | __update(dists, dists_i, tid, tid + 16); 158 | } 159 | __syncthreads(); 160 | } 161 | if (block_size >= 16) { 162 | if (tid < 8) { 163 | __update(dists, dists_i, tid, tid + 8); 164 | } 165 | __syncthreads(); 166 | } 167 | if (block_size >= 8) { 168 | if (tid < 4) { 169 | __update(dists, dists_i, tid, tid + 4); 170 | } 171 | __syncthreads(); 172 | } 173 | if (block_size >= 4) { 174 | if (tid < 2) { 175 | __update(dists, dists_i, tid, tid + 2); 176 | } 177 | __syncthreads(); 178 | } 179 | if (block_size >= 2) { 180 | if (tid < 1) { 181 | __update(dists, dists_i, tid, tid + 1); 182 | } 183 | __syncthreads(); 184 | } 185 | 186 | old = dists_i[0]; 187 | if (tid == 0) 188 | idxs[j] = old; 189 | } 190 | } 191 | 192 | void furthest_point_sampling_kernel_wrapper(int b, int n, int m, 193 | const float *dataset, float *temp, 194 | int *idxs, cudaStream_t stream) { 195 | 196 | cudaError_t err; 197 | unsigned int n_threads = opt_n_threads(n); 198 | 199 | switch (n_threads) { 200 | case 512: 201 | furthest_point_sampling_kernel<512><<>>( 202 | b, n, m, dataset, temp, idxs); 203 | break; 204 | case 256: 205 | furthest_point_sampling_kernel<256><<>>( 206 | b, n, m, dataset, temp, idxs); 207 | break; 208 | case 128: 209 | furthest_point_sampling_kernel<128><<>>( 210 | b, n, m, 
dataset, temp, idxs); 211 | break; 212 | case 64: 213 | furthest_point_sampling_kernel<64><<>>( 214 | b, n, m, dataset, temp, idxs); 215 | break; 216 | case 32: 217 | furthest_point_sampling_kernel<32><<>>( 218 | b, n, m, dataset, temp, idxs); 219 | break; 220 | case 16: 221 | furthest_point_sampling_kernel<16><<>>( 222 | b, n, m, dataset, temp, idxs); 223 | break; 224 | case 8: 225 | furthest_point_sampling_kernel<8><<>>( 226 | b, n, m, dataset, temp, idxs); 227 | break; 228 | case 4: 229 | furthest_point_sampling_kernel<4><<>>( 230 | b, n, m, dataset, temp, idxs); 231 | break; 232 | case 2: 233 | furthest_point_sampling_kernel<2><<>>( 234 | b, n, m, dataset, temp, idxs); 235 | break; 236 | case 1: 237 | furthest_point_sampling_kernel<1><<>>( 238 | b, n, m, dataset, temp, idxs); 239 | break; 240 | default: 241 | furthest_point_sampling_kernel<512><<>>( 242 | b, n, m, dataset, temp, idxs); 243 | } 244 | 245 | err = cudaGetLastError(); 246 | if (cudaSuccess != err) { 247 | fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err)); 248 | exit(-1); 249 | } 250 | } 251 | -------------------------------------------------------------------------------- /utils/linalg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from enum import Enum 3 | 4 | PDist2Order = Enum('PDist2Order', 'd_first d_second') 5 | 6 | 7 | def pdist2( 8 | X: torch.Tensor, 9 | Z: torch.Tensor = None, 10 | order: PDist2Order = PDist2Order.d_second 11 | ) -> torch.Tensor: 12 | r""" Calculates the pairwise distance between X and Z 13 | 14 | D[b, i, j] = l2 distance X[b, i] and Z[b, j] 15 | 16 | Parameters 17 | --------- 18 | X : torch.Tensor 19 | X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d 20 | Z: torch.Tensor 21 | Z is a (B, M, d) tensor. 
If Z is None, then Z = X 22 | 23 | Returns 24 | ------- 25 | torch.Tensor 26 | Distance matrix is size (B, N, M) 27 | """ 28 | 29 | if order == PDist2Order.d_second: 30 | if X.dim() == 2: 31 | X = X.unsqueeze(0) 32 | if Z is None: 33 | Z = X 34 | G = X @ Z.transpose(-2, -1) 35 | S = (X * X).sum(-1, keepdim=True) 36 | R = S.transpose(-2, -1) 37 | else: 38 | if Z.dim() == 2: 39 | Z = Z.unsqueeze(0) 40 | G = X @ Z.transpose(-2, -1) 41 | S = (X * X).sum(-1, keepdim=True) 42 | R = (Z * Z).sum(-1, keepdim=True).transpose(-2, -1) 43 | else: 44 | if X.dim() == 2: 45 | X = X.unsqueeze(0) 46 | if Z is None: 47 | Z = X 48 | G = X.transpose(-2, -1) @ Z 49 | R = (X * X).sum(-2, keepdim=True) 50 | S = R.transpose(-2, -1) 51 | else: 52 | if Z.dim() == 2: 53 | Z = Z.unsqueeze(0) 54 | G = X.transpose(-2, -1) @ Z 55 | S = (X * X).sum(-2, keepdim=True).transpose(-2, -1) 56 | R = (Z * Z).sum(-2, keepdim=True) 57 | 58 | return torch.abs(R + S - 2 * G).squeeze(0) 59 | 60 | 61 | def pdist2_slow(X, Z=None): 62 | if Z is None: Z = X 63 | D = torch.zeros(X.size(0), X.size(2), Z.size(2)) 64 | 65 | for b in range(D.size(0)): 66 | for i in range(D.size(1)): 67 | for j in range(D.size(2)): 68 | D[b, i, j] = torch.dist(X[b, :, i], Z[b, :, j]) 69 | return D 70 | 71 | 72 | if __name__ == "__main__": 73 | X = torch.randn(2, 3, 5) 74 | Z = torch.randn(2, 3, 3) 75 | 76 | print(pdist2(X, order=PDist2Order.d_first)) 77 | print(pdist2_slow(X)) 78 | print(torch.dist(pdist2(X, order=PDist2Order.d_first), pdist2_slow(X))) 79 | -------------------------------------------------------------------------------- /utils/pointnet2_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import pointnet2_utils 6 | import pytorch_utils as pt_utils 7 | from typing import List 8 | import numpy as np 9 | 10 | class _PointnetSAModuleBase(nn.Module): 11 | 12 | def __init__(self): 13 | 
super().__init__() 14 | self.npoint = None 15 | self.groupers = None 16 | self.mlps = None 17 | self.pool = False 18 | 19 | def forward(self, xyz: torch.Tensor, 20 | features: torch.Tensor = None) -> (torch.Tensor, torch.Tensor): 21 | r""" 22 | Parameters 23 | ---------- 24 | xyz : torch.Tensor 25 | (B, N, 3) tensor of the xyz coordinates of the points 26 | features : torch.Tensor 27 | (B, N, C) tensor of the descriptors of the the points 28 | 29 | Returns 30 | ------- 31 | new_xyz : torch.Tensor 32 | (B, npoint, 3) tensor of the new points' xyz 33 | new_features : torch.Tensor 34 | (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_points descriptors 35 | """ 36 | 37 | all_features = 0 38 | xyz_flipped = xyz.transpose(1, 2).contiguous() 39 | 40 | if self.npoint is not None: 41 | fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint) \ 42 | if self.pool else torch.from_numpy(np.arange(xyz.size(1))).int().cuda().repeat(xyz.size(0), 1) 43 | new_xyz = pointnet2_utils.gather_operation(xyz_flipped, fps_idx).transpose(1, 2).contiguous() 44 | else: 45 | new_xyz = None 46 | 47 | for i in range(len(self.groupers)): 48 | new_features = self.groupers[i](xyz, new_xyz, features) # (B, C, npoint, nsample) 49 | if not self.pool and self.npoint is not None: 50 | new_features = [new_features, features] 51 | new_features = self.mlps[i](new_features) # (B, mlp[-1], npoint) 52 | all_features += new_features 53 | 54 | return new_xyz, all_features 55 | 56 | 57 | class PointnetSAModuleMSG(_PointnetSAModuleBase): 58 | r"""Pointnet set abstrction layer with multiscale grouping 59 | 60 | Parameters 61 | ---------- 62 | npoint : int 63 | Number of points 64 | radii : list of float32 65 | list of radii to group with 66 | nsamples : list of int32 67 | Number of samples in each ball query 68 | mlps : list of list of int32 69 | Spec of the pointnet before the global max_pool for each scale 70 | bn : bool 71 | Use batchnorm 72 | """ 73 | 74 | def __init__( 75 | self, 76 | *, 77 | 
npoint: int, 78 | radii: List[float], 79 | nsamples: List[int], 80 | mlps: List[List[int]], 81 | group_number = 1, 82 | use_xyz: bool = True, 83 | pool: bool = False, 84 | before_pool: bool = False, 85 | after_pool: bool = False, 86 | bias = True, 87 | init = nn.init.kaiming_normal 88 | ): 89 | super().__init__() 90 | 91 | assert len(radii) == len(nsamples) == len(mlps) 92 | self.pool = pool 93 | self.npoint = npoint 94 | self.groupers = nn.ModuleList() 95 | self.mlps = nn.ModuleList() 96 | 97 | if pool: 98 | C_in = (mlps[0][0] + 3) if use_xyz else mlps[0][0] 99 | C_out = mlps[0][1] 100 | pconv = nn.Conv2d(in_channels = C_in, out_channels = C_out, kernel_size = (1, 1), 101 | stride = (1, 1), bias = bias) 102 | init(pconv.weight) 103 | if bias: 104 | nn.init.constant(pconv.bias, 0) 105 | convs = [pconv] 106 | 107 | for i in range(len(radii)): 108 | radius = radii[i] 109 | nsample = nsamples[i] 110 | self.groupers.append( 111 | pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz) 112 | if npoint is not None else pointnet2_utils.GroupAll(use_xyz) 113 | ) 114 | mlp_spec = mlps[i] 115 | if use_xyz: 116 | mlp_spec[0] += 3 117 | if npoint is None: 118 | self.mlps.append(pt_utils.GloAvgConv(C_in = mlp_spec[0], C_out = mlp_spec[1])) 119 | elif pool: 120 | self.mlps.append(pt_utils.PointConv(C_in = mlp_spec[0], C_out = mlp_spec[1], convs = convs)) 121 | else: 122 | self.mlps.append(pt_utils.EnhancedPointConv(C_in = mlp_spec[0], C_out = mlp_spec[1], group_number = group_number, before_pool = before_pool, after_pool = after_pool)) 123 | 124 | 125 | class PointnetSAModule(PointnetSAModuleMSG): 126 | r"""Pointnet set abstrction layer 127 | 128 | Parameters 129 | ---------- 130 | npoint : int 131 | Number of features 132 | radius : float 133 | Radius of ball 134 | nsample : int 135 | Number of samples in the ball query 136 | mlp : list 137 | Spec of the pointnet before the global max_pool 138 | bn : bool 139 | Use batchnorm 140 | """ 141 | 142 | def __init__( 143 | 
class PointnetFPModule(nn.Module):
    r"""Feature propagation: carries features from a known point set onto
    another point set by inverse-distance interpolation, then applies a
    shared MLP.

    Parameters
    ----------
    mlp : list
        Channel widths handed to ``pt_utils.SharedMLP``
    bn : bool
        Use batchnorm
    """

    def __init__(self, *, mlp: List[int], bn: bool = True):
        super().__init__()
        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)

    def forward(
            self, unknown: torch.Tensor, known: torch.Tensor,
            unknow_feats: torch.Tensor, known_feats: torch.Tensor
    ) -> torch.Tensor:
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) xyz positions that receive features
        known : torch.Tensor
            (B, m, 3) xyz positions that already carry features
        unknow_feats : torch.Tensor
            (B, C1, n) existing features at ``unknown`` (may be None)
        known_feats : torch.Tensor
            (B, C2, m) features to interpolate from

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) propagated features
        """
        nn_dist, nn_idx = pointnet2_utils.three_nn(unknown, known)

        # Inverse-distance weights over the three neighbours, normalised to sum to 1.
        inv_dist = 1.0 / (nn_dist + 1e-8)
        weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)

        interpolated = pointnet2_utils.three_interpolate(known_feats, nn_idx, weight)

        if unknow_feats is None:
            merged = interpolated
        else:
            merged = torch.cat([interpolated, unknow_feats], dim=1)  # (B, C2 + C1, n)

        # SharedMLP is built from 2d convolutions, hence the temporary trailing axis.
        return self.mlp(merged.unsqueeze(-1)).squeeze(-1)
class ThreeNN(Function):

    @staticmethod
    def forward(ctx, unknown: torch.Tensor,
                known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Find the three nearest neighbors of unknown in known
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the query ("unknown") points
        known : torch.Tensor
            (B, m, 3) tensor of the reference ("known") points

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) index of 3 nearest neighbors
        """
        # The extension kernel reads the raw buffers, so both inputs must be contiguous.
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        # Output buffers filled in place by the CUDA wrapper.
        dist2 = torch.cuda.FloatTensor(B, N, 3)
        idx = torch.cuda.IntTensor(B, N, 3)

        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)

        # NOTE(review): dist2 presumably holds squared distances (hence the sqrt) - confirm in the kernel.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        # No gradient is propagated to either input.
        return None, None
class GroupingOperation(Function):
    """Autograd wrapper around the extension's group_points kernels."""

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        r"""
        Gather, for every group, the feature columns selected by ``idx``.

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        batch, num_groups, group_size = idx.size()
        _, channels, num_src = features.size()

        # Output buffer filled in place by the CUDA wrapper.
        grouped = torch.cuda.FloatTensor(batch, channels, num_groups, group_size)
        pointnet2.group_points_wrapper(
            batch, channels, num_src, num_groups, group_size, features, idx, grouped
        )

        # Backward needs the indices and the source length to size the gradient.
        ctx.for_backwards = (idx, num_src)
        return grouped

    @staticmethod
    def backward(ctx,
                 grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
            ``idx`` receives no gradient
        """
        idx, num_src = ctx.for_backwards
        batch, channels, num_groups, group_size = grad_out.size()

        # Start from zeros; the wrapper writes into this buffer.
        grad_features = Variable(
            torch.cuda.FloatTensor(batch, channels, num_src).zero_()
        )
        upstream = grad_out.data.contiguous()
        pointnet2.group_points_grad_wrapper(
            batch, channels, num_src, num_groups, group_size,
            upstream, idx, grad_features.data
        )

        return grad_features, None
tensor with the indicies of the features that form the query balls 304 | """ 305 | assert new_xyz.is_contiguous() 306 | assert xyz.is_contiguous() 307 | 308 | B, N, _ = xyz.size() 309 | npoint = new_xyz.size(1) 310 | idx = torch.cuda.IntTensor(B, npoint, nsample).zero_() 311 | 312 | pointnet2.ball_query_wrapper( 313 | B, N, npoint, radius, nsample, new_xyz, xyz, idx 314 | ) 315 | 316 | return idx 317 | 318 | @staticmethod 319 | def backward(ctx, a=None): 320 | return None, None, None, None 321 | 322 | 323 | ball_query = BallQuery.apply 324 | 325 | 326 | class QueryAndGroup(nn.Module): 327 | r""" 328 | Groups with a ball query of radius 329 | 330 | Parameters 331 | --------- 332 | radius : float32 333 | Radius of ball 334 | nsample : int32 335 | Maximum number of features to gather in the ball 336 | """ 337 | 338 | def __init__(self, radius: float, nsample: int, use_xyz: bool = True): 339 | super().__init__() 340 | self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz 341 | 342 | def forward( 343 | self, 344 | xyz: torch.Tensor, 345 | new_xyz: torch.Tensor, 346 | features: torch.Tensor = None 347 | ) -> Tuple[torch.Tensor]: 348 | r""" 349 | Parameters 350 | ---------- 351 | xyz : torch.Tensor 352 | xyz coordinates of the features (B, N, 3) 353 | new_xyz : torch.Tensor 354 | centriods (B, npoint, 3) 355 | features : torch.Tensor 356 | Descriptors of the features (B, C, N) 357 | 358 | Returns 359 | ------- 360 | new_features : torch.Tensor 361 | (B, 3 + C, npoint, nsample) tensor 362 | """ 363 | 364 | idx = ball_query(self.radius, self.nsample, xyz, new_xyz) 365 | xyz_trans = xyz.transpose(1, 2).contiguous() 366 | grouped_xyz = grouping_operation( 367 | xyz_trans, idx 368 | ) # (B, 3, npoint, nsample) 369 | grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1) 370 | 371 | if features is not None: 372 | grouped_features = grouping_operation(features, idx) 373 | if self.use_xyz: 374 | new_features = torch.cat([grouped_xyz, grouped_features], 375 | dim=1) # 
class GroupAll(nn.Module):
    r"""
    Groups all features into a single group (no sampling, no ball query)

    Parameters
    ---------
    use_xyz : bool
        Prepend the xyz coordinates to the features
    """

    def __init__(self, use_xyz: bool = True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(
            self,
            xyz: torch.Tensor,
            new_xyz: torch.Tensor,
            features: torch.Tensor = None
    ) -> torch.Tensor:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, 1, N) tensor when features are given and ``use_xyz``
            is set, (B, C, 1, N) when ``use_xyz`` is off, and (B, 3, 1, N)
            (coordinates only) when ``features`` is None
        """
        # (B, N, 3) -> (B, 3, 1, N): channels first, one group holding all N points.
        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
        if features is None:
            # Coordinates are the only available descriptor, whatever use_xyz says.
            return grouped_xyz

        grouped_features = features.unsqueeze(2)
        if self.use_xyz:
            return torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
        return grouped_features
class PointConv(nn.Module):
    '''
    1x1 convolution -> batch norm -> ReLU -> max-pool over the sample axis.

    Input shape: (B, C_in, npoint, nsample)
    Output shape: (B, C_out, npoint)

    convs: optional list whose first element is the (possibly shared) 1x1
    nn.Conv2d to use. When omitted or empty, a private Conv2d(C_in, C_out)
    with PyTorch's default initialisation is created instead of failing
    on ``convs[0]``.
    '''
    def __init__(self, C_in, C_out, convs=None):
        super(PointConv, self).__init__()
        self.bn = nn.BatchNorm2d(C_out)
        self.activation = nn.ReLU(inplace=True)
        if convs is None or len(convs) == 0:
            # The signature advertises convs=None, so make that default usable.
            self.pconv = nn.Conv2d(in_channels = C_in, out_channels = C_out,
                                   kernel_size = (1, 1), stride = (1, 1))
        else:
            self.pconv = convs[0]

    def forward(self, x):  # x: (B, C_in, npoint, nsample)
        nsample = x.size(3)
        x = self.activation(self.bn(self.pconv(x)))
        # Pool across the whole sample axis, then drop the resulting singleton.
        return F.max_pool2d(x, kernel_size = (1, nsample)).squeeze(3)
class GloAvgConv(nn.Module):
    '''
    Global pooling head: 1x1 convolution, batch norm, activation, then a
    max-pool spanning the whole last axis.

    Input shape: (B, C_in, 1, nsample)
    Output shape: (B, C_out, 1) for that input (the pooled axis is squeezed away)
    '''
    def __init__(
            self,
            C_in,
            C_out,
            init=nn.init.kaiming_normal,
            bias = True,
            activation = nn.ReLU(inplace=True)
    ):
        super(GloAvgConv, self).__init__()

        pointwise = nn.Conv2d(C_in, C_out, (1, 1), stride=(1, 1), bias=bias)
        init(pointwise.weight)
        if bias:
            nn.init.constant(pointwise.bias, 0)

        self.conv_avg = pointwise
        self.bn_avg = nn.BatchNorm2d(C_out)
        self.activation = activation

    def forward(self, x):
        pool_width = x.size(3)
        hidden = self.conv_avg(x)
        hidden = self.bn_avg(hidden)
        hidden = self.activation(hidden)
        # Despite the class name, the reduction is a max over the last axis.
        pooled = F.max_pool2d(hidden, kernel_size=(1, pool_width))
        return pooled.squeeze(3)
first or not preact or (i != 0)) and bn, 129 | activation=activation 130 | if (not first or not preact or (i != 0)) else None, 131 | preact=preact 132 | ) 133 | ) 134 | 135 | 136 | class _BNBase(nn.Sequential): 137 | 138 | def __init__(self, in_size, batch_norm=None, name=""): 139 | super().__init__() 140 | self.add_module(name + "bn", batch_norm(in_size)) 141 | 142 | nn.init.constant(self[0].weight, 1.0) 143 | nn.init.constant(self[0].bias, 0) 144 | 145 | 146 | class BatchNorm1d(_BNBase): 147 | 148 | def __init__(self, in_size: int, *, name: str = ""): 149 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 150 | 151 | 152 | class BatchNorm2d(_BNBase): 153 | 154 | def __init__(self, in_size: int, name: str = ""): 155 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 156 | 157 | 158 | class BatchNorm3d(_BNBase): 159 | 160 | def __init__(self, in_size: int, name: str = ""): 161 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 162 | 163 | 164 | class _ConvBase(nn.Sequential): 165 | 166 | def __init__( 167 | self, 168 | in_size, 169 | out_size, 170 | kernel_size, 171 | stride, 172 | padding, 173 | activation, 174 | bn, 175 | init, 176 | conv=None, 177 | batch_norm=None, 178 | bias=True, 179 | preact=False, 180 | name="" 181 | ): 182 | super().__init__() 183 | 184 | bias = bias and (not bn) 185 | conv_unit = conv( 186 | in_size, 187 | out_size, 188 | kernel_size=kernel_size, 189 | stride=stride, 190 | padding=padding, 191 | bias=bias 192 | ) 193 | init(conv_unit.weight) 194 | if bias: 195 | nn.init.constant(conv_unit.bias, 0) 196 | 197 | if bn: 198 | if not preact: 199 | bn_unit = batch_norm(out_size) 200 | else: 201 | bn_unit = batch_norm(in_size) 202 | 203 | if preact: 204 | if bn: 205 | self.add_module(name + 'bn', bn_unit) 206 | 207 | if activation is not None: 208 | self.add_module(name + 'activation', activation) 209 | 210 | self.add_module(name + 'conv', conv_unit) 211 | 212 | if not preact: 213 | if bn: 214 | 
class Conv2d(_ConvBase):
    """2d specialisation of ``_ConvBase``: plugs in ``nn.Conv2d`` and the
    local ``BatchNorm2d`` wrapper, and defaults to a 1x1, stride-1,
    unpadded convolution followed by ReLU."""

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            kernel_size: Tuple[int, int] = (1, 1),
            stride: Tuple[int, int] = (1, 1),
            padding: Tuple[int, int] = (0, 0),
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=nn.init.kaiming_normal,
            bias: bool = True,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__(
            in_size,
            out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name
        )
class FC(nn.Sequential):
    """Fully connected layer with optional batch norm and activation.

    With ``preact`` the order is [bn] -> [activation] -> fc (norm sized to the
    input); otherwise fc -> [bn] -> [activation] (norm sized to the output).
    The linear layer has a zero-initialised bias only when ``bn`` is off.
    Sub-modules are registered as ``name + 'bn' / 'activation' / 'fc'``.
    """

    def __init__(
            self,
            in_size: int,
            out_size: int,
            *,
            activation=nn.ReLU(inplace=True),
            bn: bool = False,
            init=None,
            preact: bool = False,
            name: str = ""
    ):
        super().__init__()

        linear = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(linear.weight)
        if not bn:
            nn.init.constant(linear.bias, 0)

        def attach_norm_and_activation(width):
            # Shared by both orderings; only the normalised width differs.
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(width))
            if activation is not None:
                self.add_module(name + 'activation', activation)

        if preact:
            attach_norm_and_activation(in_size)

        self.add_module(name + 'fc', linear)

        if not preact:
            attach_norm_and_activation(out_size)
def group_model_params(model: nn.Module):
    """Split the model's parameters into two optimizer parameter groups.

    Parameters whose qualified name contains "bn" or "bias" go into a group
    with ``weight_decay=0.0``; all others go into a group that inherits the
    optimizer's default weight decay. Returns ``[decay_group, no_decay_group]``
    as a list of dicts.
    """
    with_decay, without_decay = [], []

    for param_name, tensor in model.named_parameters():
        exempt = ("bn" in param_name) or ("bias" in param_name)
        (without_decay if exempt else with_decay).append(tensor)

    # Every parameter must have landed in exactly one group.
    total = len(list(model.parameters()))
    assert total == len(with_decay) + len(without_decay)

    return [
        dict(params=with_decay),
        dict(params=without_decay, weight_decay=0.0)
    ]
def variable_size_collate(pad_val=0, use_shared_memory=True):
    """Build a collate function that pads variable-length samples.

    Parameters
    ----------
    pad_val :
        Value used to fill the padded tail of shorter tensors.
    use_shared_memory : bool
        Allocate the batch tensor in shared memory (intended for use inside
        DataLoader worker processes).

    Returns
    -------
    callable
        ``wrapped(batch)``: tensors are stacked into
        ``(len(batch), max_len, *trailing_dims)`` padded along dim 0 of each
        sample; numpy arrays are converted and handled the same way; numpy
        scalars, ints and floats become 1-d tensors; mappings and sequences
        are collated field by field. Anything else raises ``TypeError``.
    """
    # Local imports: `re` was used below without ever being imported, and the
    # container ABCs live in collections.abc (the aliases on `collections`
    # were removed in Python 3.10).
    import collections.abc
    import re

    _numpy_type_map = {
        'float64': torch.DoubleTensor,
        'float32': torch.FloatTensor,
        'float16': torch.HalfTensor,
        'int64': torch.LongTensor,
        'int32': torch.IntTensor,
        'int16': torch.ShortTensor,
        'int8': torch.CharTensor,
        'uint8': torch.ByteTensor,
    }

    def wrapped(batch):
        "Puts each data field into a tensor with outer dimension batch size"

        error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
        elem_type = type(batch[0])
        if torch.is_tensor(batch[0]):
            max_len = 0
            for b in batch:
                max_len = max(max_len, b.size(0))

            # Elements needed once every sample is padded to max_len rows.
            numel = sum([int(b.numel() / b.size(0) * max_len) for b in batch])
            if use_shared_memory:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                storage = batch[0].storage()._new_shared(numel)
                out = batch[0].new(storage)
            else:
                out = batch[0].new(numel)

            out = out.view(
                len(batch), max_len,
                *[batch[0].size(i) for i in range(1, batch[0].dim())]
            )
            out.fill_(pad_val)
            for i in range(len(batch)):
                out[i, 0:batch[i].size(0)] = batch[i]

            return out
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            elem = batch[0]
            if elem_type.__name__ == 'ndarray':
                # array of string classes and object
                if re.search('[SaUO]', elem.dtype.str) is not None:
                    raise TypeError(error_msg.format(elem.dtype))

                return wrapped([torch.from_numpy(b) for b in batch])
            if elem.shape == ():  # scalars
                py_type = float if elem.dtype.name.startswith('float') else int
                return _numpy_type_map[elem.dtype.name](
                    list(map(py_type, batch))
                )
        elif isinstance(batch[0], int):
            return torch.LongTensor(batch)
        elif isinstance(batch[0], float):
            return torch.DoubleTensor(batch)
        elif isinstance(batch[0], collections.abc.Mapping):
            return {key: wrapped([d[key] for d in batch]) for key in batch[0]}
        elif isinstance(batch[0], collections.abc.Sequence):
            transposed = zip(*batch)
            return [wrapped(samples) for samples in transposed]

        raise TypeError((error_msg.format(type(batch[0]))))

    return wrapped
class CrossValSplitter():
    r"""
    Class that creates cross validation splits. The train and val splits can be used in pytorch DataLoaders. The splits can be updated
    by calling next(self), by indexing (``self[k]`` selects fold ``k`` as validation), or using a loop:
        for _ in self:
            ....
    Parameters
    ---------
    numel : int
        Number of elements in the training set
    k_folds : int
        Number of folds
    shuffled : bool
        Whether or not to shuffle which data goes in which fold
    """

    def __init__(self, *, numel: int, k_folds: int, shuffled: bool = False):
        indices = np.array([i for i in range(numel)])
        if shuffled:
            np.random.shuffle(indices)

        self.folds = np.array(np.array_split(indices, k_folds), dtype=object)
        self.current_v_ind = -1

        # Fold 0 is the validation fold until another one is selected.
        self.val = torch.utils.data.sampler.SubsetRandomSampler(self.folds[0])
        self.train = torch.utils.data.sampler.SubsetRandomSampler(
            np.concatenate(self.folds[1:], axis=0)
        )

        self.metrics = {}

    def __iter__(self):
        self.current_v_ind = -1
        return self

    def __len__(self):
        return len(self.folds)

    def __getitem__(self, idx):
        """Make fold ``idx`` the validation fold; updates both samplers in place."""
        assert idx >= 0 and idx < len(self)
        # SubsetRandomSampler reads `.indices`; the former misspelled attribute
        # was silently ignored, so the split never actually changed.
        self.val.indices = self.folds[idx]
        self.train.indices = np.concatenate(
            self.folds[np.arange(len(self)) != idx], axis=0
        )

    def __next__(self):
        self.current_v_ind += 1
        if self.current_v_ind >= len(self):
            raise StopIteration

        self[self.current_v_ind]

    def update_metrics(self, to_post: dict):
        """Append each value in ``to_post`` to the history kept under its key."""
        for k, v in to_post.items():
            if k in self.metrics:
                self.metrics[k].append(v)
            else:
                self.metrics[k] = [v]

    def print_metrics(self):
        """Print mean +/- 95% t-interval half-width for every recorded metric.

        ``statistics.stdev`` needs at least two samples per metric.
        """
        for name, samples in self.metrics.items():
            xbar = stats.mean(samples)
            sx = stats.stdev(samples, xbar)
            tstar = student_t.ppf(1.0 - 0.025, len(samples) - 1)
            # `sqrt` was never imported as a bare name; use the math module.
            margin_of_error = tstar * sx / math.sqrt(len(samples))
            print("{}: {} +/- {}".format(name, xbar, margin_of_error))
import os
from torchvision import transforms
from models import DensePointCls_L6 as DensePoint
from data import ModelNet40Cls
import utils.pytorch_utils as pt_utils
import utils.pointnet2_utils as pointnet2_utils
import data.data_utils as d_utils
import argparse
import random
import yaml

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

# Fix every RNG used below so repeated runs draw the same votes.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

parser = argparse.ArgumentParser(description='DensePoint Shape Classification Voting Evaluate')
parser.add_argument('--config', default='cfgs/config_cls.yaml', type=str)

NUM_REPEAT = 300       # number of full passes over the test set
NUM_VOTE = 10          # predictions averaged per sample within one pass
NUM_FPS_POINTS = 1200  # size of the farthest-point-sampled pool each vote draws from


def _vote_on_batch(model, points, num_points, point_scale):
    """Return the predicted class index per sample, averaged over NUM_VOTE votes.

    Each vote draws ``num_points`` indices without replacement from a pool of
    NUM_FPS_POINTS farthest-point-sampled indices. Every vote after the first
    also passes the selected points through ``point_scale``.
    """
    # farthest point sampling
    fps_idx = pointnet2_utils.furthest_point_sample(points, NUM_FPS_POINTS)  # (B, npoint)
    pred = 0
    for v in range(NUM_VOTE):
        new_fps_idx = fps_idx[:, np.random.choice(NUM_FPS_POINTS, num_points, False)]
        new_points = pointnet2_utils.gather_operation(
            points.transpose(1, 2).contiguous(), new_fps_idx
        ).transpose(1, 2).contiguous()
        if v > 0:
            # The first vote sees the unscaled cloud.
            new_points.data = point_scale(new_points.data)
        pred += F.softmax(model(new_points), dim=1)
    pred /= NUM_VOTE
    _, pred_choice = torch.max(pred.data, -1)
    return pred_choice


def main():
    """Run the voting evaluation and print the accuracy of each repeat and the best one.

    Reads the YAML file given by ``--config``, copies its ``common`` section
    onto ``args``, builds the test loader and model, optionally loads
    ``args.checkpoint``, then evaluates NUM_REPEAT times.
    """
    args = parser.parse_args()
    with open(args.config) as f:
        # safe_load only builds plain Python data; yaml.load without a Loader
        # can construct arbitrary objects and is rejected by recent PyYAML.
        config = yaml.safe_load(f)
    for k, v in config['common'].items():
        setattr(args, k, v)

    test_transforms = transforms.Compose([
        d_utils.PointcloudToTensor()
    ])

    test_dataset = ModelNet40Cls(num_points=args.num_points, root=args.data_root, transforms=test_transforms, train=False)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=int(args.workers),
        pin_memory=True
    )

    model = DensePoint(num_classes=args.num_classes, input_channels=args.input_channels, use_xyz=True)
    model.cuda()

    # Compare by value: ``is not ''`` tests object identity, which is not
    # guaranteed for equal strings.
    if args.checkpoint != '':
        model.load_state_dict(torch.load(args.checkpoint))
        print('Load model successfully: %s' % (args.checkpoint))

    # evaluate
    point_scale = d_utils.PointcloudScale()   # initialize random scaling
    model.eval()
    # torch.no_grad() exists from PyTorch 0.4 on, where ``volatile`` no longer
    # has any effect; older releases still need volatile Variables.
    has_no_grad = hasattr(torch, 'no_grad')
    global_acc = 0
    for i in range(NUM_REPEAT):
        preds = []
        labels = []
        for points, target in test_dataloader:
            points, target = points.cuda(), target.cuda()

            if has_no_grad:
                with torch.no_grad():
                    pred_choice = _vote_on_batch(model, points, args.num_points, point_scale)
            else:
                pred_choice = _vote_on_batch(
                    model, Variable(points, volatile=True), args.num_points, point_scale
                )

            preds.append(pred_choice)
            labels.append(target.view(-1))

        preds = torch.cat(preds, 0)
        labels = torch.cat(labels, 0)
        # Convert the match count to float first: dividing an integer tensor by
        # an int truncates to 0 on PyTorch releases with integer tensor division.
        acc = float((preds == labels).sum()) / labels.numel()
        if acc > global_acc:
            global_acc = acc
        print('Repeat %3d \t Acc: %0.6f' % (i + 1, acc))
    print('\nBest voting acc: %0.6f' % (global_acc))


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------