├── download.sh ├── LICENSE ├── README.md ├── render_balls_so.cpp ├── show_cls.py ├── show_seg.py ├── datasets.py ├── train_classification.py ├── show_pt_yw.py ├── train_segmentation.py ├── show3d_balls.py ├── train_FoldingNet.py └── pointnet.py /download.sh: -------------------------------------------------------------------------------- 1 | wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate 2 | unzip shapenetcore_partanno_segmentation_benchmark_v0.zip 3 | rm shapenetcore_partanno_segmentation_benchmark_v0.zip 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Fei Xia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PointNet.pytorch and FoldingNet decoder, add quantization, change latent code size from 512 to 1024 2 | This repo is implementation for PointNet(https://arxiv.org/abs/1612.00593) in pytorch. The model is in `pointnet.py`. 3 | 4 | 5 | # Download data and running 6 | 7 | ``` 8 | bash build.sh #build C++ code for visualization 9 | bash download.sh #download dataset 10 | python train_classification.py #train 3D model classification 11 | python python train_segmentation.py # train 3D model segmentaion 12 | 13 | python show_seg.py --model seg/seg_model_20.pth # show segmentation results 14 | ``` 15 | 16 | # Performance 17 | Without heavy tuning, PointNet can achieve 80-90% performance in classification and segmentaion on this [dataset](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html). 18 | 19 | Sample segmentation result: 20 | ![seg](https://raw.githubusercontent.com/fxia22/pointnet.pytorch/master/misc/show3d.png?token=AE638Oy51TL2HDCaeCF273X_-Bsy6-E2ks5Y_BUzwA%3D%3D) 21 | 22 | 23 | # Links 24 | 25 | - [Project Page](http://stanford.edu/~rqi/pointnet/) 26 | - [Tensorflow implementation](https://github.com/charlesq34/pointnet) 27 | -------------------------------------------------------------------------------- /render_balls_so.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | using namespace std; 6 | 7 | struct PointInfo{ 8 | int x,y,z; 9 | float r,g,b; 10 | }; 11 | 12 | extern "C"{ 13 | 14 | void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){ 15 | r=max(r,1); 16 | vector depth(h*w,-2100000000); 17 | vector pattern; 18 | for (int dx=-r;dx<=r;dx++) 19 | for (int dy=-r;dy<=r;dy++) 20 | if (dx*dx+dy*dy=h || y2<0 || y2>=w) 
&& depth[x2*w+y2] self.num_seg_classes: 65 | self.num_seg_classes = l 66 | #print(self.num_seg_classes) 67 | 68 | 69 | def __getitem__(self, index): 70 | fn = self.datapath[index] 71 | cls = self.classes[self.datapath[index][0]] 72 | point_set = np.loadtxt(fn[1]).astype(np.float32) 73 | seg = np.loadtxt(fn[2]).astype(np.int64) 74 | #print(point_set.shape, seg.shape) 75 | 76 | choice = np.random.choice(len(seg), self.npoints, replace=True) 77 | #resample 78 | point_set = point_set[choice, :] 79 | seg = seg[choice] 80 | point_set = torch.from_numpy(point_set) 81 | seg = torch.from_numpy(seg) 82 | cls = torch.from_numpy(np.array([cls]).astype(np.int64)) 83 | if self.classification: 84 | return point_set, cls 85 | else: 86 | return point_set, seg 87 | 88 | def __len__(self): 89 | return len(self.datapath) 90 | 91 | 92 | if __name__ == '__main__': 93 | print('test') 94 | d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', class_choice = ['Chair']) 95 | print(len(d)) 96 | ps, seg = d[0] 97 | print(ps.size(), ps.type(), seg.size(),seg.type()) 98 | 99 | d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True) 100 | print(len(d)) 101 | ps, cls = d[0] 102 | print(ps.size(), ps.type(), cls.size(),cls.type()) 103 | -------------------------------------------------------------------------------- /train_classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | from datasets import PartDataset 17 | from pointnet import 
PointNetCls 18 | import torch.nn.functional as F 19 | 20 | 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--batchSize', type=int, default=32, help='input batch size') 24 | parser.add_argument('--num_points', type=int, default=2500, help='input batch size') 25 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 26 | parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for') 27 | parser.add_argument('--outf', type=str, default='cls', help='output folder') 28 | parser.add_argument('--model', type=str, default = '', help='model path') 29 | 30 | opt = parser.parse_args() 31 | print (opt) 32 | 33 | blue = lambda x:'\033[94m' + x + '\033[0m' 34 | 35 | opt.manualSeed = random.randint(1, 10000) # fix seed 36 | print("Random Seed: ", opt.manualSeed) 37 | random.seed(opt.manualSeed) 38 | torch.manual_seed(opt.manualSeed) 39 | 40 | dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, npoints = opt.num_points) 41 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, 42 | shuffle=True, num_workers=int(opt.workers)) 43 | 44 | test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False, npoints = opt.num_points) 45 | testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize, 46 | shuffle=True, num_workers=int(opt.workers)) 47 | 48 | print(len(dataset), len(test_dataset)) 49 | num_classes = len(dataset.classes) 50 | print('classes', num_classes) 51 | 52 | try: 53 | os.makedirs(opt.outf) 54 | except OSError: 55 | pass 56 | 57 | 58 | classifier = PointNetCls(k = num_classes) 59 | 60 | 61 | if opt.model != '': 62 | classifier.load_state_dict(torch.load(opt.model)) 63 | 64 | 65 | optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9) 66 | classifier.cuda() 67 | 68 | num_batch = len(dataset)/opt.batchSize 69 | 70 | 
# Main training loop: SGD updates on the train split, with a quick
# single-batch evaluation every 10 iterations and a checkpoint per epoch.
for epoch in range(opt.nepoch):
    for i, data in enumerate(dataloader, 0):
        points, target = data
        # target is (batch, 1); column 0 holds the class index per sample.
        points, target = Variable(points), Variable(target[:,0])
        # (batch, n_pts, 3) -> (batch, 3, n_pts), the layout Conv1d expects.
        points = points.transpose(2,1)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()  # BatchNorm layers in training mode
        pred, _ = classifier(points)  # (batch, k) class scores, STN transform dropped
        loss = F.nll_loss(pred, target)
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]  # argmax over classes
        correct = pred_choice.eq(target.data).cpu().sum()
        print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.item(),correct.item() / float(opt.batchSize)))

        if i % 10 == 0:
            # Peek at one test batch as a rough validation signal.
            # NOTE(review): next(enumerate(...)) rebuilds the iterator every
            # time; with shuffle=True this draws a fresh random batch.
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            points, target = Variable(points), Variable(target[:,0])
            points = points.transpose(2,1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()  # BatchNorm layers in eval mode
            pred, _ = classifier(points)
            loss = F.nll_loss(pred, target)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize)))

    # Checkpoint once per epoch.
    torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
