├── .gitignore
├── LICENSE
├── README.md
├── build.sh
├── datasets.py
├── pointnet.py
├── render_balls_so.cpp
├── requirements.txt
├── show3d_balls.py
├── show_seg.py
├── show_seg_s3d.py
├── train_cls.py
├── train_seg.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
The MIT License (MIT) Copyright (c) 2020 Yunxiao Shi

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## PointNet-PyTorch

[![Python 3.6+](https://img.shields.io/badge/Python-3.6%2B-blue)](https://www.python.org/)
[![MIT License](https://img.shields.io/badge/MIT-License-brightgreen)](./LICENSE)

This is a PyTorch implementation of [PointNet (CVPR 2017)](https://arxiv.org/abs/1612.00593 "PointNet"), with comprehensive experiments.

## Installation

It is recommended to use [conda](https://docs.conda.io/en/latest/) to manage your environment, for example:
```
conda create -n pointnet python=3.6
conda activate pointnet
pip install -r requirements.txt
```

You will also need to install [PyMesh](https://github.com/PyMesh/PyMesh "PyMesh"), which is used to load meshes in ```datasets.py```. See [here](https://github.com/PyMesh/PyMesh#Build) for installation instructions.

## Usage

This code implements object classification on ModelNet, shape part segmentation on ShapeNet, and indoor scene semantic segmentation on the Stanford 3D Indoor Scene dataset (S3DIS).

For the missing ```s3d_cat2num.txt``` when training on S3DIS, follow [#3](https://github.com/kentsyx/pointnet-pytorch/issues/3#issuecomment-643061963) to generate it once you have downloaded the dataset.

### ModelNet Classification

Download the ModelNet10 dataset from [here](http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip) or the ModelNet40 dataset from [here](https://lmb.informatik.uni-freiburg.de/resources/datasets/ORION/modelnet40_manually_aligned.tar).
Unzip and run
```
python train_cls.py -dset modelnet40 -r modelnet_root_dir -np number_of_points_to_sample
```

### ShapeNet Part Segmentation

Download the ShapeNet dataset from [here](https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip). Unzip and run
```
python train_seg.py -dset shapenet16 -r shapenet_root_dir -np number_of_points_to_sample
```

### Indoor Scene Semantic Segmentation

Download the S3DIS dataset from [here](http://buildingparser.stanford.edu/dataset.html#Download) (you need to submit a request). Unzip and do
```
cd Stanford3dDataset_v1.2
mkdir train test
mv Area_1 Area_2 Area_3 Area_4 Area_6 train
mv Area_5 test
```
to create the train/test split. Then set ```gen_labels=True``` in the class ```S3dDataset``` in datasets.py and do
```
python datasets.py
```
to generate labels for the train and test sets, respectively. __After that, always set ```gen_labels=False```.__ With labels generated, run
```
python train_seg.py -dset s3dis -r s3dis_root_dir -np number_of_points_to_sample
```
to start training.

## Visualization

First run ```sh build.sh``` to compile the ball-rendering library, then use ```show_seg.py``` to visualize segmented object parts.
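For example, to view part predictions for a ShapeNet chair (the checkpoint path below is a placeholder for whatever ```train_seg.py``` saved; the flags come from the argparse options in ```show_seg.py```):
```
python show_seg.py --model path_to_seg_checkpoint.pth --className Chair --idx 0
```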
Below are some example results.

*(example ShapeNet part segmentation images omitted in this dump)*

For S3DIS, you have to combine the scene components along with their labels into one text file (```cat``` and ```paste``` are an easy way to do this) and then pass it to ```show_seg_s3d.py```, as sketched below.
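A minimal Python sketch of that combining step (the room filenames here are hypothetical; ```parse_whole_scene``` in ```show_seg_s3d.py``` reads x y z from the first three columns and the integer label from the last column):
```
# hypothetical room files produced by the S3DIS label-generation step
import numpy as np

pts = np.loadtxt('office_1.txt')           # per-point x y z r g b
lbl = np.loadtxt('office_1_labels.txt')    # one integer label per point
np.savetxt('office_1_combined.txt', np.column_stack([pts, lbl]))
```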
Below are some example results (some clutter classes were removed for better visualization).

*(example S3DIS semantic segmentation images omitted in this dump)*

## Results

Certain design choices in the original paper are not implemented here for simplicity. There is some performance gap on ModelNet classification, while results on ShapeNet and S3DIS seem to be on par with the original paper.
| dataset | accuracy | class avg IoU |
| :------: | :------: | :------: |
| ModelNet10 | 87.2% | - |
| ModelNet40 | 85.4% | - |
| ShapeNet | - | 82.9% |
| S3DIS | 72.1% | 50.6% |
## Acknowledgements

[pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch) (many thanks)

[original tensorflow implementation](https://github.com/charlesq34/pointnet)

## LICENSE

MIT
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
g++ -std=c++11 render_balls_so.cpp -o render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import os
import random
import numpy as np
import torch
import torch.utils.data as data
import pymesh
from tqdm import tqdm
from utils import shapenet_labels

def scale_linear_bycolumn(rawdata, high=1.0, low=0.0):
    '''Linearly rescale each column of rawdata to the range [low, high].'''
    mins = np.min(rawdata, axis=0)
    maxs = np.max(rawdata, axis=0)
    rng = maxs - mins
    return high - (high-low)*(maxs-rawdata)/(rng+np.finfo(np.float32).eps)


class ClsDataset(data.Dataset):
    '''Object classification on ModelNet'''
    def __init__(self, root, npoints=1024, train=True):
        self.root = root
        self.npoints = npoints
        self.catfile = os.path.join(self.root, 'modelnet_cat2num.txt')
        self.cat = {}

        with open(self.catfile, 'r') as f:
            for line in f.readlines():
                lns = line.strip().split()
                self.cat[lns[0]] = lns[1]
        self.num_classes = len(self.cat)
        self.datapath = []
        FLAG = 'train' if train else 'test'
        for item in os.listdir(self.root):
            if os.path.isdir(os.path.join(self.root, item)):
                for f in os.listdir(os.path.join(self.root, item, FLAG)):
                    if f.endswith('.off'):
                        self.datapath.append((os.path.join(self.root, item, FLAG, f), int(self.cat[item])))

    def __getitem__(self, idx):
        fn = self.datapath[idx]
        points = pymesh.load_mesh(fn[0]).vertices
        label = fn[1]
        # sample with replacement only when the mesh has fewer vertices than npoints
        replace = True if points.shape[0] < self.npoints else False
        # ... (remainder of datasets.py, including the PartDataset and S3dDataset
        # classes imported by show_seg.py and show_seg_s3d.py, is truncated in this dump)
--------------------------------------------------------------------------------
/render_balls_so.cpp:
--------------------------------------------------------------------------------
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

struct PointInfo{
    int x,y,z;
    float r,g,b;
};

extern "C"{

// Rasterizes n points as balls of radius r into the h x w RGB buffer `show`,
// keeping only the nearest point per pixel via a z-buffer.
void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){
    r=max(r,1);
    vector<int> depth(h*w,-2100000000);  // per-pixel z-buffer
    vector<PointInfo> pattern;           // disk-shaped stamp of pixel offsets
    for (int dx=-r;dx<=r;dx++)
        for (int dy=-r;dy<=r;dy++)
            if (dx*dx+dy*dy<r*r){
                // ... (the stamp construction and the per-point splatting loop are
                // garbled in this dump; the surviving fragment of the splatting
                // loop's bounds/depth test reads:
                //   !(x2<0 || x2>=h || y2<0 || y2>=w) && depth[x2*w+y2]<z2 )
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=0.4.1
torchvision
numpy
tqdm
opencv-python
lera # to monitor training
--------------------------------------------------------------------------------
/show3d_balls.py:
--------------------------------------------------------------------------------
import numpy as np
import ctypes as ct
import cv2
import sys

showsz=800
mousex,mousey=0.5,0.5
zoom=1.0
changed=True

def onmouse(*args):
    global mousex,mousey,changed
    y=args[1]
    x=args[2]
    mousex=x/float(showsz)
    mousey=y/float(showsz)
    changed=True

cv2.namedWindow('show3d')
cv2.moveWindow('show3d',0,0)
cv2.setMouseCallback('show3d',onmouse)

# shared library compiled by build.sh
dll=np.ctypeslib.load_library('render_balls_so','.')

def showpoints(xyz,c_gt=None, c_pred = None ,waittime=0,showrot=False,magnifyBlue=0,freezerot=False,background=(0, 0, 0),normalizecolor=True,ballradius=2):
    global showsz,mousex,mousey,zoom,changed
    xyz=xyz-xyz.mean(axis=0)
    radius=((xyz**2).sum(axis=-1)**0.5).max()
    xyz/=(radius*2.2)/showsz
    if c_gt is None:
        c0=np.zeros((len(xyz),),dtype='float32')+255
        c1=np.zeros((len(xyz),),dtype='float32')+255
        c2=np.zeros((len(xyz),),dtype='float32')+255
    else:
        c0=c_gt[:,0]
        c1=c_gt[:,1]
        c2=c_gt[:,2]

    if normalizecolor:
        c0/=(c0.max()+1e-14)/255.0
        c1/=(c1.max()+1e-14)/255.0
        c2/=(c2.max()+1e-14)/255.0

    c0=np.require(c0,'float32','C')
    c1=np.require(c1,'float32','C')
    c2=np.require(c2,'float32','C')

    show=np.zeros((showsz,showsz,3),dtype='uint8')

    def render():
        # rotate around x by mouse-y and around y by mouse-x, then project
        rotmat=np.eye(3)
        if not freezerot:
            xangle=(mousey-0.5)*np.pi*1.2
        else:
            xangle=0
        rotmat=rotmat.dot(np.array([
            [1.0,0.0,0.0],
            [0.0,np.cos(xangle),-np.sin(xangle)],
            [0.0,np.sin(xangle),np.cos(xangle)],
        ]))
        if not freezerot:
            yangle=(mousex-0.5)*np.pi*1.2
        else:
            yangle=0
        rotmat=rotmat.dot(np.array([
            [np.cos(yangle),0.0,-np.sin(yangle)],
            [0.0,1.0,0.0],
            [np.sin(yangle),0.0,np.cos(yangle)],
        ]))
        rotmat*=zoom
        nxyz=xyz.dot(rotmat)+[showsz/2,showsz/2,0]

        ixyz=nxyz.astype('int32')
        show[:]=background
        dll.render_ball(
            ct.c_int(show.shape[0]),
            ct.c_int(show.shape[1]),
            show.ctypes.data_as(ct.c_void_p),
            ct.c_int(ixyz.shape[0]),
            ixyz.ctypes.data_as(ct.c_void_p),
            c0.ctypes.data_as(ct.c_void_p),
            c1.ctypes.data_as(ct.c_void_p),
            c2.ctypes.data_as(ct.c_void_p),
            ct.c_int(ballradius)
        )

        if magnifyBlue>0:
            show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=0))
            if magnifyBlue>=2:
                show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=0))
            show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=1))
            if magnifyBlue>=2:
                show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=1))
        if showrot:
            # (0,0,255) is red in BGR; the old cv2.cv.CV_RGB API was removed in OpenCV 3+
            cv2.putText(show,'xangle %d'%(int(xangle/np.pi*180)),(30,showsz-30),0,0.5,(0,0,255))
            cv2.putText(show,'yangle %d'%(int(yangle/np.pi*180)),(30,showsz-50),0,0.5,(0,0,255))
            cv2.putText(show,'zoom %d%%'%(int(zoom*100)),(30,showsz-70),0,0.5,(0,0,255))

    changed=True
    # keys: q/Q quit, t ground-truth colors, p predicted colors, n/m zoom in/out, r reset zoom, s screenshot
    while True:
        if changed:
            render()
            changed=False
        cv2.imshow('show3d',show)
        if waittime==0:
            cmd=cv2.waitKey(10)%256
        else:
            cmd=cv2.waitKey(waittime)%256
        if cmd==ord('q'):
            break
        elif cmd==ord('Q'):
            sys.exit(0)

        if cmd==ord('t') or cmd==ord('p'):
            if cmd==ord('t'):
                if c_gt is None:
                    c0=np.zeros((len(xyz),),dtype='float32')+255
                    c1=np.zeros((len(xyz),),dtype='float32')+255
                    c2=np.zeros((len(xyz),),dtype='float32')+255
                else:
                    c0=c_gt[:,0]
                    c1=c_gt[:,1]
                    c2=c_gt[:,2]
            else:
                if c_pred is None:
                    c0=np.zeros((len(xyz),),dtype='float32')+255
                    c1=np.zeros((len(xyz),),dtype='float32')+255
                    c2=np.zeros((len(xyz),),dtype='float32')+255
                else:
                    c0=c_pred[:,0]
                    c1=c_pred[:,1]
                    c2=c_pred[:,2]
            if normalizecolor:
                c0/=(c0.max()+1e-14)/255.0
                c1/=(c1.max()+1e-14)/255.0
                c2/=(c2.max()+1e-14)/255.0
            c0=np.require(c0,'float32','C')
            c1=np.require(c1,'float32','C')
            c2=np.require(c2,'float32','C')
            changed=True

        if cmd==ord('n'):
            zoom*=1.1
            changed=True
        elif cmd==ord('m'):
            zoom/=1.1
            changed=True
        elif cmd==ord('r'):
            zoom=1.0
            changed=True
        elif cmd==ord('s'):
            cv2.imwrite('show3d.png',show)
        if waittime!=0:
            break
    return cmd

if __name__=='__main__':
    np.random.seed(100)
    showpoints(np.random.randn(2500,3))
--------------------------------------------------------------------------------
/show_seg.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import argparse
import os
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from datasets import PartDataset
from pointnet import PointNetSeg
from utils import shapenet_labels
from show3d_balls import *

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--idx', type=int, default=0, help='model index')
parser.add_argument('--className', type=str, default='Chair', help='ShapeNet class to visualize')
parser.add_argument('--radius', type=int, default=2, help='radius of ball for visualization')
parser.add_argument('--cmap', type=str, default='hsv', help='valid matplotlib cmap')
parser.add_argument('--npoints', type=int, default=2500, help='points to sample')

opt = parser.parse_args()

idx = opt.idx

d = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0', class_choice=[opt.className], train=False, npoints=opt.npoints)

print('model %d/%d' % (idx, len(d)))

num_class = d.num_classes
print('number of classes', num_class)

point, seg = d[idx]

point_np = point.numpy()

cmap = plt.cm.get_cmap(opt.cmap, 10)
cmap = np.array([cmap(i) for i in range(10)])[:, :3]
gt = cmap[seg.numpy()-1, :]

classifier = PointNetSeg(k=shapenet_labels[opt.className])
classifier.load_state_dict(torch.load(opt.model))
classifier.eval()

point = point.transpose(1, 0).contiguous()
point = point.view(1, point.size()[0], point.size()[1])

pred, _ = classifier(point)

pred_choice = pred.data.max(2)[1]

pred_color = cmap[pred_choice.numpy()[0], :]

# press 't' for ground-truth colors, 'p' for predicted colors
showpoints(point_np, gt, pred_color, ballradius=opt.radius)
--------------------------------------------------------------------------------
/show_seg_s3d.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import argparse
import os
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from pointnet import PointNetSeg
from datasets import S3dDataset
from datasets import scale_linear_bycolumn

from show3d_balls import *


def parse_whole_scene(scene_path, scene_num, npoints=4096):
    scene = np.loadtxt(scene_path).astype(np.float32)
    seg = scene[:, -1].astype(np.int64)   # label is the last column
    scene = scene[:, :3]                  # keep xyz only
    # sample without replacement when the scene has enough points
    # (the comparison target below is reconstructed; the dump is garbled here)
    replace = False if (scene_num*npoints <= scene.shape[0]) else True
    # ... (remainder of show_seg_s3d.py truncated in this dump)