├── README.md
├── dataset
│   └── shapenet.py
├── figs
│   └── segmentation_result.png
├── main.py
├── model
│   └── pointnet2_part_seg.py
└── vis
    ├── show_seg_res.py
    └── view.py

/README.md:
--------------------------------------------------------------------------------
# PointNet++ Part Segmentation
This repo is an implementation of the [PointNet++](https://arxiv.org/abs/1706.02413) part segmentation model, based on [PyTorch](https://pytorch.org) and [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric).

**The model has been merged into [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric) as a point cloud segmentation [example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet2_segmentation.py), so you can try it there.**

# Performance
Segmentation on [a subset of ShapeNet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html).

| Method | mcIoU | Airplane | Bag | Cap | Car | Chair | Earphone | Guitar | Knife | Lamp | Laptop | Motorbike | Mug | Pistol | Rocket | Skateboard | Table |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| PointNet++ | 81.9 | 82.4 | 79.0 | 87.7 | 77.3 | 90.8 | 71.8 | 91.0 | 85.9 | 83.7 | 95.3 | 71.6 | 94.1 | 81.3 | 58.7 | 76.4 | 82.6 |
| PointNet++ (this repo) | | 82.5 | 76.1 | 87.8 | 77.5 | 89.89 | 73.7 | | | | 95.3 | 70.5 | | | | | |


Notes:
- mcIoU: mean per-class pIoU
- The model uses single-scale grouping with raw points as input.
- All experiments are trained with the same default configuration: npoints=2500, batch_size=8, nepoch=30. The accuracy recorded above is the test accuracy of the final epoch.


# Requirements
- python 3.6.8
- [PyTorch 1.1.0](https://pytorch.org)
- [pytorch_geometric 1.3.0](https://github.com/rusty1s/pytorch_geometric)
- torch-cluster 1.4.2
- torch-scatter 1.2.0
- torch-sparse 0.4.0
- [Open3D 0.6.0](https://github.com/intel-isl/Open3D) (optional, for visualization of segmentation results)

# Usage
Training:
```
python main.py
```

Show a segmentation result:
```
python vis/show_seg_res.py
```
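
Both scripts accept command-line flags (see the argparse blocks in `main.py` and `vis/show_seg_res.py`). For example, to train on the Chair category and then view a test sample (with the default 30 epochs, the last saved checkpoint is epoch 29):
```
python main.py --category Chair --nepoch 30 --batch_size 8
python vis/show_seg_res.py --category Chair --model ./checkpoint/seg_model_Chair_29.pth --sample_idx 5
```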

# Sample segmentation result
![segmentation_result](figs/segmentation_result.png)


# Links
- [pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch) by fxia22. This repo's training code is heavily borrowed from fxia22's repo.
- Official [PointNet](https://github.com/charlesq34/pointnet) and [PointNet++](https://github.com/charlesq34/pointnet2) TensorFlow implementations
- [PointNet++ classification example](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet2_classification.py) of the pytorch_geometric library
--------------------------------------------------------------------------------
/dataset/shapenet.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from torch.utils.data import Dataset
from torch_geometric.datasets import ShapeNet


class ShapeNetPartSegDataset(Dataset):
    '''
    Resample each raw point cloud to a fixed number of points.
    Labels are expected to already lie in the range [0, N-1]
    (asserted in __getitem__).
    '''
    def __init__(self, root_dir, category, train=True, transform=None, npoints=2500):
        categories = ['Airplane', 'Bag', 'Cap', 'Car', 'Chair', 'Earphone', 'Guitar',
                      'Knife', 'Lamp', 'Laptop', 'Motorbike', 'Mug', 'Pistol', 'Rocket', 'Skateboard', 'Table']
        # assert os.path.exists(root_dir)
        assert category in categories

        self.npoints = npoints
        self.dataset = ShapeNet(root_dir, category, train, transform)

    def __getitem__(self, index):
        data = self.dataset[index]
        points, labels = data.pos, data.y
        assert labels.min() >= 0

        # Resample to a fixed number of points (with replacement, since a
        # shape may contain fewer than npoints points)
        choice = np.random.choice(points.shape[0], self.npoints, replace=True)
        points, labels = points[choice, :], labels[choice]

        sample = {
            'points': points,  # torch.Tensor (n, 3)
            'labels': labels   # torch.Tensor (n,)
        }

        return sample

    def __len__(self):
        return len(self.dataset)

    def num_classes(self):
        return self.dataset.num_classes
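

if __name__ == '__main__':
    # Minimal usage sketch, assuming the ShapeNet part data has been
    # downloaded by torch_geometric into './shapenet' (the default
    # --dataset path used by main.py).
    dataset = ShapeNetPartSegDataset(root_dir='shapenet', category='Airplane',
                                     train=True, npoints=2500)
    sample = dataset[0]
    print('size:', len(dataset))
    print('points:', sample['points'].shape)  # torch.Size([2500, 3])
    print('labels:', sample['labels'].shape)  # torch.Size([2500])
    print('num_classes:', dataset.num_classes())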
--------------------------------------------------------------------------------
/figs/segmentation_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/pointnet2-pytorch/5df6f3491c6ce826c0e94ca7357921ce0669f414/figs/segmentation_result.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
'''
Modified from https://github.com/fxia22/pointnet.pytorch/blob/master/utils/train_segmentation.py
'''

import os
import random
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import autograd
import torch.backends.cudnn as cudnn


from dataset.shapenet import ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT

import time


## Argument parser
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='shapenet', help='dataset path')
parser.add_argument('--category', type=str, default='Airplane', help='select category')
parser.add_argument('--npoints', type=int, default=2500, help='number of points to resample to')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--nepoch', type=int, default=30, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='checkpoint', help='output folder')
parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
parser.add_argument('--test_per_batches', type=int, default=10, help='run a test batch every this many training batches')
parser.add_argument('--num_workers', type=int, default=6, help='number of data loading workers')

opt = parser.parse_args()
print(opt)


## Random seed
# opt.manual_seed = np.random.randint(1, 10000)  # random seed
# TODO: Still cannot get a deterministic result; possibly also needs
# torch.backends.cudnn.deterministic = True (at some speed cost).
opt.manual_seed = 123  # fixed seed
print('Random seed: ', opt.manual_seed)
random.seed(opt.manual_seed)
np.random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
torch.cuda.manual_seed(opt.manual_seed)


## Dataset and transform
print('Construct dataset ..')
rot_max_angle = 15
trans_max_distance = 0.01

RotTransform = GT.Compose([GT.RandomRotate(rot_max_angle, 0), GT.RandomRotate(rot_max_angle, 1), GT.RandomRotate(rot_max_angle, 2)])
TransTransform = GT.RandomTranslate(trans_max_distance)

train_transform = GT.Compose([GT.NormalizeScale(), RotTransform, TransTransform])
test_transform = GT.Compose([GT.NormalizeScale(), ])

dataset = ShapeNetPartSegDataset(
    root_dir=opt.dataset, category=opt.category, train=True, transform=train_transform, npoints=opt.npoints)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

test_dataset = ShapeNetPartSegDataset(
    root_dir=opt.dataset, category=opt.category, train=False, transform=test_transform, npoints=opt.npoints)
# Note: shuffle=True so that the test batch drawn periodically during training is a random one
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=opt.num_workers)

num_classes = dataset.num_classes()

print('dataset size: ', len(dataset))
print('test_dataset size: ', len(test_dataset))
print('num_classes: ', num_classes)

try:
    os.mkdir(opt.outf)
except OSError:
    pass


## Model, criterion and optimizer
print('Construct model ..')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float
print('cudnn.enabled: ', torch.backends.cudnn.enabled)


net = PointNet2PartSegmentNet(num_classes)

if opt.model != '':
    net.load_state_dict(torch.load(opt.model))
net = net.to(device, dtype)

criterion = nn.NLLLoss()
optimizer = optim.Adam(net.parameters())


## Train
print('Training ..')
blue = lambda x: '\033[94m' + x + '\033[0m'
num_batch = len(dataset) // opt.batch_size
test_per_batches = opt.test_per_batches

print('number of epochs: ', opt.nepoch)
print('number of batches per epoch: ', num_batch)
print('run test per batches: ', test_per_batches)

for epoch in range(opt.nepoch):
    print('Epoch {}, total epochs {}'.format(epoch+1, opt.nepoch))

    net.train()

    for batch_idx, sample in enumerate(dataloader):
        # points: (batch_size, n, 3)
        # labels: (batch_size, n)
        points, labels = sample['points'], sample['labels']
        points = points.transpose(1, 2).contiguous()  # (batch_size, 3, n)
        points, labels = points.to(device, dtype), labels.to(device, torch.long)

        optimizer.zero_grad()

        pred = net(points)  # (batch_size, n, num_classes)
        pred = pred.view(-1, num_classes)  # (batch_size * n, num_classes)
        target = labels.view(-1, 1)[:, 0]

        loss = F.nll_loss(pred, target)
        loss.backward()

        optimizer.step()

        ##
        pred_label = pred.detach().max(1)[1]
        correct = pred_label.eq(target.detach()).cpu().sum()
        total = pred_label.shape[0]

        print('[{}: {}/{}] train loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, loss.item(), float(correct.item())/total))

        ##
        if batch_idx % test_per_batches == 0:
            print('Run a test batch')
            net.eval()

            with torch.no_grad():
                # Draw one random test batch. Do not reuse the name batch_idx
                # here, or it would shadow the training loop counter and
                # corrupt the progress printout above.
                test_sample = next(iter(test_dataloader))

                points, labels = test_sample['points'], test_sample['labels']
                points = points.transpose(1, 2).contiguous()
                points, labels = points.to(device, dtype), labels.to(device, torch.long)

                pred = net(points)
                pred = pred.view(-1, num_classes)
                target = labels.view(-1, 1)[:, 0]

                loss = F.nll_loss(pred, target)

                pred_label = pred.detach().max(1)[1]
                correct = pred_label.eq(target.detach()).cpu().sum()
                total = pred_label.shape[0]
                print('[{}: {}/{}] {} loss: {} accuracy: {}'.format(epoch, batch_idx, num_batch, blue('test'), loss.item(), float(correct.item())/total))

            # Back to training mode
            net.train()

    torch.save(net.state_dict(), '{}/seg_model_{}_{}.pth'.format(opt.outf, opt.category, epoch))


## Benchmark mIoU
net.eval()
shape_ious = []

with torch.no_grad():
    for batch_idx, sample in enumerate(test_dataloader):
        points, labels = sample['points'], sample['labels']
        points = points.transpose(1, 2).contiguous()
        points = points.to(device, dtype)

        # start_t = time.time()
        pred = net(points)  # (batch_size, n, num_classes)
        # print('batch inference forward time used: {} ms'.format(time.time() - start_t))

        pred_label = pred.max(2)[1]
        pred_label = pred_label.cpu().numpy()
        target_label = labels.numpy()

        batch_size = target_label.shape[0]
        for shape_idx in range(batch_size):
            parts = range(num_classes)  # np.unique(target_label[shape_idx])
            part_ious = []
            for part in parts:
                I = np.sum(np.logical_and(pred_label[shape_idx] == part, target_label[shape_idx] == part))
                U = np.sum(np.logical_or(pred_label[shape_idx] == part, target_label[shape_idx] == part))
                # A part absent from both prediction and ground truth counts as IoU 1
                iou = 1 if U == 0 else float(I) / U
                part_ious.append(iou)
            shape_ious.append(np.mean(part_ious))

print('mIoU for category {}: {}'.format(opt.category, np.mean(shape_ious)))

print('Done.')
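
# Note on metrics: the score above is a per-shape mIoU, averaged over the test
# shapes of the single category selected by --category. The mcIoU in the
# README is a different aggregation: it appears to be the mean of these
# per-category scores across all 16 categories, each from a separate training
# run (the mean of the PointNet++ row in the README table is 81.85 ~= 81.9).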
--------------------------------------------------------------------------------
/model/pointnet2_part_seg.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, Dropout, BatchNorm1d
from torch_geometric.nn import PointConv, fps, radius, knn
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import reset
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.data.data import Data
from torch_scatter import scatter_add, scatter_max


class PointNet2SAModule(torch.nn.Module):
    def __init__(self, sample_ratio, radius, max_num_neighbors, mlp):
        super(PointNet2SAModule, self).__init__()
        self.sample_ratio = sample_ratio
        self.radius = radius
        self.max_num_neighbors = max_num_neighbors
        self.point_conv = PointConv(mlp)

    def forward(self, data):
        x, pos, batch = data

        # Sample
        idx = fps(pos, batch, ratio=self.sample_ratio)

        # Group (build graph)
        row, col = radius(pos, pos[idx], self.radius, batch, batch[idx], max_num_neighbors=self.max_num_neighbors)
        edge_index = torch.stack([col, row], dim=0)

        # Apply pointnet
        x1 = self.point_conv(x, (pos, pos[idx]), edge_index)
        pos1, batch1 = pos[idx], batch[idx]

        return x1, pos1, batch1


class PointNet2GlobalSAModule(torch.nn.Module):
    '''
    Groups all input points into a single group, so it can be viewed as a
    plain PointNet module. It returns a single output point per point cloud
    (placed at the origin).
    '''
    def __init__(self, mlp):
        super(PointNet2GlobalSAModule, self).__init__()
        self.mlp = mlp

    def forward(self, data):
        x, pos, batch = data
        # Fall back to raw coordinates when there are no input features
        x = pos if x is None else torch.cat([x, pos], dim=1)
        x1 = self.mlp(x)

        x1 = scatter_max(x1, batch, dim=0)[0]  # (batch_size, C1)

        batch_size = x1.shape[0]
        pos1 = x1.new_zeros((batch_size, 3))  # set the output point to the origin
        batch1 = torch.arange(batch_size).to(batch.device, batch.dtype)

        return x1, pos1, batch1
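
# Data-flow sketch for PointNet2SAModule (shapes are illustrative):
#   input:    x (N, C) or None, pos (N, 3), batch (N,)
#   fps       -> idx, roughly N * sample_ratio centroid indices
#   radius    -> bipartite edges from neighbors (within self.radius, capped
#                at max_num_neighbors) to each sampled centroid
#   PointConv -> x1 (len(idx), C_out), pos1 (len(idx), 3), batch1 (len(idx),)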

class PointConvFP(MessagePassing):
    '''
    Core layer of the feature propagation (FP) module: interpolates features
    from the coarser level onto the finer level by inverse-distance-weighted
    k-NN averaging, then concatenates the skip-connection features.
    '''
    def __init__(self, mlp=None):
        super(PointConvFP, self).__init__('add', 'source_to_target')
        self.mlp = mlp
        self.aggr = 'add'
        self.flow = 'source_to_target'

        self.reset_parameters()

    def reset_parameters(self):
        reset(self.mlp)

    def forward(self, x, pos, edge_index):
        r"""
        Args:
            x (tuple): (tensor, tensor) or (tensor, NoneType)
            pos (tuple): The node position matrix. Either given as
                a tensor for use in general message passing or as a tuple for
                use in message passing in bipartite graphs.
            edge_index (LongTensor): The edge indices.
        """
        # Do not pass (tensor, None) directly into propagate(), since it will check each item's size() inside.
        x_tmp = x[0] if x[1] is None else x
        aggr_out = self.propagate(edge_index, x=x_tmp, pos=pos)

        # Concatenate the skip features from the target (finer) level
        i, j = (0, 1) if self.flow == 'target_to_source' else (1, 0)
        x_target, pos_target = x[i], pos[i]

        add = [pos_target, ] if x_target is None else [x_target, pos_target]
        aggr_out = torch.cat([aggr_out, *add], dim=1)

        if self.mlp is not None:
            aggr_out = self.mlp(aggr_out)

        return aggr_out

    def message(self, x_j, pos_j, pos_i, edge_index):
        '''
        x_j: (E, in_channels)
        pos_j: (E, 3)
        pos_i: (E, 3)
        '''
        # Inverse-distance weights, normalized per target node:
        # w_j = (1 / d_j) / sum_k (1 / d_k)
        dist = (pos_j - pos_i).pow(2).sum(dim=1).pow(0.5)
        dist = torch.max(dist, torch.Tensor([1e-10]).to(dist.device, dist.dtype))
        weight = 1.0 / dist  # (E,)

        row, col = edge_index
        index = col
        num_nodes = maybe_num_nodes(index, None)
        wsum = scatter_add(weight, col, dim=0, dim_size=num_nodes)[index] + 1e-16  # (E,)
        weight /= wsum

        return weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out


class PointNet2FPModule(torch.nn.Module):
    def __init__(self, knn_num, mlp):
        super(PointNet2FPModule, self).__init__()
        self.knn_num = knn_num
        self.point_conv = PointConvFP(mlp)

    def forward(self, in_layer_data, skip_layer_data):
        in_x, in_pos, in_batch = in_layer_data
        skip_x, skip_pos, skip_batch = skip_layer_data

        row, col = knn(in_pos, skip_pos, self.knn_num, in_batch, skip_batch)
        edge_index = torch.stack([col, row], dim=0)

        x1 = self.point_conv((in_x, skip_x), (in_pos, skip_pos), edge_index)
        pos1, batch1 = skip_pos, skip_batch

        return x1, pos1, batch1


def make_mlp(in_channels, mlp_channels, batch_norm=True):
    assert len(mlp_channels) >= 1
    layers = []

    for c in mlp_channels:
        layers += [Lin(in_channels, c)]
        if batch_norm:
            layers += [BatchNorm1d(c)]
        layers += [ReLU()]

        in_channels = c

    return Seq(*layers)
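
# For example, make_mlp(3, [64, 64, 128]) expands to:
#   Seq(Lin(3, 64),   BatchNorm1d(64),  ReLU(),
#       Lin(64, 64),  BatchNorm1d(64),  ReLU(),
#       Lin(64, 128), BatchNorm1d(128), ReLU())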


class PointNet2PartSegmentNet(torch.nn.Module):
    '''
    ref:
    - https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_part_seg.py
    - https://github.com/rusty1s/pytorch_geometric/blob/master/examples/pointnet++.py
    '''
    def __init__(self, num_classes):
        super(PointNet2PartSegmentNet, self).__init__()
        self.num_classes = num_classes

        # SA1
        sa1_sample_ratio = 0.5
        sa1_radius = 0.2
        sa1_max_num_neighbors = 64
        sa1_mlp = make_mlp(3, [64, 64, 128])
        self.sa1_module = PointNet2SAModule(sa1_sample_ratio, sa1_radius, sa1_max_num_neighbors, sa1_mlp)

        # SA2
        sa2_sample_ratio = 0.25
        sa2_radius = 0.4
        sa2_max_num_neighbors = 64
        sa2_mlp = make_mlp(128+3, [128, 128, 256])
        self.sa2_module = PointNet2SAModule(sa2_sample_ratio, sa2_radius, sa2_max_num_neighbors, sa2_mlp)

        # SA3
        sa3_mlp = make_mlp(256+3, [256, 512, 1024])
        self.sa3_module = PointNet2GlobalSAModule(sa3_mlp)

        ##
        knn_num = 3

        # FP3, reverse of sa3
        fp3_knn_num = 1  # After the global SA module there is only one point per point cloud
        fp3_mlp = make_mlp(1024+256+3, [256, 256])
        self.fp3_module = PointNet2FPModule(fp3_knn_num, fp3_mlp)

        # FP2, reverse of sa2
        fp2_knn_num = knn_num
        fp2_mlp = make_mlp(256+128+3, [256, 128])
        self.fp2_module = PointNet2FPModule(fp2_knn_num, fp2_mlp)

        # FP1, reverse of sa1
        fp1_knn_num = knn_num
        fp1_mlp = make_mlp(128+3, [128, 128, 128])
        self.fp1_module = PointNet2FPModule(fp1_knn_num, fp1_mlp)

        self.fc1 = Lin(128, 128)
        self.dropout1 = Dropout(p=0.5)
        self.fc2 = Lin(128, self.num_classes)
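
    # Encoder/decoder feature widths, following the constructor above:
    #   pos (3) -> SA1 (128) -> SA2 (256) -> SA3, global (1024)
    #   -> FP3 (256) -> FP2 (128) -> FP1 (128) -> fc1 (128) -> fc2 (num_classes)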

    def forward(self, data):
        '''
        data: a batch of input, torch.Tensor or torch_geometric.data.Data type
        - torch.Tensor: (batch_size, 3, num_points), as common batch input

        - torch_geometric.data.Data, as torch_geometric batch input:
            data.x: (batch_size * ~num_points, C), batched node/point features,
                where ~num_points means each sample can have a different number of points/nodes

            data.pos: (batch_size * ~num_points, 3)

            data.batch: (batch_size * ~num_points,), a column vector of graph/pointcloud
                identifiers for all nodes of all graphs/pointclouds in the batch. See the
                pytorch_geometric documentation for more information.
        '''
        dense_input = isinstance(data, torch.Tensor)

        if dense_input:
            # Convert to torch_geometric.data.Data type
            data = data.transpose(1, 2).contiguous()
            batch_size, N, _ = data.shape  # (batch_size, num_points, 3)
            pos = data.view(batch_size*N, -1)
            batch = torch.zeros((batch_size, N), device=pos.device, dtype=torch.long)
            for i in range(batch_size):
                batch[i] = i
            batch = batch.view(-1)

            data = Data()
            data.pos, data.batch = pos, batch

        if not hasattr(data, 'x'):
            data.x = None
        data_in = data.x, data.pos, data.batch

        sa1_out = self.sa1_module(data_in)
        sa2_out = self.sa2_module(sa1_out)
        sa3_out = self.sa3_module(sa2_out)

        fp3_out = self.fp3_module(sa3_out, sa2_out)
        fp2_out = self.fp2_module(fp3_out, sa1_out)
        fp1_out = self.fp1_module(fp2_out, data_in)

        fp1_out_x, fp1_out_pos, fp1_out_batch = fp1_out
        x = self.fc2(self.dropout1(self.fc1(fp1_out_x)))
        x = F.log_softmax(x, dim=-1)

        if dense_input:
            return x.view(batch_size, N, self.num_classes)
        else:
            return x, fp1_out_batch


if __name__ == '__main__':
    num_classes = 10
    net = PointNet2PartSegmentNet(num_classes)

    #
    print('Test dense input ..')
    data1 = torch.rand((2, 3, 1024))  # (batch_size, 3, num_points)
    print('data1: ', data1.shape)

    out1 = net(data1)
    print('out1: ', out1.shape)

    #
    print('Test torch_geometric.data.Data input ..')
    def make_data_batch():
        # batch_size = 2
        pos_num1 = 1000
        pos_num2 = 1024

        data_batch = Data()

        # data_batch.x = None
        data_batch.pos = torch.cat([torch.rand(pos_num1, 3), torch.rand(pos_num2, 3)], dim=0)
        data_batch.batch = torch.cat([torch.zeros(pos_num1, dtype=torch.long), torch.ones(pos_num2, dtype=torch.long)])

        return data_batch

    data2 = make_data_batch()
    # print('data.x: ', data.x)
    print('data2.pos: ', data2.pos.shape)
    print('data2.batch: ', data2.batch.shape)

    out2_x, out2_batch = net(data2)
    print('out2_x: ', out2_x.shape)
    print('out2_batch: ', out2_batch.shape)
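
    # Expected output shapes for the two smoke tests above:
    #   out1:       torch.Size([2, 1024, 10])
    #   out2_x:     torch.Size([2024, 10])   # 1000 + 1024 points in total
    #   out2_batch: torch.Size([2024])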
--------------------------------------------------------------------------------
/vis/show_seg_res.py:
--------------------------------------------------------------------------------
# Warning: importing open3d after torch may lead to a crash, so import view
# (which imports open3d) first here
from view import view_points_labels

import sys
import os
sys.path.append(os.path.dirname(os.path.abspath(__file__)) + '/../')  # add project root directory

from dataset.shapenet import ShapeNetPartSegDataset
from model.pointnet2_part_seg import PointNet2PartSegmentNet
import torch_geometric.transforms as GT
import torch
import numpy as np
import argparse


##
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='shapenet', help='dataset path')
parser.add_argument('--category', type=str, default='Airplane', help='select category')
parser.add_argument('--npoints', type=int, default=2500, help='number of points to resample to')
parser.add_argument('--model', type=str, default='./checkpoint/seg_model_Airplane_24.pth', help='model path')
parser.add_argument('--sample_idx', type=int, default=0, help='select a sample to segment and view the result')

opt = parser.parse_args()
print(opt)


## Load dataset
print('Construct dataset ..')
test_transform = GT.Compose([GT.NormalizeScale(), ])

test_dataset = ShapeNetPartSegDataset(
    root_dir=opt.dataset,
    category=opt.category,
    train=False,
    transform=test_transform,
    npoints=opt.npoints
)
num_classes = test_dataset.num_classes()

print('test dataset size: ', len(test_dataset))
print('num_classes: ', num_classes)


## Load model
print('Construct model ..')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float

# net = PointNetPartSegmentNet(num_classes)
net = PointNet2PartSegmentNet(num_classes)

net.load_state_dict(torch.load(opt.model))
net = net.to(device, dtype)
net.eval()


##
def eval_sample(net, sample):
    '''
    sample: { 'points': tensor(n, 3), 'labels': tensor(n,) }
    return: (pred_label, gt_label) with labels shape (n,)
    '''
    net.eval()
    with torch.no_grad():
        # points: (n, 3)
        points, gt_label = sample['points'], sample['labels']
        n = points.shape[0]

        points = points.view(1, n, 3)  # make a batch
        points = points.transpose(1, 2).contiguous()
        points = points.to(device, dtype)

        pred = net(points)  # (batch_size, n, num_classes)
        pred_label = pred.max(2)[1]
        pred_label = pred_label.view(-1).cpu()  # (n,)

        assert pred_label.shape == gt_label.shape
        return (pred_label, gt_label)


def compute_mIoU(pred_label, gt_label):
    minl, maxl = np.min(gt_label), np.max(gt_label)
    ious = []
    for l in range(minl, maxl+1):
        I = np.sum(np.logical_and(pred_label == l, gt_label == l))
        U = np.sum(np.logical_or(pred_label == l, gt_label == l))
        # A label absent from both prediction and ground truth counts as IoU 1
        iou = 1 if U == 0 else float(I) / U
        ious.append(iou)
    return np.mean(ious)
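
# Worked toy example for compute_mIoU (hypothetical values):
#   pred = [0, 0, 1, 1], gt = [0, 1, 1, 1]
#   label 0: I = 1, U = 2 -> IoU = 0.5
#   label 1: I = 2, U = 3 -> IoU ~= 0.667
#   mIoU = (0.5 + 0.667) / 2 ~= 0.583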

def label_diff(pred_label, gt_label):
    '''
    Assign 1 where the labels differ, 0 where they match.
    '''
    diff = pred_label - gt_label
    diff_mask = (diff != 0)

    diff_label = np.zeros((pred_label.shape[0]), dtype=np.int32)
    diff_label[diff_mask] = 1

    return diff_label


## Get one sample and eval
sample = test_dataset[opt.sample_idx]

print('Eval test sample ..')
pred_label, gt_label = eval_sample(net, sample)
print('Eval done ..')


## Get sample result
print('Compute mIoU ..')
points = sample['points'].numpy()
pred_labels = pred_label.numpy()
gt_labels = gt_label.numpy()
diff_labels = label_diff(pred_labels, gt_labels)

print('mIoU: ', compute_mIoU(pred_labels, gt_labels))


## View result

# print('View gt labels ..')
# view_points_labels(points, gt_labels)

# print('View diff labels ..')
# view_points_labels(points, diff_labels)

print('View pred labels ..')
view_points_labels(points, pred_labels)
--------------------------------------------------------------------------------
/vis/view.py:
--------------------------------------------------------------------------------
import open3d as o3d
import numpy as np


def mini_color_table(index, norm=True):
    colors = [
        [0.5000, 0.5400, 0.5300], [0.8900, 0.1500, 0.2100], [0.6400, 0.5800, 0.5000],
        [1.0000, 0.3800, 0.0100], [1.0000, 0.6600, 0.1400], [0.4980, 1.0000, 0.0000],
        [0.4980, 1.0000, 0.8314], [0.9412, 0.9725, 1.0000], [0.5412, 0.1686, 0.8863],
        [0.5765, 0.4392, 0.8588], [0.3600, 0.1400, 0.4300], [0.5600, 0.3700, 0.6000],
    ]

    assert index >= 0 and index < len(colors)
    color = colors[index]

    if not norm:
        color[0] *= 255
        color[1] *= 255
        color[2] *= 255

    return color


def view_points(points, colors=None):
    '''
    points: np.ndarray with shape (n, 3)
    colors: [r, g, b] or np.ndarray with shape (n, 3)
    '''
    # Note: this uses the legacy Open3D 0.6 API; newer Open3D versions expose
    # these as o3d.geometry.PointCloud, o3d.utility.Vector3dVector and
    # o3d.visualization.draw_geometries.
    cloud = o3d.PointCloud()
    cloud.points = o3d.Vector3dVector(points)

    if colors is not None:
        if isinstance(colors, np.ndarray):
            cloud.colors = o3d.Vector3dVector(colors)
        else:
            cloud.paint_uniform_color(colors)

    o3d.draw_geometries([cloud])


def label2color(labels):
    '''
    labels: np.ndarray with shape (n,)
    colors (return): np.ndarray with shape (n, 3)
    '''
    num = labels.shape[0]
    colors = np.zeros((num, 3))

    minl, maxl = np.min(labels), np.max(labels)
    for l in range(minl, maxl + 1):
        colors[labels == l, :] = mini_color_table(l)

    return colors


def view_points_labels(points, labels):
    '''
    Assign colors to points by their labels and view the colored point cloud.
    points: np.ndarray with shape (n, 3)
    labels: np.ndarray with shape (n,), dtype=np.int32
    '''
    assert points.shape[0] == labels.shape[0]
    colors = label2color(labels)
    view_points(points, colors)
--------------------------------------------------------------------------------