as dset 20 | import torchvision.transforms as transforms 21 | import torchvision.utils as vutils 22 | from torch.autograd import Variable 23 | from datasets import PartDataset 24 | import torch.nn.functional as F 25 | from pointnet import FoldingNet,ChamferLoss 26 | 27 | 28 | 29 | if __name__=='__main__': 30 | 31 | np.random.seed(100) 32 | pt = np.random.rand(250,3) 33 | # fig = plt.figure() 34 | # ax = fig.add_subplot(111,projection='3d') 35 | 36 | #ax.scatter(pt[:,0],pt[:,1],pt[:,2]) 37 | #plt.show() 38 | 39 | class_choice = 'Airplane' 40 | pt_root = 'shapenetcore_partanno_segmentation_benchmark_v0' 41 | npoints = 2500 42 | 43 | shapenet_dataset = PartDataset(root = pt_root, class_choice = class_choice, classification = True,train = False) 44 | print('len(shapenet_dataset) :',len(shapenet_dataset)) 45 | dataloader = torch.utils.data.DataLoader(shapenet_dataset,batch_size=1,shuffle=False) 46 | 47 | li = list(enumerate(dataloader)) 48 | print(len(li)) 49 | 50 | # ps,cls = shapenet_dataset[0] 51 | # print('ps.size:',ps.size()) 52 | # print('ps.type:',ps.type()) 53 | # print('cls.size',cls.size()) 54 | # print('cls.type',cls.type()) 55 | 56 | # ps2,cls2 = shapenet_dataset[1] 57 | 58 | # ax.scatter(ps[:,0],ps[:,1],ps[:,2]) 59 | # ax.set_xlabel('X label') 60 | # ax.set_ylabel('Y label') 61 | # ax.set_zlabel('Z label') 62 | 63 | # # fig2 = plt.figure() 64 | # # a2 = fig2.add_subplot(111,projection='3d') 65 | # # a2.scatter(ps2[:,0],ps2[:,1],ps2[:,2]) 66 | 67 | # plt.show() 68 | 69 | foldingnet = FoldingNet() 70 | 71 | foldingnet.load_state_dict(torch.load('cls/foldingnet_model_150.pth')) 72 | foldingnet.cuda() 73 | 74 | chamferloss = ChamferLoss() 75 | chamferloss = chamferloss.cuda() 76 | #print(foldingnet) 77 | 78 | foldingnet.eval() 79 | 80 | i, data = li[4] 81 | points, target = data 82 | 83 | points = points.transpose(2,1) 84 | points = points.cuda() 85 | recon_pc, mid_pc, _ = foldingnet(points) 86 | 87 | points_show = points.cpu().detach().numpy() 88 | re_show = 
recon_pc.cpu().detach().numpy() 89 | 90 | fig_ori = plt.figure() 91 | a1 = fig_ori.add_subplot(111,projection='3d') 92 | a1.scatter(points_show[0,0,:],points_show[0,1,:],points_show[0,2,:]) 93 | #plt.savefig('points_show.png') 94 | 95 | fig_re = plt.figure() 96 | a2 = fig_re.add_subplot(111,projection='3d') 97 | a2.scatter(re_show[0,0,:],re_show[0,1,:],re_show[0,2,:]) 98 | #plt.savefig('re_show.png') 99 | 100 | plt.show() 101 | 102 | print('points.size:', points.size()) 103 | print('recon_pc.size:', recon_pc.size()) 104 | loss = chamferloss(points.transpose(2,1),recon_pc.transpose(2,1)) 105 | print('loss',loss.item()) 106 | 107 | try: 108 | os.makedirs('bin') 109 | except OSError: 110 | pass 111 | 112 | for i,data in enumerate(dataloader): 113 | points, target = data 114 | points = points.transpose(2,1) 115 | points = points.cuda() 116 | recon_pc, _, code = foldingnet(points) 117 | points_show = points.cpu().detach().numpy() 118 | #print(points_show.shape) 119 | points_show = points_show.transpose(0,2,1) 120 | re_show = recon_pc.cpu().detach().numpy() 121 | re_show = re_show.transpose(0,2,1) 122 | 123 | #batch = points.size(0) 124 | 125 | np.savetxt('recon_pc/ori_%s_%d.pts'%(class_choice,i),points_show[0]) 126 | np.savetxt('recon_pc/rec_%s_%d.pts'%(class_choice,i),re_show[0]) 127 | 128 | code_save = code.cpu().detach().numpy().astype(int) 129 | np.savetxt('bin/%s_%d.bin'%(class_choice, i), code_save) 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /train_segmentation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import 
torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | from datasets import PartDataset 17 | from pointnet import PointNetDenseCls 18 | import torch.nn.functional as F 19 | 20 | 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--batchSize', type=int, default=32, help='input batch size') 24 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 25 | parser.add_argument('--nepoch', type=int, default=25, help='number of epochs to train for') 26 | parser.add_argument('--outf', type=str, default='seg', help='output folder') 27 | parser.add_argument('--model', type=str, default = '', help='model path') 28 | 29 | 30 | opt = parser.parse_args() 31 | print (opt) 32 | 33 | opt.manualSeed = random.randint(1, 10000) # fix seed 34 | print("Random Seed: ", opt.manualSeed) 35 | random.seed(opt.manualSeed) 36 | torch.manual_seed(opt.manualSeed) 37 | 38 | dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = False, class_choice = ['Chair']) 39 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, 40 | shuffle=True, num_workers=int(opt.workers)) 41 | 42 | test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = False, class_choice = ['Chair'], train = False) 43 | testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize, 44 | shuffle=True, num_workers=int(opt.workers)) 45 | 46 | print(len(dataset), len(test_dataset)) 47 | num_classes = dataset.num_seg_classes 48 | print('classes', num_classes) 49 | try: 50 | os.makedirs(opt.outf) 51 | except OSError: 52 | pass 53 | 54 | blue = lambda x:'\033[94m' + x + '\033[0m' 55 | 56 | 57 | classifier = PointNetDenseCls(k = num_classes) 58 | 59 | if opt.model != '': 60 | 
classifier.load_state_dict(torch.load(opt.model))

optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
classifier.cuda()

num_batch = len(dataset)/opt.batchSize

# Segmentation training loop: per-point NLL loss over flattened predictions,
# with a single test batch evaluated every 10 iterations.
for epoch in range(opt.nepoch):
    for i, data in enumerate(dataloader, 0):
        points, target = data
        points, target = Variable(points), Variable(target)
        # (batch, n_pts, 3) -> (batch, 3, n_pts) for the Conv1d layers.
        points = points.transpose(2,1)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, _ = classifier(points)
        # Flatten to (batch*n_pts, num_classes) vs (batch*n_pts,) so
        # F.nll_loss scores every point independently.
        pred = pred.view(-1, num_classes)
        # The -1 shifts labels to 0-based class indices — presumably the
        # dataset's part labels are 1-based; confirm against the .seg files.
        target = target.view(-1,1)[:,0] - 1
        #print(pred.size(), target.size())
        loss = F.nll_loss(pred, target)
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        # Accuracy is per-point; assumes 2500 points per cloud (the
        # dataset's default npoints) — TODO confirm.
        print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.item(), correct.item()/float(opt.batchSize * 2500)))

        if i % 10 == 0:
            # Quick check on one (re-shuffled) test batch.
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            points, target = Variable(points), Variable(target)
            points = points.transpose(2,1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            pred, _ = classifier(points)
            pred = pred.view(-1, num_classes)
            target = target.view(-1,1)[:,0] - 1

            loss = F.nll_loss(pred, target)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize * 2500)))

    # Checkpoint once per epoch.
    torch.save(classifier.state_dict(), '%s/seg_model_%d.pth' % (opt.outf, epoch))
