├── .gitignore
├── LICENSE
├── README.md
├── misc
│   ├── modelnet_id.txt
│   ├── num_seg_classes.txt
│   └── show3d.png
├── pointnet
│   ├── __init__.py
│   ├── dataset.py
│   └── model.py
├── scripts
│   ├── build.sh
│   └── download.sh
├── setup.py
└── utils
    ├── render_balls_so.cpp
    ├── show3d_balls.py
    ├── show_cls.py
    ├── show_seg.py
    ├── train_classification.py
    └── train_segmentation.py

/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints/
data
*.pyc
*.ipynb
shapenetcore_partanno_segmentation_benchmark_v0/
*.so
.idea*
cls/
seg/
*.egg-info/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Fei Xia

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PointNet.pytorch
This repo is an implementation of PointNet (https://arxiv.org/abs/1612.00593) in PyTorch. The model is in `pointnet/model.py`.

It has been tested with PyTorch 1.0.

# Download data and run

```
git clone https://github.com/fxia22/pointnet.pytorch
cd pointnet.pytorch
pip install -e .
```

Build the visualization tool and download the dataset:
```
cd scripts
bash build.sh #build C++ code for visualization
bash download.sh #download dataset
```

Training
```
cd utils
python train_classification.py --dataset <dataset path> --nepoch=<number of epochs> --dataset_type <shapenet|modelnet40>
python train_segmentation.py --dataset <dataset path> --nepoch=<number of epochs>
```

Add `--feature_transform` to enable the feature transform.

# Performance

## Classification performance

On ModelNet40:

| | Overall Acc |
| :---: | :---: |
| Original implementation | 89.2 |
| This implementation (w/o feature transform) | 86.4 |
| This implementation (w/ feature transform) | 87.0 |

On [a subset of ShapeNet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html):

| | Overall Acc |
| :---: | :---: |
| Original implementation | N/A |
| This implementation (w/o feature transform) | 98.1 |
| This implementation (w/ feature transform) | 97.7 |

## Segmentation performance

Segmentation on [a subset of ShapeNet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html).
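The per-category scores in the table below are mean IoU (mIoU). The evaluation code itself is not shown here; as a hedged sketch, assuming the convention commonly used for ShapeNet part segmentation (average the IoU of every part label within one shape, count a part that is absent from both prediction and ground truth as IoU 1, then average over all shapes of a category), the metric could be computed like this. `shape_iou` and `category_miou` are hypothetical helpers, not functions in this repo.

```python
import numpy as np


def shape_iou(pred, target, num_parts):
    """Mean IoU over the part labels of a single shape.

    pred, target: integer arrays of shape (n_points,) with labels in [0, num_parts).
    A part that appears in neither pred nor target counts as IoU 1.
    """
    ious = []
    for part in range(num_parts):
        inter = np.sum((pred == part) & (target == part))
        union = np.sum((pred == part) | (target == part))
        ious.append(1.0 if union == 0 else inter / union)
    return float(np.mean(ious))


def category_miou(preds, targets, num_parts):
    """Average of the per-shape IoUs over all shapes of one category."""
    return float(np.mean([shape_iou(p, t, num_parts) for p, t in zip(preds, targets)]))


if __name__ == '__main__':
    pred = np.array([0, 0, 1, 1])
    target = np.array([0, 1, 1, 1])
    # part 0: 1 shared point / 2 in union = 0.5; part 1: 2 / 3
    print(shape_iou(pred, target, num_parts=2))  # about 0.583
```

Here `num_parts` would be the per-category count listed in `misc/num_seg_classes.txt`. The `.seg` labels appear to be 1-based (`utils/show_seg.py` subtracts 1 before indexing its colormap), so they would need the same shift before being passed in.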

| Class (mIoU) | Airplane | Bag | Cap | Car | Chair | Earphone | Guitar | Knife | Lamp | Laptop | Motorbike | Mug | Pistol | Rocket | Skateboard | Table |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Original implementation | 83.4 | 78.7 | 82.5 | 74.9 | 89.6 | 73.0 | 91.5 | 85.9 | 80.8 | 95.3 | 65.2 | 93.0 | 81.2 | 57.9 | 72.8 | 80.6 |
| This implementation (w/o feature transform) | 73.5 | 71.3 | 64.3 | 61.1 | 87.2 | 69.5 | 86.1 | 81.6 | 77.4 | 92.7 | 41.3 | 86.5 | 78.2 | 41.2 | 61.0 | 81.1 |
| This implementation (w/ feature transform) | | | | | 87.6 | | | | | | | | | | | 81.0 |

Note that this implementation trains each class separately, so classes with fewer training samples score lower than in the reference implementation.

Sample segmentation result:
![seg](https://raw.githubusercontent.com/fxia22/pointnet.pytorch/master/misc/show3d.png?token=AE638Oy51TL2HDCaeCF273X_-Bsy6-E2ks5Y_BUzwA%3D%3D)

# Links

- [Project Page](http://stanford.edu/~rqi/pointnet/)
- [TensorFlow implementation](https://github.com/charlesq34/pointnet)
--------------------------------------------------------------------------------
/misc/modelnet_id.txt:
--------------------------------------------------------------------------------
airplane 0
bathtub 1
bed 2
bench 3
bookshelf 4
bottle 5
bowl 6
car 7
chair 8
cone 9
cup 10
curtain 11
desk 12
door 13
dresser 14
flower_pot 15
glass_box 16
guitar 17
keyboard 18
lamp 19
laptop 20
mantel 21
monitor 22
night_stand 23
person 24
piano 25
plant 26
radio 27
range_hood 28
sink 29
sofa 30
stairs 31
stool 32
table 33
tent 34
toilet 35
tv_stand 36
vase 37
wardrobe 38
xbox 39
--------------------------------------------------------------------------------
/misc/num_seg_classes.txt:
--------------------------------------------------------------------------------
Airplane 4
Bag 2
Cap 2
Car 4
Chair 4
Earphone 3
Guitar 3
Knife 2
Lamp 4
Laptop 2
Motorbike 6
Mug 2
Pistol 3
Rocket 3
Skateboard 3
Table 3
--------------------------------------------------------------------------------
/misc/show3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dir-b/PointNet/5d0df518732c094d4b75e43dea5fcec41e7fb562/misc/show3d.png
--------------------------------------------------------------------------------
/pointnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Dir-b/PointNet/5d0df518732c094d4b75e43dea5fcec41e7fb562/pointnet/__init__.py
--------------------------------------------------------------------------------
/pointnet/dataset.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import sys
from tqdm import tqdm
import json
from plyfile import PlyData, PlyElement


def get_segmentation_classes(root):
    catfile = os.path.join(root, 'synsetoffset2category.txt')
    cat = {}
    meta = {}

    with open(catfile, 'r') as f:
        for line in f:
            ls = line.strip().split()
            cat[ls[0]] = ls[1]

    for item in cat:
        dir_seg = os.path.join(root, cat[item], 'points_label')
        dir_point = os.path.join(root, cat[item], 'points')
        fns = sorted(os.listdir(dir_point))
        meta[item] = []
        for fn in fns:
            token = (os.path.splitext(os.path.basename(fn))[0])
            meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))

    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f:
        for item in cat:
            datapath = []
            num_seg_classes = 0
            for fn in meta[item]:
                datapath.append((item, fn[0], fn[1]))

            for i in tqdm(range(len(datapath))):
                l = len(np.unique(np.loadtxt(datapath[i][-1]).astype(np.uint8)))
                if l > num_seg_classes:
                    num_seg_classes = l

            print("category {} num segmentation classes {}".format(item, num_seg_classes))
            f.write("{}\t{}\n".format(item, num_seg_classes))


def gen_modelnet_id(root):
    classes = []
    with open(os.path.join(root, 'train.txt'), 'r') as f:
        for line in f:
            classes.append(line.strip().split('/')[0])
    classes = np.unique(classes)
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'w') as f:
        for i in range(len(classes)):
            f.write('{}\t{}\n'.format(classes[i], i))


class ShapeNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')  # category -> folder mapping file
        self.cat = {}
        self.data_augmentation = data_augmentation  # apply random rotation and jitter
        self.classification = classification
        self.seg_classes = {}

        # with expression [as target]: expression is evaluated and its result is bound to target
        # (a variable or tuple); the file is closed automatically when the block exits
        with open(self.catfile, 'r') as f:  # open the category list; 'r' = open for reading
            for line in f:
                # strip(): removes leading/trailing characters (whitespace and newlines by default)
                # split(): splits on a separator and returns the pieces
                #          (by default any whitespace: spaces, newlines \n, tabs \t, ...)
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]  # self.cat: key = category name, value = folder name
        # print(self.cat)
        if class_choice is not None:  # keep only the requested categories
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

        self.id2cat = {v: k for k, v in self.cat.items()}  # inverse mapping: folder name -> category name

        self.meta = {}
        # JSON files, like XML, can store key-value pairs, arrays, etc.
        # split=train
        # str.format(): string formatting with {} placeholders instead of the older % operator
        splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))
        # from IPython import embed; embed()
        with open(splitfile, 'r') as f:
            filelist = json.load(f)
        # for item in self.cat: item is a key
        # for item in self.cat.values(): item is a value
        # for item in self.cat.items(): item is a (key, value) tuple
        # for k, v in self.cat.items(): the idiomatic way to iterate over key-value pairs
        for item in self.cat:
            self.meta[item] = []  # meta: category -> list of (points path, labels path), initially empty

        for file in filelist:  # entries of shuffled_{split}_file_list.json
            _, category, uuid = file.split('/')  # category: the category's folder; uuid: one shape in it
            if category in self.cat.values():
                # paths to the raw point cloud (.pts) and its per-point segmentation labels (.seg)
                self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid + '.pts'),
                                                         os.path.join(self.root, category, 'points_label',
                                                                      uuid + '.seg')))

        self.datapath = []
        for item in self.cat:  # item: category name (a key of self.cat)
            for fn in self.meta[item]:  # fn: (point cloud path, label path)
                self.datapath.append((item, fn[0], fn[1]))  # (category, point cloud path, segmentation label path)
        # sorted(): sorts any iterable, ascending by default; sorted(self.cat) sorts the dict's keys (categories)
        # zip(): pairs up corresponding elements of its iterable arguments into tuples
        # dict(): builds a dict, e.g. dict(zip(['one', 'two'], [1, 2])) -> {'one': 1, 'two': 2}
        # Together they map each category name to an integer class id.
        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
        print(self.classes)
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.seg_classes[ls[0]] = int(ls[1])
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
        print(self.seg_classes, self.num_seg_classes)

    # Called automatically when the dataset is indexed, e.g. dataset[i]
    def __getitem__(self, index):
        fn = self.datapath[index]  # (category, point cloud path, segmentation label path)
        cls = self.classes[self.datapath[index][0]]  # integer class label
        point_set = np.loadtxt(fn[1]).astype(np.float32)  # load the .pts point cloud
        seg = np.loadtxt(fn[2]).astype(np.int64)  # load the per-point segmentation labels
        # print(point_set.shape, seg.shape)

        choice = np.random.choice(len(seg), self.npoints, replace=True)
        # resample
        point_set = point_set[choice, :]

        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale into the unit sphere

        if self.data_augmentation:  # data augmentation
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation about the y axis
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        if self.classification:
            return point_set, cls
        else:
            return point_set, seg

    def __len__(self):
        return len(self.datapath)


class ModelNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.data_augmentation = data_augmentation
        self.fns = []
        with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f:
            for line in f:
                self.fns.append(line.strip())

        self.cat = {}
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = int(ls[1])

        print(self.cat)
        self.classes = list(self.cat.keys())

    def __getitem__(self, index):
        fn = self.fns[index]
        cls = self.cat[fn.split('/')[0]]
        with open(os.path.join(self.root, fn), 'rb') as f:
            plydata = PlyData.read(f)
        pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = pts[choice, :]

        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale into the unit sphere

        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation about the y axis
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        return point_set, cls

    def __len__(self):
        return len(self.fns)


if __name__ == '__main__':
    dataset = sys.argv[1]
    datapath = sys.argv[2]

    if dataset == 'shapenet':
        d = ShapeNetDataset(root=datapath, class_choice=['Chair'])
        print(len(d))
        ps, seg = d[0]
        print(ps.size(), ps.type(), seg.size(), seg.type())

        d = ShapeNetDataset(root=datapath, classification=True)
        print(len(d))
        ps, cls = d[0]
        print(ps.size(), ps.type(), cls.size(), cls.type())
        # get_segmentation_classes(datapath)

    if dataset == 'modelnet':
        gen_modelnet_id(datapath)
        d = ModelNetDataset(root=datapath)
        print(len(d))
        print(d[0])
--------------------------------------------------------------------------------
/pointnet/model.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F

# T-Net: a small PointNet in its own right. It predicts a 3x3 transformation matrix that
# aligns the pose of the input point cloud. Its benefit is modest, and later follow-up work
# no longer includes it.
# Fully connected layers regress 9 values, which are finally reshaped into a 3x3 matrix.
class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        # shared per-point MLP, implemented as 1x1 convolutions
        self.conv1 = torch.nn.Conv1d(3, 64, 1)  # in_channels, out_channels, kernel_size
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # fully connected layers
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()
        # batch normalization
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        # Variable is deprecated. In old PyTorch versions, Variable wrapped a tensor to add the
        # autograd attributes data, grad and grad_fn; tensors and Variables have since been
        # merged, and Variable(t) simply returns a tensor.
        # iden is the flattened 3x3 identity; repeat(batchsize, 1) gives a batchsize x 9 tensor.
        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        # view() reshapes the tensor (like numpy's reshape()); -1 lets PyTorch infer that
        # dimension (here the batch size). The result is batchsize x 3 x 3.
        x = x.view(-1, 3, 3)
        return x

# Same as STN3d, but for k-dimensional inputs: aligns the high-dimensional per-point
# features produced by the first MLP.
class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x

# Backbone: shared feature extractor for classification and segmentation
class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        # feature transform: aligns the 64-dim per-point features output by the first MLP
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]  # x: batch x 3 x n_pts; size() returns the length of each dimension
        trans = self.stn(x)  # STN3d predicts a 3x3 pose-alignment transform
        x = x.transpose(2, 1)  # (B, 3, N) -> (B, N, 3) so each point can be multiplied by the 3x3 matrix
        x = torch.bmm(x, trans)  # batched matrix multiplication
        x = x.transpose(2, 1)  # back to (B, 3, N)
        x = F.relu(self.bn1(self.conv1(x)))  # first MLP: lift each point from 3 to 64 dims

        # optionally apply the feature transform
        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x  # keep the 64-dim per-point features; segmentation concatenates them with the global feature
        x = F.relu(self.bn2(self.conv2(x)))  # second MLP, first layer: 64 -> 128
        x = self.bn3(self.conv3(x))  # second MLP, second layer: 128 -> 1024
        # Core PointNet operation: max pooling over the points is a symmetric function,
        # so the resulting global feature is invariant to the order of the points.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)  # reshape the pooled result into a 1024-dim global feature
        # True: return only the global feature (classification);
        # False: concatenate it with the per-point features (segmentation)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat

# Classification network
class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)  # k = number of classes
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    # forward() is invoked automatically when calling model(data)
    def forward(self, x):
        x, trans, trans_feat = self.feat(x)  # backbone
        x = F.relu(self.bn1(self.fc1(x)))  # third MLP, first layer: 1024 -> 512
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))  # third MLP, second layer: 512 -> 256
        x = self.fc3(x)  # fully connected layer to k class scores
        # log_softmax is numerically more stable than log(softmax(x)): it avoids the
        # overflow/underflow that exponentiation in softmax can cause
        return F.log_softmax(x, dim=1), trans, trans_feat

# Segmentation network
class PointNetDenseCls(nn.Module):
    def __init__(self, k = 2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform=feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)  # 1088 = 1024 global + 64 per-point features
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]  # batch size
        n_pts = x.size()[2]  # number of points per object
        x, trans, trans_feat = self.feat(x)  # backbone
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2,1).contiguous()
        x = F.log_softmax(x.view(-1,self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat

def feature_transform_regularizer(trans):
    d = trans.size()[1]
    batchsize = trans.size()[0]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
    return loss

if __name__ == '__main__':
    sim_data = Variable(torch.rand(32,3,2500))
    trans = STN3d()
    out = trans(sim_data)
    print('stn', out.size())
    print('loss', feature_transform_regularizer(out))

    sim_data_64d = Variable(torch.rand(32, 64, 2500))
    trans = STNkd(k=64)
    out = trans(sim_data_64d)
    print('stn64d', out.size())
    print('loss', feature_transform_regularizer(out))

    pointfeat = PointNetfeat(global_feat=True)
    out, _, _ = pointfeat(sim_data)
    print('global feat', out.size())

    pointfeat = PointNetfeat(global_feat=False)
    out, _, _ = pointfeat(sim_data)
    print('point feat', out.size())

    cls = PointNetCls(k = 5)
    out, _, _ = cls(sim_data)
    print('class', out.size())

    seg = PointNetDenseCls(k = 3)
    out, _, _ = seg(sim_data)
    print('seg', out.size())
--------------------------------------------------------------------------------
/scripts/build.sh:
--------------------------------------------------------------------------------
SCRIPT=`realpath $0`
SCRIPTPATH=`dirname $SCRIPT`
echo $SCRIPTPATH

g++ -std=c++11 $SCRIPTPATH/../utils/render_balls_so.cpp -o $SCRIPTPATH/../utils/render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0
--------------------------------------------------------------------------------
/scripts/download.sh:
--------------------------------------------------------------------------------
SCRIPT=`realpath $0`
SCRIPTPATH=`dirname $SCRIPT`

cd $SCRIPTPATH/..
wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate
unzip shapenetcore_partanno_segmentation_benchmark_v0.zip
rm shapenetcore_partanno_segmentation_benchmark_v0.zip
cd -
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# install using 'pip install -e .'

from setuptools import setup

setup(name='pointnet',
      packages=['pointnet'],
      package_dir={'pointnet': 'pointnet'},
      install_requires=['torch',
                        'tqdm',
                        'plyfile'],
      version='0.0.1')
--------------------------------------------------------------------------------
/utils/render_balls_so.cpp:
--------------------------------------------------------------------------------
#include
#include
#include
#include
using namespace std;

struct PointInfo{
    int x,y,z;
    float r,g,b;
};

extern "C"{

void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){
    r=max(r,1);
    vector depth(h*w,-2100000000);
    vector pattern;
    for (int dx=-r;dx<=r;dx++)
        for (int dy=-r;dy<=r;dy++)
            if (dx*dx+dy*dy=h || y2<0 || y2>=w) && depth[x2*w+y2] 0:
            show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(
                show[:, :, 0], 1, axis=0))
            if magnifyBlue >= 2:
                show[:, :, 0] = np.maximum(show[:, :, 0],
                                           np.roll(show[:, :, 0], -1, axis=0))
            show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(
                show[:, :, 0], 1, axis=1))
            if magnifyBlue >= 2:
                show[:, :, 0] = np.maximum(show[:, :, 0],
                                           np.roll(show[:, :, 0], -1, axis=1))
        if showrot:
            # (0, 0, 255) is red in OpenCV's BGR order; cv2.cv.CV_RGB no longer exists in OpenCV 3+
            cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)),
                        (30, showsz - 30), 0, 0.5, (0, 0, 255))
            cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)),
                        (30, showsz - 50), 0, 0.5, (0, 0, 255))
            cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)), (30, showsz - 70), 0,
                        0.5, (0, 0, 255))

    changed = True
    while True:
        if changed:
            render()
            changed = False
        cv2.imshow('show3d', show)
        if waittime == 0:
            cmd = cv2.waitKey(10) % 256
        else:
            cmd = cv2.waitKey(waittime) % 256
        if cmd == ord('q'):
            break
        elif cmd == ord('Q'):
            sys.exit(0)

        if cmd == ord('t') or cmd == ord('p'):
            if cmd == ord('t'):
                if c_gt is None:
                    c0 = np.zeros((len(xyz),), dtype='float32') + 255
                    c1 = np.zeros((len(xyz),), dtype='float32') + 255
                    c2 = np.zeros((len(xyz),), dtype='float32') + 255
                else:
                    c0 = c_gt[:, 0]
                    c1 = c_gt[:, 1]
                    c2 = c_gt[:, 2]
            else:
                if c_pred is None:
                    c0 = np.zeros((len(xyz),), dtype='float32') + 255
                    c1 = np.zeros((len(xyz),), dtype='float32') + 255
                    c2 = np.zeros((len(xyz),), dtype='float32') + 255
                else:
                    c0 = c_pred[:, 0]
                    c1 = c_pred[:, 1]
                    c2 = c_pred[:, 2]
            if normalizecolor:
                c0 /= (c0.max() + 1e-14) / 255.0
                c1 /= (c1.max() + 1e-14) / 255.0
                c2 /= (c2.max() + 1e-14) / 255.0
            c0 = np.require(c0, 'float32', 'C')
            c1 = np.require(c1, 'float32', 'C')
            c2 = np.require(c2, 'float32', 'C')
            changed = True

        if cmd == ord('n'):
            zoom *= 1.1
            changed = True
        elif cmd == ord('m'):
            zoom /= 1.1
            changed = True
        elif cmd == ord('r'):
            zoom = 1.0
            changed = True
        elif cmd == ord('s'):
            cv2.imwrite('show3d.png', show)
        if waittime != 0:
            break
    return cmd


if __name__ == '__main__':
    # np.random.seed(100)
    # showpoints(np.random.randn(2500, 3))

    # Reading a .pts point cloud, method 1: parse the text file line by line
    with open('../shapenetcore_partanno_segmentation_benchmark_v0/02691156/points/1a04e3eab45ca15dd86060f189eb133.pts',
              'r') as f:
        data = f.readlines()

    pts = []
    for line in data:
        point = line.strip().split(' ')
        x, y, z = [float(i) for i in point[:3]]  # float() instead of eval(): safer and faster
        pts.append([x, y, z])
    res = np.array(pts)  # (N, 3) array of point coordinates

    # Reading a .pts point cloud, method 2: np.loadtxt
    point_set = np.loadtxt('../shapenetcore_partanno_segmentation_benchmark_v0/02691156/points/1a04e3eab45ca15dd86060f189eb133.pts').astype(np.float32)
    showpoints(point_set)
--------------------------------------------------------------------------------
/utils/show_cls.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
from pointnet.dataset import ShapeNetDataset
from pointnet.model import PointNetCls
import torch.nn.functional as F


#showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default = '', help='model path')
parser.add_argument('--num_points', type=int, default=2500, help='number of points per sample')


opt = parser.parse_args()
print(opt)

test_dataset = ShapeNetDataset(
    root='../shapenetcore_partanno_segmentation_benchmark_v0',
    split='test',
    classification=True,
    npoints=opt.num_points,
    data_augmentation=False)

testdataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=True)

classifier = PointNetCls(k=len(test_dataset.classes))
classifier.cuda()
classifier.load_state_dict(torch.load(opt.model))
classifier.eval()


for i, data in enumerate(testdataloader, 0):
    points, target = data
    points, target = Variable(points), Variable(target[:, 0])
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    with torch.no_grad():  # inference only: no gradients needed
        pred, _, _ = classifier(points)
        loss = F.nll_loss(pred, target)

    pred_choice = pred.data.max(1)[1]
    correct = pred_choice.eq(target.data).cpu().sum().item()
    # divide by the actual batch size: the last batch can hold fewer than 32 samples
    print('i:%d loss: %f accuracy: %f' % (i, loss.data.item(), correct / float(points.size()[0])))
--------------------------------------------------------------------------------
/utils/show_seg.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from show3d_balls import showpoints
import argparse
import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
from pointnet.dataset import ShapeNetDataset
from pointnet.model import PointNetDenseCls
import matplotlib.pyplot as plt


#showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--idx', type=int, default=0, help='index of the test sample to visualize')
parser.add_argument('--dataset', type=str, default='', help='dataset path')
parser.add_argument('--class_choice', type=str, default='', help='category to visualize, e.g. Chair')

opt = parser.parse_args()
print(opt)

d = ShapeNetDataset(
    root=opt.dataset,
    class_choice=[opt.class_choice],
    split='test',
    data_augmentation=False)

idx = opt.idx

print("model %d/%d" % (idx, len(d)))
point, seg = d[idx]
print(point.size(), seg.size())
point_np = point.numpy()

cmap = plt.get_cmap("hsv", 10)  # plt.cm.get_cmap was removed in Matplotlib 3.9
cmap = np.array([cmap(i) for i in range(10)])[:, :3]
gt = cmap[seg.numpy() - 1, :]  # shift 1-based part labels to 0-based colormap indices

state_dict = torch.load(opt.model, map_location='cpu')  # the model runs on the CPU here
classifier = PointNetDenseCls(k= state_dict['conv4.weight'].size()[0])
classifier.load_state_dict(state_dict)
classifier.eval()

point = point.transpose(1, 0).contiguous()

point = Variable(point.view(1, point.size()[0], point.size()[1]))
pred, _, _ = classifier(point)
pred_choice = pred.data.max(2)[1]
print(pred_choice)

#print(pred_choice.size())
pred_color = cmap[pred_choice.numpy()[0], :]

#print(pred_color.shape)
showpoints(point_np, gt, pred_color)
--------------------------------------------------------------------------------
/utils/train_classification.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
from pointnet.model import PointNetCls, feature_transform_regularizer
import torch.nn.functional as F
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument(
    '--batchSize', type=int, default=32, help='input batch size')
parser.add_argument(
    '--num_points', type=int, default=2500, help='number of points per sample')
parser.add_argument(
    '--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument(
    '--nepoch', type=int, default=250, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='cls', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")

opt = parser.parse_args()
print(opt)

blue = lambda x: '\033[94m' + x + '\033[0m'

opt.manualSeed = random.randint(1, 10000)  # pick a random seed and print it so the run can be reproduced
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

if opt.dataset_type == 'shapenet':
    dataset = ShapeNetDataset(
root=opt.dataset, 43 | classification=True, 44 | npoints=opt.num_points) 45 | 46 | test_dataset = ShapeNetDataset( 47 | root=opt.dataset, 48 | classification=True, 49 | split='test', 50 | npoints=opt.num_points, 51 | data_augmentation=False) 52 | elif opt.dataset_type == 'modelnet40': 53 | dataset = ModelNetDataset( 54 | root=opt.dataset, 55 | npoints=opt.num_points, 56 | split='trainval') 57 | 58 | test_dataset = ModelNetDataset( 59 | root=opt.dataset, 60 | split='test', 61 | npoints=opt.num_points, 62 | data_augmentation=False) 63 | else: 64 | exit('wrong dataset type') 65 | 66 | 67 | dataloader = torch.utils.data.DataLoader( 68 | dataset, 69 | batch_size=opt.batchSize, 70 | shuffle=True, 71 | num_workers=int(opt.workers)) 72 | 73 | testdataloader = torch.utils.data.DataLoader( 74 | test_dataset, 75 | batch_size=opt.batchSize, 76 | shuffle=True, 77 | num_workers=int(opt.workers)) 78 | 79 | print(len(dataset), len(test_dataset)) 80 | num_classes = len(dataset.classes) 81 | print('classes', num_classes) 82 | 83 | try: 84 | os.makedirs(opt.outf) 85 | except OSError: 86 | pass 87 | 88 | classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform) 89 | 90 | if opt.model != '': 91 | classifier.load_state_dict(torch.load(opt.model)) 92 | 93 | # 优化器:adam-Adaptive Moment Estimation(自适应矩估计),利用梯度的一阶矩和二阶矩动态调整每个参数的学习率 94 | # betas:用于计算梯度一阶矩和二阶矩的系数 95 | optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999)) 96 | # 学习率调整:每个step_size次epoch后,学习率x0.5 97 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 98 | # Moves all model parameters and buffers to the GPU. 99 | classifier.cuda() 100 | 101 | num_batch = len(dataset) / opt.batchSize # 计算batch的数量 102 | 103 | for epoch in range(opt.nepoch): 104 | scheduler.step() 105 | # 将一个可遍历对象组合为一个索引序列,同时列出数据和数据下标,(0, seq[0])... 
106 | # __init__(self, iterable, start=0),参数为可遍历对象及起始位置 107 | for i, data in enumerate(dataloader, 0): 108 | points, target = data 109 | target = target[:, 0] # 取所有行的第0列 110 | points = points.transpose(2, 1) # 维度交换 111 | points, target = points.cuda(), target.cuda() # tensor转到cuda上 112 | optimizer.zero_grad() # 梯度清除,避免backward时梯度累加 113 | classifier = classifier.train() # 训练模式,使能BN和dropout 114 | pred, trans, trans_feat = classifier(points) # 网络结果预测输出 115 | loss = F.nll_loss(pred, target) # 损失函数:负log似然损失,在分类网络中使用了log_softmax,二者结合其实就是交叉熵损失函数 116 | if opt.feature_transform: 117 | loss += feature_transform_regularizer(trans_feat) * 0.001 118 | loss.backward() # loss反向传播 119 | optimizer.step() # 梯度下降,参数优化 120 | pred_choice = pred.data.max(1)[1] # max(1)返回每一行中的最大值及索引,[1]取出索引(代表着类别) 121 | correct = pred_choice.eq(target.data).cpu().sum() # 判断和target是否匹配,并计算匹配的数量 122 | print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize))) 123 | 124 | # 每10次batch之后,进行一次测试 125 | if i % 10 == 0: 126 | j, data = next(enumerate(testdataloader, 0)) 127 | points, target = data 128 | target = target[:, 0] 129 | points = points.transpose(2, 1) 130 | points, target = points.cuda(), target.cuda() 131 | classifier = classifier.eval() # 测试模式,固定住BN和dropout 132 | pred, _, _ = classifier(points) 133 | loss = F.nll_loss(pred, target) 134 | pred_choice = pred.data.max(1)[1] 135 | correct = pred_choice.eq(target.data).cpu().sum() 136 | print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize))) 137 | 138 | torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch)) 139 | 140 | total_correct = 0 141 | total_testset = 0 142 | for i,data in tqdm(enumerate(testdataloader, 0)): 143 | points, target = data 144 | target = target[:, 0] 145 | points = points.transpose(2, 1) 146 | points, target = points.cuda(), target.cuda() 147 | classifier = 
classifier.eval() 148 | pred, _, _ = classifier(points) 149 | pred_choice = pred.data.max(1)[1] 150 | correct = pred_choice.eq(target.data).cpu().sum() 151 | total_correct += correct.item() 152 | total_testset += points.size()[0] 153 | 154 | print("final accuracy {}".format(total_correct / float(total_testset))) -------------------------------------------------------------------------------- /utils/train_segmentation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torch.nn.parallel 7 | import torch.optim as optim 8 | import torch.utils.data 9 | from pointnet.dataset import ShapeNetDataset 10 | from pointnet.model import PointNetDenseCls, feature_transform_regularizer 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '--batchSize', type=int, default=32, help='input batch size') 19 | parser.add_argument( 20 | '--workers', type=int, help='number of data loading workers', default=4) 21 | parser.add_argument( 22 | '--nepoch', type=int, default=25, help='number of epochs to train for') 23 | parser.add_argument('--outf', type=str, default='seg', help='output folder') 24 | parser.add_argument('--model', type=str, default='', help='model path') 25 | parser.add_argument('--dataset', type=str, required=True, help="dataset path") 26 | parser.add_argument('--class_choice', type=str, default='Chair', help="class_choice") 27 | parser.add_argument('--feature_transform', action='store_true', help="use feature transform") 28 | 29 | opt = parser.parse_args() 30 | print(opt) 31 | 32 | opt.manualSeed = random.randint(1, 10000) # fix seed 33 | print("Random Seed: ", opt.manualSeed) 34 | random.seed(opt.manualSeed) 35 | torch.manual_seed(opt.manualSeed) 36 | 37 | dataset = ShapeNetDataset( 38 | root=opt.dataset, 39 
| classification=False, 40 | class_choice=[opt.class_choice]) 41 | dataloader = torch.utils.data.DataLoader( 42 | dataset, 43 | batch_size=opt.batchSize, 44 | shuffle=True, 45 | num_workers=int(opt.workers)) 46 | 47 | test_dataset = ShapeNetDataset( 48 | root=opt.dataset, 49 | classification=False, 50 | class_choice=[opt.class_choice], 51 | split='test', 52 | data_augmentation=False) 53 | testdataloader = torch.utils.data.DataLoader( 54 | test_dataset, 55 | batch_size=opt.batchSize, 56 | shuffle=True, 57 | num_workers=int(opt.workers)) 58 | 59 | print(len(dataset), len(test_dataset)) 60 | num_classes = dataset.num_seg_classes 61 | print('classes', num_classes) 62 | try: 63 | os.makedirs(opt.outf) 64 | except OSError: 65 | pass 66 | 67 | blue = lambda x: '\033[94m' + x + '\033[0m' 68 | 69 | classifier = PointNetDenseCls(k=num_classes, feature_transform=opt.feature_transform) 70 | 71 | if opt.model != '': 72 | classifier.load_state_dict(torch.load(opt.model)) 73 | 74 | optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999)) 75 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 76 | classifier.cuda() 77 | 78 | num_batch = len(dataset) / opt.batchSize 79 | 80 | for epoch in range(opt.nepoch): 81 | scheduler.step() 82 | for i, data in enumerate(dataloader, 0): 83 | points, target = data 84 | points = points.transpose(2, 1) 85 | points, target = points.cuda(), target.cuda() 86 | optimizer.zero_grad() 87 | classifier = classifier.train() 88 | pred, trans, trans_feat = classifier(points) 89 | pred = pred.view(-1, num_classes) 90 | target = target.view(-1, 1)[:, 0] - 1 91 | #print(pred.size(), target.size()) 92 | loss = F.nll_loss(pred, target) 93 | if opt.feature_transform: 94 | loss += feature_transform_regularizer(trans_feat) * 0.001 95 | loss.backward() 96 | optimizer.step() 97 | pred_choice = pred.data.max(1)[1] 98 | correct = pred_choice.eq(target.data).cpu().sum() 99 | print('[%d: %d/%d] train loss: %f accuracy: %f' % 
(epoch, i, num_batch, loss.item(), correct.item()/float(opt.batchSize * 2500))) 100 | 101 | if i % 10 == 0: 102 | j, data = next(enumerate(testdataloader, 0)) 103 | points, target = data 104 | points = points.transpose(2, 1) 105 | points, target = points.cuda(), target.cuda() 106 | classifier = classifier.eval() 107 | pred, _, _ = classifier(points) 108 | pred = pred.view(-1, num_classes) 109 | target = target.view(-1, 1)[:, 0] - 1 110 | loss = F.nll_loss(pred, target) 111 | pred_choice = pred.data.max(1)[1] 112 | correct = pred_choice.eq(target.data).cpu().sum() 113 | print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize * 2500))) 114 | 115 | torch.save(classifier.state_dict(), '%s/seg_model_%s_%d.pth' % (opt.outf, opt.class_choice, epoch)) 116 | 117 | ## benchmark mIOU 118 | shape_ious = [] 119 | for i,data in tqdm(enumerate(testdataloader, 0)): 120 | points, target = data 121 | points = points.transpose(2, 1) 122 | points, target = points.cuda(), target.cuda() 123 | classifier = classifier.eval() 124 | pred, _, _ = classifier(points) 125 | pred_choice = pred.data.max(2)[1] 126 | 127 | pred_np = pred_choice.cpu().data.numpy() 128 | target_np = target.cpu().data.numpy() - 1 129 | 130 | for shape_idx in range(target_np.shape[0]): 131 | parts = range(num_classes)#np.unique(target_np[shape_idx]) 132 | part_ious = [] 133 | for part in parts: 134 | I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part)) 135 | U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part)) 136 | if U == 0: 137 | iou = 1 #If the union of groundtruth and prediction points is empty, then count part IoU as 1 138 | else: 139 | iou = I / float(U) 140 | part_ious.append(iou) 141 | shape_ious.append(np.mean(part_ious)) 142 | 143 | print("mIOU for class {}: {}".format(opt.class_choice, np.mean(shape_ious))) 
--------------------------------------------------------------------------------
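The `F.nll_loss` call in `utils/train_classification.py` relies on the classifier ending in `log_softmax`, which makes the pair equivalent to cross-entropy. A minimal, dependency-free sketch (plain Python instead of torch; `log_softmax`, `nll_loss` and `cross_entropy` here are hypothetical helpers, not the PyTorch functions) checks that equivalence numerically on made-up scores and also shows the argmax that `pred.data.max(1)[1]` extracts:

```python
import math


def log_softmax(logits):
    # log(softmax(z)) computed stably: subtract the max before exponentiating.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_norm for x in logits]


def nll_loss(log_probs, targets):
    # Batch mean of -log p[target], what F.nll_loss returns with reduction='mean' and no weights.
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


def cross_entropy(logits, targets):
    # Textbook cross-entropy on raw scores: -log(softmax(z)[t]), averaged over the batch.
    total = 0.0
    for z, t in zip(logits, targets):
        exps = [math.exp(x) for x in z]
        total -= math.log(exps[t] / sum(exps))
    return total / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # 2 samples, 3 classes (made-up scores)
targets = [0, 2]

via_nll = nll_loss([log_softmax(z) for z in logits], targets)
via_ce = cross_entropy(logits, targets)
# index of the largest score per row, the same thing pred.max(1)[1] picks out
predicted = [max(range(len(z)), key=z.__getitem__) for z in logits]
print(via_nll, via_ce, predicted)
```

The two loss values agree up to floating-point rounding, which is why the script can pair `log_softmax` output with `nll_loss` rather than calling a cross-entropy function on raw scores.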
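Both training scripts pair Adam (`lr=0.001`) with `StepLR(optimizer, step_size=20, gamma=0.5)`. Assuming `scheduler.step()` runs once per epoch, after that epoch's `optimizer.step()` calls (the order PyTorch 1.1 and later expect), the learning rate used during epoch `e` is `0.001 * 0.5 ** (e // 20)`. The sketch below evaluates that closed form; `step_lr` is a hypothetical helper, not a PyTorch API:

```python
def step_lr(base_lr, epoch, step_size=20, gamma=0.5):
    # Closed form of the StepLR schedule when scheduler.step() runs once at the
    # end of every epoch: decay by gamma once per completed block of step_size epochs.
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 19, 20, 39, 40, 249):
    print(epoch, step_lr(0.001, epoch))
```

With the classification default of 250 epochs, the last ten epochs run at `0.001 * 0.5**12`, roughly 2.4e-7; the segmentation default of 25 epochs sees a single halving, at epoch 20.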
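The mIoU benchmark at the end of `utils/train_segmentation.py` scores each shape as the mean IoU over all `num_classes` part labels, counting a part whose prediction/ground-truth union is empty as IoU 1, and then averages those scores over all test shapes. A plain-Python sketch of the per-shape score (the helper `shape_miou` is hypothetical and assumes labels are already 0-based, as after the script's `- 1` shift) makes the convention concrete:

```python
def shape_miou(pred, target, num_parts):
    # Mean IoU over every part label 0..num_parts-1 for a single shape.
    # A part absent from both prediction and ground truth (empty union) counts as IoU 1,
    # matching the convention in the benchmark loop.
    ious = []
    for part in range(num_parts):
        inter = sum(1 for p, t in zip(pred, target) if p == part and t == part)
        union = sum(1 for p, t in zip(pred, target) if p == part or t == part)
        ious.append(1.0 if union == 0 else inter / union)
    return sum(ious) / num_parts


pred = [0, 0, 1, 1, 2, 2]    # predicted part label per point (made-up)
target = [0, 0, 1, 2, 2, 2]  # ground-truth part label per point (made-up)
print(shape_miou(pred, target, num_parts=4))
```

Here the part IoUs are 1, 1/2, 2/3 and 1 (part 3 appears nowhere), giving about 0.792. Averaging only over the parts present in the ground truth (the `np.unique(target_np[shape_idx])` alternative noted in the script) would give (1 + 1/2 + 2/3) / 3, about 0.722, for the same shape, so the two conventions are not interchangeable when comparing reported numbers.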