├── .gitignore
├── LICENSE
├── README.md
├── build.sh
├── datasets.py
├── pointnet.py
├── render_balls_so.cpp
├── requirements.txt
├── show3d_balls.py
├── show_seg.py
├── show_seg_s3d.py
├── train_cls.py
├── train_seg.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT) Copyright (c) 2020 Yunxiao Shi
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 |
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## PointNet-PyTorch
2 |
3 | [Python](https://www.python.org/)
4 | [MIT License](./LICENSE)
5 |
6 | This is a PyTorch implementation of [PointNet (CVPR 2017)](https://arxiv.org/abs/1612.00593 "PointNet"), with comprehensive experiments.
7 |
8 | ## Installation
9 |
10 | It is recommended to use [conda](https://docs.conda.io/en/latest/) to manage your environment. For example:
11 | ```
12 | conda create -n pointnet python=3.6
13 | conda activate pointnet
14 | pip install -r requirements.txt
15 | ```
16 |
17 | You may also need to install [PyMesh](https://github.com/PyMesh/PyMesh "PyMesh"); see [here](https://github.com/PyMesh/PyMesh#Build) for build instructions.
18 |
19 | ## Usage
20 |
21 | This code implements object classification on ModelNet, shape part segmentation on ShapeNet, and indoor scene semantic segmentation on S3DIS (Stanford Large-Scale 3D Indoor Spaces).
22 |
23 | The ```s3d_cat2num.txt``` file needed for training on S3DIS is not included in this repo; once you have the dataset downloaded, follow [#3](https://github.com/kentsyx/pointnet-pytorch/issues/3#issuecomment-643061963) to generate it (a sketch is given below).
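
A minimal sketch of that generation step, assuming (per the issue) annotation files named like ```chair_1.txt``` and an output format of one ```name id``` pair per line, mirroring how ```modelnet_cat2num.txt``` is read in datasets.py; adjust the glob to your directory layout:

```
import os, glob

# collect category names from annotation file names, e.g. 'chair_1.txt' -> 'chair'
cats = set()
for f in glob.glob('path_to_s3dis/*/*/Annotations/*.txt'):
    cats.add(os.path.basename(f).split('_')[0])

with open('s3d_cat2num.txt', 'w') as out:
    for i, c in enumerate(sorted(cats)):
        out.write('%s %d\n' % (c, i))
```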
24 |
25 | ### ModelNet Classification
26 |
27 | Download the ModelNet10 dataset from [here](http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip) or the ModelNet40 dataset from [here](https://lmb.informatik.uni-freiburg.de/resources/datasets/ORION/modelnet40_manually_aligned.tar). Unzip and run
28 | ```
29 | python train_cls.py -dset modelnet40 -r modelnet_root_dir -np number_of_points_to_sample
30 | ```
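
For example, with ModelNet40 unpacked to ```./modelnet40``` (the path and point count here are only an illustration; 1024 points matches the dataset loader's default):
```
python train_cls.py -dset modelnet40 -r ./modelnet40 -np 1024
```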
31 |
32 | ### ShapeNet Part Segmentation
33 |
34 | Download the ShapeNet dataset from [here](https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip). Unzip and run
35 | ```
36 | python train_seg.py -dset shapenet16 -r shapenet_root_dir -np number_of_points_to_sample
37 | ```
38 |
39 | ### Indoor Scene Semantic Segmentation
40 |
41 | Download the S3DIS dataset from [here](http://buildingparser.stanford.edu/dataset.html#Download) (you need to submit a request). Unzip and do
42 | ```
43 | cd Stanford3dDataset_v1.2
44 | mkdir train test
45 | mv Area_1 Area_2 Area_3 Area_4 Area_6 train
46 | mv Area_5 test
47 | ```
48 | to create the train/test split. Then set ```gen_labels=True``` in the ```S3dDataset``` class in datasets.py and do
49 | ```
50 | python datasets.py
51 | ```
52 | to generate labels for the train and test sets respectively. __After that, always set ```gen_labels=False```.__ With the labels generated, do
53 | ```
54 | python train_seg.py -dset s3dis -r s3dis_root_dir -np number_of_points_to_sample
55 | ```
56 | to start training.
57 |
58 | ## Visualization
59 |
60 | First run ```sh build.sh``` to build the rendering backend, then use ```show_seg.py``` to visualize segmented object parts; an example invocation follows. Below are some example results.
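
For example, with a trained segmentation checkpoint (the path below is hypothetical; ```--cmap``` and ```--npoints``` are also accepted):
```
python show_seg.py --model seg_model.pth --className Chair --idx 0 --radius 2
```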
61 |
62 |
63 |
65 |
66 | For S3DIS, you have to combine the scene components along with their labels into one text file (```cat``` and ```paste``` are an easy way to do this; see the sketch below) and then pass it to ```show_seg_s3d.py```. Below are some example results (some clutter classes were removed for better visualization).
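
A minimal sketch of that combine step, assuming one text file of ```x y z ...``` rows per scene and a matching label file with one integer per row (the file names here are hypothetical; the loader only requires that the label end up in the last column):
```
cat Annotations/*.txt > scene_points.txt
paste -d ' ' scene_points.txt scene_labels.txt > scene_with_labels.txt
```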
67 |
68 |
69 |
71 |
72 | ## Results
73 |
74 | Certain design choices in the original paper are not implemented here for simplicity. There is some performance gap on ModelNet classification, while the ShapeNet and S3DIS results seem to be on par with the original paper.
75 |
76 |
77 |
78 | | dataset | accuracy | class avg IoU |
79 | | :------: | :------: | :------: |
80 | | ModelNet10 | 87.2% | - |
81 | | ModelNet40 | 85.4% | - |
82 | | ShapeNet | - | 82.9% |
83 | | S3DIS | 72.1% | 50.6% |
84 |
85 |
86 |
87 | ## Acknowledgements
88 |
89 | [pointnet.pytorch](https://github.com/fxia22/pointnet.pytorch) (many thanks)
90 |
91 | [original tensorflow implementation](https://github.com/charlesq34/pointnet)
92 |
93 | ## LICENSE
94 |
95 | MIT
96 |
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
1 | g++ -std=c++11 render_balls_so.cpp -o render_balls_so.so -shared -fPIC -O2 -D_GLIBCXX_USE_CXX11_ABI=0
2 |
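
This compiles the C++ point renderer into ```render_balls_so.so```, which ```show3d_balls.py``` loads via ```numpy.ctypeslib```. A quick smoke test of the build (a sketch; run from the repo root):

```
import numpy as np
# raises OSError if the shared library is missing or failed to build
dll = np.ctypeslib.load_library('render_balls_so', '.')
print(dll.render_ball)  # the symbol exported by render_balls_so.cpp
```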
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import random
5 | import numpy as np
6 | import torch
7 | import torch.utils.data as data
8 | import pymesh
9 | from tqdm import tqdm
10 | from utils import shapenet_labels
11 |
12 | def scale_linear_bycolumn(rawdata, high=1.0, low=0.0):  # rescale each column linearly onto [low, high]
13 |     mins = np.min(rawdata, axis=0)
14 |     maxs = np.max(rawdata, axis=0)
15 |     rng = maxs - mins
16 |     return high - (high-low)*(maxs-rawdata)/(rng+np.finfo(np.float32).eps)  # eps guards zero-range columns
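
# Illustration (not part of the original file): with the defaults high=1.0,
# low=0.0 each column is mapped linearly onto [0, 1]:
#   >>> scale_linear_bycolumn(np.array([[0., 10.], [5., 20.]]))
#   array([[0., 0.],
#          [1., 1.]])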
17 |
18 |
19 | class ClsDataset(data.Dataset):
20 |     '''Object classification on ModelNet'''
21 |     def __init__(self, root, npoints=1024, train=True):
22 |         self.root = root
23 |         self.npoints = npoints
24 |         self.catfile = os.path.join(self.root, 'modelnet_cat2num.txt')
25 |         self.cat = {}
26 |
27 |         with open(self.catfile, 'r') as f:
28 |             for line in f.readlines():
29 |                 lns = line.strip().split()
30 |                 self.cat[lns[0]] = lns[1]
31 |         self.num_classes = len(self.cat)
32 |         self.datapath = []
33 |         FLAG = 'train' if train else 'test'
34 |         for item in os.listdir(self.root):
35 |             if os.path.isdir(os.path.join(self.root, item)):
36 |                 for f in os.listdir(os.path.join(self.root, item, FLAG)):
37 |                     if f.endswith('.off'):
38 |                         self.datapath.append((os.path.join(self.root, item, FLAG, f), int(self.cat[item])))
39 |
40 |
41 |     def __getitem__(self, idx):
42 |         fn = self.datapath[idx]
43 |         points = pymesh.load_mesh(fn[0]).vertices
44 |         label = fn[1]
45 |         replace = True if points.shape[0] < self.npoints else False

[the remainder of datasets.py and the whole /pointnet.py section are missing from this dump]

--------------------------------------------------------------------------------
/render_balls_so.cpp:
--------------------------------------------------------------------------------
1 | #include <cstdio>
2 | #include <vector>
3 | #include <algorithm>
4 | #include <math.h>
5 | using namespace std;
6 |
7 | struct PointInfo{
8 |     int x,y,z;
9 |     float r,g,b;
10 | };
11 |
12 | extern "C"{
13 |
14 | void render_ball(int h,int w,unsigned char * show,int n,int * xyzs,float * c0,float * c1,float * c2,int r){
15 |     r=max(r,1);
16 |     vector<int> depth(h*w,-2100000000);
17 |     vector<PointInfo> pattern;
18 |     for (int dx=-r;dx<=r;dx++)
19 |         for (int dy=-r;dy<=r;dy++)
20 |             if (dx*dx+dy*dy<r*r){

[lines 21 onward are garbled in this dump; the only legible fragment is the
depth-buffer test `!(x2<0 || x2>=h || y2<0 || y2>=w) && depth[x2*w+y2]<z2`]

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.1
2 | torchvision
3 | numpy
4 | tqdm
5 | opencv-python
6 | lera # to monitor training
7 |
8 |
--------------------------------------------------------------------------------
/show3d_balls.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import ctypes as ct
3 | import cv2
4 | import sys
5 | showsz=800
6 | mousex,mousey=0.5,0.5
7 | zoom=1.0
8 | changed=True
9 | def onmouse(*args):
10 |     global mousex,mousey,changed
11 |     y=args[1]
12 |     x=args[2]
13 |     mousex=x/float(showsz)
14 |     mousey=y/float(showsz)
15 |     changed=True
16 | cv2.namedWindow('show3d')
17 | cv2.moveWindow('show3d',0,0)
18 | cv2.setMouseCallback('show3d',onmouse)
19 |
20 | dll=np.ctypeslib.load_library('render_balls_so','.')  # C renderer compiled by build.sh
21 |
22 | def showpoints(xyz,c_gt=None, c_pred = None ,waittime=0,showrot=False,magnifyBlue=0,freezerot=False,background=(0, 0, 0),normalizecolor=True,ballradius=2):
23 |     global showsz,mousex,mousey,zoom,changed
24 |     xyz=xyz-xyz.mean(axis=0)  # center the cloud at the origin
25 |     radius=((xyz**2).sum(axis=-1)**0.5).max()
26 |     xyz/=(radius*2.2)/showsz  # scale the cloud to fit the window
27 |     if c_gt is None:
28 |         c0=np.zeros((len(xyz),),dtype='float32')+255
29 |         c1=np.zeros((len(xyz),),dtype='float32')+255
30 |         c2=np.zeros((len(xyz),),dtype='float32')+255
31 |     else:
32 |         c0=c_gt[:,0]
33 |         c1=c_gt[:,1]
34 |         c2=c_gt[:,2]
35 |
36 |
37 |     if normalizecolor:
38 |         c0/=(c0.max()+1e-14)/255.0
39 |         c1/=(c1.max()+1e-14)/255.0
40 |         c2/=(c2.max()+1e-14)/255.0
41 |
42 |
43 |     c0=np.require(c0,'float32','C')
44 |     c1=np.require(c1,'float32','C')
45 |     c2=np.require(c2,'float32','C')
46 |
47 |     show=np.zeros((showsz,showsz,3),dtype='uint8')
48 |     def render():
49 |         rotmat=np.eye(3)
50 |         if not freezerot:
51 |             xangle=(mousey-0.5)*np.pi*1.2  # mouse y drives rotation about the x axis
52 |         else:
53 |             xangle=0
54 |         rotmat=rotmat.dot(np.array([
55 |             [1.0,0.0,0.0],
56 |             [0.0,np.cos(xangle),-np.sin(xangle)],
57 |             [0.0,np.sin(xangle),np.cos(xangle)],
58 |         ]))
59 |         if not freezerot:
60 |             yangle=(mousex-0.5)*np.pi*1.2  # mouse x drives rotation about the y axis
61 |         else:
62 |             yangle=0
63 |         rotmat=rotmat.dot(np.array([
64 |             [np.cos(yangle),0.0,-np.sin(yangle)],
65 |             [0.0,1.0,0.0],
66 |             [np.sin(yangle),0.0,np.cos(yangle)],
67 |         ]))
68 |         rotmat*=zoom  # fold zoom into the rotation matrix
69 |         nxyz=xyz.dot(rotmat)+[showsz/2,showsz/2,0]
70 |
71 |         ixyz=nxyz.astype('int32')
72 |         show[:]=background
73 |         dll.render_ball(
74 |             ct.c_int(show.shape[0]),
75 |             ct.c_int(show.shape[1]),
76 |             show.ctypes.data_as(ct.c_void_p),
77 |             ct.c_int(ixyz.shape[0]),
78 |             ixyz.ctypes.data_as(ct.c_void_p),
79 |             c0.ctypes.data_as(ct.c_void_p),
80 |             c1.ctypes.data_as(ct.c_void_p),
81 |             c2.ctypes.data_as(ct.c_void_p),
82 |             ct.c_int(ballradius)
83 |         )
84 |
85 |         if magnifyBlue>0:
86 |             show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=0))
87 |             if magnifyBlue>=2:
88 |                 show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=0))
89 |             show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=1))
90 |             if magnifyBlue>=2:
91 |                 show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=1))
92 |         if showrot:
93 |             cv2.putText(show,'xangle %d'%(int(xangle/np.pi*180)),(30,showsz-30),0,0.5,(0,0,255))  # red in BGR; cv2.cv.CV_RGB no longer exists in OpenCV 3+
94 |             cv2.putText(show,'yangle %d'%(int(yangle/np.pi*180)),(30,showsz-50),0,0.5,(0,0,255))
95 |             cv2.putText(show,'zoom %d%%'%(int(zoom*100)),(30,showsz-70),0,0.5,(0,0,255))
96 |     changed=True
97 |     while True:
98 |         if changed:
99 |             render()
100 |             changed=False
101 |         cv2.imshow('show3d',show)
102 |         if waittime==0:
103 |             cmd=cv2.waitKey(10)%256
104 |         else:
105 |             cmd=cv2.waitKey(waittime)%256
106 |         if cmd==ord('q'):
107 |             break
108 |         elif cmd==ord('Q'):
109 |             sys.exit(0)
110 |
111 |         if cmd==ord('t') or cmd == ord('p'):
112 |             if cmd == ord('t'):
113 |                 if c_gt is None:
114 |                     c0=np.zeros((len(xyz),),dtype='float32')+255
115 |                     c1=np.zeros((len(xyz),),dtype='float32')+255
116 |                     c2=np.zeros((len(xyz),),dtype='float32')+255
117 |                 else:
118 |                     c0=c_gt[:,0]
119 |                     c1=c_gt[:,1]
120 |                     c2=c_gt[:,2]
121 |             else:
122 |                 if c_pred is None:
123 |                     c0=np.zeros((len(xyz),),dtype='float32')+255
124 |                     c1=np.zeros((len(xyz),),dtype='float32')+255
125 |                     c2=np.zeros((len(xyz),),dtype='float32')+255
126 |                 else:
127 |                     c0=c_pred[:,0]
128 |                     c1=c_pred[:,1]
129 |                     c2=c_pred[:,2]
130 |             if normalizecolor:
131 |                 c0/=(c0.max()+1e-14)/255.0
132 |                 c1/=(c1.max()+1e-14)/255.0
133 |                 c2/=(c2.max()+1e-14)/255.0
134 |             c0=np.require(c0,'float32','C')
135 |             c1=np.require(c1,'float32','C')
136 |             c2=np.require(c2,'float32','C')
137 |             changed = True
138 |
139 |
140 |
141 |         if cmd==ord('n'):
142 |             zoom*=1.1
143 |             changed=True
144 |         elif cmd==ord('m'):
145 |             zoom/=1.1
146 |             changed=True
147 |         elif cmd==ord('r'):
148 |             zoom=1.0
149 |             changed=True
150 |         elif cmd==ord('s'):
151 |             cv2.imwrite('show3d.png',show)
152 |         if waittime!=0:
153 |             break
154 |     return cmd
155 | if __name__=='__main__':
156 |     np.random.seed(100)
157 |     showpoints(np.random.randn(2500,3))
158 |
159 |
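
A hedged usage example for showpoints (the height-based coloring is just an illustration). In the viewer, 't'/'p' switch between the c_gt and c_pred colors, 'n'/'m' zoom in/out, 'r' resets zoom, 's' saves show3d.png, and 'q' quits:

```
import numpy as np
from show3d_balls import showpoints

pts = np.random.randn(2500, 3).astype('float32')
# color by height: put each point's z-value into the first color channel
c = np.zeros_like(pts)
c[:, 0] = 255.0 * (pts[:, 2] - pts[:, 2].min()) / (np.ptp(pts[:, 2]) + 1e-9)
showpoints(pts, c_gt=c, ballradius=2)
```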
--------------------------------------------------------------------------------
/show_seg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 | import os
5 | import random
6 |
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import torch
10 | import torch.nn as nn
11 |
12 | from datasets import PartDataset
13 | from pointnet import PointNetSeg
14 | from utils import shapenet_labels
15 | from show3d_balls import *
16 |
17 | parser = argparse.ArgumentParser()
18 |
19 | parser.add_argument('--model', type=str, default='', help='model path')
20 | parser.add_argument('--idx', type=int, default=0, help='model index')
21 | parser.add_argument('--className', type=str, default='Chair', help='object class to visualize')
22 | parser.add_argument('--radius', type=int, default=2, help='radius of ball for visualization')
23 | parser.add_argument('--cmap', type=str, default='hsv', help='valid matplotlib cmap')
24 | parser.add_argument('--npoints', type=int, default=2500, help='points to sample')
25 |
26 | opt = parser.parse_args()
27 |
28 | idx = opt.idx
29 |
30 | d = PartDataset(root='shapenetcore_partanno_segmentation_benchmark_v0', class_choice=[opt.className], train=False, npoints=opt.npoints)
31 |
32 | print('model %d/%d' % (idx, len(d)))
33 |
34 | num_class = d.num_classes
35 | print('number of classes', num_class)
36 |
37 | point, seg = d[idx]
38 |
39 | point_np = point.numpy()
40 |
41 | cmap = plt.cm.get_cmap(opt.cmap, 10)
42 | cmap = np.array([cmap(i) for i in range(10)])[:, :3]
43 | gt = cmap[seg.numpy()-1, :]
44 |
45 | classifier = PointNetSeg(k=shapenet_labels[opt.className])
46 | classifier.load_state_dict(torch.load(opt.model))
47 | classifier.eval()
48 |
49 | point = point.transpose(1, 0).contiguous()
50 | point = point.view(1, point.size()[0], point.size()[1])
51 |
52 |
53 | pred, _ = classifier(point)
54 |
55 | pred_choice = pred.data.max(2)[1]
56 |
57 | pred_color = cmap[pred_choice.numpy()[0], :]
58 |
59 | showpoints(point_np, gt, pred_color, ballradius=opt.radius)  # 't' shows ground truth, 'p' shows predictions
60 |
61 |
62 |
--------------------------------------------------------------------------------
/show_seg_s3d.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 | import os
5 | import random
6 |
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | import torch
10 | import torch.nn as nn
11 |
12 | from pointnet import PointNetSeg
13 | from datasets import S3dDataset
14 | from datasets import scale_linear_bycolumn
15 |
16 | from show3d_balls import *
17 |
18 |
19 | def parse_whole_scene(scene_path, scene_num, npoints=4096):
20 |     scene = np.loadtxt(scene_path).astype(np.float32)
21 |     seg = scene[:, -1].astype(np.int64)
22 |     scene = scene[:, :3]
23 |     replace = False if (scene_num*npoints