def onmouse(*args):
    """OpenCV mouse callback: record the cursor position (normalized to the
    window size) and flag the scene for re-rendering.

    args follows cv2's (event, x, y, flags, param) convention.
    NOTE(review): the original intentionally stores args[1] as "y" and
    args[2] as "x" (axes swapped); that mapping is preserved here.
    """
    global mousex, mousey, changed
    swapped_y, swapped_x = args[1], args[2]
    mousex = swapped_x / float(showsz)
    mousey = swapped_y / float(showsz)
    changed = True
ct.c_int(show.shape[1]), 82 | show.ctypes.data_as(ct.c_void_p), 83 | ct.c_int(ixyz.shape[0]), 84 | ixyz.ctypes.data_as(ct.c_void_p), 85 | c0.ctypes.data_as(ct.c_void_p), 86 | c1.ctypes.data_as(ct.c_void_p), 87 | c2.ctypes.data_as(ct.c_void_p), 88 | ct.c_int(ballradius) 89 | ) 90 | 91 | if magnifyBlue>0: 92 | show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=0)) 93 | if magnifyBlue>=2: 94 | show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=0)) 95 | show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=1)) 96 | if magnifyBlue>=2: 97 | show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=1)) 98 | if showrot: 99 | cv2.putText(show,'xangle %d'%(int(xangle/np.pi*180)),(30,showsz-30),0,0.5,cv2.cv.CV_RGB(255,0,0)) 100 | cv2.putText(show,'yangle %d'%(int(yangle/np.pi*180)),(30,showsz-50),0,0.5,cv2.cv.CV_RGB(255,0,0)) 101 | cv2.putText(show,'zoom %d%%'%(int(zoom*100)),(30,showsz-70),0,0.5,cv2.cv.CV_RGB(255,0,0)) 102 | changed=True 103 | while True: 104 | if changed: 105 | render() 106 | changed=False 107 | cv2.imshow('show3d',show) 108 | if waittime==0: 109 | cmd=cv2.waitKey(10)%256 110 | else: 111 | cmd=cv2.waitKey(waittime)%256 112 | if cmd==ord('q'): 113 | break 114 | elif cmd==ord('Q'): 115 | sys.exit(0) 116 | 117 | if cmd==ord('t') or cmd == ord('p'): 118 | if cmd == ord('t'): 119 | if c_gt is None: 120 | c0=np.zeros((len(xyz),),dtype='float32')+255 121 | c1=np.zeros((len(xyz),),dtype='float32')+255 122 | c2=np.zeros((len(xyz),),dtype='float32')+255 123 | else: 124 | c0=c_gt[:,0] 125 | c1=c_gt[:,1] 126 | c2=c_gt[:,2] 127 | else: 128 | if c_pred is None: 129 | c0=np.zeros((len(xyz),),dtype='float32')+255 130 | c1=np.zeros((len(xyz),),dtype='float32')+255 131 | c2=np.zeros((len(xyz),),dtype='float32')+255 132 | else: 133 | c0=c_pred[:,0] 134 | c1=c_pred[:,1] 135 | c2=c_pred[:,2] 136 | if normalizecolor: 137 | c0/=(c0.max()+1e-14)/255.0 138 | c1/=(c1.max()+1e-14)/255.0 139 | c2/=(c2.max()+1e-14)/255.0 140 | 
c0=np.require(c0,'float32','C') 141 | c1=np.require(c1,'float32','C') 142 | c2=np.require(c2,'float32','C') 143 | changed = True 144 | 145 | 146 | 147 | if cmd==ord('n'): 148 | zoom*=1.1 149 | changed=True 150 | elif cmd==ord('m'): 151 | zoom/=1.1 152 | changed=True 153 | elif cmd==ord('r'): 154 | zoom=1.0 155 | changed=True 156 | elif cmd==ord('s'): 157 | cv2.imwrite('show3d.png',show) 158 | if waittime!=0: 159 | break 160 | return cmd 161 | if __name__=='__main__': 162 | 163 | np.random.seed(100) 164 | showpoints(np.random.randn(2500,3)) 165 | 166 | -------------------------------------------------------------------------------- /train_FoldingNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.parallel 9 | import torch.backends.cudnn as cudnn 10 | import torch.optim as optim 11 | import torch.utils.data 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | import torchvision.utils as vutils 15 | from torch.autograd import Variable 16 | from datasets import PartDataset 17 | from pointnet import PointNetCls 18 | from pointnet import FoldingNet 19 | from pointnet import FoldingNet_1024 20 | from pointnet import ChamferLoss 21 | import torch.nn.functional as F 22 | from visdom import Visdom 23 | import time 24 | from mpl_toolkits.mplot3d import Axes3D 25 | import matplotlib.pyplot as plt 26 | 27 | vis = Visdom() 28 | line = vis.line(np.arange(10)) 29 | 30 | 31 | 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--batchSize', type=int, default=8, help='input batch size') 35 | parser.add_argument('--num_points', type=int, default=2500, help='input batch size') 36 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 37 | parser.add_argument('--nepoch', 
type=int, default=25, help='number of epochs to train for') 38 | parser.add_argument('--outf', type=str, default='cls', help='output folder') 39 | parser.add_argument('--model', type=str, default = '', help='model path') 40 | 41 | opt = parser.parse_args() 42 | opt.nepoch = 200 # yw add 43 | print(opt) 44 | 45 | blue = lambda x:'\033[94m' + x + '\033[0m' 46 | 47 | opt.manualSeed = random.randint(1, 10000) # fix seed 48 | print("Random Seed: ", opt.manualSeed) 49 | random.seed(opt.manualSeed) 50 | torch.manual_seed(opt.manualSeed) 51 | 52 | dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, npoints = opt.num_points) 53 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, 54 | shuffle=True, num_workers=int(opt.workers)) 55 | 56 | test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True, train = False, npoints = opt.num_points) 57 | testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=opt.batchSize, 58 | shuffle=True, num_workers=int(opt.workers)) 59 | 60 | print(len(dataset), len(test_dataset)) 61 | num_classes = len(dataset.classes) 62 | print('classes', num_classes) 63 | 64 | try: 65 | os.makedirs(opt.outf) 66 | except OSError: 67 | pass 68 | 69 | 70 | #classifier = PointNetCls(k = num_classes) 71 | foldingnet = FoldingNet_1024() 72 | 73 | 74 | 75 | if opt.model != '': 76 | foldingnet.load_state_dict(torch.load(opt.model)) 77 | 78 | 79 | #optimizer = optim.SGD(foldingnet.parameters(), lr=0.01, momentum=0.9) 80 | optimizer = optim.Adam(foldingnet.parameters(),lr = 0.0001,weight_decay=1e-6) 81 | foldingnet.cuda() 82 | 83 | num_batch = len(dataset)/opt.batchSize 84 | 85 | chamferloss = ChamferLoss() 86 | chamferloss.cuda() 87 | 88 | start_time = time.time() 89 | time_p, loss_p, loss_m = [],[],[] 90 | 91 | for epoch in range(opt.nepoch): 92 | sum_loss = 0 93 | sum_step = 0 94 | sum_mid_loss = 0 95 | for i, data in 
enumerate(dataloader, 0): 96 | points, target = data 97 | 98 | #print(points.size()) 99 | 100 | points, target = Variable(points), Variable(target[:,0]) 101 | points = points.transpose(2,1) 102 | points, target = points.cuda(), target.cuda() 103 | optimizer.zero_grad() 104 | foldingnet = foldingnet.train() 105 | recon_pc, mid_pc, _ = foldingnet(points) 106 | 107 | loss = chamferloss(points.transpose(2,1),recon_pc.transpose(2,1)) 108 | loss.backward() 109 | optimizer.step() 110 | 111 | mid_loss = chamferloss(points.transpose(2,1),mid_pc.transpose(2,1)) 112 | 113 | # store loss and step 114 | sum_loss += loss.item()*points.size(0) 115 | sum_mid_loss += mid_loss.item()*points.size(0) 116 | sum_step += points.size(0) 117 | 118 | print('[%d: %d/%d] train loss: %f middle loss: %f' %(epoch, i, num_batch, loss.item(),mid_loss.item())) 119 | 120 | if i % 100 == 0: 121 | j, data = next(enumerate(testdataloader, 0)) 122 | points, target = data 123 | points, target = Variable(points), Variable(target[:,0]) 124 | points = points.transpose(2,1) 125 | points, target = points.cuda(), target.cuda() 126 | foldingnet = foldingnet.eval() 127 | recon_pc, mid_pc, _ = foldingnet(points) 128 | loss = chamferloss(points.transpose(2,1),recon_pc.transpose(2,1)) 129 | 130 | mid_loss = chamferloss(points.transpose(2,1),mid_pc.transpose(2,1)) 131 | 132 | # prepare show result 133 | points_show = points.cpu().detach().numpy() 134 | #points_show = points_show[0] 135 | re_show = recon_pc.cpu().detach().numpy() 136 | #re_show = re_show[0] 137 | 138 | 139 | fig_ori = plt.figure() 140 | a1 = fig_ori.add_subplot(111,projection='3d') 141 | a1.scatter(points_show[0,0,:],points_show[0,1,:],points_show[0,2,:]) 142 | plt.savefig('points_show.png') 143 | 144 | fig_re = plt.figure() 145 | a2 = fig_re.add_subplot(111,projection='3d') 146 | a2.scatter(re_show[0,0,:],re_show[0,1,:],re_show[0,2,:]) 147 | plt.savefig('re_show.png') 148 | 149 | 150 | # plot results 151 | time_p.append(time.time()-start_time) 152 | 
loss_p.append(sum_loss/sum_step) 153 | loss_m.append(sum_mid_loss/sum_step) 154 | vis.line(X=np.array(time_p), 155 | Y=np.array(loss_p), 156 | win=line, 157 | opts=dict(legend=["Loss"])) 158 | 159 | 160 | 161 | print('[%d: %d/%d] %s test loss: %f middle test loss: %f' %(epoch, i, num_batch, blue('test'), loss.item(), mid_loss.item())) 162 | sum_step = 0 163 | sum_loss = 0 164 | sum_mid_loss = 0 165 | 166 | torch.save(foldingnet.state_dict(), '%s/foldingnet_model_%d.pth' % (opt.outf, epoch)) 167 | -------------------------------------------------------------------------------- /pointnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim as optim 10 | import torch.utils.data 11 | import torchvision.transforms as transforms 12 | import torchvision.utils as vutils 13 | from torch.autograd import Variable 14 | from torch.autograd import Function 15 | from PIL import Image 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import pdb 19 | import torch.nn.functional as F 20 | 21 | 22 | class STN3d(nn.Module): 23 | def __init__(self): 24 | super(STN3d, self).__init__() 25 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 26 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 27 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 28 | self.fc1 = nn.Linear(1024, 512) 29 | self.fc2 = nn.Linear(512, 256) 30 | self.fc3 = nn.Linear(256, 9) 31 | self.relu = nn.ReLU() 32 | 33 | self.bn1 = nn.BatchNorm1d(64) 34 | self.bn2 = nn.BatchNorm1d(128) 35 | self.bn3 = nn.BatchNorm1d(1024) 36 | self.bn4 = nn.BatchNorm1d(512) 37 | self.bn5 = nn.BatchNorm1d(256) 38 | 39 | def forward(self, x): 40 | batchsize = x.size()[0] 41 | x = F.relu(self.bn1(self.conv1(x))) 42 | x = F.relu(self.bn2(self.conv2(x))) 43 | x = F.relu(self.bn3(self.conv3(x))) 
class PointNetCls(nn.Module):
    """PointNet shape classifier: global feature -> 3-layer MLP -> log-probs.

    Args:
        k: number of output classes.

    forward(x) takes a (batch, 3, n_pts) point cloud and returns a tuple
    (log_probs, trans): log_probs is (batch, k) log-softmax class scores,
    trans is the (batch, 3, 3) spatial transform produced by the STN.
    """

    def __init__(self, k=2):
        super(PointNetCls, self).__init__()
        self.feat = PointNetfeat(global_feat=True)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans = self.feat(x)  # (batch, 1024) global feature
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.fc3(x)
        # BUG FIX: log_softmax must normalize over the class dimension
        # (dim=1), not the batch dimension (dim=0). With dim=0 each sample's
        # "log-probabilities" depended on the other samples in the batch,
        # corrupting both F.nll_loss training and accuracy evaluation.
        return F.log_softmax(x, dim=1), trans
