11 | python train_segmentation.py # train 3D model segmentation 12 | 13 | python show_seg.py --model seg/seg_model_20.pth # show segmentation results 14 | ``` 15 | 16 | # Performance 17 | Without heavy tuning, PointNet can achieve 80-90% performance in classification and segmentation on this [dataset](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html).
3 | 4 | 5 | # Download data and running 6 | 7 | ``` 8 | bash build.sh #build C++ code for visualization 9 | bash download.sh #download dataset 10 | python train_classification.py #train 3D model classification 11 | python python train_segmentation.py # train 3D model segmentaion 12 | 13 | python show_seg.py --model seg/seg_model_20.pth # show segmentation results 14 | ``` 15 | 16 | # Performance 17 | Without heavy tuning, PointNet can achieve 80-90% performance in classification and segmentaion on this [dataset](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html). 18 | 19 | Sample segmentation result: 20 | ![seg](https://raw.githubusercontent.com/fxia22/pointnet.pytorch/master/misc/show3d.png?token=AE638Oy51TL2HDCaeCF273X_-Bsy6-E2ks5Y_BUzwA%3D%3D) 21 | 22 | 23 | # Links 24 | 25 | - [Project Page](http://stanford.edu/~rqi/pointnet/) 26 | - [Tensorflow implementation](https://github.com/charlesq34/pointnet) 27 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | g++ -std=c++11 render_balls_so.cpp -o render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 2 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import errno 7 | import torch 8 | import json 9 | import codecs 10 | import numpy as np 11 | import sys 12 | import torchvision.transforms as transforms 13 | import argparse 14 | import json 15 | 16 | 17 | class PartDataset(data.Dataset): 18 | def __init__(self, root, npoints = 2500, classification = False, class_choice = None, train = True): 19 | self.npoints = npoints 20 | self.root = root 21 | self.catfile = os.path.join(self.root, 
'synsetoffset2category.txt') 22 | self.cat = {} 23 | 24 | self.classification = classification 25 | 26 | with open(self.catfile, 'r') as f: 27 | for line in f: 28 | ls = line.strip().split() 29 | self.cat[ls[0]] = ls[1] 30 | #print(self.cat) 31 | if not class_choice is None: 32 | self.cat = {k:v for k,v in self.cat.items() if k in class_choice} 33 | 34 | self.meta = {} 35 | for item in self.cat: 36 | #print('category', item) 37 | self.meta[item] = [] 38 | dir_point = os.path.join(self.root, self.cat[item], 'points') 39 | dir_seg = os.path.join(self.root, self.cat[item], 'points_label') 40 | #print(dir_point, dir_seg) 41 | fns = sorted(os.listdir(dir_point)) 42 | if train: 43 | fns = fns[:int(len(fns) * 0.9)] 44 | else: 45 | fns = fns[int(len(fns) * 0.9):] 46 | 47 | #print(os.path.basename(fns)) 48 | for fn in fns: 49 | token = (os.path.splitext(os.path.basename(fn))[0]) 50 | self.meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg'))) 51 | 52 | self.datapath = [] 53 | for item in self.cat: 54 | for fn in self.meta[item]: 55 | self.datapath.append((item, fn[0], fn[1])) 56 | 57 | 58 | self.classes = dict(zip(sorted(self.cat), range(len(self.cat)))) 59 | print(self.classes) 60 | self.num_seg_classes = 0 61 | if not self.classification: 62 | for i in range(len(self.datapath)//50): 63 | l = len(np.unique(np.loadtxt(self.datapath[i][-1]).astype(np.uint8))) 64 | if l > self.num_seg_classes: 65 | self.num_seg_classes = l 66 | #print(self.num_seg_classes) 67 | 68 | 69 | def __getitem__(self, index): 70 | fn = self.datapath[index] 71 | cls = self.classes[self.datapath[index][0]] 72 | point_set = np.loadtxt(fn[1]).astype(np.float32) 73 | seg = np.loadtxt(fn[2]).astype(np.int64) 74 | #print(point_set.shape, seg.shape) 75 | 76 | choice = np.random.choice(len(seg), self.npoints, replace=True) 77 | #resample 78 | point_set = point_set[choice, :] 79 | seg = seg[choice] 80 | point_set = torch.from_numpy(point_set) 81 | seg = 
torch.from_numpy(seg) 82 | cls = torch.from_numpy(np.array([cls]).astype(np.int64)) 83 | if self.classification: 84 | return point_set, cls 85 | else: 86 | return point_set, seg 87 | 88 | def __len__(self): 89 | return len(self.datapath) 90 | 91 | 92 | if __name__ == '__main__': 93 | print('test') 94 | d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', class_choice = ['Chair']) 95 | print(len(d)) 96 | ps, seg = d[0] 97 | print(ps.size(), ps.type(), seg.size(),seg.type()) 98 | 99 | d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True) 100 | print(len(d)) 101 | ps, cls = d[0] 102 | print(ps.size(), ps.type(), cls.size(),cls.type()) 103 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate 2 | unzip shapenetcore_partanno_segmentation_benchmark_v0.zip 3 | rm shapenetcore_partanno_segmentation_benchmark_v0.zip 4 | -------------------------------------------------------------------------------- /pointnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim as optim 10 | import torch.utils.data 11 | import torchvision.transforms as transforms 12 | import torchvision.utils as vutils 13 | from torch.autograd import Variable 14 | from torch.autograd import Function 15 | from PIL import Image 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import pdb 19 | import torch.nn.functional as F 20 | import scipy.sparse 21 | from prepare_graph import build_graph_core 22 | import sys 23 | 24 | 25 | 
class STN3d(nn.Module): 26 | def __init__(self): 27 | super(STN3d, self).__init__() 28 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 29 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 30 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 31 | self.fc1 = nn.Linear(1024, 512) 32 | self.fc2 = nn.Linear(512, 256) 33 | self.fc3 = nn.Linear(256, 9) 34 | self.relu = nn.ReLU() 35 | 36 | self.bn1 = nn.BatchNorm1d(64) 37 | self.bn2 = nn.BatchNorm1d(128) 38 | self.bn3 = nn.BatchNorm1d(1024) 39 | self.bn4 = nn.BatchNorm1d(512) 40 | self.bn5 = nn.BatchNorm1d(256) 41 | 42 | def forward(self, x): 43 | batchsize = x.size()[0] 44 | x = F.relu(self.bn1(self.conv1(x))) 45 | x = F.relu(self.bn2(self.conv2(x))) 46 | x = F.relu(self.bn3(self.conv3(x))) 47 | x = torch.max(x, 2, keepdim=True)[0] 48 | x = x.view(-1, 1024) 49 | 50 | x = F.relu(self.bn4(self.fc1(x))) 51 | x = F.relu(self.bn5(self.fc2(x))) 52 | x = self.fc3(x) 53 | 54 | iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat( 55 | batchsize, 1) 56 | if x.is_cuda: 57 | iden = iden.cuda() 58 | x = x + iden 59 | x = x.view(-1, 3, 3) 60 | return x 61 | 62 | 63 | class PointNetfeat(nn.Module): 64 | def __init__(self, global_feat=True): 65 | super(PointNetfeat, self).__init__() 66 | self.stn = STN3d() 67 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 68 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 69 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 70 | self.bn1 = nn.BatchNorm1d(64) 71 | self.bn2 = nn.BatchNorm1d(128) 72 | self.bn3 = nn.BatchNorm1d(1024) 73 | self.global_feat = global_feat 74 | 75 | def forward(self, x): 76 | batchsize = x.size()[0] 77 | n_pts = x.size()[2] 78 | trans = self.stn(x) 79 | x = x.transpose(2, 1) 80 | x = torch.bmm(x, trans) 81 | x = x.transpose(2, 1) 82 | x = F.relu(self.bn1(self.conv1(x))) 83 | pointfeat = x 84 | x = F.relu(self.bn2(self.conv2(x))) 85 | x = self.bn3(self.conv3(x)) # x = batch,1024,n(n=2048) 86 | x = torch.max(x, 2, keepdim=True)[0] # x = batch,1024,1 87 | x = 
x.view(-1, 1024) # x = batch,1024 88 | if self.global_feat: 89 | return x, trans 90 | else: 91 | x = x.view(-1, 1024, 1).repeat(1, 1, n_pts) 92 | return torch.cat([x, pointfeat], 1), trans 93 | 94 | 95 | class PointNetCls(nn.Module): 96 | def __init__(self, k=2): 97 | super(PointNetCls, self).__init__() 98 | self.feat = PointNetfeat(global_feat=True) 99 | self.fc1 = nn.Linear(1024, 512) 100 | self.fc2 = nn.Linear(512, 256) 101 | self.fc3 = nn.Linear(256, k) 102 | self.bn1 = nn.BatchNorm1d(512) 103 | self.bn2 = nn.BatchNorm1d(256) 104 | self.relu = nn.ReLU() 105 | 106 | def forward(self, x): 107 | x, trans = self.feat(x) 108 | x = F.relu(self.bn1(self.fc1(x))) 109 | x = F.relu(self.bn2(self.fc2(x))) 110 | x = self.fc3(x) 111 | return F.log_softmax(x, dim=0), trans 112 | 113 | 114 | # *************** YW test FoldingNet ************ 115 | class FoldingNetEnc(nn.Module): 116 | def __init__(self): 117 | super(FoldingNetEnc, self).__init__() 118 | self.feat = PointNetfeat(global_feat=True) 119 | self.fc1 = nn.Linear(1024, 512) 120 | self.fc2 = nn.Linear(512, 512) 121 | # self.fc2 = nn.Linear(512, 256) 122 | # self.fc3 = nn.Linear(256, k) 123 | self.bn1 = nn.BatchNorm1d(512) 124 | # self.bn2 = nn.BatchNorm1d(256) 125 | self.relu = nn.ReLU() 126 | 127 | def forward(self, x): 128 | x, trans = self.feat(x) # x = batch,1024 129 | x = F.relu(self.bn1(self.fc1(x))) # x = batch,512 130 | x = self.fc2(x) # x = batch,512 131 | 132 | return x, trans 133 | 134 | 135 | class Graph_Pooling(nn.Module): 136 | def __init__(self): 137 | super(Graph_Pooling,self).__init__() 138 | 139 | def forward(self, x, batch_graph): # x: batch, channel, num_of_point. 
batch_graph: batch 140 | num_points = x.size(2) 141 | batch_size = x.size(0) 142 | assert (x.size(0)==len(batch_graph)) 143 | A_x = torch.zeros(x.size()) 144 | aa = torch.zeros(num_points, 16) 145 | 146 | if x.is_cuda: 147 | A_x = A_x.cuda() 148 | for b in range(batch_size): 149 | bth_graph = batch_graph[b] 150 | index_b = bth_graph.nonzero() 151 | x_batch = x[b,:,:] # channel, num_of_point 152 | x_batch = x_batch.transpose(0,1) # num_of_point, channel 153 | 154 | for i in range(num_points): 155 | idx = index_b[0] == i 156 | ele = index_b[1][idx] 157 | rand_idx = np.random.choice(len(ele),16,replace=False) 158 | ele = ele[rand_idx] 159 | aa[i,:] = torch.from_numpy(ele) 160 | 161 | aa = aa.to(torch.int64) 162 | A_batch = x_batch[aa] # num_of_point,16,channel 163 | if x.is_cuda: 164 | A_batch = A_batch.cuda() 165 | A_batch = torch.max(A_batch,dim = 1, keepdim=False)[0] # num_of_point,channel 166 | A_x[b,:,:] = A_batch.transpose(0,1) 167 | 168 | 169 | # for i in range(num_points): 170 | # i_nb_index = bth_graph[i, :].nonzero()[1] # the ith point's neighbors' index 171 | # A_x[b, :, i] = torch.max(x[b:b+1, :, i_nb_index], dim=2, keepdim=True)[0].view(-1) # the output size should be 1,channels,1 172 | 173 | A_x = torch.max(A_x, x) # compare to itself 174 | 175 | return A_x # batch,channel,num of point 176 | 177 | class FoldingNetEnc_with_graph(nn.Module): 178 | def __init__(self): 179 | super(FoldingNetEnc_with_graph, self).__init__() 180 | self.conv1 = torch.nn.Conv1d(12, 64, 1) 181 | self.conv2 = torch.nn.Conv1d(64, 64, 1) 182 | self.conv3 = torch.nn.Conv1d(64, 64, 1) 183 | self.conv4 = torch.nn.Conv1d(64, 128, 1) 184 | self.conv5 = torch.nn.Conv1d(128, 1024, 1) 185 | 186 | self.fc1 = nn.Linear(1024, 512) 187 | self.fc2 = nn.Linear(512, 512) 188 | 189 | self.graph_pooling = Graph_Pooling() 190 | 191 | self.bn1 = nn.BatchNorm1d(64) 192 | self.bn2 = nn.BatchNorm1d(64) 193 | self.bn3 = nn.BatchNorm1d(64) 194 | self.bn4 = nn.BatchNorm1d(128) 195 | self.bn5 = 
nn.BatchNorm1d(1024) 196 | self.bn6 = nn.BatchNorm1d(512) 197 | 198 | 199 | def forward(self, x, Cov, batch_graph): # x: batch,3,n; Cov: batch,9,n; batch_graph: batch * scipy.sparse.csr_matrix 200 | 201 | x_cov = torch.cat((x, Cov), 1) # x_cov: batch,12,n 202 | x_cov = F.relu(self.bn1(self.conv1(x_cov))) 203 | x_cov = F.relu(self.bn2(self.conv2(x_cov))) 204 | x_cov = F.relu(self.bn3(self.conv3(x_cov))) # x_cov : batch,64,n 205 | 206 | # A_x = torch.zeros(x_cov.size()) 207 | # if x_cov.is_cuda: 208 | # A_x = A_x.cuda() 209 | # for b in range(batch_size): 210 | # bth_graph = batch_graph[b] 211 | # for i in range(num_points): 212 | # i_nb_index = bth_graph[i, :].nonzero()[0] # the ith point's neighbors' index 213 | # A_x[b,:,i] = torch.max(x_cov[b, :, i_nb_index], dim = 2, keepdim=True)[0] # the output size should be 1,64,1 214 | 215 | A_x = self.graph_pooling(x_cov, batch_graph) # A_x: batch,64,n 216 | A_x = F.relu(A_x) 217 | A_x = F.relu(self.bn4(self.conv4(A_x))) # batch,128,n 218 | A_x = F.relu(self.graph_pooling(A_x, batch_graph)) # batch,128,n 219 | A_x = self.bn5(self.conv5(A_x)) # batch,1024,n 220 | A_x = torch.max(A_x, dim=2, keepdim=True)[0] # batch,1024,1 221 | A_x = A_x.view(-1,1024) # batch,1024 222 | A_x = F.relu(self.bn6(self.fc1(A_x))) # batch,512 223 | A_x = self.fc2(A_x) # batch,512 224 | 225 | return A_x 226 | 227 | 228 | class FoldingNetDecFold1(nn.Module): 229 | def __init__(self): 230 | super(FoldingNetDecFold1, self).__init__() 231 | self.conv1 = nn.Conv1d(514, 512, 1) 232 | self.conv2 = nn.Conv1d(512, 512, 1) 233 | self.conv3 = nn.Conv1d(512, 3, 1) 234 | 235 | self.relu = nn.ReLU() 236 | 237 | def forward(self, x): # input x = batch,514,45^2 238 | x = self.relu(self.conv1(x)) # x = batch,512,45^2 239 | x = self.relu(self.conv2(x)) 240 | x = self.conv3(x) 241 | 242 | return x 243 | 244 | 245 | class FoldingNetDecFold2(nn.Module): 246 | def __init__(self): 247 | super(FoldingNetDecFold2, self).__init__() 248 | self.conv1 = nn.Conv1d(515, 512, 1) 
249 | self.conv2 = nn.Conv1d(512, 512, 1) 250 | self.conv3 = nn.Conv1d(512, 3, 1) 251 | self.relu = nn.ReLU() 252 | 253 | def forward(self, x): # input x = batch,515,45^2 254 | x = self.relu(self.conv1(x)) 255 | x = self.relu(self.conv2(x)) 256 | x = self.conv3(x) 257 | return x 258 | 259 | 260 | def GridSamplingLayer(batch_size, meshgrid): 261 | ''' 262 | output Grid points as a NxD matrix 263 | 264 | params = { 265 | 'batch_size': 8 266 | 'meshgrid': [[-0.3,0.3,45],[-0.3,0.3,45]] 267 | } 268 | ''' 269 | 270 | ret = np.meshgrid(*[np.linspace(it[0], it[1], num=it[2]) for it in meshgrid]) 271 | ndim = len(meshgrid) 272 | grid = np.zeros((np.prod([it[2] for it in meshgrid]), ndim), dtype=np.float32) # MxD 273 | for d in range(ndim): 274 | grid[:, d] = np.reshape(ret[d], -1) 275 | g = np.repeat(grid[np.newaxis, ...], repeats=batch_size, axis=0) 276 | 277 | return g 278 | 279 | 280 | class FoldingNetDec(nn.Module): 281 | def __init__(self): 282 | super(FoldingNetDec, self).__init__() 283 | self.fold1 = FoldingNetDecFold1() 284 | self.fold2 = FoldingNetDecFold2() 285 | 286 | def forward(self, x): # input x = batch, 512 287 | batch_size = x.size(0) 288 | x = torch.unsqueeze(x, 1) # x = batch,1,512 289 | x = x.repeat(1, 45 ** 2, 1) # x = batch,45^2,512 290 | code = x 291 | code = x.transpose(2, 1) # x = batch,512,45^2 292 | 293 | meshgrid = [[-0.3, 0.3, 45], [-0.3, 0.3, 45]] 294 | grid = GridSamplingLayer(batch_size, meshgrid) # grid = batch,45^2,2 295 | grid = torch.from_numpy(grid) 296 | 297 | if x.is_cuda: 298 | grid = grid.cuda() 299 | 300 | x = torch.cat((x, grid), 2) # x = batch,45^2,514 301 | x = x.transpose(2, 1) # x = batch,514,45^2 302 | 303 | x = self.fold1(x) # x = batch,3,45^2 304 | p1 = x # to observe 305 | 306 | x = torch.cat((code, x), 1) # x = batch,515,45^2 307 | 308 | x = self.fold2(x) # x = batch,3,45^2 309 | 310 | return x, p1 311 | 312 | class Quantization(Function): 313 | #def __init__(self): 314 | # super(Quantization, self).__init__() 315 | 316 | 
@staticmethod 317 | def forward(ctx, input): 318 | output = torch.round(input) 319 | return output 320 | 321 | @staticmethod 322 | def backward(ctx,grad_output): 323 | return grad_output 324 | 325 | class Quantization_module(nn.Module): 326 | def __init__(self): 327 | super().__init__() 328 | 329 | def forward(self, input): 330 | return Quantization.apply(input) 331 | 332 | class FoldingNet(nn.Module): 333 | def __init__(self): 334 | super(FoldingNet, self).__init__() 335 | self.encoder = FoldingNetEnc() 336 | self.decoder = FoldingNetDec() 337 | self.quan = Quantization_module() 338 | 339 | def forward(self, x): # input x = batch,3,number of points 340 | code, tran = self.encoder(x) # code = batch,512 341 | code = self.quan(code) # quantization 342 | 343 | '''if self.training == 0: # if now is evaluation, save code 344 | try: 345 | os.makedirs('bin') 346 | except OSError: 347 | pass 348 | code_save = code.cpu().detach() 349 | code_save = code_save.numpy() 350 | code_save = code_save.astype(int) 351 | np.savetxt('./bin/test.bin', code_save) 352 | ''' 353 | 354 | x, x_middle = self.decoder(code) # x = batch,3,45^2 355 | 356 | return x, x_middle,code 357 | 358 | class FoldingNet_graph(nn.Module): 359 | def __init__(self): 360 | super(FoldingNet_graph, self).__init__() 361 | self.encoder = FoldingNetEnc_with_graph() 362 | self.decoder = FoldingNetDec() 363 | self.quan = Quantization_module() 364 | 365 | def forward(self, x, Cov, batch_graph): 366 | ''' 367 | x: batch,3,n; Cov: batch,9,n; batch_graph: batch * scipy.sparse.csr_matrix 368 | ''' 369 | code = self.encoder(x,Cov,batch_graph) 370 | code = self.quan(code) 371 | x, x_middle = self.decoder(code) # x = batch,3,45^2 372 | 373 | return x, x_middle, code 374 | 375 | 376 | 377 | 378 | 379 | def ChamferDistance(x, y): # for example, x = batch,2025,3 y = batch,2048,3 380 | # compute chamfer distance between tow point clouds x and y 381 | 382 | x_size = x.size() 383 | y_size = y.size() 384 | assert (x_size[0] == 
y_size[0]) 385 | assert (x_size[2] == y_size[2]) 386 | x = torch.unsqueeze(x, 1) # x = batch,1,2025,3 387 | y = torch.unsqueeze(y, 2) # y = batch,2048,1,3 388 | 389 | x = x.repeat(1, y_size[1], 1, 1) # x = batch,2048,2025,3 390 | y = y.repeat(1, 1, x_size[1], 1) # y = batch,2048,2025,3 391 | 392 | x_y = x - y 393 | x_y = torch.pow(x_y, 2) # x_y = batch,2048,2025,3 394 | x_y = torch.sum(x_y, 3, keepdim=True) # x_y = batch,2048,2025,1 395 | x_y = torch.squeeze(x_y, 3) # x_y = batch,2048,2025 396 | x_y_row, _ = torch.min(x_y, 1, keepdim=True) # x_y_row = batch,1,2025 397 | x_y_col, _ = torch.min(x_y, 2, keepdim=True) # x_y_col = batch,2048,1 398 | 399 | x_y_row = torch.mean(x_y_row, 2, keepdim=True) # x_y_row = batch,1,1 400 | x_y_col = torch.mean(x_y_col, 1, keepdim=True) # batch,1,1 401 | x_y_row_col = torch.cat((x_y_row, x_y_col), 2) # batch,1,2 402 | chamfer_distance, _ = torch.max(x_y_row_col, 2, keepdim=True) # batch,1,1 403 | # chamfer_distance = torch.reshape(chamfer_distance,(x_size[0],-1)) #batch,1 404 | # chamfer_distance = torch.squeeze(chamfer_distance,1) # batch 405 | chamfer_distance = torch.mean(chamfer_distance) 406 | return chamfer_distance 407 | 408 | 409 | class ChamferLoss(nn.Module): 410 | # chamfer distance loss 411 | def __init__(self): 412 | super(ChamferLoss, self).__init__() 413 | 414 | def forward(self, x, y): 415 | return ChamferDistance(x, y) 416 | 417 | 418 | 419 | 420 | 421 | class PointNetDenseCls(nn.Module): 422 | def __init__(self, k=2): 423 | super(PointNetDenseCls, self).__init__() 424 | self.k = k 425 | self.feat = PointNetfeat(global_feat=False) 426 | self.conv1 = torch.nn.Conv1d(1088, 512, 1) 427 | self.conv2 = torch.nn.Conv1d(512, 256, 1) 428 | self.conv3 = torch.nn.Conv1d(256, 128, 1) 429 | self.conv4 = torch.nn.Conv1d(128, self.k, 1) 430 | self.bn1 = nn.BatchNorm1d(512) 431 | self.bn2 = nn.BatchNorm1d(256) 432 | self.bn3 = nn.BatchNorm1d(128) 433 | 434 | def forward(self, x): 435 | batchsize = x.size()[0] 436 | n_pts = 
x.size()[2] 437 | x, trans = self.feat(x) 438 | x = F.relu(self.bn1(self.conv1(x))) 439 | x = F.relu(self.bn2(self.conv2(x))) 440 | x = F.relu(self.bn3(self.conv3(x))) 441 | x = self.conv4(x) 442 | x = x.transpose(2, 1).contiguous() 443 | x = F.log_softmax(x.view(-1, self.k), dim=-1) 444 | x = x.view(batchsize, n_pts, self.k) 445 | return x, trans 446 | 447 | 448 | if __name__ == '__main__': 449 | # sim_data = Variable(torch.rand(32, 3, 2500)) 450 | # trans = STN3d() 451 | # out = trans(sim_data) 452 | # print('stn', out.size()) 453 | # 454 | # pointfeat = PointNetfeat(global_feat=True) 455 | # out, _ = pointfeat(sim_data) 456 | # print('global feat', out.size()) 457 | # 458 | # pointfeat = PointNetfeat(global_feat=False) 459 | # out, _ = pointfeat(sim_data) 460 | # print('point feat', out.size()) 461 | # 462 | # cls = PointNetCls(k=5) 463 | # out, _ = cls(sim_data) 464 | # print('class', out.size()) 465 | # 466 | # seg = PointNetDenseCls(k=3) 467 | # out, _ = seg(sim_data) 468 | # print('seg', out.size()) 469 | 470 | # YW test 471 | 472 | 473 | # sim_data = torch.rand(32,515,45*45) 474 | # print('sim_data ',sim_data.size()) 475 | # 476 | # fold2 = FoldingNetDecFold2() 477 | # out = fold2(sim_data) 478 | # print('fold2 ',out.size()) 479 | # 480 | # meshgrid = [[-0.3,0.3,45],[-0.3,0.3,45]] 481 | # out = GridSamplingLayer(3,meshgrid) 482 | # print('meshgrid',out.shape) 483 | # 484 | # sim_data = torch.rand(32,512) 485 | # sim_data.cuda() 486 | # dec = FoldingNetDec() 487 | # if sim_data.is_cuda: 488 | # dec.cuda() 489 | # out,out2 = dec(sim_data) 490 | # print('dec',out.size()) 491 | # print('fold1 result',out2.size()) 492 | # 493 | # 494 | # sim_data = torch.rand(32,3,2500) 495 | # 496 | # enc = FoldingNetEnc() 497 | # out, _ = enc(sim_data) 498 | # print(out.size()) 499 | # 500 | # 501 | # 502 | # foldnet = FoldingNet() 503 | # foldnet.cuda() 504 | # 505 | # out , out2 = foldnet(sim_data) 506 | # print('reconsructed point cloud', out.size()) 507 | # print('middle 
result',out2.size()) 508 | 509 | 510 | # x = torch.rand(16, 2048, 3) 511 | # y = x 512 | # # y = torch.rand(16,2025,3) 513 | # 514 | # cs = ChamferDistance(x, y) 515 | # print('chamfer distance', cs) 516 | # 517 | # closs = ChamferLoss() 518 | # print('chamfer loss', closs(x, y)) 519 | 520 | # YW test graph encoder 521 | 522 | parser = argparse.ArgumentParser(sys.argv[0]) 523 | 524 | parser.add_argument('--pts_mn40', type=int, default=2048, 525 | help="number of points per modelNet40 object") 526 | parser.add_argument('--pts_shapenet_part', type=int, default=2048, 527 | help="number of points per shapenet_part object") 528 | parser.add_argument('--pts_shapenet', type=int, default=2048, 529 | help="number of points per shapenet object") 530 | parser.add_argument('-md', '--mode', type=str, default="M", 531 | help="mode used to compute graphs: M, P") 532 | parser.add_argument('-m', '--metric', type=str, default='euclidean', 533 | help="metric for distance calculation (manhattan/euclidean)") 534 | parser.add_argument('--no-shuffle', dest='shuffle', action='store_false', default=True, 535 | help="whether to shuffle data (1) or not (0) before saving") 536 | parser.add_argument('--regenerate', dest='regenerate', action='store_true', default=False, 537 | help='regenerate from raw pointnet data or not (default: False)') 538 | 539 | args = parser.parse_args(sys.argv[1:]) 540 | args.script_folder = os.path.dirname(os.path.abspath(__file__)) 541 | 542 | args.knn = 16 543 | args.mode = 'M' 544 | args.metric = 'euclidean' 545 | sim_data = torch.rand(10, 2048, 3) 546 | # ith, ith_graph, nbsd, cov = build_graph_core((0, sim_data), args) 547 | 548 | batch_size = sim_data.size(0) 549 | batch_graph = [] 550 | Cov = torch.zeros(10,2048,9) 551 | for i in range(batch_size): 552 | ith_graph, nbsd, cov_i = build_graph_core(sim_data[i].numpy(), args) 553 | batch_graph.append(ith_graph) 554 | Cov[i,:,:] = torch.from_numpy(cov_i) 555 | print('Cov.size: ', Cov.size()) 556 | 
print('len(batch_graph):', len(batch_graph)) 557 | 558 | sim_data = sim_data.transpose(2,1) 559 | Cov = Cov.transpose(2,1) 560 | fold = FoldingNet_graph() 561 | out, out_middle, code = fold(sim_data, Cov, batch_graph) 562 | print('out', out.size()) 563 | print('out_middle', out_middle.size()) 564 | print('code ', code.size()) 565 | 566 | -------------------------------------------------------------------------------- /prepare_graph.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import numpy as np 5 | import h5py 6 | import argparse 7 | import scipy.sparse 8 | from sklearn.neighbors import KDTree 9 | import multiprocessing as multiproc 10 | from functools import partial 11 | import glog as logger 12 | from copy import deepcopy 13 | import errno 14 | import gdown #https://github.com/wkentaro/gdown 15 | import pyxis 16 | import torch 17 | 18 | 19 | def edges2A(edges, n_nodes, mode='P', sparse_mat_type=scipy.sparse.csr_matrix): 20 | ''' 21 | note: assume no (i,i)-like edge 22 | edges: <2xE> 23 | ''' 24 | edges = np.array(edges).astype(int) 25 | 26 | data_D = np.zeros(n_nodes, dtype=np.float32) 27 | for d in range(n_nodes): 28 | data_D[ d ] = len(np.where(edges[0] == d)[0]) # compute the number of node which pick node_i as their neighbor 29 | 30 | if mode.upper() == 'M': # 'M' means max pooling, which use the same graph matrix as the adjacency matrix 31 | data = np.ones(edges[0].shape[0], dtype=np.int32) 32 | elif mode.upper() == 'P': 33 | data = 1. 
/ data_D[ edges[0] ] 34 | else: 35 | raise NotImplementedError("edges2A with unknown mode=" + mode) 36 | 37 | return sparse_mat_type((data, edges), shape=(n_nodes, n_nodes)) 38 | 39 | def knn_search(data, knn=16, metric="euclidean", symmetric=True): 40 | """ 41 | Args: 42 | data: Nx3 43 | knn: default=16 44 | """ 45 | assert(knn>0) 46 | n_data_i = data.shape[0] 47 | kdt = KDTree(data, leaf_size=30, metric=metric) 48 | 49 | nbs = kdt.query(data, k=knn+1, return_distance=True) # nbs[0]:NN distance,N*17. nbs[1]:NN index,N*17 50 | cov = np.zeros((n_data_i,9), dtype=np.float32) 51 | adjdict = dict() 52 | # wadj = np.zeros((n_data_i, n_data_i), dtype=np.float32) 53 | for i in range(n_data_i): 54 | # nbsd = nbs[0][i] 55 | nbsi = nbs[1][i] #index i, N*17 YW comment 56 | cov[i] = np.cov(data[nbsi[1:]].T).reshape(-1) #compute local covariance matrix 57 | for j in range(knn): 58 | if symmetric: 59 | adjdict[(i, nbsi[j+1])] = 1 60 | adjdict[(nbsi[j+1], i)] = 1 61 | # wadj[i, nbsi[j + 1]] = 1.0 / nbsd[j + 1] 62 | # wadj[nbsi[j + 1], i] = 1.0 / nbsd[j + 1] 63 | else: 64 | adjdict[(i, nbsi[j+1])] = 1 65 | # wadj[i, nbsi[j + 1]] = 1.0 / nbsd[j + 1] 66 | edges = np.array(list(adjdict.keys()), dtype=int).T 67 | return edges, nbs[0], cov #, wadj 68 | 69 | def build_graph_core(ith_datai, args): 70 | try: 71 | #ith, xyi = ith_datai #xyi: 2048x3 72 | xyi = ith_datai # xyi: 2048x3 73 | n_data_i = xyi.shape[0] 74 | edges, nbsd, cov = knn_search(xyi, knn=args.knn, metric=args.metric) 75 | ith_graph = edges2A(edges, n_data_i, args.mode, sparse_mat_type=scipy.sparse.csr_matrix) 76 | nbsd=np.asarray(nbsd)[:, 1:] 77 | nbsd=np.reshape(nbsd, -1) 78 | 79 | #if ith % 500 == 0: 80 | #logger.info('{} processed: {}'.format(args.flag, ith)) 81 | 82 | #return ith, ith_graph, nbsd, cov 83 | return ith_graph, nbsd, cov 84 | except KeyboardInterrupt: 85 | exit(-1) 86 | 87 | def build_graph(points, args): # points: batch, num of points, 3 88 | 89 | points = points.numpy() 90 | batch_graph = [] 91 | Cov = 
torch.zeros(points.shape[0], args.num_points, 9) 92 | 93 | pool = multiproc.Pool(2) 94 | pool_func = partial(build_graph_core, args=args) 95 | rets = pool.map(pool_func, points) 96 | pool.close() 97 | count = 0 98 | for ret in rets: 99 | ith_graph, _, cov = ret 100 | batch_graph.append(ith_graph) 101 | Cov[count,:,:] = torch.from_numpy(cov) 102 | count = count+1 103 | del rets 104 | 105 | return batch_graph, Cov 106 | 107 | 108 | 109 | 110 | if __name__ == '__main__': 111 | 112 | # test YW 113 | parser = argparse.ArgumentParser(sys.argv[0]) 114 | 115 | parser.add_argument('--pts_mn40', type=int, default=2048, 116 | help="number of points per modelNet40 object") 117 | parser.add_argument('--pts_shapenet_part', type=int, default=2048, 118 | help="number of points per shapenet_part object") 119 | parser.add_argument('--pts_shapenet', type=int, default=2048, 120 | help="number of points per shapenet object") 121 | parser.add_argument('-md', '--mode', type=str, default="M", 122 | help="mode used to compute graphs: M, P") 123 | parser.add_argument('-m', '--metric', type=str, default='euclidean', 124 | help="metric for distance calculation (manhattan/euclidean)") 125 | parser.add_argument('--no-shuffle', dest='shuffle', action='store_false', default=True, 126 | help="whether to shuffle data (1) or not (0) before saving") 127 | parser.add_argument('--regenerate', dest='regenerate', action='store_true', default=False, 128 | help='regenerate from raw pointnet data or not (default: False)') 129 | 130 | args = parser.parse_args(sys.argv[1:]) 131 | args.script_folder = os.path.dirname(os.path.abspath(__file__)) 132 | 133 | args.knn = 16 134 | args.mode = 'M' 135 | args.batchSize = 10 136 | args.num_points = 2048 137 | args.metric = 'euclidean' 138 | sim_data = torch.rand(10,2048, 3) 139 | #ith, ith_graph, nbsd, cov = build_graph_core((0, sim_data), args) 140 | #ith_graph, nbsd, cov = build_graph_core(sim_data, args) 141 | 142 | batch_graph, Cov = build_graph(sim_data, args) 143 
| 144 | print('done') 145 | -------------------------------------------------------------------------------- /render_balls_so.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | using namespace std; 6 | 7 | struct PointInfo{ 8 | int x,y,z; 9 | float r,g,b; 10 | }; 11 | 12 | extern "C"{ 13 | 14 | void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){ 15 | r=max(r,1); 16 | vector depth(h*w,-2100000000); 17 | vector pattern; 18 | for (int dx=-r;dx<=r;dx++) 19 | for (int dy=-r;dy<=r;dy++) 20 | if (dx*dx+dy*dy=h || y2<0 || y2>=w) && depth[x2*w+y2]0: 92 | show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=0)) 93 | if magnifyBlue>=2: 94 | show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=0)) 95 | show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=1)) 96 | if magnifyBlue>=2: 97 | show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=1)) 98 | if showrot: 99 | cv2.putText(show,'xangle %d'%(int(xangle/np.pi*180)),(30,showsz-30),0,0.5,cv2.cv.CV_RGB(255,0,0)) 100 | cv2.putText(show,'yangle %d'%(int(yangle/np.pi*180)),(30,showsz-50),0,0.5,cv2.cv.CV_RGB(255,0,0)) 101 | cv2.putText(show,'zoom %d%%'%(int(zoom*100)),(30,showsz-70),0,0.5,cv2.cv.CV_RGB(255,0,0)) 102 | changed=True 103 | while True: 104 | if changed: 105 | render() 106 | changed=False 107 | cv2.imshow('show3d',show) 108 | if waittime==0: 109 | cmd=cv2.waitKey(10)%256 110 | else: 111 | cmd=cv2.waitKey(waittime)%256 112 | if cmd==ord('q'): 113 | break 114 | elif cmd==ord('Q'): 115 | sys.exit(0) 116 | 117 | if cmd==ord('t') or cmd == ord('p'): 118 | if cmd == ord('t'): 119 | if c_gt is None: 120 | c0=np.zeros((len(xyz),),dtype='float32')+255 121 | c1=np.zeros((len(xyz),),dtype='float32')+255 122 | c2=np.zeros((len(xyz),),dtype='float32')+255 123 | else: 124 | c0=c_gt[:,0] 125 | c1=c_gt[:,1] 126 | c2=c_gt[:,2] 127 | else: 128 | if c_pred 
is None: 129 | c0=np.zeros((len(xyz),),dtype='float32')+255 130 | c1=np.zeros((len(xyz),),dtype='float32')+255 131 | c2=np.zeros((len(xyz),),dtype='float32')+255 132 | else: 133 | c0=c_pred[:,0] 134 | c1=c_pred[:,1] 135 | c2=c_pred[:,2] 136 | if normalizecolor: 137 | c0/=(c0.max()+1e-14)/255.0 138 | c1/=(c1.max()+1e-14)/255.0 139 | c2/=(c2.max()+1e-14)/255.0 140 | c0=np.require(c0,'float32','C') 141 | c1=np.require(c1,'float32','C') 142 | c2=np.require(c2,'float32','C') 143 | changed = True 144 | 145 | 146 | 147 | if cmd==ord('n'): 148 | zoom*=1.1 149 | changed=True 150 | elif cmd==ord('m'): 151 | zoom/=1.1 152 | changed=True 153 | elif cmd==ord('r'): 154 | zoom=1.0 155 | changed=True 156 | elif cmd==ord('s'): 157 | cv2.imwrite('show3d.png',show) 158 | if waittime!=0: 159 | break 160 | return cmd 161 | if __name__=='__main__': 162 | 163 | np.random.seed(100) 164 | showpoints(np.random.randn(2500,3)) 165 | 166 | -------------------------------------------------------------------------------- /show_cls.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | from datasets import PartDataset 17 | from pointnet import PointNetCls 18 | import torch.nn.functional as F 19 | import matplotlib.pyplot as plt 20 | 21 | 22 | #showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500))) 23 | 24 | parser = argparse.ArgumentParser() 25 | 26 | parser.add_argument('--model', type=str, default = '', help='model path') 27 | parser.add_argument('--num_points', type=int, 
default=2500, help='input batch size') 28 | 29 | 30 | opt = parser.parse_args() 31 | print (opt) 32 | 33 | test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0' , train = False, classification = True, npoints = opt.num_points) 34 | 35 | testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle = True) 36 | 37 | 38 | classifier = PointNetCls(k = len(test_dataset.classes), num_points = opt.num_points) 39 | classifier.cuda() 40 | classifier.load_state_dict(torch.load(opt.model)) 41 | classifier.eval() 42 | 43 | 44 | for i, data in enumerate(testdataloader, 0): 45 | points, target = data 46 | points, target = Variable(points), Variable(target[:, 0]) 47 | points = points.transpose(2, 1) 48 | points, target = points.cuda(), target.cuda() 49 | pred, _ = classifier(points) 50 | loss = F.nll_loss(pred, target) 51 | 52 | pred_choice = pred.data.max(1)[1] 53 | correct = pred_choice.eq(target.data).cpu().sum() 54 | print('i:%d loss: %f accuracy: %f' %(i, loss.data[0], correct/float(32))) 55 | -------------------------------------------------------------------------------- /show_pt_yw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mpl_toolkits.mplot3d import Axes3D 3 | import matplotlib.pyplot as plt 4 | import sys 5 | from PIL import Image 6 | import os 7 | import os.path 8 | import errno 9 | import torch 10 | import argparse 11 | import json 12 | import codecs 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.parallel 16 | import torch.backends.cudnn as cudnn 17 | import torch.optim as optim 18 | import torch.utils.data 19 | import torchvision.datasets as dset 20 | import torchvision.transforms as transforms 21 | import torchvision.utils as vutils 22 | from torch.autograd import Variable 23 | from datasets import PartDataset 24 | import torch.nn.functional as F 25 | from pointnet import FoldingNet,ChamferLoss,FoldingNet_graph 26 | from 
prepare_graph import build_graph 27 | 28 | 29 | 30 | if __name__=='__main__': 31 | 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--batchSize', type=int, default=8, help='input batch size') 34 | parser.add_argument('--num_points', type=int, default=2500, help='input batch size') 35 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 36 | parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for') 37 | parser.add_argument('--outf', type=str, default='cls', help='output folder') 38 | parser.add_argument('--model', type=str, default='', help='model path') 39 | parser.add_argument('-md', '--mode', type=str, default="M", help="mode used to compute graphs: M, P") 40 | parser.add_argument('-m', '--metric', type=str, default='euclidean', 41 | help="metric for distance calculation (manhattan/euclidean)") 42 | 43 | opt = parser.parse_args() 44 | opt.nepoch = 200 # yw add 45 | opt.knn = 16 # yw 46 | 47 | np.random.seed(100) 48 | pt = np.random.rand(250,3) 49 | # fig = plt.figure() 50 | # ax = fig.add_subplot(111,projection='3d') 51 | 52 | #ax.scatter(pt[:,0],pt[:,1],pt[:,2]) 53 | #plt.show() 54 | 55 | class_choice = 'Airplane' 56 | pt_root = 'shapenetcore_partanno_segmentation_benchmark_v0' 57 | npoints = 2500 58 | 59 | shapenet_dataset = PartDataset(root = pt_root, class_choice = class_choice, classification = True,train = True) 60 | print('len(shapenet_dataset) :',len(shapenet_dataset)) 61 | dataloader = torch.utils.data.DataLoader(shapenet_dataset,batch_size=1,shuffle=False) 62 | 63 | li = list(enumerate(dataloader)) 64 | print(len(li)) 65 | 66 | # ps,cls = shapenet_dataset[0] 67 | # print('ps.size:',ps.size()) 68 | # print('ps.type:',ps.type()) 69 | # print('cls.size',cls.size()) 70 | # print('cls.type',cls.type()) 71 | 72 | # ps2,cls2 = shapenet_dataset[1] 73 | 74 | # ax.scatter(ps[:,0],ps[:,1],ps[:,2]) 75 | # ax.set_xlabel('X label') 76 | # ax.set_ylabel('Y label') 77 | # 
ax.set_zlabel('Z label') 78 | 79 | # # fig2 = plt.figure() 80 | # # a2 = fig2.add_subplot(111,projection='3d') 81 | # # a2.scatter(ps2[:,0],ps2[:,1],ps2[:,2]) 82 | 83 | # plt.show() 84 | 85 | foldingnet = FoldingNet_graph() 86 | 87 | #foldingnet.load_state_dict(torch.load('cls_fold_512code_2500points/foldingnet_model_170.pth')) 88 | foldingnet.load_state_dict(torch.load('cls_fold_512code_2500points_170_restart/foldingnet_model_150.pth')) 89 | 90 | foldingnet.cuda() 91 | 92 | chamferloss = ChamferLoss() 93 | chamferloss = chamferloss.cuda() 94 | #print(foldingnet) 95 | 96 | foldingnet.eval() 97 | 98 | i, data = li[1] 99 | points, target = data 100 | 101 | batch_graph, Cov = build_graph(points, opt) 102 | 103 | Cov = Cov.transpose(2, 1) 104 | Cov = Cov.cuda() 105 | 106 | points = points.transpose(2,1) 107 | points = points.cuda() 108 | recon_pc, mid_pc, _ = foldingnet(points ,Cov, batch_graph) 109 | 110 | points_show = points.cpu().detach().numpy() 111 | re_show = recon_pc.cpu().detach().numpy() 112 | 113 | fig_ori = plt.figure() 114 | a1 = fig_ori.add_subplot(111,projection='3d') 115 | a1.scatter(points_show[0,0,:],points_show[0,1,:],points_show[0,2,:]) 116 | #plt.savefig('points_show.png') 117 | 118 | fig_re = plt.figure() 119 | a2 = fig_re.add_subplot(111,projection='3d') 120 | a2.scatter(re_show[0,0,:],re_show[0,1,:],re_show[0,2,:]) 121 | #plt.savefig('re_show.png') 122 | 123 | plt.show() 124 | 125 | print('points.size:', points.size()) 126 | print('recon_pc.size:', recon_pc.size()) 127 | loss = chamferloss(points.transpose(2,1),recon_pc.transpose(2,1)) 128 | print('loss',loss.item()) 129 | 130 | try: 131 | os.makedirs('bin') 132 | except OSError: 133 | pass 134 | 135 | for i,data in enumerate(dataloader): 136 | points, target = data 137 | points = points.transpose(2,1) 138 | points = points.cuda() 139 | recon_pc, _, code = foldingnet(points) 140 | points_show = points.cpu().detach().numpy() 141 | #print(points_show.shape) 142 | points_show = 
points_show.transpose(0,2,1) 143 | re_show = recon_pc.cpu().detach().numpy() 144 | re_show = re_show.transpose(0,2,1) 145 | 146 | #batch = points.size(0) 147 | 148 | np.savetxt('recon_pc/ori_%s_%d.pts'%(class_choice,i),points_show[0]) 149 | np.savetxt('recon_pc/rec_%s_%d.pts'%(class_choice,i),re_show[0]) 150 | 151 | code_save = code.cpu().detach().numpy().astype(int) 152 | np.savetxt('bin/%s_%d.bin'%(class_choice, i), code_save) 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /show_seg.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from show3d_balls import * 3 | import argparse 4 | import os 5 | import random 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import torchvision.datasets as dset 14 | import torchvision.transforms as transforms 15 | import torchvision.utils as vutils 16 | from torch.autograd import Variable 17 | from datasets import PartDataset 18 | from pointnet import PointNetDenseCls 19 | import torch.nn.functional as F 20 | import matplotlib.pyplot as plt 21 | 22 | 23 | #showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500))) 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument('--model', type=str, default = '', help='model path') 28 | parser.add_argument('--idx', type=int, default = 0, help='model index') 29 | 30 | 31 | 32 | opt = parser.parse_args() 33 | print (opt) 34 | 35 | d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', class_choice = ['Chair'], train = False) 36 | 37 | idx = opt.idx 38 | 39 | print("model %d/%d" %( idx, len(d))) 40 | 41 | point, seg = d[idx] 42 | print(point.size(), seg.size()) 
43 | 44 | point_np = point.numpy() 45 | 46 | 47 | 48 | cmap = plt.cm.get_cmap("hsv", 10) 49 | cmap = np.array([cmap(i) for i in range(10)])[:,:3] 50 | gt = cmap[seg.numpy() - 1, :] 51 | 52 | classifier = PointNetDenseCls(k = 4) 53 | classifier.load_state_dict(torch.load(opt.model)) 54 | classifier.eval() 55 | 56 | point = point.transpose(1,0).contiguous() 57 | 58 | point = Variable(point.view(1, point.size()[0], point.size()[1])) 59 | pred, _ = classifier(point) 60 | pred_choice = pred.data.max(2)[1] 61 | print(pred_choice) 62 | 63 | #print(pred_choice.size()) 64 | pred_color = cmap[pred_choice.numpy()[0], :] 65 | 66 | #print(pred_color.shape) 67 | showpoints(point_np, gt, pred_color) 68 | 69 | -------------------------------------------------------------------------------- /train_FoldingNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | from datasets import PartDataset 17 | from pointnet import PointNetCls 18 | from pointnet import FoldingNet 19 | from pointnet import ChamferLoss 20 | import torch.nn.functional as F 21 | from visdom import Visdom 22 | import time 23 | from mpl_toolkits.mplot3d import Axes3D 24 | import matplotlib.pyplot as plt 25 | 26 | 27 | vis = Visdom() 28 | line = vis.line(np.arange(10)) 29 | 30 | 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--batchSize', type=int, default=8, help='input batch size') 35 | parser.add_argument('--num_points', type=int, default=2500, help='input batch size') 36 | 
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='cls', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')

