├── README.md
├── data_utils.py
├── doc
│   └── intro.png
├── evaluate.py
├── modelnet_data.py
├── point_utils.py
├── pointhop.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # PointHop: *An Explainable Machine Learning Method for Point Cloud Classification*
2 | Created by Min Zhang, Haoxuan You, Pranav Kadam, Shan Liu, C.-C. Jay Kuo from University of Southern California.
3 |
4 | 
5 |
6 | ### Introduction
7 | This work is the official implementation of our [arXiv tech report](https://arxiv.org/abs/1907.12766). We propose a novel explainable machine learning method for point clouds, called the PointHop method.
8 |
9 | We address the problem of unordered point cloud data by using a space partitioning procedure and developing a robust descriptor that characterizes the relationship between a point and its one-hop neighborhood in a PointHop unit. Furthermore, we use the Saab transform to reduce the attribute dimension in each PointHop unit. In the classification stage, we feed the feature vector to a classifier and explore ensemble methods to improve the classification performance. Experimental results show that the training complexity of the PointHop method is significantly lower than that of state-of-the-art deep-learning-based methods, with comparable classification performance.
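The sketch below is a simplified illustration (not the repository's exact code) of what a single PointHop unit computes for one point cloud: each sampled point's k nearest neighbors are split into the eight octants around it, the neighbor attributes are averaged per octant, and the concatenated descriptor is reduced with a plain PCA standing in for the Saab transform.

```python
import numpy as np
from sklearn.decomposition import PCA

def pointhop_unit_sketch(xyz, attr, k=8, n_kernels=15):
    """xyz: (N, 3) coordinates, attr: (N, d) attributes of one point cloud."""
    N, d = attr.shape
    dist = ((xyz[:, None, :] - xyz[None, :, :]) ** 2).sum(-1)   # (N, N) squared distances
    nn = np.argsort(dist, axis=1)[:, :k]                        # (N, k) nearest neighbours
    rel = xyz[nn] - xyz[:, None, :]                             # (N, k, 3) offsets from each centre
    octant = (rel[..., 0] >= 0) * 4 + (rel[..., 1] >= 0) * 2 + (rel[..., 2] >= 0)
    desc = np.zeros((N, 8, d))
    for o in range(8):                                          # mean attribute per octant
        mask = (octant == o)[..., None]
        desc[:, o, :] = (attr[nn] * mask).sum(axis=1) / np.maximum(mask.sum(axis=1), 1)
    # dimension reduction of the 8*d descriptor (plain PCA as a stand-in for the Saab transform)
    pca = PCA(n_components=min(n_kernels, 8 * d))
    return pca.fit_transform(desc.reshape(N, 8 * d))            # (N, n_kernels)

pts = np.random.randn(1024, 3).astype(np.float32)
print(pointhop_unit_sketch(pts, pts).shape)   # the first hop uses xyz itself as the attribute
```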
10 |
11 | In this repository, we release code and data for training a PointHop classification network on point clouds sampled from 3D shapes.
12 |
13 | ### Spark version
14 | This implementation has a high memory requirement. If you only have 16/32GB of memory, please use our [new distributed version](https://github.com/minzhang-1/PointHop-PointHop2_Spark), which is built on Apache Spark. It runs the baseline within 20 minutes using less than 12GB of memory.
15 |
16 | ### Citation
17 | If you find our work useful in your research, please consider citing:
18 |
19 |     @article{zhang2020pointhop,
20 |       title={PointHop: An Explainable Machine Learning Method for Point Cloud Classification},
21 |       author={Zhang, Min and You, Haoxuan and Kadam, Pranav and Liu, Shan and Kuo, C-C Jay},
22 |       journal={IEEE Transactions on Multimedia},
23 |       year={2020},
24 |       publisher={IEEE}
25 |     }
26 |
27 | ### Installation
28 |
29 | The code has been tested with Python 3.5. You may need to install the h5py, PyTorch and scikit-learn packages; pickle and threading are part of the Python standard library.
30 |
31 | To install h5py for Python:
32 | ```bash
33 | sudo apt-get install libhdf5-dev
34 | sudo pip install h5py
35 | ```
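To quickly confirm that the Python-side dependencies are importable (versions are whatever your environment provides; this is only a sanity check, not part of the training pipeline):

```python
import h5py
import sklearn
import torch

print(h5py.__version__, sklearn.__version__, torch.__version__)
```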
36 |
37 | ### Usage
38 | To train a single model to classify point clouds sampled from 3D shapes:
39 |
40 | python3 train.py
41 |
42 | After the above training, we can evaluate the single model.
43 |
44 | python3 evaluate.py
45 |
46 | If you would like to achieve better performance, you can change the argument `ensemble` from `False` to `True` in both `train.py` and `evaluate.py`.
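Concretely, this means editing the argparse default in both scripts; a minimal sketch of the line after the edit (the rest of the argument list stays unchanged):

```python
import argparse

parser = argparse.ArgumentParser()
# in train.py and evaluate.py, flip the default of the existing --ensemble argument
parser.add_argument('--ensemble', default=True, help='Ensemble or not')
```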
47 |
48 | If you run the code on a laptop with limited memory, you can increase the arguments `num_batch_train` or `num_batch_test`. To get the same speed and performance as in the paper, set both `num_batch_train` and `num_batch_test` to 1 and change `IncrementalPCA` to `PCA` in pointhop.py.
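The relevant lines live in `find_kernels_pca` of pointhop.py. The fragment below is a self-contained sketch of the two alternatives, using dummy data; the `IncrementalPCA` call is the memory-friendly default, and the commented `PCA` call reproduces the exact setting from the paper:

```python
import numpy as np
from sklearn.decomposition import PCA, IncrementalPCA

training_data = np.random.randn(4096, 24)   # dummy stand-in for the centred patches
n_batch = 20

# memory-friendly default used in pointhop.py: fit the PCA in mini-batches
batch_size = training_data.shape[0] // n_batch
pca = IncrementalPCA(n_components=training_data.shape[1], whiten=True,
                     batch_size=batch_size, copy=False)
# exact (paper) setting: plain PCA on the full matrix, needs much more memory
# pca = PCA(n_components=training_data.shape[1], svd_solver='full', whiten=True)
pca.fit(training_data)
print(pca.components_.shape)
```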
49 |
50 | Log files and network parameters will be saved to the `log` folder. Point clouds of ModelNet40 models in HDF5 files will be automatically downloaded (416MB) to the data folder. Each point cloud contains 2048 points uniformly sampled from a shape surface, zero-mean and normalized into a unit sphere. There are also text files in `data/modelnet40_ply_hdf5_2048` specifying the IDs of shapes in the h5 files.
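As a quick sanity check, each HDF5 file can be inspected as below (a minimal sketch assuming the archive was unzipped into the working directory and keeps its standard file names):

```python
import h5py

# file name assumed from the standard ModelNet40 HDF5 release
with h5py.File('modelnet40_ply_hdf5_2048/ply_data_train0.h5', 'r') as f:
    data = f['data'][:]      # (num_clouds, 2048, 3) xyz coordinates per shape
    label = f['label'][:]    # (num_clouds, 1) integer class ids in [0, 39]
    print(data.shape, label.shape)
```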
51 |
52 |
53 |
--------------------------------------------------------------------------------
/data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def normal_pc(pc):
5 | """
6 | normalize point cloud in range L
7 | :param pc: type list
8 | :return: type list
9 | """
10 | pc_mean = pc.mean(axis=0)
11 | pc = pc - pc_mean
12 | pc_L_max = np.max(np.sqrt(np.sum(abs(pc ** 2), axis=-1)))
13 | pc = pc/pc_L_max
14 | return pc
15 |
16 |
17 | def rotation_point_cloud(pc):
18 | """
19 | Randomly rotate the point clouds to augment the dataset
20 | rotation is per shape based along up direction
21 | :param pc: B X N X 3 array, original batch of point clouds
22 | :return: BxNx3 array, rotated batch of point clouds
23 | """
24 | # rotated_data = np.zeros(pc.shape, dtype=np.float32)
25 |
26 | rotation_angle = np.random.uniform() * 2 * np.pi
27 | cosval = np.cos(rotation_angle)
28 | sinval = np.sin(rotation_angle)
29 | rotation_matrix = np.array([[cosval, 0, sinval],
30 | [0, 1, 0],
31 | [-sinval, 0, cosval]])
32 | rotated_data = np.dot(pc.reshape((-1, 3)), rotation_matrix)
33 |
34 | return rotated_data
35 |
36 |
37 | def rotate_point_cloud_by_angle(pc, rotation_angle):
38 | """
39 | Randomly rotate the point clouds to augment the dataset
40 | rotation is per shape based along up direction
41 | :param pc: B X N X 3 array, original batch of point clouds
42 | :param rotation_angle: angle of rotation
43 | :return: BxNx3 array, rotated batch of point clouds
44 | """
45 | # rotated_data = np.zeros(pc.shape, dtype=np.float32)
46 |
47 | # rotation_angle = np.random.uniform() * 2 * np.pi
48 | cosval = np.cos(rotation_angle)
49 | sinval = np.sin(rotation_angle)
50 | rotation_matrix = np.array([[cosval, 0, sinval],
51 | [0, 1, 0],
52 | [-sinval, 0, cosval]])
53 | rotated_data = np.dot(pc.reshape((-1, 3)), rotation_matrix)
54 |
55 | return rotated_data
56 |
57 |
58 | def jitter_point_cloud(pc, sigma=0.01, clip=0.05):
59 | """
60 | Randomly jitter points. jittering is per point.
61 | :param pc: B X N X 3 array, original batch of point clouds
62 | :param sigma:
63 | :param clip:
64 | :return:
65 | """
66 | jittered_data = np.clip(sigma * np.random.randn(*pc.shape), -1 * clip, clip)
67 | jittered_data += pc
68 | return jittered_data
69 |
70 |
71 | def shift_point_cloud(pc, shift_range=0.1):
72 | """ Randomly shift point cloud. Shift is per point cloud.
73 | Input:
74 | BxNx3 array, original batch of point clouds
75 | Return:
76 | BxNx3 array, shifted batch of point clouds
77 | """
78 | N, C = pc.shape
79 | shifts = np.random.uniform(-shift_range, shift_range, 3)
80 | pc += shifts
81 | return pc
82 |
83 |
84 | def random_scale_point_cloud(pc, scale_low=0.8, scale_high=1.25):
85 | """ Randomly scale the point cloud. Scale is per point cloud.
86 | Input:
87 | BxNx3 array, original batch of point clouds
88 | Return:
89 | BxNx3 array, scaled batch of point clouds
90 | """
91 | N, C = pc.shape
92 | scales = np.random.uniform(scale_low, scale_high, 1)
93 | pc *= scales
94 | return pc
95 |
96 |
97 | def rotate_perturbation_point_cloud(pc, angle_sigma=0.06, angle_clip=0.18):
98 | """ Randomly perturb the point clouds by small rotations
99 | Input:
100 | BxNx3 array, original batch of point clouds
101 | Return:
102 | BxNx3 array, rotated batch of point clouds
103 | """
104 | # rotated_data = np.zeros(pc.shape, dtype=np.float32)
105 | angles = np.clip(angle_sigma * np.random.randn(3), -angle_clip, angle_clip)
106 | Rx = np.array([[1, 0, 0],
107 | [0, np.cos(angles[0]), -np.sin(angles[0])],
108 | [0, np.sin(angles[0]), np.cos(angles[0])]])
109 | Ry = np.array([[np.cos(angles[1]), 0, np.sin(angles[1])],
110 | [0, 1, 0],
111 | [-np.sin(angles[1]), 0, np.cos(angles[1])]])
112 | Rz = np.array([[np.cos(angles[2]), -np.sin(angles[2]), 0],
113 | [np.sin(angles[2]), np.cos(angles[2]), 0],
114 | [0, 0, 1]])
115 | R = np.dot(Rz, np.dot(Ry, Rx))
116 | shape_pc = pc
117 | rotated_data = np.dot(shape_pc.reshape((-1, 3)), R)
118 | return rotated_data
119 |
120 |
121 | def pc_augment(pc, angle):
122 | pc = rotate_point_cloud_by_angle(pc, angle)
123 | # pc = rotation_point_cloud(pc)
124 | # pc = jitter_point_cloud(pc)
125 | # pc = random_scale_point_cloud(pc)
126 | # pc = rotate_perturbation_point_cloud(pc)
127 | # pc = shift_point_cloud(pc)
128 | return pc
129 |
130 |
131 | def data_augment(train_data, angle):
132 | return pc_augment(train_data, angle).reshape(-1, train_data.shape[1], train_data.shape[2])
133 |
134 |
--------------------------------------------------------------------------------
/doc/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minzhang-1/PointHop/4eb9618f7c37d59d2a6267288df1edfd6906c38e/doc/intro.png
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import sklearn
4 | import modelnet_data
5 | import pointhop
6 | import numpy as np
7 | import data_utils
8 | import os
9 | import time
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--num_batch_test', type=int, default=1, help='Batch Number')
13 | parser.add_argument('--initial_point', type=int, default=1024, help='Point Number [256/512/1024/2048]')
14 | parser.add_argument('--ensemble', default=False, help='Ensemble or not')
15 | parser.add_argument('--rotation_angle', default=np.pi/4, help='Rotate angle')
16 | parser.add_argument('--rotation_freq', default=8, help='Rotate time')
17 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
18 | parser.add_argument('--num_point', default=[1024, 128, 128, 128], help='Point Number after down sampling')
19 | parser.add_argument('--num_sample', default=[64, 64, 64, 64], help='KNN query number')
20 | parser.add_argument('--num_filter', default=[15, 25, 40, 80], help='Filter Number ')
21 | parser.add_argument('--pooling_method', default=[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1],
22 | [1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1], [0, 1, 1, 0],
23 | [0, 1, 0, 1], [0, 0, 1, 1], [1, 1, 1, 0], [1, 1, 0, 1],
24 | [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]],
25 | help='Pooling methods [mean, max, l1, l2]')
26 | FLAGS = parser.parse_args()
27 |
28 | num_batch_test = FLAGS.num_batch_test
29 | initial_point = FLAGS.initial_point
30 | ENSEMBLE = FLAGS.ensemble
31 | angle_rotation = FLAGS.rotation_angle
32 | freq_rotation = FLAGS.rotation_freq
33 | num_point = FLAGS.num_point
34 | num_sample = FLAGS.num_sample
35 | num_filter = FLAGS.num_filter
36 | pooling = FLAGS.pooling_method
37 |
38 |
39 | LOG_DIR = FLAGS.log_dir
40 | if not os.path.exists(LOG_DIR):
41 | os.mkdir(LOG_DIR)
42 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_test.txt'), 'w')
43 | LOG_FOUT.write(str(FLAGS) + '\n')
44 |
45 |
46 | def log_string(out_str):
47 | LOG_FOUT.write(out_str+'\n')
48 | LOG_FOUT.flush()
49 | print(out_str)
50 |
51 |
52 | def main():
53 | time_start = time.time()
54 |
55 | # load data
56 | test_data, test_label = modelnet_data.data_load(num_point=initial_point, data_dir='modelnet40_ply_hdf5_2048', train=False)
57 |
58 | if ENSEMBLE:
59 | angle = np.repeat(angle_rotation, freq_rotation)
60 | else:
61 | angle = [0]
62 |
63 | with open(os.path.join(LOG_DIR, 'params.pkl'), 'rb') as f:
64 | params = pickle.load(f, encoding='latin')
65 |
66 | # get feature and pca parameter
67 | feature_test = []
68 | for i in range(len(angle)):
69 | print('------------Test ', i, '--------------')
70 |
71 | pca_params = params['stage %d pca_params' % i]
72 |
73 | final_feature, feature = pointhop.pointhop_pred(
74 | test_data, n_batch=num_batch_test, pca_params=pca_params, n_newpoint=num_point, n_sample=num_sample, layer_num=num_filter,
75 | idx_save=None, new_xyz_save=None)
76 |
77 | feature = pointhop.extract(feature)
78 |
79 | feature_test.append(feature)
80 | test_data = data_utils.data_augment(test_data, angle[i])
81 |
82 | feature_test = np.concatenate(feature_test, axis=-1)
83 |
84 | clf_tmp = params['clf']
85 | pred_test_tmp = []
86 | acc_test_tmp = []
87 | for i in range(len(pooling)):
88 | clf = clf_tmp['pooling method %d' % i]
89 | feature_test_tmp = pointhop.aggregate(feature_test, pooling[i])
90 | pred_test = clf.predict(feature_test_tmp)
91 | acc_test = sklearn.metrics.accuracy_score(test_label, pred_test)
92 | pred_test_tmp.append(pred_test)
93 | acc_test_tmp.append(acc_test)
94 | idx = np.argmax(acc_test_tmp)
95 | pred_test = pred_test_tmp[idx]
96 | acc = pointhop.average_acc(test_label, pred_test)
97 |
98 | time_end = time.time()
99 |
100 | log_string("test acc is {}".format(acc_test_tmp[idx]))
101 | log_string('test mean acc is {}'.format(np.mean(acc)))
102 | log_string('per-class acc is {}'.format(str(acc)))
103 |     log_string('total time cost is {} minutes'.format((time_end - time_start)//60))
104 |
105 | # with open(os.path.join(LOG_DIR, 'ensemble_pred_test.pkl'), 'wb') as f:
106 | # pickle.dump(pred_test, f)
107 |
108 |
109 | if __name__ == '__main__':
110 | main()
111 |
112 |
--------------------------------------------------------------------------------
/modelnet_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import h5py
3 | import numpy as np
4 | from sklearn.model_selection import train_test_split
5 |
6 |
7 | def load_dir(data_dir, name='train_files.txt'):
8 | with open(os.path.join(data_dir,name),'r') as f:
9 | lines = f.readlines()
10 | return [os.path.join(data_dir, line.rstrip().split('/')[-1]) for line in lines]
11 |
12 |
13 | def shuffle_data(data):
14 | """ Shuffle data order.
15 | Input:
16 | data: B,N,... numpy array
17 | Return:
18 | shuffled data, shuffle indices
19 | """
20 | idx = np.arange(data.shape[0])
21 | np.random.shuffle(idx)
22 | return data[idx, ...], idx
23 |
24 |
25 | def shuffle_points(data):
26 | """ Shuffle orders of points in each point cloud -- changes FPS behavior.
27 | Input:
28 | BxNxC array
29 | Output:
30 | BxNxC array
31 | """
32 | idx = np.arange(data.shape[1])
33 | np.random.shuffle(idx)
34 | return data[:, idx, :], idx
35 |
36 |
37 | def xyz2sphere(data):
38 | """
39 | Input: data(B,N,3) xyz_coordinates
40 | Return: data(B,N,3) sphere_coordinates
41 | """
42 | r = np.sqrt(np.sum(data**2, axis=2, keepdims=False))
43 | theta = np.arccos(data[...,2]*1.0/r)
44 |     phi = np.arctan2(data[..., 1], data[..., 0])  # arctan2 keeps the correct quadrant
45 |
46 | if len(r.shape) == 2:
47 | r = np.expand_dims(r, 2)
48 | if len(theta.shape) == 2:
49 | theta = np.expand_dims(theta, 2)
50 | if len(phi.shape) == 2:
51 | phi = np.expand_dims(phi, 2)
52 |
53 | data_sphere = np.concatenate([r, theta, phi], axis=2)
54 | return data_sphere
55 |
56 |
57 | def xyz2cylind(data):
58 | """
59 | Input: data(B,N,3) xyz_coordinates
60 | Return: data(B,N,3) cylindrical_coordinates
61 | """
62 | r = np.sqrt(np.sum(data[...,:2]**2, axis=2, keepdims=False))
63 |     phi = np.arctan2(data[..., 1], data[..., 0])  # arctan2 keeps the correct quadrant
64 | z = data[...,2]
65 |
66 | if len(r.shape) == 2:
67 | r = np.expand_dims(r, 2)
68 | if len(z.shape) == 2:
69 | z = np.expand_dims(z, 2)
70 | if len(phi.shape) == 2:
71 | phi = np.expand_dims(phi, 2)
72 |
73 |     data_cylind = np.concatenate([r, z, phi], axis=2)
74 |     return data_cylind
75 |
76 |
77 | def data_load(num_point=None, data_dir='modelnet40_ply_hdf5_2048', train=True):
78 | if not os.path.exists('modelnet40_ply_hdf5_2048'):
79 | www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
80 | zipfile = os.path.basename(www)
81 | os.system('wget --no-check-certificate %s; unzip %s' % (www, zipfile))
82 | os.system('rm %s' % (zipfile))
83 |
84 | if train:
85 | data_pth = load_dir(data_dir, name='train_files.txt')
86 | else:
87 | data_pth = load_dir(data_dir, name='test_files.txt')
88 |
89 | point_list = []
90 | label_list = []
91 | for pth in data_pth:
92 | data_file = h5py.File(pth, 'r')
93 | point = data_file['data'][:]
94 | label = data_file['label'][:]
95 | point_list.append(point)
96 | label_list.append(label)
97 | data = np.concatenate(point_list, axis=0)
98 | label = np.concatenate(label_list, axis=0)
99 | # data, idx = shuffle_data(data)
100 | # data, ind = shuffle_points(data)
101 |
102 | if not num_point:
103 | return data[:, :, :], label
104 | else:
105 | return data[:, :num_point, :], label
106 |
107 |
108 | def data_separate(data, label):
109 | seed = 7
110 | np.random.seed(seed)
111 | train_data, valid_data, train_label, valid_label = train_test_split(data, label, test_size=0.1, random_state=seed)
112 |
113 | return train_data, train_label, valid_data, valid_label
114 |
115 |
116 |
--------------------------------------------------------------------------------
/point_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import threading
4 |
5 |
6 | def calc_distances(tmp, pts):
7 | '''
8 |
9 | :param tmp:(B, k, 3)/(B, 3)
10 | :param pts:(B, N, 3)
11 |     :return: squared distances, (B, N, k)/(B, N)
12 | '''
13 | if len(tmp.shape) == 2:
14 | tmp = np.expand_dims(tmp, axis=1)
15 | tmp_trans = np.transpose(tmp, [0,2,1])
16 | xy = np.matmul(pts, tmp_trans)
17 | pts_square = (pts**2).sum(axis=2, keepdims=True)
18 | tmp_square_trans = (tmp_trans**2).sum(axis=1, keepdims=True)
19 | return np.squeeze(pts_square + tmp_square_trans - 2 * xy)
20 |
21 |
22 | def index_points(points, idx):
23 | """
24 | Input:
25 | points: input points data, [B, N, C]
26 | idx: sample index data, [B, S]
27 | Return:
28 | new_points:, indexed points data, [B, S, C]
29 | """
30 | B = points.shape[0]
31 | view_shape = list(idx.shape)
32 | view_shape[1:] = [1] * (len(view_shape) - 1)
33 | repeat_shape = list(idx.shape)
34 | repeat_shape[0] = 1
35 | batch_indices = np.tile(np.arange(B).reshape(view_shape),repeat_shape)
36 | new_points = points[batch_indices, idx, :]
37 | return new_points
38 |
39 |
40 | def furthest_point_sample(pts, K):
41 | """
42 | Input:
43 | pts: pointcloud data, [B, N, C]
44 | K: number of samples
45 | Return:
46 | (B, K, 3)
47 | """
48 | B, N, C = pts.shape
49 | centroids = np.zeros((B, K), dtype=int)
50 | distance = np.ones((B, N), dtype=int) * 1e10
51 | farthest = np.random.randint(0, N, (B,))
52 | batch_indices = np.arange(B)
53 | for i in range(K):
54 | centroids[:, i] = farthest
55 | centroid = pts[batch_indices, farthest, :].reshape(B, 1, 3)
56 | dist = np.sum((pts - centroid) ** 2, axis=-1)
57 | mask = dist < distance
58 | distance[mask] = dist[mask]
59 | farthest = np.argmax(distance, axis=-1)
60 | return index_points(pts, centroids)
61 |
62 |
63 | def knn_query(new_pts, pts, n_sample, idx):
64 | '''
65 | new_pts:(B, K, 3)
66 | pts:(B, N, 3)
67 | n_sample:int
68 | :return: nn_idx (B, n_sample, K)
69 | '''
70 | distance_matrix = calc_distances(new_pts, pts)
71 | # nn_idx = np.argsort(distance_matrix, axis=1, kind='stable')[:, :n_sample, :] # (B, n, K)
72 |     nn_idx = np.argpartition(distance_matrix, (0, n_sample), axis=1)[:, :n_sample, :]  # partial sort: the n_sample nearest points per query; index 0 holds the closest one
73 | idx.append(nn_idx)
74 |
75 |
76 | def knn(new_xyz, point_data, n_sample):
77 | idx1 = []
78 | idx2 = []
79 | idx3 = []
80 | idx4 = []
81 | idx5 = []
82 | idx6 = []
83 | idx7 = []
84 | idx8 = []
85 | threads = []
86 |     batch_size = point_data.shape[0]//8  # split the batch across 8 worker threads
87 | t1 = threading.Thread(target=knn_query, args=(new_xyz[:batch_size], point_data[:batch_size], n_sample, idx1))
88 | threads.append(t1)
89 | t2 = threading.Thread(target=knn_query, args=(new_xyz[batch_size:2*batch_size], point_data[batch_size:2*batch_size], n_sample, idx2))
90 | threads.append(t2)
91 | t3 = threading.Thread(target=knn_query, args=(new_xyz[2*batch_size:3*batch_size], point_data[2*batch_size:3*batch_size], n_sample, idx3))
92 | threads.append(t3)
93 | t4 = threading.Thread(target=knn_query, args=(new_xyz[3*batch_size:4*batch_size], point_data[3*batch_size:4*batch_size], n_sample, idx4))
94 | threads.append(t4)
95 | t5 = threading.Thread(target=knn_query, args=(new_xyz[4*batch_size:5*batch_size], point_data[4*batch_size:5*batch_size], n_sample, idx5))
96 | threads.append(t5)
97 | t6 = threading.Thread(target=knn_query, args=(new_xyz[5*batch_size:6*batch_size], point_data[5*batch_size:6*batch_size], n_sample, idx6))
98 | threads.append(t6)
99 | t7 = threading.Thread(target=knn_query, args=(new_xyz[6*batch_size:7*batch_size], point_data[6*batch_size:7*batch_size], n_sample, idx7))
100 | threads.append(t7)
101 | t8 = threading.Thread(target=knn_query, args=(new_xyz[7*batch_size:], point_data[7*batch_size:], n_sample, idx8))
102 | threads.append(t8)
103 |
104 | for t in threads:
105 |         t.daemon = False
106 | t.start()
107 | for t in threads:
108 |         if t.is_alive():
109 | t.join()
110 | idx = idx1 + idx2 + idx3 + idx4 + idx5 + idx6 + idx7 + idx8
111 | idx_tmp = np.concatenate(idx, axis=0)
112 |
113 | return idx_tmp
114 |
115 |
116 | def gather_ops(nn_idx, pts):
117 | """
118 | nn_idx:(B, n_sample, K)
119 | pts:(B, N, dim)
120 | :return: pc_n(B, n_sample, K, dim)
121 | """
122 | num_newpts = nn_idx.shape[2]
123 | num_dim = pts.shape[2]
124 | pts_expand = torch.from_numpy(pts).type(torch.FloatTensor).unsqueeze(2).expand(-1, -1, num_newpts, -1)
125 | nn_idx_expand = torch.from_numpy(nn_idx).type(torch.LongTensor).unsqueeze(3).expand(-1, -1, -1, num_dim)
126 | pc_n = torch.gather(pts_expand, 1, nn_idx_expand)
127 | return pc_n.numpy()
128 |
129 |
130 | def calc_feature(pc_temp, pc_bin, pc_gather):
131 | value = np.multiply(pc_temp, pc_bin)
132 | value = np.sum(value, axis=2, keepdims=True)
133 | num = np.sum(pc_bin, axis=2, keepdims=True)
134 | final = np.squeeze(value/num)
135 | pc_gather.append(final)
136 |
137 |
138 | def gather_fea(nn_idx, point_data, fea):
139 | """
140 | nn_idx:(B, n_sample, K)
141 | pts:(B, N, dim)
142 | :return: pc_n(B, K, dim_fea)
143 | """
144 | num_newpts = nn_idx.shape[2]
145 | assert point_data.shape[:-1] == fea.shape[:-1]
146 | pts_fea = np.concatenate([point_data, fea], axis=-1)
147 | num_dim = pts_fea.shape[2]
148 |
149 | pts_fea_expand = index_points(pts_fea, nn_idx)
150 | # print(pts_fea_expand.shape)
151 | pts_fea_expand = pts_fea_expand.transpose(0, 2, 1, 3) # (B, K, n_sample, dim)
152 | pc_n = pts_fea_expand[..., :3]
153 | pc_temp = pts_fea_expand[..., 3:]
154 |
155 |     pc_n_center = np.expand_dims(pc_n[:, :, 0, :], axis=2)  # neighbour 0 is the query point itself (argpartition pins the closest point at index 0)
156 | pc_n_uncentered = pc_n - pc_n_center
157 |
158 | pc_idx = []
159 | pc_idx.append(pc_n_uncentered[:, :, :, 0] >= 0)
160 | pc_idx.append(pc_n_uncentered[:, :, :, 0] <= 0)
161 | pc_idx.append(pc_n_uncentered[:, :, :, 1] >= 0)
162 | pc_idx.append(pc_n_uncentered[:, :, :, 1] <= 0)
163 | pc_idx.append(pc_n_uncentered[:, :, :, 2] >= 0)
164 | pc_idx.append(pc_n_uncentered[:, :, :, 2] <= 0)
165 |
166 | pc_bin = []
167 | pc_bin.append(np.expand_dims((pc_idx[0] * pc_idx[2] * pc_idx[4])*1.0, axis=3))
168 | pc_bin.append(np.expand_dims((pc_idx[0] * pc_idx[2] * pc_idx[5])*1.0, axis=3))
169 | pc_bin.append(np.expand_dims((pc_idx[0] * pc_idx[3] * pc_idx[4])*1.0, axis=3))
170 | pc_bin.append(np.expand_dims((pc_idx[0] * pc_idx[3] * pc_idx[5])*1.0, axis=3))
171 | pc_bin.append(np.expand_dims((pc_idx[1] * pc_idx[2] * pc_idx[4])*1.0, axis=3))
172 | pc_bin.append(np.expand_dims((pc_idx[1] * pc_idx[2] * pc_idx[5])*1.0, axis=3))
173 | pc_bin.append(np.expand_dims((pc_idx[1] * pc_idx[3] * pc_idx[4])*1.0, axis=3))
174 | pc_bin.append(np.expand_dims((pc_idx[1] * pc_idx[3] * pc_idx[5])*1.0, axis=3))
175 |
176 | pc_gather1 = []
177 | pc_gather2 = []
178 | pc_gather3 = []
179 | pc_gather4 = []
180 | pc_gather5 = []
181 | pc_gather6 = []
182 | pc_gather7 = []
183 | pc_gather8 = []
184 | threads = []
185 | t1 = threading.Thread(target=calc_feature, args=(pc_temp, pc_bin[0], pc_gather1))
186 | threads.append(t1)
187 | t2 = threading.Thread(target=calc_feature, args=(pc_temp, pc_bin[1], pc_gather2))
188 | threads.append(t2)
189 | t3 = threading.Thread(target=calc_feature, args=(pc_temp, pc_bin[2], pc_gather3))
190 | threads.append(t3)
191 | t4 = threading.Thread(target=calc_feature, args=(pc_temp, pc_bin[3], pc_gather4))
192 | threads.append(t4)
193 | t5 = threading.Thread(target=calc_feature, args=(pc_temp, pc_bin[4], pc_gather5))
194 | threads.append(t5)
195 | t6 = threading.Thread(target=calc_feature, args=(pc_temp, pc_bin[5], pc_gather6))
196 | threads.append(t6)
197 | t7 = threading.Thread(target=calc_feature, args=(pc_temp, pc_bin[6], pc_gather7))
198 | threads.append(t7)
199 | t8 = threading.Thread(target=calc_feature, args=(pc_temp, pc_bin[7], pc_gather8))
200 | threads.append(t8)
201 | for t in threads:
202 |         t.daemon = False
203 | t.start()
204 | for t in threads:
205 |         if t.is_alive():
206 | t.join()
207 | pc_gather = pc_gather1 + pc_gather2 + pc_gather3 + pc_gather4 + pc_gather5 + pc_gather6 + pc_gather7 + pc_gather8
208 | pc_fea = np.concatenate(pc_gather, axis=2)
209 |
210 | return pc_fea
211 |
212 |
213 | def gather_global_fea(feature, xyz, part=5):
214 | '''
215 |
216 | :param feature: (B, n_point, dim)
217 | :param xyz: (B, n_point, 3)
218 | :param part:int
219 | :return: (B, dim*part)
220 | '''
221 |
222 | pts_square = (xyz**2).sum(axis=2, keepdims=False)
223 | dis = np.sqrt(pts_square) # (B, n_point)
224 | total_fea = []
225 | for i in range(part):
226 | idx = (dis >= (i/float(part))) * (dis <= ((i+1)/float(part)))*1.0
227 | part_fea = (feature*np.expand_dims(idx, axis=2)).max(axis=1, keepdims=False)
228 | total_fea.append(part_fea)
229 | return np.concatenate(total_fea, axis=1)
--------------------------------------------------------------------------------
/pointhop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sklearn
3 | from sklearn.decomposition import PCA, IncrementalPCA
4 | from numpy import linalg as LA
5 |
6 | from sklearn import svm
7 | from sklearn import ensemble
8 |
9 |
10 | import point_utils
11 |
12 |
13 | def pointhop_train(train_data, n_batch, n_newpoint, n_sample, layer_num, energy_percent):
14 | '''
15 | Train based on the provided samples.
16 | :param train_data: [num_samples, num_point, feature_dimension]
17 | :param n_newpoint: point numbers used in every stage
18 | :param n_sample: k nearest neighbors
19 | :param layer_num: num kernels to be preserved
20 | :param energy_percent: the percent of energy to be preserved
21 | :return: idx, new_idx, final stage feature, feature, pca_params
22 | '''
23 |
24 | num_data = train_data.shape[0]
25 | pca_params = {}
26 | idx_save = {}
27 | new_xyz_save = {}
28 |
29 | point_data = train_data
30 | batch_size = num_data//n_batch
31 | grouped_feature = None
32 | feature_train = []
33 |
34 | feature_data = train_data
35 |
36 | for i in range(len(n_newpoint)):
37 | print(i)
38 | point_num = point_data.shape[1]
39 | print('Start sampling-------------')
40 | if n_newpoint[i] == point_num:
41 | new_xyz = point_data
42 | else:
43 | new_xyz = point_utils.furthest_point_sample(point_data, n_newpoint[i])
44 |
45 | new_xyz_save['Layer_{:d}'.format(i)] = new_xyz
46 |
47 | print('Start query and gathering-------------')
48 | # time_start = time.time()
49 |         if grouped_feature is not None:
50 | idx, grouped_feature = query_and_gather(new_xyz, n_batch, batch_size, point_data, grouped_feature, n_sample[i], None)
51 | else:
52 | idx, grouped_feature = query_and_gather(new_xyz, n_batch, batch_size, point_data, feature_data, n_sample[i], None)
53 |
54 | idx_save['Layer_%d' % (i)] = idx
55 | grouped_feature = grouped_feature.reshape(num_data*n_newpoint[i], -1)
56 | print('ok-------------')
57 |
58 | kernels, mean = find_kernels_pca(grouped_feature, layer_num[i], energy_percent, n_batch)
59 | if i == 0:
60 | transformed = np.matmul(grouped_feature, np.transpose(kernels))
61 | else:
62 | bias = LA.norm(grouped_feature, axis=1)
63 | bias = np.max(bias)
64 | pca_params['Layer_{:d}/bias'.format(i)] = bias
65 | grouped_feature = grouped_feature + bias
66 |
67 | transformed = np.matmul(grouped_feature, np.transpose(kernels))
68 | e = np.zeros((1, kernels.shape[0]))
69 | e[0, 0] = 1
70 | transformed -= bias*e
71 | grouped_feature = transformed.reshape(num_data, n_newpoint[i], -1)
72 | print(grouped_feature.shape)
73 | feature_train.append(grouped_feature)
74 | pca_params['Layer_{:d}/kernel'.format(i)] = kernels
75 | pca_params['Layer_{:d}/pca_mean'.format(i)] = mean
76 | point_data = new_xyz
77 | final_feature = grouped_feature.max(axis=1, keepdims=False)
78 |
79 | return idx_save, new_xyz_save, final_feature, feature_train, pca_params
80 |
81 |
82 | def pointhop_pred(test_data, n_batch, pca_params, n_newpoint, n_sample, layer_num, idx_save, new_xyz_save):
83 | '''
84 | Test based on the provided samples.
85 | :param test_data: [num_samples, num_point, feature_dimension]
86 | :param pca_params: pca kernel and mean
87 | :param n_newpoint: point numbers used in every stage
88 | :param n_sample: k nearest neighbors
89 | :param layer_num: num kernels to be preserved
90 | :param idx_save: knn index
91 | :param new_xyz_save: down sample index
92 | :return: final stage feature, feature, pca_params
93 | '''
94 |
95 | num_data = test_data.shape[0]
96 | point_data = test_data
97 | grouped_feature = None
98 | feature_test = []
99 | batch_size = num_data//n_batch
100 |
101 | feature_data = test_data
102 |
103 | for i in range(len(n_newpoint)):
104 | if not new_xyz_save:
105 | point_num = point_data.shape[1]
106 | if n_newpoint[i] == point_num:
107 | new_xyz = point_data
108 | else:
109 | new_xyz = point_utils.furthest_point_sample(point_data, n_newpoint[i])
110 | else:
111 | print('---------------loading idx--------------')
112 | new_xyz = new_xyz_save['Layer_{:d}'.format(i)]
113 |
114 |         if grouped_feature is not None:
115 | idx, grouped_feature = query_and_gather(new_xyz, n_batch, batch_size, point_data, grouped_feature, n_sample[i], None)
116 | else:
117 | idx, grouped_feature = query_and_gather(new_xyz, n_batch, batch_size, point_data, feature_data, n_sample[i], None)
118 |
119 | grouped_feature = grouped_feature.reshape(num_data*n_newpoint[i], -1)
120 |
121 | kernels = pca_params['Layer_{:d}/kernel'.format(i)]
122 | mean = pca_params['Layer_{:d}/pca_mean'.format(i)]
123 |
124 | if i == 0:
125 | transformed = np.matmul(grouped_feature, np.transpose(kernels))
126 | else:
127 | bias = pca_params['Layer_{:d}/bias'.format(i)]
128 | grouped_feature = grouped_feature + bias
129 | transformed = np.matmul(grouped_feature, np.transpose(kernels))
130 | e = np.zeros((1, kernels.shape[0]))
131 | e[0, 0] = 1
132 | transformed -= bias*e
133 | grouped_feature = transformed.reshape(num_data, n_newpoint[i], -1)
134 | feature_test.append(grouped_feature)
135 | point_data = new_xyz
136 | final_feature = grouped_feature.max(axis=1, keepdims=False)
137 | return final_feature, feature_test
138 |
139 |
140 | def query_and_gather(new_xyz, n_batch, batch_size, pts_coor, pts_fea, n_sample, pooling):
141 | idx = []
142 | grouped_feature = []
143 | for j in range(n_batch):
144 | if j != n_batch - 1:
145 | idx_tmp = point_utils.knn(new_xyz[j * batch_size:(j + 1) * batch_size],
146 | pts_coor[j * batch_size:(j + 1) * batch_size]
147 | , n_sample)
148 | grouped_feature_tmp = point_utils.gather_fea(idx_tmp, pts_coor[j * batch_size:(j + 1) * batch_size],
149 | pts_fea[j * batch_size:(j + 1) * batch_size])
150 | else:
151 | idx_tmp = point_utils.knn(new_xyz[j * batch_size:], pts_coor[j * batch_size:], n_sample)
152 | grouped_feature_tmp = point_utils.gather_fea(idx_tmp, pts_coor[j * batch_size:],
153 | pts_fea[j * batch_size:])
154 | if pooling is not None:
155 | grouped_feature_tmp = grouped_feature_tmp.reshape(grouped_feature_tmp.shape[0], grouped_feature_tmp.shape[1], 8, -1)
156 | grouped_feature_tmp = extract(grouped_feature_tmp, pooling, 2)
157 | idx.append(idx_tmp)
158 | grouped_feature.append(grouped_feature_tmp)
159 | idx = np.concatenate(idx, axis=0)
160 | grouped_feature = np.concatenate(grouped_feature, axis=0)
161 | return idx, grouped_feature
162 |
163 |
164 | def remove_mean(features, axis):
165 | '''
166 | Remove the dataset mean.
167 | :param features [num_samples,...]
168 | :param axis the axis to compute mean
169 |
170 | '''
171 | feature_mean = np.mean(features, axis=axis, keepdims=True)
172 | feature_remove_mean = features-feature_mean
173 | return feature_remove_mean, feature_mean
174 |
175 |
176 | def remove_zero_patch(samples):
177 | std_var = (np.std(samples, axis=1)).reshape(-1, 1)
178 | ind_bool = (std_var == 0)
179 | ind = np.where(ind_bool==True)[0]
180 | # print('zero patch shape:',ind.shape)
181 | samples_new = np.delete(samples, ind, 0)
182 | return samples_new
183 |
184 |
185 | def find_kernels_pca(sample_patches, num_kernels, energy_percent, n_batch):
186 | '''
187 | Do the PCA based on the provided samples.
188 | If num_kernels is not set, will use energy_percent.
189 | If neither is set, will preserve all kernels.
190 | :param samples: [num_samples, feature_dimension]
191 | :param num_kernels: num kernels to be preserved
192 | :param energy_percent: the percent of energy to be preserved
193 | :return: kernels, sample_mean
194 | '''
195 | # Remove patch mean
196 | sample_patches_centered, dc = remove_mean(sample_patches, axis=1)
197 | sample_patches_centered = remove_zero_patch(sample_patches_centered)
198 | # Remove feature mean (Set E(X)=0 for each dimension)
199 | training_data, feature_expectation = remove_mean(sample_patches_centered, axis=0)
200 |
201 | # pca = PCA(n_components=training_data.shape[1], svd_solver='full', whiten=True)
202 | batch_size = training_data.shape[0]//n_batch
203 | pca = IncrementalPCA(n_components=training_data.shape[1], whiten=True, batch_size=batch_size, copy=False)
204 | pca.fit(training_data)
205 |
206 | # Compute the number of kernels corresponding to preserved energy
207 | if energy_percent:
208 | energy = np.cumsum(pca.explained_variance_ratio_)
209 | num_components = np.sum(energy < energy_percent)+1
210 | else:
211 | num_components = num_kernels
212 |
213 | kernels = pca.components_[:num_components, :]
214 | mean = pca.mean_
215 |
216 | num_channels = sample_patches.shape[-1]
217 | largest_ev = [np.var(dc*np.sqrt(num_channels))]
218 | dc_kernel = 1/np.sqrt(num_channels)*np.ones((1, num_channels))/np.sqrt(largest_ev)
219 | kernels = np.concatenate((dc_kernel, kernels), axis=0)
220 |
221 | print("Num of kernels: %d" % num_components)
222 | print("Energy percent: %f" % np.cumsum(pca.explained_variance_ratio_)[num_components-1])
223 | return kernels, mean
224 |
225 |
226 | def extract(feat):
227 | '''
228 | Do feature extraction based on the provided feature.
229 |     :param feat: list of per-hop features, each [num_samples, num_points, feature_dimension]
230 | # :param pooling: pooling method to be used
231 | :return: feature
232 | '''
233 | mean = []
234 | maxi = []
235 | l1 = []
236 | l2 = []
237 |
238 | for i in range(len(feat)):
239 | mean.append(feat[i].mean(axis=1, keepdims=False))
240 | maxi.append(feat[i].max(axis=1, keepdims=False))
241 | l1.append(np.linalg.norm(feat[i], ord=1, axis=1, keepdims=False))
242 | l2.append(np.linalg.norm(feat[i], ord=2, axis=1, keepdims=False))
243 | mean = np.concatenate(mean, axis=-1)
244 | maxi = np.concatenate(maxi, axis=-1)
245 | l1 = np.concatenate(l1, axis=-1)
246 | l2 = np.concatenate(l2, axis=-1)
247 |
248 | return [mean, maxi, l1, l2]
249 |
250 |
251 | def aggregate(feat, pool):
252 | feature = []
253 | for j in range(len(feat)):
254 | feature.append(feat[j] * pool[j])
255 | feature = np.concatenate(feature, axis=-1)
256 | return feature
257 |
258 |
259 | def classify(feature_train, train_label, feature_valid, valid_label, pooling):
260 | '''
261 | Train classifier based on the provided feature.
262 | :param feature_train: [num_samples, feature_dimension]
263 | :param train_label: train label provided
264 | :param feature_valid: [num_samples, feature_dimension]
265 |     :param valid_label: validation label provided
266 |     :param pooling: pooling methods provided
267 |     :return: classifier, train accuracy, validation accuracy, per-class accuracy
268 | '''
269 |
270 | clf_tmp = {}
271 | acc_train = []
272 | acc_valid = []
273 | pred_valid = []
274 | for i in range(len(pooling)):
275 | feat_tmp_train = aggregate(feature_train, pooling[i])
276 | feat_tmp_valid = aggregate(feature_valid, pooling[i])
277 | clf = rf_classifier(feat_tmp_train, np.squeeze(train_label))
278 | pred_train = clf.predict(feat_tmp_train)
279 | acc_train.append(sklearn.metrics.accuracy_score(train_label, pred_train))
280 | pred_valid_tmp = clf.predict(feat_tmp_valid)
281 | pred_valid.append(pred_valid_tmp)
282 | acc_valid.append(sklearn.metrics.accuracy_score(valid_label, pred_valid_tmp))
283 | clf_tmp['pooling method %d' % i] = clf
284 | idx = np.argmax(acc_valid)
285 | acc = average_acc(valid_label, pred_valid[idx])
286 | # print(pooling[idx])
287 |
288 |     # optional debug dump of the aggregated features and labels (last pooling setting)
289 |     feature = {}
290 |     label = {}
291 |     feature['train'] = feat_tmp_train
292 |     feature['test'] = feat_tmp_valid
293 |     label['train'] = train_label
294 |     label['test'] = valid_label
295 |     import pickle
296 |     with open('feat.pkl', 'wb') as f:
297 |         pickle.dump(feature, f)
298 |     with open('label.pkl', 'wb') as f:
299 |         pickle.dump(label, f)
300 | return clf_tmp, acc_train[idx], acc_valid[idx], acc
301 |
302 |
303 | def average_acc(label, pred_label):
304 |
305 | classes = np.arange(40)
306 | acc = np.zeros(len(classes))
307 | for i in range(len(classes)):
308 | ind = np.where(label == classes[i])[0]
309 | pred_test_special = pred_label[ind]
310 | acc[i] = len(np.where(pred_test_special == classes[i])[0])/float(len(ind))
311 | return acc
312 |
313 |
314 | def onehot_encoding(n_class, labels):
315 |
316 | targets = labels.reshape(-1)
317 | one_hot_targets = np.eye(n_class)[targets]
318 | return one_hot_targets
319 |
320 |
321 | # SVM
322 | def svm_classifier(feat, y):
323 |     '''
324 |     Train an SVM classifier based on the provided feature.
325 |     :param feat: [num_samples, feature_dimension]
326 |     :param y: label provided
327 |     :return: classifier
328 |     '''
329 | clf = svm.SVC(probability=True,gamma='auto')
330 | clf.fit(feat, y)
331 | return clf
332 |
333 |
334 | # RF
335 | def rf_classifier(feat, y):
336 |     '''
337 |     Train a random forest classifier based on the provided feature.
338 |     :param feat: [num_samples, feature_dimension]
339 |     :param y: label provided
340 |     :return: classifier
341 |     '''
342 | clf = ensemble.RandomForestClassifier(n_estimators=128, bootstrap=False,
343 | n_jobs=-1)
344 | clf.fit(feat, y)
345 | return clf
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import modelnet_data
4 | import pointhop
5 | import numpy as np
6 | import data_utils
7 | import os
8 | import time
9 |
10 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--num_batch_train', type=int, default=20, help='Batch Number')
14 | parser.add_argument('--num_batch_test', type=int, default=1, help='Batch Number')
15 | parser.add_argument('--initial_point', type=int, default=1024, help='Point Number [256/512/1024/2048]')
16 | parser.add_argument('--validation', default=False, help='Split train data or not')
17 | parser.add_argument('--ensemble', default=False, help='Ensemble or not')
18 | parser.add_argument('--rotation_angle', default=np.pi/4, help='Rotate angle')
19 | parser.add_argument('--rotation_freq', default=8, help='Rotate time')
20 | parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
21 | parser.add_argument('--num_point', default=[1024, 128, 128, 64], help='Point Number after down sampling')
22 | parser.add_argument('--num_sample', default=[64, 64, 64, 64], help='KNN query number')
23 | parser.add_argument('--num_filter', default=[15, 25, 40, 80], help='Filter Number ')
24 | parser.add_argument('--pooling_method', default=[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1],
25 | [1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1], [0, 1, 1, 0],
26 | [0, 1, 0, 1], [0, 0, 1, 1], [1, 1, 1, 0], [1, 1, 0, 1],
27 | [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]],
28 | help='Pooling methods [mean, max, l1, l2]')
29 | FLAGS = parser.parse_args()
30 |
31 | num_batch_train = FLAGS.num_batch_train
32 | num_batch_test = FLAGS.num_batch_test
33 | initial_point = FLAGS.initial_point
34 | VALID = FLAGS.validation
35 | ENSEMBLE = FLAGS.ensemble
36 | angle_rotation = FLAGS.rotation_angle
37 | freq_rotation = FLAGS.rotation_freq
38 | num_point = FLAGS.num_point
39 | num_sample = FLAGS.num_sample
40 | num_filter = FLAGS.num_filter
41 | pooling = FLAGS.pooling_method
42 |
43 |
44 | LOG_DIR = FLAGS.log_dir
45 | if not os.path.exists(LOG_DIR):
46 | os.mkdir(LOG_DIR)
47 | LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
48 | LOG_FOUT.write(str(FLAGS) + '\n')
49 |
50 |
51 | def log_string(out_str):
52 | LOG_FOUT.write(out_str+'\n')
53 | LOG_FOUT.flush()
54 | print(out_str)
55 |
56 |
57 | def main():
58 | time_start = time.time()
59 |
60 | # load data
61 | train_data, train_label = modelnet_data.data_load(num_point=initial_point, data_dir=os.path.join(BASE_DIR, 'modelnet40_ply_hdf5_2048'), train=True)
62 | test_data, test_label = modelnet_data.data_load(num_point=initial_point, data_dir=os.path.join(BASE_DIR, 'modelnet40_ply_hdf5_2048'), train=False)
63 |
64 | # validation set
65 | if VALID:
66 | train_data, train_label, valid_data, valid_label = modelnet_data.data_separate(train_data, train_label)
67 | else:
68 | valid_data = test_data
69 | valid_label = test_label
70 |
71 | print(train_data.shape)
72 | print(valid_data.shape)
73 |
74 | if ENSEMBLE:
75 | angle = np.repeat(angle_rotation, freq_rotation)
76 | else:
77 | angle = [0]
78 |
79 | params = {}
80 | feat_train = []
81 | feat_valid = []
82 | for i in range(len(angle)):
83 | print('------------Train ', i, '--------------')
84 | idx_save, new_xyz_save, final_feature_train, feature_train, pca_params = \
85 | pointhop.pointhop_train(train_data, n_batch=num_batch_train, n_newpoint=num_point, n_sample=num_sample, layer_num=num_filter,
86 | energy_percent=None)
87 | print('------------Validation ', i, '--------------')
88 |
89 | final_feature_valid, feature_valid = pointhop.pointhop_pred(
90 | valid_data, n_batch=num_batch_test, pca_params=pca_params, n_newpoint=num_point, n_sample=num_sample, layer_num=num_filter,
91 | idx_save=None, new_xyz_save=None)
92 |
93 | feature_train = pointhop.extract(feature_train)
94 | feature_valid = pointhop.extract(feature_valid)
95 | feat_train.append(feature_train)
96 | feat_valid.append(feature_valid)
97 | params['stage %d pca_params' % i] = pca_params
98 |
99 | train_data = data_utils.data_augment(train_data, angle[i])
100 | valid_data = data_utils.data_augment(valid_data, angle[i])
101 |
102 | feat_train = np.concatenate(feat_train, axis=-1)
103 | feat_valid = np.concatenate(feat_valid, axis=-1)
104 |
105 | clf, acc_train, acc_valid, acc = pointhop.classify(feat_train, train_label, feat_valid, valid_label, pooling)
106 | params['clf'] = clf
107 |
108 | time_end = time.time()
109 |
110 | log_string("train acc is {}".format(acc_train))
111 | log_string('eval acc is {}'.format(acc_valid))
112 | log_string('eval mean acc is {}'.format(np.mean(acc)))
113 | log_string('per-class acc is {}'.format(str(acc)))
114 |     log_string('total time cost is {} minutes'.format((time_end - time_start)//60))
115 |
116 | with open(os.path.join(LOG_DIR, 'params.pkl'), 'wb') as f:
117 | pickle.dump(params, f)
118 |
119 |
120 | if __name__ == '__main__':
121 | main()
122 |
123 |
--------------------------------------------------------------------------------