nn.Conv1d(512, 3, 1) 172 | 173 | self.relu = nn.ReLU() 174 | 175 | def forward(self, x): # input x = batch,1026,45^2 176 | x = self.relu(self.conv1(x)) # x = batch,1024,45^2 177 | x = self.relu(self.conv2(x)) 178 | x = self.conv3(x) 179 | 180 | return x 181 | 182 | class FoldingNetDecFold2(nn.Module): 183 | def __init__(self): 184 | super(FoldingNetDecFold2, self).__init__() 185 | self.conv1 = nn.Conv1d(515, 512, 1) 186 | self.conv2 = nn.Conv1d(512, 512, 1) 187 | self.conv3 = nn.Conv1d(512, 3, 1) 188 | self.relu = nn.ReLU() 189 | 190 | def forward(self, x): # input x = batch,515,45^2 191 | x = self.relu(self.conv1(x)) 192 | x = self.relu(self.conv2(x)) 193 | x = self.conv3(x) 194 | return x 195 | 196 | class FoldingNetDecFold2_1024(nn.Module): 197 | def __init__(self): 198 | super(FoldingNetDecFold2_1024, self).__init__() 199 | self.conv1 = nn.Conv1d(1027, 1024, 1) 200 | self.conv2 = nn.Conv1d(1024, 512, 1) 201 | self.conv3 = nn.Conv1d(512, 3, 1) 202 | self.relu = nn.ReLU() 203 | 204 | def forward(self, x): # input x = batch,1027,45^2 205 | x = self.relu(self.conv1(x)) 206 | x = self.relu(self.conv2(x)) 207 | x = self.conv3(x) 208 | return x 209 | 210 | def GridSamplingLayer(batch_size, meshgrid): 211 | ''' 212 | output Grid points as a NxD matrix 213 | 214 | params = { 215 | 'batch_size': 8 216 | 'meshgrid': [[-0.3,0.3,45],[-0.3,0.3,45]] 217 | } 218 | ''' 219 | 220 | ret = np.meshgrid(*[np.linspace(it[0], it[1], num=it[2]) for it in meshgrid]) 221 | ndim = len(meshgrid) 222 | grid = np.zeros((np.prod([it[2] for it in meshgrid]), ndim), dtype=np.float32) # MxD 223 | for d in range(ndim): 224 | grid[:, d] = np.reshape(ret[d], -1) 225 | g = np.repeat(grid[np.newaxis, ...], repeats=batch_size, axis=0) 226 | 227 | return g 228 | 229 | 230 | class FoldingNetDec(nn.Module): 231 | def __init__(self): 232 | super(FoldingNetDec, self).__init__() 233 | self.fold1 = FoldingNetDecFold1() 234 | self.fold2 = FoldingNetDecFold2() 235 | 236 | def forward(self, x): # input x = 
batch, 512 237 | batch_size = x.size(0) 238 | x = torch.unsqueeze(x, 1) # x = batch,1,512 239 | x = x.repeat(1, 45 ** 2, 1) # x = batch,45^2,512 240 | code = x 241 | code = x.transpose(2, 1) # x = batch,512,45^2 242 | 243 | meshgrid = [[-0.3, 0.3, 45], [-0.3, 0.3, 45]] 244 | grid = GridSamplingLayer(batch_size, meshgrid) # grid = batch,45^2,2 245 | grid = torch.from_numpy(grid) 246 | 247 | if x.is_cuda: 248 | grid = grid.cuda() 249 | 250 | x = torch.cat((x, grid), 2) # x = batch,45^2,514 251 | x = x.transpose(2, 1) # x = batch,514,45^2 252 | 253 | x = self.fold1(x) # x = batch,3,45^2 254 | p1 = x # to observe 255 | 256 | x = torch.cat((code, x), 1) # x = batch,515,45^2 257 | 258 | x = self.fold2(x) # x = batch,3,45^2 259 | 260 | return x, p1 261 | 262 | 263 | class FoldingNetDec_1024(nn.Module): 264 | def __init__(self): 265 | super(FoldingNetDec_1024, self).__init__() 266 | self.fold1 = FoldingNetDecFold1_1024() 267 | self.fold2 = FoldingNetDecFold2_1024() 268 | 269 | def forward(self, x): # input x = batch, 1024 270 | batch_size = x.size(0) 271 | x = torch.unsqueeze(x, 1) # x = batch,1,1024 272 | x = x.repeat(1, 45 ** 2, 1) # x = batch,45^2,1024 273 | code = x 274 | code = x.transpose(2, 1) # x = batch,1024,45^2 275 | 276 | meshgrid = [[-0.3, 0.3, 45], [-0.3, 0.3, 45]] 277 | grid = GridSamplingLayer(batch_size, meshgrid) # grid = batch,45^2,2 278 | grid = torch.from_numpy(grid) 279 | 280 | if x.is_cuda: 281 | grid = grid.cuda() 282 | 283 | x = torch.cat((x, grid), 2) # x = batch,45^2,1026 284 | x = x.transpose(2, 1) # x = batch,1026,45^2 285 | 286 | x = self.fold1(x) # x = batch,3,45^2 287 | p1 = x # to observe 288 | 289 | x = torch.cat((code, x), 1) # x = batch,1027,45^2 290 | 291 | x = self.fold2(x) # x = batch,3,45^2 292 | 293 | return x, p1 294 | 295 | class Quantization(Function): 296 | #def __init__(self): 297 | # super(Quantization, self).__init__() 298 | 299 | @staticmethod 300 | def forward(ctx, input): 301 | output = torch.round(input) 302 | return 
output 303 | 304 | @staticmethod 305 | def backward(ctx,grad_output): 306 | return grad_output 307 | 308 | class Quantization_module(nn.Module): 309 | def __init__(self): 310 | super().__init__() 311 | 312 | def forward(self, input): 313 | return Quantization.apply(input) 314 | 315 | class FoldingNet(nn.Module): 316 | def __init__(self): 317 | super(FoldingNet, self).__init__() 318 | self.encoder = FoldingNetEnc() 319 | self.decoder = FoldingNetDec() 320 | self.quan = Quantization_module() 321 | 322 | def forward(self, x): # input x = batch,3,number of points 323 | code, tran = self.encoder(x) # code = batch,512 324 | code = self.quan(code) # quantization 325 | 326 | '''if self.training == 0: # if now is evaluation, save code 327 | try: 328 | os.makedirs('bin') 329 | except OSError: 330 | pass 331 | code_save = code.cpu().detach() 332 | code_save = code_save.numpy() 333 | code_save = code_save.astype(int) 334 | np.savetxt('./bin/test.bin', code_save) 335 | ''' 336 | 337 | x, x_middle = self.decoder(code) # x = batch,3,45^2 338 | 339 | return x, x_middle,code 340 | 341 | class FoldingNet_1024(nn.Module): 342 | def __init__(self): 343 | super(FoldingNet_1024, self).__init__() 344 | self.encoder = FoldingNetEnc_1024() 345 | self.decoder = FoldingNetDec_1024() 346 | self.quan = Quantization_module() 347 | 348 | def forward(self, x): # input x = batch,3,number of points 349 | code, tran = self.encoder(x) # code = batch,512 350 | code = self.quan(code) # quantization 351 | 352 | '''if self.training == 0: # if now is evaluation, save code 353 | try: 354 | os.makedirs('bin') 355 | except OSError: 356 | pass 357 | code_save = code.cpu().detach() 358 | code_save = code_save.numpy() 359 | code_save = code_save.astype(int) 360 | np.savetxt('./bin/test.bin', code_save) 361 | ''' 362 | 363 | x, x_middle = self.decoder(code) # x = batch,3,45^2 364 | 365 | return x, x_middle,code 366 | 367 | 368 | def ChamferDistance(x, y): # for example, x = batch,2025,3 y = batch,2048,3 369 | # 
compute chamfer distance between tow point clouds x and y 370 | 371 | x_size = x.size() 372 | y_size = y.size() 373 | assert (x_size[0] == y_size[0]) 374 | assert (x_size[2] == y_size[2]) 375 | x = torch.unsqueeze(x, 1) # x = batch,1,2025,3 376 | y = torch.unsqueeze(y, 2) # y = batch,2048,1,3 377 | 378 | x = x.repeat(1, y_size[1], 1, 1) # x = batch,2048,2025,3 379 | y = y.repeat(1, 1, x_size[1], 1) # y = batch,2048,2025,3 380 | 381 | x_y = x - y 382 | x_y = torch.pow(x_y, 2) # x_y = batch,2048,2025,3 383 | x_y = torch.sum(x_y, 3, keepdim=True) # x_y = batch,2048,2025,1 384 | x_y = torch.squeeze(x_y, 3) # x_y = batch,2048,2025 385 | x_y_row, _ = torch.min(x_y, 1, keepdim=True) # x_y_row = batch,1,2025 386 | x_y_col, _ = torch.min(x_y, 2, keepdim=True) # x_y_col = batch,2048,1 387 | 388 | x_y_row = torch.mean(x_y_row, 2, keepdim=True) # x_y_row = batch,1,1 389 | x_y_col = torch.mean(x_y_col, 1, keepdim=True) # batch,1,1 390 | x_y_row_col = torch.cat((x_y_row, x_y_col), 2) # batch,1,2 391 | chamfer_distance, _ = torch.max(x_y_row_col, 2, keepdim=True) # batch,1,1 392 | # chamfer_distance = torch.reshape(chamfer_distance,(x_size[0],-1)) #batch,1 393 | # chamfer_distance = torch.squeeze(chamfer_distance,1) # batch 394 | chamfer_distance = torch.mean(chamfer_distance) 395 | return chamfer_distance 396 | 397 | 398 | class ChamferLoss(nn.Module): 399 | # chamfer distance loss 400 | def __init__(self): 401 | super(ChamferLoss, self).__init__() 402 | 403 | def forward(self, x, y): 404 | return ChamferDistance(x, y) 405 | 406 | 407 | 408 | 409 | 410 | class PointNetDenseCls(nn.Module): 411 | def __init__(self, k=2): 412 | super(PointNetDenseCls, self).__init__() 413 | self.k = k 414 | self.feat = PointNetfeat(global_feat=False) 415 | self.conv1 = torch.nn.Conv1d(1088, 512, 1) 416 | self.conv2 = torch.nn.Conv1d(512, 256, 1) 417 | self.conv3 = torch.nn.Conv1d(256, 128, 1) 418 | self.conv4 = torch.nn.Conv1d(128, self.k, 1) 419 | self.bn1 = nn.BatchNorm1d(512) 420 | self.bn2 = 
nn.BatchNorm1d(256) 421 | self.bn3 = nn.BatchNorm1d(128) 422 | 423 | def forward(self, x): 424 | batchsize = x.size()[0] 425 | n_pts = x.size()[2] 426 | x, trans = self.feat(x) 427 | x = F.relu(self.bn1(self.conv1(x))) 428 | x = F.relu(self.bn2(self.conv2(x))) 429 | x = F.relu(self.bn3(self.conv3(x))) 430 | x = self.conv4(x) 431 | x = x.transpose(2, 1).contiguous() 432 | x = F.log_softmax(x.view(-1, self.k), dim=-1) 433 | x = x.view(batchsize, n_pts, self.k) 434 | return x, trans 435 | 436 | 437 | if __name__ == '__main__': 438 | sim_data = Variable(torch.rand(32, 3, 2500)) 439 | trans = STN3d() 440 | out = trans(sim_data) 441 | print('stn', out.size()) 442 | 443 | pointfeat = PointNetfeat(global_feat=True) 444 | out, _ = pointfeat(sim_data) 445 | print('global feat', out.size()) 446 | 447 | pointfeat = PointNetfeat(global_feat=False) 448 | out, _ = pointfeat(sim_data) 449 | print('point feat', out.size()) 450 | 451 | cls = PointNetCls(k=5) 452 | out, _ = cls(sim_data) 453 | print('class', out.size()) 454 | 455 | seg = PointNetDenseCls(k=3) 456 | out, _ = seg(sim_data) 457 | print('seg', out.size()) 458 | 459 | # YW test 460 | 461 | ''' 462 | sim_data = torch.rand(32,515,45*45) 463 | print('sim_data ',sim_data.size()) 464 | 465 | fold2 = FoldingNetDecFold2() 466 | out = fold2(sim_data) 467 | print('fold2 ',out.size()) 468 | 469 | meshgrid = [[-0.3,0.3,45],[-0.3,0.3,45]] 470 | out = GridSamplingLayer(3,meshgrid) 471 | print('meshgrid',out.shape) 472 | 473 | sim_data = torch.rand(32,512) 474 | sim_data.cuda() 475 | dec = FoldingNetDec() 476 | if sim_data.is_cuda: 477 | dec.cuda() 478 | out,out2 = dec(sim_data) 479 | print('dec',out.size()) 480 | print('fold1 result',out2.size()) 481 | 482 | 483 | sim_data = torch.rand(32,3,2500) 484 | 485 | enc = FoldingNetEnc() 486 | out, _ = enc(sim_data) 487 | print(out.size()) 488 | 489 | 490 | 491 | foldnet = FoldingNet() 492 | foldnet.cuda() 493 | 494 | out , out2 = foldnet(sim_data) 495 | print('reconsructed point cloud', 
out.size()) 496 | print('middle result',out2.size()) 497 | ''' 498 | 499 | sim_data = torch.rand(32,3,2500) 500 | print('sim_data', sim_data.size()) 501 | enc_1024 = FoldingNetEnc_1024() 502 | out, _ = enc_1024(sim_data) 503 | print('enc_1024', out.size()) 504 | 505 | sim_data = torch.rand(32,1024) 506 | dec_1024 = FoldingNetDec_1024() 507 | out, out2 = dec_1024(sim_data) 508 | print('dec ',out.size(),out2.size()) 509 | 510 | sim_data = torch.rand(32,3,2500) 511 | fold = FoldingNet_1024() 512 | out, out2, code = fold(sim_data) 513 | print('fold',out.size(),out2.size(),code.size()) 514 | 515 | 516 | 517 | x = torch.rand(16, 2048, 3) 518 | y = x 519 | # y = torch.rand(16,2025,3) 520 | 521 | cs = ChamferDistance(x, y) 522 | print('chamfer distance', cs) 523 | 524 | closs = ChamferLoss() 525 | print('chamfer loss', closs(x, y)) 526 | --------------------------------------------------------------------------------