├── .gitignore
├── LICENSE
├── README.md
├── misc
    ├── modelnet_id.txt
    ├── num_seg_classes.txt
    └── show3d.png
├── pointnet
    ├── __init__.py
    ├── dataset.py
    └── model.py
├── scripts
    ├── build.sh
    └── download.sh
├── setup.py
└── utils
    ├── render_balls_so.cpp
    ├── show3d_balls.py
    ├── show_cls.py
    ├── show_seg.py
    ├── train_classification.py
    └── train_segmentation.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | data
3 | *.pyc
4 | *.ipynb
5 | shapenetcore_partanno_segmentation_benchmark_v0/
6 | *.so
7 | .idea*
8 | cls/
9 | seg/
10 | *.egg-info/
11 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2017 Fei Xia
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PointNet.pytorch
2 | This repo is an implementation of PointNet (https://arxiv.org/abs/1612.00593) in PyTorch. The model is in `pointnet/model.py`.
3 | 
4 | It is tested with pytorch-1.0.
5 | 
6 | # Downloading data and running
7 | 
8 | ```
9 | git clone https://github.com/fxia22/pointnet.pytorch
10 | cd pointnet.pytorch
11 | pip install -e .
12 | ```
13 | 
14 | Build the visualization tool and download the dataset:
15 | ```
16 | cd scripts
17 | bash build.sh # build C++ code for visualization
18 | bash download.sh # download dataset
19 | ```
20 | 
21 | Training:
22 | ```
23 | cd utils
24 | python train_classification.py --dataset <dataset path> --nepoch=<number of epochs> --dataset_type <modelnet40 | shapenet>
25 | python train_segmentation.py --dataset <dataset path> --nepoch=<number of epochs>
26 | ```
27 | 
28 | Pass `--feature_transform` to enable the feature transform.
29 | 
30 | # Performance
31 | 
32 | ## Classification performance
33 | 
34 | On ModelNet40:
35 | 
36 | | | Overall Acc |
37 | | :---: | :---: |
38 | | Original implementation | 89.2 |
39 | | This implementation (w/o feature transform) | 86.4 |
40 | | This implementation (w/ feature transform) | 87.0 |
41 | 
42 | On [a subset of ShapeNet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html):
43 | 
44 | | | Overall Acc |
45 | | :---: | :---: |
46 | | Original implementation | N/A |
47 | | This implementation (w/o feature transform) | 98.1 |
48 | | This implementation (w/ feature transform) | 97.7 |
49 | 
50 | ## Segmentation performance
51 | 
52 | Segmentation on [a subset of ShapeNet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html).
53 | 
54 | | Class (mIoU) | Airplane | Bag | Cap | Car | Chair | Earphone | Guitar | Knife | Lamp | Laptop | Motorbike | Mug | Pistol | Rocket | Skateboard | Table |
55 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
56 | | Original implementation | 83.4 | 78.7 | 82.5 | 74.9 | 89.6 | 73.0 | 91.5 | 85.9 | 80.8 | 95.3 | 65.2 | 93.0 | 81.2 | 57.9 | 72.8 | 80.6 |
57 | | This implementation (w/o feature transform) | 73.5 | 71.3 | 64.3 | 61.1 | 87.2 | 69.5 | 86.1 | 81.6 | 77.4 | 92.7 | 41.3 | 86.5 | 78.2 | 41.2 | 61.0 | 81.1 |
58 | | This implementation (w/ feature transform) | | | | | 87.6 | | | | | | | | | | | 81.0 |
59 | 
60 | Note that this implementation trains each class separately, so classes with less data will have slightly lower performance than the reference implementation.
61 | 
62 | Sample segmentation result:
63 | ![seg](https://raw.githubusercontent.com/fxia22/pointnet.pytorch/master/misc/show3d.png?token=AE638Oy51TL2HDCaeCF273X_-Bsy6-E2ks5Y_BUzwA%3D%3D)
64 | 
65 | # Links
66 | 
67 | - [Project Page](http://stanford.edu/~rqi/pointnet/)
68 | - [Tensorflow implementation](https://github.com/charlesq34/pointnet)
69 | 
--------------------------------------------------------------------------------
/misc/modelnet_id.txt:
--------------------------------------------------------------------------------
1 | airplane 0
2 | bathtub 1
3 | bed 2
4 | bench 3
5 | bookshelf 4
6 | bottle 5
7 | bowl 6
8 | car 7
9 | chair 8
10 | cone 9
11 | cup 10
12 | curtain 11
13 | desk 12
14 | door 13
15 | dresser 14
16 | flower_pot 15
17 | glass_box 16
18 | guitar 17
19 | keyboard 18
20 | lamp 19
21 | laptop 20
22 | mantel 21
23 | monitor 22
24 | night_stand 23
25 | person 24
26 | piano 25
27 | plant 26
28 | radio 27
29 | range_hood 28
30 | sink 29
31 | sofa 30
32 | stairs 31
33 | stool 32
34 | table 33
35 | tent 34
36 | toilet 35
37 | tv_stand 36
38 | vase 37
39 | wardrobe 38
40 | xbox 39
41 | 
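The `misc/modelnet_id.txt` listing above pairs each ModelNet40 class name with an integer label, one `name id` pair per line. As a rough illustration of how such a file can be read, here is a minimal sketch. `parse_class_ids` and `invert_class_ids` are hypothetical helpers written for this note, not functions from the repo; the sketch only assumes whitespace-separated fields, which matches the split-and-convert loop in `ModelNetDataset`.

```python
def parse_class_ids(lines):
    """Turn an iterable of 'name id' lines into a {name: id} dict.

    Hypothetical helper for illustration; blank or malformed lines are skipped.
    """
    class_ids = {}
    for line in lines:
        fields = line.split()
        if len(fields) != 2:
            continue  # ignore empty lines and anything that is not 'name id'
        name, idx = fields
        class_ids[name] = int(idx)
    return class_ids


def invert_class_ids(class_ids):
    """Build the reverse {id: name} lookup, e.g. to label predictions."""
    return {idx: name for name, idx in class_ids.items()}


if __name__ == "__main__":
    sample = ["airplane 0\n", "bathtub 1\n", "\n", "bed 2\n"]
    ids = parse_class_ids(sample)
    print(ids)
    print(invert_class_ids(ids)[1])
```

With a real file, the same function can be fed an open file handle, for example `with open(path) as f: ids = parse_class_ids(f)`, where `path` points at the mapping file.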
-------------------------------------------------------------------------------- /misc/num_seg_classes.txt: -------------------------------------------------------------------------------- 1 | Airplane 4 2 | Bag 2 3 | Cap 2 4 | Car 4 5 | Chair 4 6 | Earphone 3 7 | Guitar 3 8 | Knife 2 9 | Lamp 4 10 | Laptop 2 11 | Motorbike 6 12 | Mug 2 13 | Pistol 3 14 | Rocket 3 15 | Skateboard 3 16 | Table 3 17 | -------------------------------------------------------------------------------- /misc/show3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxia22/pointnet.pytorch/bafbf401e0af19be3262e448c59313fd2be0e421/misc/show3d.png -------------------------------------------------------------------------------- /pointnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fxia22/pointnet.pytorch/bafbf401e0af19be3262e448c59313fd2be0e421/pointnet/__init__.py -------------------------------------------------------------------------------- /pointnet/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | import os 4 | import os.path 5 | import torch 6 | import numpy as np 7 | import sys 8 | from tqdm import tqdm 9 | import json 10 | from plyfile import PlyData, PlyElement 11 | 12 | def get_segmentation_classes(root): 13 | catfile = os.path.join(root, 'synsetoffset2category.txt') 14 | cat = {} 15 | meta = {} 16 | 17 | with open(catfile, 'r') as f: 18 | for line in f: 19 | ls = line.strip().split() 20 | cat[ls[0]] = ls[1] 21 | 22 | for item in cat: 23 | dir_seg = os.path.join(root, cat[item], 'points_label') 24 | dir_point = os.path.join(root, cat[item], 'points') 25 | fns = sorted(os.listdir(dir_point)) 26 | meta[item] = [] 27 | for fn in fns: 28 | token = (os.path.splitext(os.path.basename(fn))[0]) 29 | 
meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg'))) 30 | 31 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f: 32 | for item in cat: 33 | datapath = [] 34 | num_seg_classes = 0 35 | for fn in meta[item]: 36 | datapath.append((item, fn[0], fn[1])) 37 | 38 | for i in tqdm(range(len(datapath))): 39 | l = len(np.unique(np.loadtxt(datapath[i][-1]).astype(np.uint8))) 40 | if l > num_seg_classes: 41 | num_seg_classes = l 42 | 43 | print("category {} num segmentation classes {}".format(item, num_seg_classes)) 44 | f.write("{}\t{}\n".format(item, num_seg_classes)) 45 | 46 | def gen_modelnet_id(root): 47 | classes = [] 48 | with open(os.path.join(root, 'train.txt'), 'r') as f: 49 | for line in f: 50 | classes.append(line.strip().split('/')[0]) 51 | classes = np.unique(classes) 52 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'w') as f: 53 | for i in range(len(classes)): 54 | f.write('{}\t{}\n'.format(classes[i], i)) 55 | 56 | class ShapeNetDataset(data.Dataset): 57 | def __init__(self, 58 | root, 59 | npoints=2500, 60 | classification=False, 61 | class_choice=None, 62 | split='train', 63 | data_augmentation=True): 64 | self.npoints = npoints 65 | self.root = root 66 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') 67 | self.cat = {} 68 | self.data_augmentation = data_augmentation 69 | self.classification = classification 70 | self.seg_classes = {} 71 | 72 | with open(self.catfile, 'r') as f: 73 | for line in f: 74 | ls = line.strip().split() 75 | self.cat[ls[0]] = ls[1] 76 | #print(self.cat) 77 | if not class_choice is None: 78 | self.cat = {k: v for k, v in self.cat.items() if k in class_choice} 79 | 80 | self.id2cat = {v: k for k, v in self.cat.items()} 81 | 82 | self.meta = {} 83 | splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split)) 84 | #from 
IPython import embed; embed() 85 | filelist = json.load(open(splitfile, 'r')) 86 | for item in self.cat: 87 | self.meta[item] = [] 88 | 89 | for file in filelist: 90 | _, category, uuid = file.split('/') 91 | if category in self.cat.values(): 92 | self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'), 93 | os.path.join(self.root, category, 'points_label', uuid+'.seg'))) 94 | 95 | self.datapath = [] 96 | for item in self.cat: 97 | for fn in self.meta[item]: 98 | self.datapath.append((item, fn[0], fn[1])) 99 | 100 | self.classes = dict(zip(sorted(self.cat), range(len(self.cat)))) 101 | print(self.classes) 102 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f: 103 | for line in f: 104 | ls = line.strip().split() 105 | self.seg_classes[ls[0]] = int(ls[1]) 106 | self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]] 107 | print(self.seg_classes, self.num_seg_classes) 108 | 109 | def __getitem__(self, index): 110 | fn = self.datapath[index] 111 | cls = self.classes[self.datapath[index][0]] 112 | point_set = np.loadtxt(fn[1]).astype(np.float32) 113 | seg = np.loadtxt(fn[2]).astype(np.int64) 114 | #print(point_set.shape, seg.shape) 115 | 116 | choice = np.random.choice(len(seg), self.npoints, replace=True) 117 | #resample 118 | point_set = point_set[choice, :] 119 | 120 | point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0) # center 121 | dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0) 122 | point_set = point_set / dist #scale 123 | 124 | if self.data_augmentation: 125 | theta = np.random.uniform(0,np.pi*2) 126 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]]) 127 | point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # random rotation 128 | point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter 129 | 130 | seg = seg[choice] 131 | point_set = 
torch.from_numpy(point_set) 132 | seg = torch.from_numpy(seg) 133 | cls = torch.from_numpy(np.array([cls]).astype(np.int64)) 134 | 135 | if self.classification: 136 | return point_set, cls 137 | else: 138 | return point_set, seg 139 | 140 | def __len__(self): 141 | return len(self.datapath) 142 | 143 | class ModelNetDataset(data.Dataset): 144 | def __init__(self, 145 | root, 146 | npoints=2500, 147 | split='train', 148 | data_augmentation=True): 149 | self.npoints = npoints 150 | self.root = root 151 | self.split = split 152 | self.data_augmentation = data_augmentation 153 | self.fns = [] 154 | with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f: 155 | for line in f: 156 | self.fns.append(line.strip()) 157 | 158 | self.cat = {} 159 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f: 160 | for line in f: 161 | ls = line.strip().split() 162 | self.cat[ls[0]] = int(ls[1]) 163 | 164 | print(self.cat) 165 | self.classes = list(self.cat.keys()) 166 | 167 | def __getitem__(self, index): 168 | fn = self.fns[index] 169 | cls = self.cat[fn.split('/')[0]] 170 | with open(os.path.join(self.root, fn), 'rb') as f: 171 | plydata = PlyData.read(f) 172 | pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T 173 | choice = np.random.choice(len(pts), self.npoints, replace=True) 174 | point_set = pts[choice, :] 175 | 176 | point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0) # center 177 | dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0) 178 | point_set = point_set / dist # scale 179 | 180 | if self.data_augmentation: 181 | theta = np.random.uniform(0, np.pi * 2) 182 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) 183 | point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix) # random rotation 184 | point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter 185 | 186 | 
point_set = torch.from_numpy(point_set.astype(np.float32)) 187 | cls = torch.from_numpy(np.array([cls]).astype(np.int64)) 188 | return point_set, cls 189 | 190 | 191 | def __len__(self): 192 | return len(self.fns) 193 | 194 | if __name__ == '__main__': 195 | dataset = sys.argv[1] 196 | datapath = sys.argv[2] 197 | 198 | if dataset == 'shapenet': 199 | d = ShapeNetDataset(root = datapath, class_choice = ['Chair']) 200 | print(len(d)) 201 | ps, seg = d[0] 202 | print(ps.size(), ps.type(), seg.size(),seg.type()) 203 | 204 | d = ShapeNetDataset(root = datapath, classification = True) 205 | print(len(d)) 206 | ps, cls = d[0] 207 | print(ps.size(), ps.type(), cls.size(),cls.type()) 208 | # get_segmentation_classes(datapath) 209 | 210 | if dataset == 'modelnet': 211 | gen_modelnet_id(datapath) 212 | d = ModelNetDataset(root=datapath) 213 | print(len(d)) 214 | print(d[0]) 215 | 216 | -------------------------------------------------------------------------------- /pointnet/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | 11 | class STN3d(nn.Module): 12 | def __init__(self): 13 | super(STN3d, self).__init__() 14 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 15 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 16 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 17 | self.fc1 = nn.Linear(1024, 512) 18 | self.fc2 = nn.Linear(512, 256) 19 | self.fc3 = nn.Linear(256, 9) 20 | self.relu = nn.ReLU() 21 | 22 | self.bn1 = nn.BatchNorm1d(64) 23 | self.bn2 = nn.BatchNorm1d(128) 24 | self.bn3 = nn.BatchNorm1d(1024) 25 | self.bn4 = nn.BatchNorm1d(512) 26 | self.bn5 = nn.BatchNorm1d(256) 27 | 28 | 29 | def forward(self, x): 30 | batchsize = x.size()[0] 31 | x = F.relu(self.bn1(self.conv1(x))) 32 | x = 
F.relu(self.bn2(self.conv2(x))) 33 | x = F.relu(self.bn3(self.conv3(x))) 34 | x = torch.max(x, 2, keepdim=True)[0] 35 | x = x.view(-1, 1024) 36 | 37 | x = F.relu(self.bn4(self.fc1(x))) 38 | x = F.relu(self.bn5(self.fc2(x))) 39 | x = self.fc3(x) 40 | 41 | iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1) 42 | if x.is_cuda: 43 | iden = iden.cuda() 44 | x = x + iden 45 | x = x.view(-1, 3, 3) 46 | return x 47 | 48 | 49 | class STNkd(nn.Module): 50 | def __init__(self, k=64): 51 | super(STNkd, self).__init__() 52 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 53 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 54 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 55 | self.fc1 = nn.Linear(1024, 512) 56 | self.fc2 = nn.Linear(512, 256) 57 | self.fc3 = nn.Linear(256, k*k) 58 | self.relu = nn.ReLU() 59 | 60 | self.bn1 = nn.BatchNorm1d(64) 61 | self.bn2 = nn.BatchNorm1d(128) 62 | self.bn3 = nn.BatchNorm1d(1024) 63 | self.bn4 = nn.BatchNorm1d(512) 64 | self.bn5 = nn.BatchNorm1d(256) 65 | 66 | self.k = k 67 | 68 | def forward(self, x): 69 | batchsize = x.size()[0] 70 | x = F.relu(self.bn1(self.conv1(x))) 71 | x = F.relu(self.bn2(self.conv2(x))) 72 | x = F.relu(self.bn3(self.conv3(x))) 73 | x = torch.max(x, 2, keepdim=True)[0] 74 | x = x.view(-1, 1024) 75 | 76 | x = F.relu(self.bn4(self.fc1(x))) 77 | x = F.relu(self.bn5(self.fc2(x))) 78 | x = self.fc3(x) 79 | 80 | iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1) 81 | if x.is_cuda: 82 | iden = iden.cuda() 83 | x = x + iden 84 | x = x.view(-1, self.k, self.k) 85 | return x 86 | 87 | class PointNetfeat(nn.Module): 88 | def __init__(self, global_feat = True, feature_transform = False): 89 | super(PointNetfeat, self).__init__() 90 | self.stn = STN3d() 91 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 92 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 93 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 94 | self.bn1 = 
nn.BatchNorm1d(64) 95 | self.bn2 = nn.BatchNorm1d(128) 96 | self.bn3 = nn.BatchNorm1d(1024) 97 | self.global_feat = global_feat 98 | self.feature_transform = feature_transform 99 | if self.feature_transform: 100 | self.fstn = STNkd(k=64) 101 | 102 | def forward(self, x): 103 | n_pts = x.size()[2] 104 | trans = self.stn(x) 105 | x = x.transpose(2, 1) 106 | x = torch.bmm(x, trans) 107 | x = x.transpose(2, 1) 108 | x = F.relu(self.bn1(self.conv1(x))) 109 | 110 | if self.feature_transform: 111 | trans_feat = self.fstn(x) 112 | x = x.transpose(2,1) 113 | x = torch.bmm(x, trans_feat) 114 | x = x.transpose(2,1) 115 | else: 116 | trans_feat = None 117 | 118 | pointfeat = x 119 | x = F.relu(self.bn2(self.conv2(x))) 120 | x = self.bn3(self.conv3(x)) 121 | x = torch.max(x, 2, keepdim=True)[0] 122 | x = x.view(-1, 1024) 123 | if self.global_feat: 124 | return x, trans, trans_feat 125 | else: 126 | x = x.view(-1, 1024, 1).repeat(1, 1, n_pts) 127 | return torch.cat([x, pointfeat], 1), trans, trans_feat 128 | 129 | class PointNetCls(nn.Module): 130 | def __init__(self, k=2, feature_transform=False): 131 | super(PointNetCls, self).__init__() 132 | self.feature_transform = feature_transform 133 | self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform) 134 | self.fc1 = nn.Linear(1024, 512) 135 | self.fc2 = nn.Linear(512, 256) 136 | self.fc3 = nn.Linear(256, k) 137 | self.dropout = nn.Dropout(p=0.3) 138 | self.bn1 = nn.BatchNorm1d(512) 139 | self.bn2 = nn.BatchNorm1d(256) 140 | self.relu = nn.ReLU() 141 | 142 | def forward(self, x): 143 | x, trans, trans_feat = self.feat(x) 144 | x = F.relu(self.bn1(self.fc1(x))) 145 | x = F.relu(self.bn2(self.dropout(self.fc2(x)))) 146 | x = self.fc3(x) 147 | return F.log_softmax(x, dim=1), trans, trans_feat 148 | 149 | 150 | class PointNetDenseCls(nn.Module): 151 | def __init__(self, k = 2, feature_transform=False): 152 | super(PointNetDenseCls, self).__init__() 153 | self.k = k 154 | self.feature_transform=feature_transform 
155 | self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform) 156 | self.conv1 = torch.nn.Conv1d(1088, 512, 1) 157 | self.conv2 = torch.nn.Conv1d(512, 256, 1) 158 | self.conv3 = torch.nn.Conv1d(256, 128, 1) 159 | self.conv4 = torch.nn.Conv1d(128, self.k, 1) 160 | self.bn1 = nn.BatchNorm1d(512) 161 | self.bn2 = nn.BatchNorm1d(256) 162 | self.bn3 = nn.BatchNorm1d(128) 163 | 164 | def forward(self, x): 165 | batchsize = x.size()[0] 166 | n_pts = x.size()[2] 167 | x, trans, trans_feat = self.feat(x) 168 | x = F.relu(self.bn1(self.conv1(x))) 169 | x = F.relu(self.bn2(self.conv2(x))) 170 | x = F.relu(self.bn3(self.conv3(x))) 171 | x = self.conv4(x) 172 | x = x.transpose(2,1).contiguous() 173 | x = F.log_softmax(x.view(-1,self.k), dim=-1) 174 | x = x.view(batchsize, n_pts, self.k) 175 | return x, trans, trans_feat 176 | 177 | def feature_transform_regularizer(trans): 178 | d = trans.size()[1] 179 | batchsize = trans.size()[0] 180 | I = torch.eye(d)[None, :, :] 181 | if trans.is_cuda: 182 | I = I.cuda() 183 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2))) 184 | return loss 185 | 186 | if __name__ == '__main__': 187 | sim_data = Variable(torch.rand(32,3,2500)) 188 | trans = STN3d() 189 | out = trans(sim_data) 190 | print('stn', out.size()) 191 | print('loss', feature_transform_regularizer(out)) 192 | 193 | sim_data_64d = Variable(torch.rand(32, 64, 2500)) 194 | trans = STNkd(k=64) 195 | out = trans(sim_data_64d) 196 | print('stn64d', out.size()) 197 | print('loss', feature_transform_regularizer(out)) 198 | 199 | pointfeat = PointNetfeat(global_feat=True) 200 | out, _, _ = pointfeat(sim_data) 201 | print('global feat', out.size()) 202 | 203 | pointfeat = PointNetfeat(global_feat=False) 204 | out, _, _ = pointfeat(sim_data) 205 | print('point feat', out.size()) 206 | 207 | cls = PointNetCls(k = 5) 208 | out, _, _ = cls(sim_data) 209 | print('class', out.size()) 210 | 211 | seg = PointNetDenseCls(k = 3) 212 | out, 
_, _ = seg(sim_data) 213 | print('seg', out.size()) 214 | -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | SCRIPT=`realpath $0` 2 | SCRIPTPATH=`dirname $SCRIPT` 3 | echo $SCRIPTPATH 4 | 5 | g++ -std=c++11 $SCRIPTPATH/../utils/render_balls_so.cpp -o $SCRIPTPATH/../utils/render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 6 | -------------------------------------------------------------------------------- /scripts/download.sh: -------------------------------------------------------------------------------- 1 | SCRIPT=`realpath $0` 2 | SCRIPTPATH=`dirname $SCRIPT` 3 | 4 | cd $SCRIPTPATH/.. 5 | wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate 6 | unzip shapenetcore_partanno_segmentation_benchmark_v0.zip 7 | rm shapenetcore_partanno_segmentation_benchmark_v0.zip 8 | cd - 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # install using 'pip install -e .' 
2 | 3 | from setuptools import setup 4 | 5 | setup(name='pointnet', 6 | packages=['pointnet'], 7 | package_dir={'pointnet': 'pointnet'}, 8 | install_requires=['torch', 9 | 'tqdm', 10 | 'plyfile'], 11 | version='0.0.1') 12 | -------------------------------------------------------------------------------- /utils/render_balls_so.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | using namespace std; 6 | 7 | struct PointInfo{ 8 | int x,y,z; 9 | float r,g,b; 10 | }; 11 | 12 | extern "C"{ 13 | 14 | void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){ 15 | r=max(r,1); 16 | vector depth(h*w,-2100000000); 17 | vector pattern; 18 | for (int dx=-r;dx<=r;dx++) 19 | for (int dy=-r;dy<=r;dy++) 20 | if (dx*dx+dy*dy=h || y2<0 || y2>=w) && depth[x2*w+y2] 0: 87 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll( 88 | show[:, :, 0], 1, axis=0)) 89 | if magnifyBlue >= 2: 90 | show[:, :, 0] = np.maximum(show[:, :, 0], 91 | np.roll(show[:, :, 0], -1, axis=0)) 92 | show[:, :, 0] = np.maximum(show[:, :, 0], np.roll( 93 | show[:, :, 0], 1, axis=1)) 94 | if magnifyBlue >= 2: 95 | show[:, :, 0] = np.maximum(show[:, :, 0], 96 | np.roll(show[:, :, 0], -1, axis=1)) 97 | if showrot: 98 | cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)), 99 | (30, showsz - 30), 0, 0.5, cv2.cv.CV_RGB(255, 0, 0)) 100 | cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)), 101 | (30, showsz - 50), 0, 0.5, cv2.cv.CV_RGB(255, 0, 0)) 102 | cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)), (30, showsz - 70), 0, 103 | 0.5, cv2.cv.CV_RGB(255, 0, 0)) 104 | changed = True 105 | while True: 106 | if changed: 107 | render() 108 | changed = False 109 | cv2.imshow('show3d', show) 110 | if waittime == 0: 111 | cmd = cv2.waitKey(10) % 256 112 | else: 113 | cmd = cv2.waitKey(waittime) % 256 114 | if cmd == ord('q'): 115 | break 116 | elif cmd == ord('Q'): 117 | 
sys.exit(0) 118 | 119 | if cmd == ord('t') or cmd == ord('p'): 120 | if cmd == ord('t'): 121 | if c_gt is None: 122 | c0 = np.zeros((len(xyz), ), dtype='float32') + 255 123 | c1 = np.zeros((len(xyz), ), dtype='float32') + 255 124 | c2 = np.zeros((len(xyz), ), dtype='float32') + 255 125 | else: 126 | c0 = c_gt[:, 0] 127 | c1 = c_gt[:, 1] 128 | c2 = c_gt[:, 2] 129 | else: 130 | if c_pred is None: 131 | c0 = np.zeros((len(xyz), ), dtype='float32') + 255 132 | c1 = np.zeros((len(xyz), ), dtype='float32') + 255 133 | c2 = np.zeros((len(xyz), ), dtype='float32') + 255 134 | else: 135 | c0 = c_pred[:, 0] 136 | c1 = c_pred[:, 1] 137 | c2 = c_pred[:, 2] 138 | if normalizecolor: 139 | c0 /= (c0.max() + 1e-14) / 255.0 140 | c1 /= (c1.max() + 1e-14) / 255.0 141 | c2 /= (c2.max() + 1e-14) / 255.0 142 | c0 = np.require(c0, 'float32', 'C') 143 | c1 = np.require(c1, 'float32', 'C') 144 | c2 = np.require(c2, 'float32', 'C') 145 | changed = True 146 | 147 | if cmd==ord('n'): 148 | zoom*=1.1 149 | changed=True 150 | elif cmd==ord('m'): 151 | zoom/=1.1 152 | changed=True 153 | elif cmd==ord('r'): 154 | zoom=1.0 155 | changed=True 156 | elif cmd==ord('s'): 157 | cv2.imwrite('show3d.png',show) 158 | if waittime!=0: 159 | break 160 | return cmd 161 | 162 | if __name__ == '__main__': 163 | np.random.seed(100) 164 | showpoints(np.random.randn(2500, 3)) 165 | -------------------------------------------------------------------------------- /utils/show_cls.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn.parallel 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | from pointnet.dataset import ShapeNetDataset 8 | from pointnet.model import PointNetCls 9 | import torch.nn.functional as F 10 | 11 | 12 | #showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500))) 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | 
parser.add_argument('--model', type=str, default = '', help='model path')
17 | parser.add_argument('--num_points', type=int, default=2500, help='number of points sampled per shape')
18 | 
19 | 
20 | opt = parser.parse_args()
21 | print(opt)
22 | 
23 | test_dataset = ShapeNetDataset(
24 |     root='shapenetcore_partanno_segmentation_benchmark_v0',
25 |     split='test',
26 |     classification=True,
27 |     npoints=opt.num_points,
28 |     data_augmentation=False)
29 | 
30 | testdataloader = torch.utils.data.DataLoader(
31 |     test_dataset, batch_size=32, shuffle=True)
32 | 
33 | classifier = PointNetCls(k=len(test_dataset.classes))
34 | classifier.cuda()
35 | classifier.load_state_dict(torch.load(opt.model))
36 | classifier.eval()
37 | 
38 | 
39 | for i, data in enumerate(testdataloader, 0):
40 |     points, target = data
41 |     points, target = Variable(points), Variable(target[:, 0])
42 |     points = points.transpose(2, 1)
43 |     points, target = points.cuda(), target.cuda()
44 |     pred, _, _ = classifier(points)
45 |     loss = F.nll_loss(pred, target)
46 | 
47 |     pred_choice = pred.data.max(1)[1]
48 |     correct = pred_choice.eq(target.data).cpu().sum()
49 |     print('i:%d loss: %f accuracy: %f' % (i, loss.item(), correct.item() / float(points.size(0))))
50 | 
--------------------------------------------------------------------------------
/utils/show_seg.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from show3d_balls import showpoints
3 | import argparse
4 | import numpy as np
5 | import torch
6 | import torch.nn.parallel
7 | import torch.utils.data
8 | from torch.autograd import Variable
9 | from pointnet.dataset import ShapeNetDataset
10 | from pointnet.model import PointNetDenseCls
11 | import matplotlib.pyplot as plt
12 | 
13 | 
14 | #showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))
15 | 
16 | parser = argparse.ArgumentParser()
17 | 
18 | parser.add_argument('--model', type=str, default='', help='model path')
19 | 
parser.add_argument('--idx', type=int, default=0, help='model index') 20 | parser.add_argument('--dataset', type=str, default='', help='dataset path') 21 | parser.add_argument('--class_choice', type=str, default='', help='class choice') 22 | 23 | opt = parser.parse_args() 24 | print(opt) 25 | 26 | d = ShapeNetDataset( 27 | root=opt.dataset, 28 | class_choice=[opt.class_choice], 29 | split='test', 30 | data_augmentation=False) 31 | 32 | idx = opt.idx 33 | 34 | print("model %d/%d" % (idx, len(d))) 35 | point, seg = d[idx] 36 | print(point.size(), seg.size()) 37 | point_np = point.numpy() 38 | 39 | cmap = plt.cm.get_cmap("hsv", 10) 40 | cmap = np.array([cmap(i) for i in range(10)])[:, :3] 41 | gt = cmap[seg.numpy() - 1, :] 42 | 43 | state_dict = torch.load(opt.model) 44 | classifier = PointNetDenseCls(k= state_dict['conv4.weight'].size()[0]) 45 | classifier.load_state_dict(state_dict) 46 | classifier.eval() 47 | 48 | point = point.transpose(1, 0).contiguous() 49 | 50 | point = Variable(point.view(1, point.size()[0], point.size()[1])) 51 | pred, _, _ = classifier(point) 52 | pred_choice = pred.data.max(2)[1] 53 | print(pred_choice) 54 | 55 | #print(pred_choice.size()) 56 | pred_color = cmap[pred_choice.numpy()[0], :] 57 | 58 | #print(pred_color.shape) 59 | showpoints(point_np, gt, pred_color) 60 | -------------------------------------------------------------------------------- /utils/train_classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torch.nn.parallel 7 | import torch.optim as optim 8 | import torch.utils.data 9 | from pointnet.dataset import ShapeNetDataset, ModelNetDataset 10 | from pointnet.model import PointNetCls, feature_transform_regularizer 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | 
'--batchSize', type=int, default=32, help='input batch size')
18 | parser.add_argument(
19 |     '--num_points', type=int, default=2500, help='number of points sampled per shape')
20 | parser.add_argument(
21 |     '--workers', type=int, help='number of data loading workers', default=4)
22 | parser.add_argument(
23 |     '--nepoch', type=int, default=250, help='number of epochs to train for')
24 | parser.add_argument('--outf', type=str, default='cls', help='output folder')
25 | parser.add_argument('--model', type=str, default='', help='model path')
26 | parser.add_argument('--dataset', type=str, required=True, help="dataset path")
27 | parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
28 | parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
29 | 
30 | opt = parser.parse_args()
31 | print(opt)
32 | 
33 | blue = lambda x: '\033[94m' + x + '\033[0m'
34 | 
35 | opt.manualSeed = random.randint(1, 10000) # draw a random seed and print it so the run can be reproduced
36 | print("Random Seed: ", opt.manualSeed)
37 | random.seed(opt.manualSeed)
38 | torch.manual_seed(opt.manualSeed)
39 | 
40 | if opt.dataset_type == 'shapenet':
41 |     dataset = ShapeNetDataset(
42 |         root=opt.dataset,
43 |         classification=True,
44 |         npoints=opt.num_points)
45 | 
46 |     test_dataset = ShapeNetDataset(
47 |         root=opt.dataset,
48 |         classification=True,
49 |         split='test',
50 |         npoints=opt.num_points,
51 |         data_augmentation=False)
52 | elif opt.dataset_type == 'modelnet40':
53 |     dataset = ModelNetDataset(
54 |         root=opt.dataset,
55 |         npoints=opt.num_points,
56 |         split='trainval')
57 | 
58 |     test_dataset = ModelNetDataset(
59 |         root=opt.dataset,
60 |         split='test',
61 |         npoints=opt.num_points,
62 |         data_augmentation=False)
63 | else:
64 |     exit('wrong dataset type')
65 | 
66 | 
67 | dataloader = torch.utils.data.DataLoader(
68 |     dataset,
69 |     batch_size=opt.batchSize,
70 |     shuffle=True,
71 |     num_workers=int(opt.workers))
72 | 
73 | testdataloader = torch.utils.data.DataLoader(
74 |     test_dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

print(len(dataset), len(test_dataset))
num_classes = len(dataset.classes)
print('classes', num_classes)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)

if opt.model != '':
    classifier.load_state_dict(torch.load(opt.model))


optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
classifier.cuda()

num_batch = len(dataset) / opt.batchSize

for epoch in range(opt.nepoch):
    # Note: PyTorch >= 1.1 expects scheduler.step() after optimizer.step(),
    # i.e. at the end of the epoch. The README states this code is tested with 1.0.
    scheduler.step()
    for i, data in enumerate(dataloader, 0):
        points, target = data
        target = target[:, 0]  # labels arrive as (B, 1); flatten to (B,)
        points = points.transpose(2, 1)  # (B, N, 3) -> (B, 3, N)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, trans, trans_feat = classifier(points)
        loss = F.nll_loss(pred, target)
        if opt.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        # Divide by the actual batch size; the last batch may be smaller than opt.batchSize.
        print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(target.size(0))))

        if i % 10 == 0:
            # Evaluate on a single batch drawn from the shuffled test loader.
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            pred, _, _ = classifier(points)
            loss = F.nll_loss(pred, target)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch,
                  blue('test'), loss.item(), correct.item() / float(target.size(0))))

    # Save a checkpoint after every epoch.
    torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))

# Final pass over the whole test set.
total_correct = 0
total_testset = 0
for i, data in tqdm(enumerate(testdataloader, 0)):
    points, target = data
    target = target[:, 0]
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    classifier = classifier.eval()
    pred, _, _ = classifier(points)
    pred_choice = pred.data.max(1)[1]
    correct = pred_choice.eq(target.data).cpu().sum()
    total_correct += correct.item()
    total_testset += points.size()[0]

print("final accuracy {}".format(total_correct / float(total_testset)))
--------------------------------------------------------------------------------
/utils/train_segmentation.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from pointnet.dataset import ShapeNetDataset
from pointnet.model import PointNetDenseCls, feature_transform_regularizer
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np


parser = argparse.ArgumentParser()
parser.add_argument(
    '--batchSize', type=int, default=32, help='input batch size')
parser.add_argument(
    '--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument(
    '--nepoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='seg', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
parser.add_argument('--class_choice', type=str,
                    default='Chair', help="object category to train the segmentation on")
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")

opt = parser.parse_args()
print(opt)

opt.manualSeed = random.randint(1, 10000)  # random seed (printed for reference, not fixed across runs)
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

dataset = ShapeNetDataset(
    root=opt.dataset,
    classification=False,
    class_choice=[opt.class_choice])
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

test_dataset = ShapeNetDataset(
    root=opt.dataset,
    classification=False,
    class_choice=[opt.class_choice],
    split='test',
    data_augmentation=False)
testdataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

print(len(dataset), len(test_dataset))
num_classes = dataset.num_seg_classes
print('classes', num_classes)
try:
    os.makedirs(opt.outf)
except OSError:
    pass

# Wrap a string in ANSI escape codes so it prints in blue.
blue = lambda x: '\033[94m' + x + '\033[0m'

classifier = PointNetDenseCls(k=num_classes, feature_transform=opt.feature_transform)

if opt.model != '':
    classifier.load_state_dict(torch.load(opt.model))

optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
classifier.cuda()

num_batch = len(dataset) / opt.batchSize

for epoch in range(opt.nepoch):
    # Note: PyTorch >= 1.1 expects scheduler.step() after optimizer.step(),
    # i.e. at the end of the epoch. The README states this code is tested with 1.0.
    scheduler.step()
    for i, data in enumerate(dataloader, 0):
        points, target = data
        points = points.transpose(2, 1)  # (B, N, 3) -> (B, 3, N)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, trans, trans_feat = classifier(points)
        # Flatten to one row per point; part labels are 1-based in the data, so shift to 0-based.
        pred = pred.view(-1, num_classes)
        target = target.view(-1, 1)[:, 0] - 1
        #print(pred.size(), target.size())
        loss = F.nll_loss(pred, target)
        if opt.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        # Divide by the actual number of points in the batch rather than a hard-coded batchSize * 2500.
        print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(target.size(0))))

        if i % 10 == 0:
            # Evaluate on a single batch drawn from the shuffled test loader.
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            pred, _, _ = classifier(points)
            pred = pred.view(-1, num_classes)
            target = target.view(-1, 1)[:, 0] - 1
            loss = F.nll_loss(pred, target)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item() / float(target.size(0))))

    # Save a checkpoint after every epoch.
    torch.save(classifier.state_dict(), '%s/seg_model_%s_%d.pth' % (opt.outf, opt.class_choice, epoch))

## benchmark mIOU
shape_ious = []
for i, data in tqdm(enumerate(testdataloader, 0)):
    points, target = data
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    classifier = classifier.eval()
    pred, _, _ = classifier(points)
    pred_choice = pred.data.max(2)[1]  # predicted part label per point, shape (B, N)

    pred_np = pred_choice.cpu().data.numpy()
    target_np = target.cpu().data.numpy() - 1  # shift labels to 0-based

    for shape_idx in range(target_np.shape[0]):
        parts = range(num_classes)#np.unique(target_np[shape_idx])
        part_ious = []
        for part in parts:
            I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
            U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part))
            if U == 0:
                iou = 1 #If the union of groundtruth and prediction points is empty, then count part IoU as 1
            else:
                iou = I / float(U)
            part_ious.append(iou)
        shape_ious.append(np.mean(part_ious))

print("mIOU for class {}: {}".format(opt.class_choice, np.mean(shape_ious)))
--------------------------------------------------------------------------------
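The per-shape mIOU computed at the end of `train_segmentation.py` can be checked by hand on a tiny input. The sketch below is not part of the repository: `shape_miou` is a hypothetical helper, and it uses plain Python lists instead of NumPy so it runs without dependencies. It is meant to follow the same rule as the loop above, including counting a part that is absent from both prediction and ground truth as IoU 1.

```python
def shape_miou(pred, target, num_parts):
    """Mean IoU over all part labels for a single shape (hypothetical helper).

    pred and target are equal-length sequences of 0-based part labels.
    A part absent from both prediction and ground truth counts as IoU 1.
    """
    part_ious = []
    for part in range(num_parts):
        # Points where both agree on this part.
        intersection = sum(1 for p, t in zip(pred, target) if p == part and t == part)
        # Points where either one assigns this part.
        union = sum(1 for p, t in zip(pred, target) if p == part or t == part)
        part_ious.append(1.0 if union == 0 else intersection / float(union))
    return sum(part_ious) / len(part_ious)


# Made-up 6-point shape with 3 part labels; part 2 never appears.
pred = [0, 0, 1, 1, 1, 0]
target = [0, 0, 0, 1, 1, 1]
print(shape_miou(pred, target, num_parts=3))
```

Parts 0 and 1 each score 2/4 = 0.5, and the empty part 2 scores 1, so the shape's mean is 2/3. If only the two parts that actually occur were averaged, the result would be 0.5, which shows how the "empty union counts as 1" rule raises the reported score for shapes that use few of the category's parts.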