opt = parser.parse_args()
opt.nepoch = 200  # yw add: override CLI epoch count
print(opt)

blue = lambda x: '\033[94m' + x + '\033[0m'

opt.manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

dataset = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0',
                      classification=True, npoints=opt.num_points)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

test_dataset = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0',
                           classification=True, train=False, npoints=opt.num_points)
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=int(opt.workers))

print(len(dataset), len(test_dataset))
num_classes = len(dataset.classes)
print('classes', num_classes)

try:
    os.makedirs(opt.outf)
except OSError:
    pass  # output dir already exists

foldingnet = FoldingNet()

if opt.model != '':
    foldingnet.load_state_dict(torch.load(opt.model))

optimizer = optim.Adam(foldingnet.parameters(), lr=0.0001, weight_decay=1e-6)
foldingnet.cuda()

# BUG FIX: integer batch count ('/' yields a float under Python 3).
num_batch = len(dataset) // opt.batchSize

chamferloss = ChamferLoss()
chamferloss.cuda()

start_time = time.time()
time_p, loss_p, loss_m = [], [], []
| for epoch in range(opt.nepoch): 92 | sum_loss = 0 93 | sum_step = 0 94 | sum_mid_loss = 0 95 | for i, data in enumerate(dataloader, 0): 96 | points, target = data 97 | 98 | #print(points.size()) 99 | 100 | points, target = Variable(points), Variable(target[:,0]) 101 | points = points.transpose(2,1) 102 | points, target = points.cuda(), target.cuda() 103 | optimizer.zero_grad() 104 | foldingnet = foldingnet.train() 105 | recon_pc, mid_pc, _ = foldingnet(points) 106 | 107 | loss = chamferloss(points.transpose(2,1),recon_pc.transpose(2,1)) 108 | loss.backward() 109 | optimizer.step() 110 | 111 | mid_loss = chamferloss(points.transpose(2,1),mid_pc.transpose(2,1)) 112 | 113 | # store loss and step 114 | sum_loss += loss.item()*points.size(0) 115 | sum_mid_loss += mid_loss.item()*points.size(0) 116 | sum_step += points.size(0) 117 | 118 | print('[%d: %d/%d] train loss: %f middle loss: %f' %(epoch, i, num_batch, loss.item(),mid_loss.item())) 119 | 120 | if i % 100 == 0: 121 | j, data = next(enumerate(testdataloader, 0)) 122 | points, target = data 123 | points, target = Variable(points), Variable(target[:,0]) 124 | points = points.transpose(2,1) 125 | points, target = points.cuda(), target.cuda() 126 | foldingnet = foldingnet.eval() 127 | recon_pc, mid_pc, _ = foldingnet(points) 128 | loss = chamferloss(points.transpose(2,1),recon_pc.transpose(2,1)) 129 | 130 | mid_loss = chamferloss(points.transpose(2,1),mid_pc.transpose(2,1)) 131 | 132 | # prepare show result 133 | points_show = points.cpu().detach().numpy() 134 | #points_show = points_show[0] 135 | re_show = recon_pc.cpu().detach().numpy() 136 | #re_show = re_show[0] 137 | 138 | 139 | fig_ori = plt.figure() 140 | a1 = fig_ori.add_subplot(111,projection='3d') 141 | a1.scatter(points_show[0,0,:],points_show[0,1,:],points_show[0,2,:]) 142 | plt.savefig('points_show.png') 143 | 144 | fig_re = plt.figure() 145 | a2 = fig_re.add_subplot(111,projection='3d') 146 | a2.scatter(re_show[0,0,:],re_show[0,1,:],re_show[0,2,:]) 147 
| plt.savefig('re_show.png') 148 | 149 | 150 | # plot results 151 | time_p.append(time.time()-start_time) 152 | loss_p.append(sum_loss/sum_step) 153 | loss_m.append(sum_mid_loss/sum_step) 154 | vis.line(X=np.array(time_p), 155 | Y=np.array(loss_p), 156 | win=line, 157 | opts=dict(legend=["Loss"])) 158 | 159 | 160 | 161 | print('[%d: %d/%d] %s test loss: %f middle test loss: %f' %(epoch, i, num_batch, blue('test'), loss.item(), mid_loss.item())) 162 | sum_step = 0 163 | sum_loss = 0 164 | sum_mid_loss = 0 165 | 166 | torch.save(foldingnet.state_dict(), '%s/foldingnet_model_%d.pth' % (opt.outf, epoch)) 167 | -------------------------------------------------------------------------------- /train_FoldingNet_graph.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | from datasets import PartDataset 17 | from pointnet import PointNetCls 18 | from pointnet import FoldingNet, FoldingNet_graph 19 | from pointnet import ChamferLoss 20 | import torch.nn.functional as F 21 | from visdom import Visdom 22 | import time 23 | from mpl_toolkits.mplot3d import Axes3D 24 | import matplotlib.pyplot as plt 25 | from prepare_graph import build_graph 26 | 27 | 28 | 29 | vis = Visdom() 30 | line = vis.line(np.arange(10)) 31 | 32 | 33 | 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--batchSize', type=int, default=8, help='input batch size') 37 | parser.add_argument('--num_points', type=int, default=2500, help='input batch size') 38 | parser.add_argument('--workers', 
type=int, help='number of data loading workers', default=4) 39 | parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for') 40 | parser.add_argument('--outf', type=str, default='cls', help='output folder') 41 | parser.add_argument('--model', type=str, default = '', help='model path') 42 | parser.add_argument('-md', '--mode', type=str, default="M", help="mode used to compute graphs: M, P") 43 | parser.add_argument('-m', '--metric', type=str, default='euclidean', 44 | help="metric for distance calculation (manhattan/euclidean)") 45 | 46 | opt = parser.parse_args() 47 | opt.nepoch = 200 # yw add 48 | opt.knn = 16 # yw 49 | print(opt) 50 | 51 | blue = lambda x:'\033[94m' + x + '\033[0m' 52 | 53 | opt.manualSeed = random.randint(1, 10000) # fix seed 54 | print("Random Seed: ", opt.manualSeed) 55 | random.seed(opt.manualSeed) 56 | torch.manual_seed(opt.manualSeed) 57 | 58 | dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, npoints = opt.num_points) 59 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, 60 | shuffle=True, num_workers=int(opt.workers)) 61 | 62 | test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False, npoints = opt.num_points) 63 | testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize, 64 | shuffle=True, num_workers=int(opt.workers)) 65 | 66 | print(len(dataset), len(test_dataset)) 67 | num_classes = len(dataset.classes) 68 | print('classes', num_classes) 69 | 70 | try: 71 | os.makedirs(opt.outf) 72 | except OSError: 73 | pass 74 | 75 | 76 | #classifier = PointNetCls(k = num_classes) 77 | #foldingnet = FoldingNet() 78 | foldingnet = FoldingNet_graph() 79 | 80 | 81 | 82 | if opt.model != '': 83 | foldingnet.load_state_dict(torch.load(opt.model)) 84 | 85 | 86 | #optimizer = optim.SGD(foldingnet.parameters(), lr=0.01, momentum=0.9) 87 | optimizer = 
optim.Adam(foldingnet.parameters(),lr = 0.0001,weight_decay=1e-6) 88 | foldingnet.cuda() 89 | 90 | num_batch = len(dataset)/opt.batchSize 91 | 92 | chamferloss = ChamferLoss() 93 | chamferloss.cuda() 94 | 95 | start_time = time.time() 96 | time_p, loss_p, loss_m = [],[],[] 97 | 98 | for epoch in range(opt.nepoch): 99 | sum_loss = 0 100 | sum_step = 0 101 | sum_mid_loss = 0 102 | for i, data in enumerate(dataloader, 0): 103 | points, target = data 104 | # print(points.size()) 105 | 106 | # build graph 107 | batch_graph, Cov = build_graph(points, opt) 108 | 109 | Cov = Cov.transpose(2,1) 110 | Cov = Cov.cuda() 111 | #points, target = Variable(points), Variable(target[:,0]) 112 | points = points.transpose(2,1) 113 | points, target = points.cuda(), target.cuda() 114 | 115 | optimizer.zero_grad() 116 | foldingnet = foldingnet.train() 117 | recon_pc, mid_pc, _ = foldingnet(points, Cov, batch_graph) 118 | 119 | loss = chamferloss(points.transpose(2,1),recon_pc.transpose(2,1)) 120 | loss.backward() 121 | optimizer.step() 122 | 123 | mid_loss = chamferloss(points.transpose(2,1),mid_pc.transpose(2,1)) 124 | 125 | # store loss and step 126 | sum_loss += loss.item()*points.size(0) 127 | sum_mid_loss += mid_loss.item()*points.size(0) 128 | sum_step += points.size(0) 129 | 130 | print('[%d: %d/%d] train loss: %f middle loss: %f' %(epoch, i, num_batch, loss.item(),mid_loss.item())) 131 | 132 | if i % 100 == 0: 133 | j, data = next(enumerate(testdataloader, 0)) 134 | points, target = data 135 | # build graph 136 | batch_graph, Cov = build_graph(points, opt) 137 | Cov = Cov.transpose(2, 1) 138 | Cov = Cov.cuda() 139 | 140 | points, target = Variable(points), Variable(target[:,0]) 141 | points = points.transpose(2,1) 142 | points, target = points.cuda(), target.cuda() 143 | foldingnet = foldingnet.eval() 144 | recon_pc, mid_pc, _ = foldingnet(points, Cov, batch_graph) 145 | loss = chamferloss(points.transpose(2,1),recon_pc.transpose(2,1)) 146 | 147 | mid_loss = 
chamferloss(points.transpose(2,1),mid_pc.transpose(2,1)) 148 | 149 | # prepare show result 150 | points_show = points.cpu().detach().numpy() 151 | #points_show = points_show[0] 152 | re_show = recon_pc.cpu().detach().numpy() 153 | #re_show = re_show[0] 154 | 155 | 156 | fig_ori = plt.figure() 157 | a1 = fig_ori.add_subplot(111,projection='3d') 158 | a1.scatter(points_show[0,0,:],points_show[0,1,:],points_show[0,2,:]) 159 | plt.savefig('points_show.png') 160 | 161 | fig_re = plt.figure() 162 | a2 = fig_re.add_subplot(111,projection='3d') 163 | a2.scatter(re_show[0,0,:],re_show[0,1,:],re_show[0,2,:]) 164 | plt.savefig('re_show.png') 165 | 166 | 167 | # plot results 168 | time_p.append(time.time()-start_time) 169 | loss_p.append(sum_loss/sum_step) 170 | loss_m.append(sum_mid_loss/sum_step) 171 | vis.line(X=np.array(time_p), 172 | Y=np.array(loss_p), 173 | win=line, 174 | opts=dict(legend=["Loss"])) 175 | 176 | 177 | 178 | print('[%d: %d/%d] %s test loss: %f middle test loss: %f' %(epoch, i, num_batch, blue('test'), loss.item(), mid_loss.item())) 179 | sum_step = 0 180 | sum_loss = 0 181 | sum_mid_loss = 0 182 | 183 | torch.save(foldingnet.state_dict(), '%s/foldingnet_model_%d.pth' % (opt.outf, epoch)) 184 | -------------------------------------------------------------------------------- /train_classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | from datasets import PartDataset 17 | from pointnet import PointNetCls 18 | import torch.nn.functional as F 19 | 20 | 
21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--batchSize', type=int, default=32, help='input batch size') 24 | parser.add_argument('--num_points', type=int, default=2500, help='input batch size') 25 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 26 | parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for') 27 | parser.add_argument('--outf', type=str, default='cls', help='output folder') 28 | parser.add_argument('--model', type=str, default = '', help='model path') 29 | 30 | opt = parser.parse_args() 31 | print (opt) 32 | 33 | blue = lambda x:'\033[94m' + x + '\033[0m' 34 | 35 | opt.manualSeed = random.randint(1, 10000) # fix seed 36 | print("Random Seed: ", opt.manualSeed) 37 | random.seed(opt.manualSeed) 38 | torch.manual_seed(opt.manualSeed) 39 | 40 | dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, npoints = opt.num_points) 41 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, 42 | shuffle=True, num_workers=int(opt.workers)) 43 | 44 | test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False, npoints = opt.num_points) 45 | testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize, 46 | shuffle=True, num_workers=int(opt.workers)) 47 | 48 | print(len(dataset), len(test_dataset)) 49 | num_classes = len(dataset.classes) 50 | print('classes', num_classes) 51 | 52 | try: 53 | os.makedirs(opt.outf) 54 | except OSError: 55 | pass 56 | 57 | 58 | classifier = PointNetCls(k = num_classes) 59 | 60 | 61 | if opt.model != '': 62 | classifier.load_state_dict(torch.load(opt.model)) 63 | 64 | 65 | optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) 66 | classifier.cuda() 67 | 68 | num_batch = len(dataset)/opt.batchSize 69 | 70 | for epoch in range(opt.nepoch): 71 | for i, data in 
enumerate(dataloader, 0): 72 | points, target = data 73 | points, target = Variable(points), Variable(target[:,0]) 74 | points = points.transpose(2,1) 75 | points, target = points.cuda(), target.cuda() 76 | optimizer.zero_grad() 77 | classifier = classifier.train() 78 | pred, _ = classifier(points) 79 | loss = F.nll_loss(pred, target) 80 | loss.backward() 81 | optimizer.step() 82 | pred_choice = pred.data.max(1)[1] 83 | correct = pred_choice.eq(target.data).cpu().sum() 84 | print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.item(),correct.item() / float(opt.batchSize))) 85 | 86 | if i % 10 == 0: 87 | j, data = next(enumerate(testdataloader, 0)) 88 | points, target = data 89 | points, target = Variable(points), Variable(target[:,0]) 90 | points = points.transpose(2,1) 91 | points, target = points.cuda(), target.cuda() 92 | classifier = classifier.eval() 93 | pred, _ = classifier(points) 94 | loss = F.nll_loss(pred, target) 95 | pred_choice = pred.data.max(1)[1] 96 | correct = pred_choice.eq(target.data).cpu().sum() 97 | print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize))) 98 | 99 | torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch)) 100 | -------------------------------------------------------------------------------- /train_segmentation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | from datasets import PartDataset 17 | from pointnet 
import PointNetDenseCls
import torch.nn.functional as F


parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--outf', type=str, default='seg', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')

opt = parser.parse_args()
print(opt)

opt.manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

dataset = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0',
                      classification=False, class_choice=['Chair'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

test_dataset = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0',
                           classification=False, class_choice=['Chair'], train=False)
testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize,
                                             shuffle=True, num_workers=int(opt.workers))

print(len(dataset), len(test_dataset))
num_classes = dataset.num_seg_classes
print('classes', num_classes)

try:
    os.makedirs(opt.outf)
except OSError:
    pass  # output dir already exists

blue = lambda x: '\033[94m' + x + '\033[0m'

classifier = PointNetDenseCls(k=num_classes)

if opt.model != '':
    classifier.load_state_dict(torch.load(opt.model))

optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
classifier.cuda()

# BUG FIX: integer batch count ('/' yields a float under Python 3).
num_batch = len(dataset) // opt.batchSize

for epoch in range(opt.nepoch):
    for i, data in enumerate(dataloader, 0):
        points, target = data
        points, target = Variable(points), Variable(target)
        points = points.transpose(2, 1)            # (B, N, 3) -> (B, 3, N)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, _ = classifier(points)
        pred = pred.view(-1, num_classes)
        target = target.view(-1, 1)[:, 0] - 1      # labels are 1-based on disk
        loss = F.nll_loss(pred, target)
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        # BUG FIX: normalise by the true number of labelled points in this
        # batch (target.size(0)) instead of the hard-coded opt.batchSize*2500,
        # which is wrong for the final (possibly short) batch.
        print('[%d: %d/%d] train loss: %f accuracy: %f'
              % (epoch, i, num_batch, loss.item(), correct.item() / float(target.size(0))))

        if i % 10 == 0:
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            points, target = Variable(points), Variable(target)
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            pred, _ = classifier(points)
            pred = pred.view(-1, num_classes)
            target = target.view(-1, 1)[:, 0] - 1

            loss = F.nll_loss(pred, target)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f'
                  % (epoch, i, num_batch, blue('test'), loss.item(),
                     correct.item() / float(target.size(0))))

    torch.save(classifier.state_dict(), '%s/seg_model_%d.pth' % (opt.outf, epoch))