├── .gitattributes ├── LICENSE ├── README.md ├── checkpoints └── checkpoint.pth ├── data_utils └── ModelNetDataLoader.py ├── eval_cls_conv.py ├── model └── pointconv.py ├── provider.py ├── train_cls_conv.py └── utils ├── pointconv_util.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | checkpoints/checkpoint.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Wenxuan Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PointConv 2 | **PointConv: Deep Convolutional Networks on 3D Point Clouds.** CVPR 2019 3 | Wenxuan Wu, Zhongang Qi, Li Fuxin. 4 | 5 | ## Introduction 6 | This project is based on our CVPR2019 paper. You can find the [arXiv](https://arxiv.org/abs/1811.07246) version here. 7 | 8 | ``` 9 | @inproceedings{wu2019pointconv, 10 | title={Pointconv: Deep convolutional networks on 3d point clouds}, 11 | author={Wu, Wenxuan and Qi, Zhongang and Fuxin, Li}, 12 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 13 | pages={9621--9630}, 14 | year={2019} 15 | } 16 | ``` 17 | 18 | Unlike images which are represented in regular dense grids, 3D point clouds are irregular and unordered, hence applying convolution on them can be difficult. In this paper, we extend the dynamic filter to a new convolution operation, named PointConv. PointConv can be applied on point clouds to build deep convolutional networks. We treat convolution kernels as nonlinear functions of the local coordinates of 3D points comprised of weight and density functions. With respect to a given point, the weight functions are learned with multi-layer perceptron networks and the density functions through kernel density estimation. A novel reformulation is proposed for efficiently computing the weight functions, which allowed us to dramatically scale up the network and significantly improve its performance. The learned convolution kernel can be used to compute translation-invariant and permutation-invariant convolution on any point set in the 3D space. Besides, PointConv can also be used as deconvolution operators to propagate features from a subsampled point cloud back to its original resolution. 
Experiments on ModelNet40, ShapeNet, and ScanNet show that deep convolutional neural networks built on PointConv are able to achieve state-of-the-art on challenging semantic segmentation benchmarks on 3D point clouds. Besides, our experiments converting CIFAR-10 into a point cloud showed that networks built on PointConv can match the performance of convolutional networks in 2D images of a similar structure. 19 | 20 | ## Installation 21 | The code is modified from repo [Pointnet_Pointnet2_pytorch](https://github.com/yanx27/Pointnet_Pointnet2_pytorch). Please install [PyTorch](https://pytorch.org/), [pandas](https://pandas.pydata.org/), and [sklearn](https://scikit-learn.org/). 22 | The code has been tested with Python 3.5, pytorch 1.2, CUDA 10.0 and cuDNN 7.6 on Ubuntu 16.04. 23 | 24 | ## Usage 25 | ### ModelNet40 Classification 26 | 27 | Download the ModelNet40 dataset from [here](https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip). This dataset is the same one used in [PointNet](https://arxiv.org/abs/1612.00593), thanks to [Charles Qi](https://github.com/charlesq34/pointnet). Copy the unzipped dataset to ```./data/modelnet40_normal_resampled```. 28 | 29 | To train the model, 30 | ``` 31 | python train_cls_conv.py --model pointconv_modelnet40 --normal 32 | ``` 33 | 34 | To evaluate the model, 35 | ``` 36 | python eval_cls_conv.py --checkpoint ./checkpoints/checkpoint.pth --normal 37 | ``` 38 | 39 | ## License 40 | This repository is released under the MIT License (see LICENSE file for details). 
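A note on the training command in the Usage section: it passes `--model`, while `train_cls_conv.py` defines the option as `--model_name`. This works because Python's `argparse` accepts any unambiguous prefix of a long option by default. The sketch below is a minimal illustration, not the repository's parser; `build_parser` is a hypothetical helper that reproduces only two of the script's options.

```python
import argparse


def build_parser():
    # Hypothetical stand-in for parse_args() in train_cls_conv.py;
    # only --model_name and --normal are reproduced here.
    parser = argparse.ArgumentParser('PointConv')
    parser.add_argument('--model_name', default='pointconv', help='model name')
    parser.add_argument('--normal', action='store_true', default=False,
                        help='Whether to use normal information')
    return parser


# '--model' is an unambiguous prefix of '--model_name', so argparse
# resolves it to that option (prefix matching is on by default).
args = build_parser().parse_args(['--model', 'pointconv_modelnet40', '--normal'])
print(args.model_name, args.normal)
```

Spelling the flag out as `--model_name pointconv_modelnet40` is equivalent and does not rely on prefix matching.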
41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /checkpoints/checkpoint.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4a526dd7dae44ebd6b9c2b6a9a7eac6bb6608b88d04c74ed11b93adb909e9bb4 3 | size 156655599 4 | -------------------------------------------------------------------------------- /data_utils/ModelNetDataLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | import os 4 | from torch.utils.data import Dataset 5 | warnings.filterwarnings('ignore') 6 | 7 | def pc_normalize(pc): 8 | centroid = np.mean(pc, axis=0) 9 | pc = pc - centroid 10 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 11 | pc = pc / m 12 | return pc 13 | 14 | def farthest_point_sample(point, npoint): 15 | """ 16 | Input: 17 | xyz: pointcloud data, [N, D] 18 | npoint: number of samples 19 | Return: 20 | centroids: sampled pointcloud index, [npoint, D] 21 | """ 22 | N, D = point.shape 23 | xyz = point[:,:3] 24 | centroids = np.zeros((npoint,)) 25 | distance = np.ones((N,)) * 1e10 26 | farthest = np.random.randint(0, N) 27 | for i in range(npoint): 28 | centroids[i] = farthest 29 | centroid = xyz[farthest, :] 30 | dist = np.sum((xyz - centroid) ** 2, -1) 31 | mask = dist < distance 32 | distance[mask] = dist[mask] 33 | farthest = np.argmax(distance, -1) 34 | point = point[centroids.astype(np.int32)] 35 | return point 36 | 37 | class ModelNetDataLoader(Dataset): 38 | def __init__(self, root, npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000): 39 | self.root = root 40 | self.npoints = npoint 41 | self.uniform = uniform 42 | self.split = split 43 | self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt') 44 | 45 | self.cat = [line.rstrip() for line in open(self.catfile)] 46 | self.classes = dict(zip(self.cat, 
range(len(self.cat)))) 47 | self.normal_channel = normal_channel 48 | 49 | shape_ids = {} 50 | shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))] 51 | shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))] 52 | 53 | assert (split == 'train' or split == 'test') 54 | shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]] 55 | # list of (shape_name, shape_txt_file_path) tuple 56 | self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i 57 | in range(len(shape_ids[split]))] 58 | print('The size of %s data is %d'%(split,len(self.datapath))) 59 | 60 | self.cache_size = cache_size # how many data points to cache in memory 61 | self.cache = {} # from index to (point_set, cls) tuple 62 | 63 | def __len__(self): 64 | return len(self.datapath) 65 | 66 | def _get_item(self, index): 67 | if index in self.cache: 68 | point_set, cls = self.cache[index] 69 | else: 70 | fn = self.datapath[index] 71 | cls = self.classes[self.datapath[index][0]] 72 | cls = np.array([cls]).astype(np.int32) 73 | point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32) 74 | 75 | if len(self.cache) < self.cache_size: 76 | self.cache[index] = (point_set, cls) 77 | 78 | if self.uniform: 79 | point_set = farthest_point_sample(point_set, self.npoints) 80 | else: 81 | if self.split == 'train': 82 | train_idx = np.array(range(point_set.shape[0])) 83 | point_set = point_set[train_idx[:self.npoints],:] 84 | else: 85 | point_set = point_set[0:self.npoints,:] 86 | 87 | point_set[:, 0:3] = pc_normalize(point_set[:, 0:3]) 88 | 89 | if not self.normal_channel: 90 | point_set = point_set[:, 0:3] 91 | 92 | return point_set, cls 93 | 94 | def __getitem__(self, index): 95 | return self._get_item(index) 96 | 97 | 98 | if __name__ == '__main__': 99 | import torch 100 | 101 | data = 
ModelNetDataLoader('./data/modelnet40_normal_resampled/',split='train', uniform=False, normal_channel=True,) 102 | DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True) 103 | for point,label in DataLoader: 104 | import ipdb; ipdb.set_trace() 105 | print(point.shape) 106 | print(label.shape) 107 | 108 | 109 | -------------------------------------------------------------------------------- /eval_cls_conv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import numpy as np 5 | import torch 6 | import torch.nn.parallel 7 | import torch.utils.data 8 | import torch.nn.functional as F 9 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 10 | import datetime 11 | import logging 12 | from pathlib import Path 13 | from tqdm import tqdm 14 | from utils.utils import test, save_checkpoint 15 | from model.pointconv import PointConvDensityClsSsg as PointConvClsSsg 16 | 17 | 18 | def parse_args(): 19 | '''PARAMETERS''' 20 | parser = argparse.ArgumentParser('PointConv') 21 | parser.add_argument('--batchsize', type=int, default=32, help='batch size') 22 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 23 | parser.add_argument('--checkpoint', type=str, default=None, help='checkpoint') 24 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [default: 1024]') 25 | parser.add_argument('--num_workers', type=int, default=16, help='Worker Number [default: 16]') 26 | parser.add_argument('--model_name', default='pointconv', help='model name') 27 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]') 28 | return parser.parse_args() 29 | 30 | def main(args): 31 | '''HYPER PARAMETER''' 32 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 33 | 34 | '''CREATE DIR''' 35 | experiment_dir = Path('./eval_experiment/') 36 | 
experiment_dir.mkdir(exist_ok=True) 37 | file_dir = Path(str(experiment_dir) + '/%s_ModelNet40-'%args.model_name + str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))) 38 | file_dir.mkdir(exist_ok=True) 39 | checkpoints_dir = file_dir.joinpath('checkpoints/') 40 | checkpoints_dir.mkdir(exist_ok=True) 41 | os.system('cp %s %s' % (args.checkpoint, checkpoints_dir)) 42 | log_dir = file_dir.joinpath('logs/') 43 | log_dir.mkdir(exist_ok=True) 44 | 45 | '''LOG''' 46 | args = parse_args() 47 | logger = logging.getLogger(args.model_name) 48 | logger.setLevel(logging.INFO) 49 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 50 | file_handler = logging.FileHandler(str(log_dir) + '/eval_%s_cls.txt'%args.model_name) 51 | file_handler.setLevel(logging.INFO) 52 | file_handler.setFormatter(formatter) 53 | logger.addHandler(file_handler) 54 | logger.info('---------------------------------------------------EVAL---------------------------------------------------') 55 | logger.info('PARAMETER ...') 56 | logger.info(args) 57 | 58 | '''DATA LOADING''' 59 | logger.info('Load dataset ...') 60 | DATA_PATH = './data/modelnet40_normal_resampled/' 61 | 62 | TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal) 63 | testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batchsize, shuffle=False, num_workers=args.num_workers) 64 | logger.info("The number of test data is: %d", len(TEST_DATASET)) 65 | 66 | seed = 3 67 | torch.manual_seed(seed) 68 | if torch.cuda.is_available(): 69 | torch.cuda.manual_seed_all(seed) 70 | 71 | '''MODEL LOADING''' 72 | num_class = 40 73 | classifier = PointConvClsSsg(num_class).cuda() 74 | if args.checkpoint is not None: 75 | print('Load CheckPoint...') 76 | logger.info('Load CheckPoint') 77 | checkpoint = torch.load(args.checkpoint) 78 | start_epoch = checkpoint['epoch'] 79 | classifier.load_state_dict(checkpoint['model_state_dict']) 80 | 
else: 81 | print('Please load Checkpoint to eval...') 82 | sys.exit(0) 83 | start_epoch = 0 84 | 85 | blue = lambda x: '\033[94m' + x + '\033[0m' 86 | 87 | '''EVAL''' 88 | logger.info('Start evaluating...') 89 | print('Start evaluating...') 90 | 91 | classifier = classifier.eval() 92 | mean_correct = [] 93 | for batch_id, data in tqdm(enumerate(testDataLoader, 0), total=len(testDataLoader), smoothing=0.9): 94 | pointcloud, target = data 95 | target = target[:, 0] 96 | 97 | points = pointcloud.permute(0, 2, 1) 98 | points, target = points.cuda(), target.cuda() 99 | with torch.no_grad(): 100 | pred = classifier(points[:, :3, :], points[:, 3:, :]) 101 | pred_choice = pred.data.max(1)[1] 102 | correct = pred_choice.eq(target.long().data).cpu().sum() 103 | 104 | mean_correct.append(correct.item()/float(points.size()[0])) 105 | 106 | accuracy = np.mean(mean_correct) 107 | print('Total Accuracy: %f'%accuracy) 108 | 109 | logger.info('Total Accuracy: %f'%accuracy) 110 | logger.info('End of evaluation...') 111 | 112 | if __name__ == '__main__': 113 | args = parse_args() 114 | main(args) 115 | -------------------------------------------------------------------------------- /model/pointconv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classification Model 3 | Author: Wenxuan Wu 4 | Date: September 2019 5 | """ 6 | import torch.nn as nn 7 | import torch 8 | import numpy as np 9 | import torch.nn.functional as F 10 | from utils.pointconv_util import PointConvDensitySetAbstraction 11 | 12 | class PointConvDensityClsSsg(nn.Module): 13 | def __init__(self, num_classes = 40): 14 | super(PointConvDensityClsSsg, self).__init__() 15 | feature_dim = 3 16 | self.sa1 = PointConvDensitySetAbstraction(npoint=512, nsample=32, in_channel=feature_dim + 3, mlp=[64, 64, 128], bandwidth = 0.1, group_all=False) 17 | self.sa2 = PointConvDensitySetAbstraction(npoint=128, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], bandwidth = 0.2, 
group_all=False) 18 | self.sa3 = PointConvDensitySetAbstraction(npoint=1, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], bandwidth = 0.4, group_all=True) 19 | self.fc1 = nn.Linear(1024, 512) 20 | self.bn1 = nn.BatchNorm1d(512) 21 | self.drop1 = nn.Dropout(0.7) 22 | self.fc2 = nn.Linear(512, 256) 23 | self.bn2 = nn.BatchNorm1d(256) 24 | self.drop2 = nn.Dropout(0.7) 25 | self.fc3 = nn.Linear(256, num_classes) 26 | 27 | def forward(self, xyz, feat): 28 | B, _, _ = xyz.shape 29 | l1_xyz, l1_points = self.sa1(xyz, feat) 30 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 31 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 32 | x = l3_points.view(B, 1024) 33 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 34 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 35 | x = self.fc3(x) 36 | x = F.log_softmax(x, -1) 37 | return x 38 | 39 | if __name__ == '__main__': 40 | import os 41 | import torch 42 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 43 | input = torch.randn((8,3,2048)) 44 | feat = torch.randn((8,3,2048)) # 3-channel per-point features (e.g. normals); forward() takes (xyz, feat) 45 | model = PointConvDensityClsSsg(num_classes=40) 46 | output = model(input, feat) 47 | print(output.size()) 48 | 49 | -------------------------------------------------------------------------------- /provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normalize_data(batch_data): 4 | """ Normalize the batch data, use coordinates of the block centered at origin, 5 | Input: 6 | BxNxC array 7 | Output: 8 | BxNxC array 9 | """ 10 | B, N, C = batch_data.shape 11 | normal_data = np.zeros((B, N, C)) 12 | for b in range(B): 13 | pc = batch_data[b] 14 | centroid = np.mean(pc, axis=0) 15 | pc = pc - centroid 16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 17 | pc = pc / m 18 | normal_data[b] = pc 19 | return normal_data 20 | 21 | 22 | def shuffle_data(data, labels): 23 | """ Shuffle data and labels. 24 | Input: 25 | data: B,N,... numpy array 26 | label: B,... 
numpy array 27 | Return: 28 | shuffled data, label and shuffle indices 29 | """ 30 | idx = np.arange(len(labels)) 31 | np.random.shuffle(idx) 32 | return data[idx, ...], labels[idx], idx 33 | 34 | def shuffle_points(batch_data): 35 | """ Shuffle orders of points in each point cloud -- changes FPS behavior. 36 | Use the same shuffling idx for the entire batch. 37 | Input: 38 | BxNxC array 39 | Output: 40 | BxNxC array 41 | """ 42 | idx = np.arange(batch_data.shape[1]) 43 | np.random.shuffle(idx) 44 | return batch_data[:,idx,:] 45 | 46 | def rotate_point_cloud(batch_data): 47 | """ Randomly rotate the point clouds to augment the dataset 48 | rotation is per shape based along up direction 49 | Input: 50 | BxNx3 array, original batch of point clouds 51 | Return: 52 | BxNx3 array, rotated batch of point clouds 53 | """ 54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 55 | for k in range(batch_data.shape[0]): 56 | rotation_angle = np.random.uniform() * 2 * np.pi 57 | cosval = np.cos(rotation_angle) 58 | sinval = np.sin(rotation_angle) 59 | rotation_matrix = np.array([[cosval, 0, sinval], 60 | [0, 1, 0], 61 | [-sinval, 0, cosval]]) 62 | shape_pc = batch_data[k, ...] 63 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 64 | return rotated_data 65 | 66 | def rotate_point_cloud_z(batch_data): 67 | """ Randomly rotate the point clouds to augment the dataset 68 | rotation is per shape, about the z axis 69 | Input: 70 | BxNx3 array, original batch of point clouds 71 | Return: 72 | BxNx3 array, rotated batch of point clouds 73 | """ 74 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 75 | for k in range(batch_data.shape[0]): 76 | rotation_angle = np.random.uniform() * 2 * np.pi 77 | cosval = np.cos(rotation_angle) 78 | sinval = np.sin(rotation_angle) 79 | rotation_matrix = np.array([[cosval, sinval, 0], 80 | [-sinval, cosval, 0], 81 | [0, 0, 1]]) 82 | shape_pc = batch_data[k, ...] 83 | rotated_data[k, ...] 
= np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 84 | return rotated_data 85 | 86 | def rotate_point_cloud_with_normal(batch_xyz_normal): 87 | ''' Randomly rotate XYZ, normal point cloud. 88 | Input: 89 | batch_xyz_normal: B,N,6, first three channels are XYZ, last 3 all normal 90 | Output: 91 | B,N,6, rotated XYZ, normal point cloud 92 | ''' 93 | for k in range(batch_xyz_normal.shape[0]): 94 | rotation_angle = np.random.uniform() * 2 * np.pi 95 | cosval = np.cos(rotation_angle) 96 | sinval = np.sin(rotation_angle) 97 | rotation_matrix = np.array([[cosval, 0, sinval], 98 | [0, 1, 0], 99 | [-sinval, 0, cosval]]) 100 | shape_pc = batch_xyz_normal[k,:,0:3] 101 | shape_normal = batch_xyz_normal[k,:,3:6] 102 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 103 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 104 | return batch_xyz_normal 105 | 106 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): 107 | """ Randomly perturb the point clouds by small rotations 108 | Input: 109 | BxNx6 array, original batch of point clouds and point normals 110 | Return: 111 | BxNx3 array, rotated batch of point clouds 112 | """ 113 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 114 | for k in range(batch_data.shape[0]): 115 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 116 | Rx = np.array([[1,0,0], 117 | [0,np.cos(angles[0]),-np.sin(angles[0])], 118 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 119 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 120 | [0,1,0], 121 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 122 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 123 | [np.sin(angles[2]),np.cos(angles[2]),0], 124 | [0,0,1]]) 125 | R = np.dot(Rz, np.dot(Ry,Rx)) 126 | shape_pc = batch_data[k,:,0:3] 127 | shape_normal = batch_data[k,:,3:6] 128 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) 129 | 
rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) 130 | return rotated_data 131 | 132 | 133 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 134 | """ Rotate the point cloud along up direction with certain angle. 135 | Input: 136 | BxNx3 array, original batch of point clouds 137 | Return: 138 | BxNx3 array, rotated batch of point clouds 139 | """ 140 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 141 | for k in range(batch_data.shape[0]): 142 | #rotation_angle = np.random.uniform() * 2 * np.pi 143 | cosval = np.cos(rotation_angle) 144 | sinval = np.sin(rotation_angle) 145 | rotation_matrix = np.array([[cosval, 0, sinval], 146 | [0, 1, 0], 147 | [-sinval, 0, cosval]]) 148 | shape_pc = batch_data[k,:,0:3] 149 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 150 | return rotated_data 151 | 152 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): 153 | """ Rotate the point cloud along up direction with certain angle. 
154 | Input: 155 | BxNx6 array, original batch of point clouds with normal 156 | scalar, angle of rotation 157 | Return: 158 | BxNx6 array, rotated batch of point clouds with normal 159 | """ 160 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 161 | for k in range(batch_data.shape[0]): 162 | #rotation_angle = np.random.uniform() * 2 * np.pi 163 | cosval = np.cos(rotation_angle) 164 | sinval = np.sin(rotation_angle) 165 | rotation_matrix = np.array([[cosval, 0, sinval], 166 | [0, 1, 0], 167 | [-sinval, 0, cosval]]) 168 | shape_pc = batch_data[k,:,0:3] 169 | shape_normal = batch_data[k,:,3:6] 170 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 171 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) 172 | return rotated_data 173 | 174 | 175 | 176 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 177 | """ Randomly perturb the point clouds by small rotations 178 | Input: 179 | BxNx3 array, original batch of point clouds 180 | Return: 181 | BxNx3 array, rotated batch of point clouds 182 | """ 183 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 184 | for k in range(batch_data.shape[0]): 185 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 186 | Rx = np.array([[1,0,0], 187 | [0,np.cos(angles[0]),-np.sin(angles[0])], 188 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 189 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 190 | [0,1,0], 191 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 192 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 193 | [np.sin(angles[2]),np.cos(angles[2]),0], 194 | [0,0,1]]) 195 | R = np.dot(Rz, np.dot(Ry,Rx)) 196 | shape_pc = batch_data[k, ...] 197 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 198 | return rotated_data 199 | 200 | 201 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 202 | """ Randomly jitter points. jittering is per point. 
203 | Input: 204 | BxNx3 array, original batch of point clouds 205 | Return: 206 | BxNx3 array, jittered batch of point clouds 207 | """ 208 | B, N, C = batch_data.shape 209 | assert(clip > 0) 210 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 211 | jittered_data += batch_data 212 | return jittered_data 213 | 214 | def shift_point_cloud(batch_data, shift_range=0.1): 215 | """ Randomly shift point cloud. Shift is per point cloud. 216 | Input: 217 | BxNx3 array, original batch of point clouds 218 | Return: 219 | BxNx3 array, shifted batch of point clouds 220 | """ 221 | B, N, C = batch_data.shape 222 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 223 | for batch_index in range(B): 224 | batch_data[batch_index,:,:] += shifts[batch_index,:] 225 | return batch_data 226 | 227 | 228 | def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25): 229 | """ Randomly scale the point cloud. Scale is per point cloud. 230 | Input: 231 | BxNx3 array, original batch of point clouds 232 | Return: 233 | BxNx3 array, scaled batch of point clouds 234 | """ 235 | B, N, C = batch_data.shape 236 | scales = np.random.uniform(scale_low, scale_high, B) 237 | for batch_index in range(B): 238 | batch_data[batch_index,:,:] *= scales[batch_index] 239 | return batch_data 240 | 241 | def random_point_dropout(batch_pc, max_dropout_ratio=0.875): 242 | ''' batch_pc: BxNx3 ''' 243 | for b in range(batch_pc.shape[0]): 244 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 245 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 246 | if len(drop_idx)>0: 247 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point 248 | return batch_pc 249 | 250 | def random_point_dropout_v2(batch_pc, max_dropout_ratio=0.875): 251 | ''' batch_pc: BxNx3 ''' 252 | for b in range(batch_pc.shape[0]): 253 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~0.875 254 | drop_idx = 
np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 255 | keep_idx = np.where(np.random.random((batch_pc.shape[1]))>dropout_ratio)[0] 256 | if len(keep_idx) > len(drop_idx): 257 | batch_pc[b, drop_idx, :] = batch_pc[b, keep_idx[:len(drop_idx)], :] 258 | else: 259 | batch_pc[b, drop_idx, :] = batch_pc[b, :len(drop_idx), :] 260 | 261 | return batch_pc 262 | 263 | 264 | 265 | 266 | -------------------------------------------------------------------------------- /train_cls_conv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn.parallel 5 | import torch.utils.data 6 | import torch.nn.functional as F 7 | from data_utils.ModelNetDataLoader import ModelNetDataLoader 8 | import datetime 9 | import logging 10 | from pathlib import Path 11 | from tqdm import tqdm 12 | from utils.utils import test, save_checkpoint 13 | from model.pointconv import PointConvDensityClsSsg as PointConvClsSsg 14 | import provider 15 | import numpy as np 16 | 17 | 18 | def parse_args(): 19 | '''PARAMETERS''' 20 | parser = argparse.ArgumentParser('PointConv') 21 | parser.add_argument('--batchsize', type=int, default=32, help='batch size in training') 22 | parser.add_argument('--epoch', default=400, type=int, help='number of epoch in training') 23 | parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training') 24 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 25 | parser.add_argument('--num_point', type=int, default=1024, help='Point Number [default: 1024]') 26 | parser.add_argument('--num_workers', type=int, default=16, help='Worker Number [default: 16]') 27 | parser.add_argument('--optimizer', type=str, default='SGD', help='optimizer for training') 28 | parser.add_argument('--pretrain', type=str, default=None,help='whether use pretrain model') 29 | parser.add_argument('--decay_rate', type=float, default=1e-4, 
help='weight decay (L2 penalty) passed to the Adam optimizer') 30 | parser.add_argument('--model_name', default='pointconv', help='model name') 31 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]') 32 | return parser.parse_args() 33 | 34 | def main(args): 35 | '''HYPER PARAMETER''' 36 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 37 | 38 | '''CREATE DIR''' 39 | experiment_dir = Path('./experiment/') 40 | experiment_dir.mkdir(exist_ok=True) 41 | file_dir = Path(str(experiment_dir) + '/%s_ModelNet40-'%args.model_name + str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))) 42 | file_dir.mkdir(exist_ok=True) 43 | checkpoints_dir = file_dir.joinpath('checkpoints/') 44 | checkpoints_dir.mkdir(exist_ok=True) 45 | log_dir = file_dir.joinpath('logs/') 46 | log_dir.mkdir(exist_ok=True) 47 | 48 | '''LOG''' 49 | args = parse_args() 50 | logger = logging.getLogger(args.model_name) 51 | logger.setLevel(logging.INFO) 52 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 53 | file_handler = logging.FileHandler(str(log_dir) + '/train_%s_cls.txt'%args.model_name) 54 | file_handler.setLevel(logging.INFO) 55 | file_handler.setFormatter(formatter) 56 | logger.addHandler(file_handler) 57 | logger.info('---------------------------------------------------TRAINING---------------------------------------------------') 58 | logger.info('PARAMETER ...') 59 | logger.info(args) 60 | 61 | '''DATA LOADING''' 62 | logger.info('Load dataset ...') 63 | DATA_PATH = './data/modelnet40_normal_resampled/' 64 | 65 | TRAIN_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='train', normal_channel=args.normal) 66 | TEST_DATASET = ModelNetDataLoader(root=DATA_PATH, npoint=args.num_point, split='test', normal_channel=args.normal) 67 | trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batchsize, shuffle=True, num_workers=args.num_workers) 68 | testDataLoader = 
torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batchsize, shuffle=False, num_workers=args.num_workers) 69 | 70 | logger.info("The number of training data is: %d", len(TRAIN_DATASET)) 71 | logger.info("The number of test data is: %d", len(TEST_DATASET)) 72 | 73 | seed = 3 74 | torch.manual_seed(seed) 75 | if torch.cuda.is_available(): 76 | torch.cuda.manual_seed_all(seed) 77 | 78 | '''MODEL LOADING''' 79 | num_class = 40 80 | classifier = PointConvClsSsg(num_class).cuda() 81 | if args.pretrain is not None: 82 | print('Use pretrain model...') 83 | logger.info('Use pretrain model') 84 | checkpoint = torch.load(args.pretrain) 85 | start_epoch = checkpoint['epoch'] 86 | classifier.load_state_dict(checkpoint['model_state_dict']) 87 | else: 88 | print('No existing model, starting training from scratch...') 89 | start_epoch = 0 90 | 91 | 92 | if args.optimizer == 'SGD': 93 | optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) 94 | elif args.optimizer == 'Adam': 95 | optimizer = torch.optim.Adam( 96 | classifier.parameters(), 97 | lr=args.learning_rate, 98 | betas=(0.9, 0.999), 99 | eps=1e-08, 100 | weight_decay=args.decay_rate 101 | ) 102 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.7) 103 | global_epoch = 0 104 | global_step = 0 105 | best_tst_accuracy = 0.0 106 | blue = lambda x: '\033[94m' + x + '\033[0m' 107 | 108 | '''TRANING''' 109 | logger.info('Start training...') 110 | for epoch in range(start_epoch,args.epoch): 111 | print('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch)) 112 | logger.info('Epoch %d (%d/%s):' ,global_epoch + 1, epoch + 1, args.epoch) 113 | mean_correct = [] 114 | 115 | scheduler.step() 116 | for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9): 117 | points, target = data 118 | points = points.data.numpy() 119 | jittered_data = provider.random_scale_point_cloud(points[:,:, 0:3], scale_low=2.0/3, scale_high=3/2.0) 120 
| jittered_data = provider.shift_point_cloud(jittered_data, shift_range=0.2) 121 | points[:, :, 0:3] = jittered_data 122 | points = provider.random_point_dropout_v2(points) 123 | points = provider.shuffle_points(points) 124 | points = torch.Tensor(points) 125 | target = target[:, 0] 126 | 127 | points = points.transpose(2, 1) 128 | points, target = points.cuda(), target.cuda() 129 | optimizer.zero_grad() 130 | 131 | classifier = classifier.train() 132 | pred = classifier(points[:, :3, :], points[:, 3:, :]) 133 | loss = F.nll_loss(pred, target.long()) 134 | pred_choice = pred.data.max(1)[1] 135 | correct = pred_choice.eq(target.long().data).cpu().sum() 136 | mean_correct.append(correct.item() / float(points.size()[0])) 137 | loss.backward() 138 | optimizer.step() 139 | global_step += 1 140 | 141 | train_acc = np.mean(mean_correct) 142 | print('Train Accuracy: %f' % train_acc) 143 | logger.info('Train Accuracy: %f' % train_acc) 144 | 145 | acc = test(classifier, testDataLoader) 146 | 147 | if (acc >= best_tst_accuracy) and epoch > 5: 148 | best_tst_accuracy = acc 149 | logger.info('Save model...') 150 | save_checkpoint( 151 | global_epoch + 1, 152 | train_acc, 153 | acc, 154 | classifier, 155 | optimizer, 156 | str(checkpoints_dir), 157 | args.model_name) 158 | print('Saving model....') 159 | 160 | print('\r Loss: %f' % loss.data) 161 | logger.info('Loss: %.2f', loss.data) 162 | print('\r Test %s: %f *** %s: %f' % (blue('Accuracy'),acc, blue('Best Accuracy'),best_tst_accuracy)) 163 | logger.info('Test Accuracy: %f *** Best Test Accuracy: %f', acc, best_tst_accuracy) 164 | 165 | 166 | global_epoch += 1 167 | print('Best Accuracy: %f'%best_tst_accuracy) 168 | 169 | logger.info('End of training...') 170 | 171 | if __name__ == '__main__': 172 | args = parse_args() 173 | main(args) 174 | -------------------------------------------------------------------------------- /utils/pointconv_util.py: -------------------------------------------------------------------------------- 1 | """ 
2 | Utility functions for PointConv
3 | Originally from : https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/utils.py
4 | Modified by Wenxuan Wu
5 | Date: September 2019
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from time import time
11 | import numpy as np
12 | from sklearn.neighbors import KernelDensity  # the old sklearn.neighbors.kde path is gone from recent scikit-learn releases
13 | 
14 | def timeit(tag, t):
15 |     print("{}: {}s".format(tag, time() - t))
16 |     return time()
17 | 
18 | def square_distance(src, dst):
19 |     """
20 |     Calculate the squared Euclidean distance between every pair of points.
21 | 
22 |     src^T * dst = xn * xm + yn * ym + zn * zm;
23 |     sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
24 |     sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
25 |     dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
26 |          = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
27 | 
28 |     Input:
29 |         src: source points, [B, N, C]
30 |         dst: target points, [B, M, C]
31 |     Output:
32 |         dist: per-point square distance, [B, N, M]
33 |     """
34 |     B, N, _ = src.shape
35 |     _, M, _ = dst.shape
36 |     dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
37 |     dist += torch.sum(src ** 2, -1).view(B, N, 1)
38 |     dist += torch.sum(dst ** 2, -1).view(B, 1, M)
39 |     return dist
40 | 
41 | def index_points(points, idx):
42 |     """
43 |     Gather the points selected by idx, batch by batch.
44 |     Input:
45 |         points: input points data, [B, N, C]
46 |         idx: sample index data, [B, S]
47 |     Return:
48 |         new_points: indexed points data, [B, S, C]
49 |     """
50 |     device = points.device
51 |     B = points.shape[0]
52 |     view_shape = list(idx.shape)
53 |     view_shape[1:] = [1] * (len(view_shape) - 1)
54 |     repeat_shape = list(idx.shape)
55 |     repeat_shape[0] = 1
56 |     batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
57 |     new_points = points[batch_indices, idx, :]
58 |     return new_points
59 | 
60 | def farthest_point_sample(xyz, npoint):
61 |     """
62 |     Input:
63 |         xyz: pointcloud data, [B, N, C]
64 |         npoint: number of samples
65 |     Return:
66 |         centroids: sampled pointcloud index,
[B, npoint] 67 | """ 68 | #import ipdb; ipdb.set_trace() 69 | device = xyz.device 70 | B, N, C = xyz.shape 71 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 72 | distance = torch.ones(B, N).to(device) * 1e10 73 | #farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 74 | farthest = torch.zeros(B, dtype=torch.long).to(device) 75 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 76 | for i in range(npoint): 77 | centroids[:, i] = farthest 78 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 79 | dist = torch.sum((xyz - centroid) ** 2, -1) 80 | mask = dist < distance 81 | distance[mask] = dist[mask] 82 | farthest = torch.max(distance, -1)[1] 83 | return centroids 84 | 85 | def query_ball_point(radius, nsample, xyz, new_xyz): 86 | """ 87 | Input: 88 | radius: local region radius 89 | nsample: max sample number in local region 90 | xyz: all points, [B, N, C] 91 | new_xyz: query points, [B, S, C] 92 | Return: 93 | group_idx: grouped points index, [B, S, nsample] 94 | """ 95 | device = xyz.device 96 | B, N, C = xyz.shape 97 | _, S, _ = new_xyz.shape 98 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 99 | sqrdists = square_distance(new_xyz, xyz) 100 | group_idx[sqrdists > radius ** 2] = N 101 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 102 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 103 | mask = group_idx == N 104 | group_idx[mask] = group_first[mask] 105 | return group_idx 106 | 107 | def knn_point(nsample, xyz, new_xyz): 108 | """ 109 | Input: 110 | nsample: max sample number in local region 111 | xyz: all points, [B, N, C] 112 | new_xyz: query points, [B, S, C] 113 | Return: 114 | group_idx: grouped points index, [B, S, nsample] 115 | """ 116 | sqrdists = square_distance(new_xyz, xyz) 117 | _, group_idx = torch.topk(sqrdists, nsample, dim = -1, largest=False, sorted=False) 118 | return group_idx 119 | 120 | def 
sample_and_group(npoint, nsample, xyz, points, density_scale = None):
121 |     """
122 |     Input:
123 |         npoint: number of centroids picked by farthest point sampling
124 |         nsample: number of nearest neighbors grouped around each centroid
125 |         xyz: input points position data, [B, N, C]
126 |         points: input points data, [B, N, D]
127 |     Return:
128 |         new_xyz: sampled points position data, [B, npoint, C]
129 |         new_points: sampled points data, [B, npoint, nsample, C+D]
130 |     """
131 |     B, N, C = xyz.shape
132 |     S = npoint
133 |     fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]
134 |     new_xyz = index_points(xyz, fps_idx)
135 |     idx = knn_point(nsample, xyz, new_xyz)
136 |     grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
137 |     grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
138 |     if points is not None:
139 |         grouped_points = index_points(points, idx)
140 |         new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
141 |     else:
142 |         new_points = grouped_xyz_norm
143 | 
144 |     if density_scale is None:
145 |         return new_xyz, new_points, grouped_xyz_norm, idx
146 |     else:
147 |         grouped_density = index_points(density_scale, idx)
148 |         return new_xyz, new_points, grouped_xyz_norm, idx, grouped_density
149 | 
150 | def sample_and_group_all(xyz, points, density_scale = None):
151 |     """
152 |     Input:
153 |         xyz: input points position data, [B, N, C]
154 |         points: input points data, [B, N, D]
155 |     Return:
156 |         new_xyz: sampled points position data, [B, 1, C]
157 |         new_points: sampled points data, [B, 1, N, C+D]
158 |     """
159 |     device = xyz.device
160 |     B, N, C = xyz.shape
161 |     #new_xyz = torch.zeros(B, 1, C).to(device)
162 |     new_xyz = xyz.mean(dim = 1, keepdim = True)
163 |     grouped_xyz = xyz.view(B, 1, N, C) - new_xyz.view(B, 1, 1, C)
164 |     if points is not None:
165 |         new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
166 |     else:
167 |         new_points = grouped_xyz
168 |     if density_scale is None:
169 |         return new_xyz, new_points, grouped_xyz
170 |     else:
171 |         grouped_density = density_scale.view(B, 1, N, 1)
172 |         return new_xyz,
new_points, grouped_xyz, grouped_density 173 | 174 | def group(nsample, xyz, points): 175 | """ 176 | Input: 177 | npoint: 178 | nsample: 179 | xyz: input points position data, [B, N, C] 180 | points: input points data, [B, N, D] 181 | Return: 182 | new_xyz: sampled points position data, [B, 1, C] 183 | new_points: sampled points data, [B, 1, N, C+D] 184 | """ 185 | B, N, C = xyz.shape 186 | S = N 187 | new_xyz = xyz 188 | idx = knn_point(nsample, xyz, new_xyz) 189 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 190 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 191 | if points is not None: 192 | grouped_points = index_points(points, idx) 193 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 194 | else: 195 | new_points = grouped_xyz_norm 196 | 197 | return new_points, grouped_xyz_norm 198 | 199 | def compute_density(xyz, bandwidth): 200 | ''' 201 | xyz: input points position data, [B, N, C] 202 | ''' 203 | #import ipdb; ipdb.set_trace() 204 | B, N, C = xyz.shape 205 | sqrdists = square_distance(xyz, xyz) 206 | gaussion_density = torch.exp(- sqrdists / (2.0 * bandwidth * bandwidth)) / (2.5 * bandwidth) 207 | xyz_density = gaussion_density.mean(dim = -1) 208 | 209 | return xyz_density 210 | 211 | class DensityNet(nn.Module): 212 | def __init__(self, hidden_unit = [16, 8]): 213 | super(DensityNet, self).__init__() 214 | self.mlp_convs = nn.ModuleList() 215 | self.mlp_bns = nn.ModuleList() 216 | 217 | self.mlp_convs.append(nn.Conv2d(1, hidden_unit[0], 1)) 218 | self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[0])) 219 | for i in range(1, len(hidden_unit)): 220 | self.mlp_convs.append(nn.Conv2d(hidden_unit[i - 1], hidden_unit[i], 1)) 221 | self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[i])) 222 | self.mlp_convs.append(nn.Conv2d(hidden_unit[-1], 1, 1)) 223 | self.mlp_bns.append(nn.BatchNorm2d(1)) 224 | 225 | def forward(self, density_scale): 226 | for i, conv in enumerate(self.mlp_convs): 227 | 
bn = self.mlp_bns[i]
228 |             density_scale = bn(conv(density_scale))
229 |             if i == len(self.mlp_convs) - 1:  # last layer; the original test "i == len(self.mlp_convs)" was never true, so ReLU was always applied
230 |                 density_scale = torch.sigmoid(density_scale)
231 |             else:
232 |                 density_scale = F.relu(density_scale)
233 | 
234 |         return density_scale
235 | 
236 | class WeightNet(nn.Module):
237 | 
238 |     def __init__(self, in_channel, out_channel, hidden_unit = [8, 8]):
239 |         super(WeightNet, self).__init__()
240 | 
241 |         self.mlp_convs = nn.ModuleList()
242 |         self.mlp_bns = nn.ModuleList()
243 |         if hidden_unit is None or len(hidden_unit) == 0:
244 |             self.mlp_convs.append(nn.Conv2d(in_channel, out_channel, 1))
245 |             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
246 |         else:
247 |             self.mlp_convs.append(nn.Conv2d(in_channel, hidden_unit[0], 1))
248 |             self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[0]))
249 |             for i in range(1, len(hidden_unit)):
250 |                 self.mlp_convs.append(nn.Conv2d(hidden_unit[i - 1], hidden_unit[i], 1))
251 |                 self.mlp_bns.append(nn.BatchNorm2d(hidden_unit[i]))
252 |             self.mlp_convs.append(nn.Conv2d(hidden_unit[-1], out_channel, 1))
253 |             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
254 | 
255 |     def forward(self, localized_xyz):
256 |         #xyz : BxCxKxN
257 | 
258 |         weights = localized_xyz
259 |         for i, conv in enumerate(self.mlp_convs):
260 |             bn = self.mlp_bns[i]
261 |             weights = F.relu(bn(conv(weights)))
262 | 
263 |         return weights
264 | 
265 | class PointConvSetAbstraction(nn.Module):
266 |     def __init__(self, npoint, nsample, in_channel, mlp, group_all):
267 |         super(PointConvSetAbstraction, self).__init__()
268 |         self.npoint = npoint
269 |         self.nsample = nsample
270 |         self.mlp_convs = nn.ModuleList()
271 |         self.mlp_bns = nn.ModuleList()
272 |         last_channel = in_channel
273 |         for out_channel in mlp:
274 |             self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
275 |             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
276 |             last_channel = out_channel
277 | 
278 |         self.weightnet = WeightNet(3, 16)
279 |         self.linear = nn.Linear(16 * mlp[-1], mlp[-1])
280 |         self.bn_linear
= nn.BatchNorm1d(mlp[-1]) 281 | self.group_all = group_all 282 | 283 | def forward(self, xyz, points): 284 | """ 285 | Input: 286 | xyz: input points position data, [B, C, N] 287 | points: input points data, [B, D, N] 288 | Return: 289 | new_xyz: sampled points position data, [B, C, S] 290 | new_points_concat: sample points feature data, [B, D', S] 291 | """ 292 | B = xyz.shape[0] 293 | xyz = xyz.permute(0, 2, 1) 294 | if points is not None: 295 | points = points.permute(0, 2, 1) 296 | 297 | if self.group_all: 298 | new_xyz, new_points, grouped_xyz_norm = sample_and_group_all(xyz, points) 299 | else: 300 | new_xyz, new_points, grouped_xyz_norm, _ = sample_and_group(self.npoint, self.nsample, xyz, points) 301 | # new_xyz: sampled points position data, [B, npoint, C] 302 | # new_points: sampled points data, [B, npoint, nsample, C+D] 303 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 304 | for i, conv in enumerate(self.mlp_convs): 305 | bn = self.mlp_bns[i] 306 | new_points = F.relu(bn(conv(new_points))) 307 | 308 | grouped_xyz = grouped_xyz_norm.permute(0, 3, 2, 1) 309 | weights = self.weightnet(grouped_xyz) 310 | new_points = torch.matmul(input=new_points.permute(0, 3, 1, 2), other = weights.permute(0, 3, 2, 1)).view(B, self.npoint, -1) 311 | new_points = self.linear(new_points) 312 | new_points = self.bn_linear(new_points.permute(0, 2, 1)) 313 | new_points = F.relu(new_points) 314 | new_xyz = new_xyz.permute(0, 2, 1) 315 | 316 | return new_xyz, new_points 317 | 318 | class PointConvDensitySetAbstraction(nn.Module): 319 | def __init__(self, npoint, nsample, in_channel, mlp, bandwidth, group_all): 320 | super(PointConvDensitySetAbstraction, self).__init__() 321 | self.npoint = npoint 322 | self.nsample = nsample 323 | self.mlp_convs = nn.ModuleList() 324 | self.mlp_bns = nn.ModuleList() 325 | last_channel = in_channel 326 | for out_channel in mlp: 327 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 328 | 
self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 329 | last_channel = out_channel 330 | 331 | self.weightnet = WeightNet(3, 16) 332 | self.linear = nn.Linear(16 * mlp[-1], mlp[-1]) 333 | self.bn_linear = nn.BatchNorm1d(mlp[-1]) 334 | self.densitynet = DensityNet() 335 | self.group_all = group_all 336 | self.bandwidth = bandwidth 337 | 338 | def forward(self, xyz, points): 339 | """ 340 | Input: 341 | xyz: input points position data, [B, C, N] 342 | points: input points data, [B, D, N] 343 | Return: 344 | new_xyz: sampled points position data, [B, C, S] 345 | new_points_concat: sample points feature data, [B, D', S] 346 | """ 347 | B = xyz.shape[0] 348 | N = xyz.shape[2] 349 | xyz = xyz.permute(0, 2, 1) 350 | if points is not None: 351 | points = points.permute(0, 2, 1) 352 | 353 | xyz_density = compute_density(xyz, self.bandwidth) 354 | inverse_density = 1.0 / xyz_density 355 | 356 | if self.group_all: 357 | new_xyz, new_points, grouped_xyz_norm, grouped_density = sample_and_group_all(xyz, points, inverse_density.view(B, N, 1)) 358 | else: 359 | new_xyz, new_points, grouped_xyz_norm, _, grouped_density = sample_and_group(self.npoint, self.nsample, xyz, points, inverse_density.view(B, N, 1)) 360 | # new_xyz: sampled points position data, [B, npoint, C] 361 | # new_points: sampled points data, [B, npoint, nsample, C+D] 362 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 363 | for i, conv in enumerate(self.mlp_convs): 364 | bn = self.mlp_bns[i] 365 | new_points = F.relu(bn(conv(new_points))) 366 | 367 | inverse_max_density = grouped_density.max(dim = 2, keepdim=True)[0] 368 | density_scale = grouped_density / inverse_max_density 369 | density_scale = self.densitynet(density_scale.permute(0, 3, 2, 1)) 370 | new_points = new_points * density_scale 371 | 372 | grouped_xyz = grouped_xyz_norm.permute(0, 3, 2, 1) 373 | weights = self.weightnet(grouped_xyz) 374 | new_points = torch.matmul(input=new_points.permute(0, 3, 1, 2), other = 
weights.permute(0, 3, 2, 1)).view(B, self.npoint, -1)
375 |         new_points = self.linear(new_points)
376 |         new_points = self.bn_linear(new_points.permute(0, 2, 1))
377 |         new_points = F.relu(new_points)
378 |         new_xyz = new_xyz.permute(0, 2, 1)
379 | 
380 |         return new_xyz, new_points
381 | 
382 | 
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # *_*coding:utf-8 *_*
2 | import os
3 | import numpy as np
4 | import torch
5 | import matplotlib.pyplot as plt
6 | from torch.autograd import Variable
7 | from tqdm import tqdm
8 | from collections import defaultdict
9 | import datetime
10 | import pandas as pd
11 | import torch.nn.functional as F
12 | def to_categorical(y, num_classes):
13 |     """ 1-hot encodes a tensor """
14 |     new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
15 |     if (y.is_cuda):
16 |         return new_y.cuda()
17 |     return new_y
18 | 
19 | def show_example(x, y, x_reconstruction, y_pred,save_dir, figname):
20 |     x = x.squeeze().cpu().data.numpy()
21 |     x = np.transpose(x, (0, 2, 1))  # x is a NumPy array at this point, which has no .permute()
22 |     y = y.cpu().data.numpy()
23 |     x_reconstruction = x_reconstruction.squeeze().cpu().data.numpy()
24 |     _, y_pred = torch.max(y_pred, -1)
25 |     y_pred = y_pred.cpu().data.numpy()
26 | 
27 |     fig, ax = plt.subplots(1, 2)
28 |     ax[0].imshow(x, cmap='Greys')
29 |     ax[0].set_title('Input: %d' % y)
30 |     ax[1].imshow(x_reconstruction, cmap='Greys')
31 |     ax[1].set_title('Output: %d' % y_pred)
32 |     plt.savefig(save_dir + figname + '.png')
33 | 
34 | def save_checkpoint(epoch, train_accuracy, test_accuracy, model, optimizer, path,modelnet='checkpoint'):
35 |     savepath = path + '/%s-%f-%04d.pth' % (modelnet,test_accuracy, epoch)
36 |     state = {
37 |         'epoch': epoch,
38 |         'train_accuracy': train_accuracy,
39 |         'test_accuracy': test_accuracy,
40 |         'model_state_dict': model.state_dict(),
41 |         'optimizer_state_dict': optimizer.state_dict(),
42 |     }
43 |     torch.save(state, savepath)
44 | 
45 | 
def test(model, loader):
46 |     total_correct = 0.0
47 |     total_seen = 0.0
48 |     for j, data in enumerate(loader, 0):
49 |         points, target = data
50 |         target = target[:, 0]
51 |         points = points.transpose(2, 1)
52 |         points, target = points.cuda(), target.cuda()
53 |         classifier = model.eval()
54 |         with torch.no_grad():
55 |             pred = classifier(points[:, :3, :], points[:, 3:, :])
56 |         pred_choice = pred.data.max(1)[1]
57 |         correct = pred_choice.eq(target.long().data).cpu().sum()
58 |         total_correct += correct.item()
59 |         total_seen += float(points.size()[0])
60 | 
61 |     accuracy = total_correct / total_seen
62 |     return accuracy
63 | 
64 | def compute_cat_iou(pred,target,iou_tabel):
65 |     iou_list = []
66 |     target = target.cpu().data.numpy()
67 |     for j in range(pred.size(0)):
68 |         batch_pred = pred[j]
69 |         batch_target = target[j]
70 |         batch_choice = batch_pred.data.max(1)[1].cpu().data.numpy()
71 |         for cat in np.unique(batch_target):
72 |             # intersection = np.sum((batch_target == cat) & (batch_choice == cat))
73 |             # union = float(np.sum((batch_target == cat) | (batch_choice == cat)))
74 |             # iou = intersection/union if not union ==0 else 1
75 |             I = np.sum(np.logical_and(batch_choice == cat, batch_target == cat))
76 |             U = np.sum(np.logical_or(batch_choice == cat, batch_target == cat))
77 |             if U == 0:
78 |                 iou = 1 # If the union of groundtruth and prediction points is empty, then count part IoU as 1
79 |             else:
80 |                 iou = I / float(U)
81 |             iou_tabel[cat,0] += iou
82 |             iou_tabel[cat,1] += 1
83 |             iou_list.append(iou)
84 |     return iou_tabel,iou_list
85 | 
86 | def compute_overall_iou(pred, target, num_classes):
87 |     shape_ious = []
88 |     pred_np = pred.cpu().data.numpy()
89 |     target_np = target.cpu().data.numpy()
90 |     for shape_idx in range(pred.size(0)):
91 |         part_ious = []
92 |         for part in range(num_classes):
93 |             I = np.sum(np.logical_and(pred_np[shape_idx].argmax(1) == part, target_np[shape_idx] == part))  # argmax yields the predicted part label; max would return the score
94 |             U = np.sum(np.logical_or(pred_np[shape_idx].argmax(1) == part, target_np[shape_idx]
== part)) 95 | if U == 0: 96 | iou = 1 #If the union of groundtruth and prediction points is empty, then count part IoU as 1 97 | else: 98 | iou = I / float(U) 99 | part_ious.append(iou) 100 | shape_ious.append(np.mean(part_ious)) 101 | return shape_ious 102 | 103 | def test_partseg(model, loader, catdict, num_classes = 50,forpointnet2=False): 104 | ''' catdict = {0:Airplane, 1:Airplane, ...49:Table} ''' 105 | iou_tabel = np.zeros((len(catdict),3)) 106 | iou_list = [] 107 | metrics = defaultdict(lambda:list()) 108 | hist_acc = [] 109 | # mean_correct = [] 110 | for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(loader), total=len(loader), smoothing=0.9): 111 | batchsize, num_point,_= points.size() 112 | points, label, target, norm_plt = Variable(points.float()),Variable(label.long()), Variable(target.long()),Variable(norm_plt.float()) 113 | points = points.transpose(2, 1) 114 | norm_plt = norm_plt.transpose(2, 1) 115 | points, label, target, norm_plt = points.cuda(), label.squeeze().cuda(), target.cuda(), norm_plt.cuda() 116 | if forpointnet2: 117 | seg_pred = model(points, norm_plt, to_categorical(label, 16)) 118 | else: 119 | labels_pred, seg_pred, _ = model(points,to_categorical(label,16)) 120 | # labels_pred_choice = labels_pred.data.max(1)[1] 121 | # labels_correct = labels_pred_choice.eq(label.long().data).cpu().sum() 122 | # mean_correct.append(labels_correct.item() / float(points.size()[0])) 123 | # print(pred.size()) 124 | iou_tabel, iou = compute_cat_iou(seg_pred,target,iou_tabel) 125 | iou_list+=iou 126 | # shape_ious += compute_overall_iou(pred, target, num_classes) 127 | seg_pred = seg_pred.contiguous().view(-1, num_classes) 128 | target = target.view(-1, 1)[:, 0] 129 | pred_choice = seg_pred.data.max(1)[1] 130 | correct = pred_choice.eq(target.data).cpu().sum() 131 | metrics['accuracy'].append(correct.item()/ (batchsize * num_point)) 132 | iou_tabel[:,2] = iou_tabel[:,0] /iou_tabel[:,1] 133 | hist_acc += metrics['accuracy'] 134 | 
metrics['accuracy'] = np.mean(hist_acc) 135 | metrics['inctance_avg_iou'] = np.mean(iou_list) 136 | # metrics['label_accuracy'] = np.mean(mean_correct) 137 | iou_tabel = pd.DataFrame(iou_tabel,columns=['iou','count','mean_iou']) 138 | iou_tabel['Category_IOU'] = [catdict[i] for i in range(len(catdict)) ] 139 | cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean() 140 | metrics['class_avg_iou'] = np.mean(cat_iou) 141 | 142 | return metrics, hist_acc, cat_iou 143 | 144 | def test_semseg(model, loader, catdict, num_classes = 13, pointnet2=False): 145 | iou_tabel = np.zeros((len(catdict),3)) 146 | metrics = defaultdict(lambda:list()) 147 | hist_acc = [] 148 | for batch_id, (points, target) in tqdm(enumerate(loader), total=len(loader), smoothing=0.9): 149 | batchsize, num_point, _ = points.size() 150 | points, target = Variable(points.float()), Variable(target.long()) 151 | points = points.transpose(2, 1) 152 | points, target = points.cuda(), target.cuda() 153 | if pointnet2: 154 | pred = model(points[:, :3, :], points[:, 3:, :]) 155 | else: 156 | pred, _ = model(points) 157 | # print(pred.size()) 158 | iou_tabel, iou_list = compute_cat_iou(pred,target,iou_tabel) 159 | # shape_ious += compute_overall_iou(pred, target, num_classes) 160 | pred = pred.contiguous().view(-1, num_classes) 161 | target = target.view(-1, 1)[:, 0] 162 | pred_choice = pred.data.max(1)[1] 163 | correct = pred_choice.eq(target.data).cpu().sum() 164 | metrics['accuracy'].append(correct.item()/ (batchsize * num_point)) 165 | iou_tabel[:,2] = iou_tabel[:,0] /iou_tabel[:,1] 166 | hist_acc += metrics['accuracy'] 167 | metrics['accuracy'] = np.mean(metrics['accuracy']) 168 | metrics['iou'] = np.mean(iou_tabel[:, 2]) 169 | iou_tabel = pd.DataFrame(iou_tabel,columns=['iou','count','mean_iou']) 170 | iou_tabel['Category_IOU'] = [catdict[i] for i in range(len(catdict)) ] 171 | # print(iou_tabel) 172 | cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean() 173 | 174 | return metrics, 
hist_acc, cat_iou 175 | 176 | 177 | def compute_avg_curve(y, n_points_avg): 178 | avg_kernel = np.ones((n_points_avg,)) / n_points_avg 179 | rolling_mean = np.convolve(y, avg_kernel, mode='valid') 180 | return rolling_mean 181 | 182 | def plot_loss_curve(history,n_points_avg,n_points_plot,save_dir): 183 | curve = np.asarray(history['loss'])[-n_points_plot:] 184 | avg_curve = compute_avg_curve(curve, n_points_avg) 185 | plt.plot(avg_curve, '-g') 186 | 187 | curve = np.asarray(history['margin_loss'])[-n_points_plot:] 188 | avg_curve = compute_avg_curve(curve, n_points_avg) 189 | plt.plot(avg_curve, '-b') 190 | 191 | curve = np.asarray(history['reconstruction_loss'])[-n_points_plot:] 192 | avg_curve = compute_avg_curve(curve, n_points_avg) 193 | plt.plot(avg_curve, '-r') 194 | 195 | plt.legend(['Total Loss', 'Margin Loss', 'Reconstruction Loss']) 196 | plt.savefig(save_dir + '/'+ str(datetime.datetime.now().strftime('%Y-%m-%d %H-%M')) + '_total_result.png') 197 | plt.close() 198 | 199 | def plot_acc_curve(total_train_acc,total_test_acc,save_dir): 200 | plt.plot(total_train_acc, '-b',label = 'train_acc') 201 | plt.plot(total_test_acc, '-r',label = 'test_acc') 202 | plt.legend() 203 | plt.ylabel('acc') 204 | plt.xlabel('epoch') 205 | plt.title('Accuracy of training and test') 206 | plt.savefig(save_dir +'/'+ str(datetime.datetime.now().strftime('%Y-%m-%d %H-%M'))+'_total_acc.png') 207 | plt.close() 208 | 209 | def show_point_cloud(tuple,seg_label=[],title=None): 210 | import matplotlib.pyplot as plt 211 | if seg_label == []: 212 | x = [x[0] for x in tuple] 213 | y = [y[1] for y in tuple] 214 | z = [z[2] for z in tuple] 215 | ax = plt.subplot(111, projection='3d') 216 | ax.scatter(x, y, z, c='b', cmap='spectral') 217 | ax.set_zlabel('Z') 218 | ax.set_ylabel('Y') 219 | ax.set_xlabel('X') 220 | else: 221 | category = list(np.unique(seg_label)) 222 | color = ['b','r','g','y','w','b','p'] 223 | ax = plt.subplot(111, projection='3d') 224 | for categ_index in 
range(len(category)): 225 | tuple_seg = tuple[seg_label == category[categ_index]] 226 | x = [x[0] for x in tuple_seg] 227 | y = [y[1] for y in tuple_seg] 228 | z = [z[2] for z in tuple_seg] 229 | ax.scatter(x, y, z, c=color[categ_index], cmap='spectral') 230 | ax.set_zlabel('Z') 231 | ax.set_ylabel('Y') 232 | ax.set_xlabel('X') 233 | plt.title(title) 234 | plt.show() --------------------------------------------------------------------------------
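
The greedy loop in `farthest_point_sample` (utils/pointconv_util.py) can be easier to follow without the batched tensor indexing. Below is a minimal pure-Python sketch of the same idea for a single point cloud, assuming 3D points given as tuples. The name `farthest_point_sample_py` is hypothetical and not part of this repository. Like the repository code, it starts from index 0 rather than a random point.

```python
def farthest_point_sample_py(points, npoint):
    """Greedy farthest point sampling on one point cloud (illustrative sketch).

    points: list of (x, y, z) tuples
    npoint: number of indices to select
    Returns the list of selected indices, starting from index 0.
    """
    n = len(points)
    # Squared distance from each point to its nearest already-selected centroid.
    distance = [float("inf")] * n
    farthest = 0
    centroids = []
    for _ in range(npoint):
        centroids.append(farthest)
        cx, cy, cz = points[farthest]
        for i, (x, y, z) in enumerate(points):
            d = (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2
            if d < distance[i]:
                distance[i] = d
        # The next centroid is the point farthest from every selected centroid.
        farthest = max(range(n), key=distance.__getitem__)
    return centroids


cloud = [(0, 0, 0), (1, 0, 0), (10, 0, 0), (5, 0, 0)]
print(farthest_point_sample_py(cloud, 3))  # -> [0, 2, 3]
```

The batched PyTorch version keeps the same `distance` bookkeeping, but updates it with a boolean mask over a `[B, N]` tensor instead of an inner Python loop.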